random-api-go/service/s3_fetcher.go

331 lines
8.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"log"
"net/url"
"random-api-go/model"
"regexp"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
// S3Fetcher S3获取器
type S3Fetcher struct {
timeout time.Duration
}
// NewS3Fetcher 创建S3获取器
func NewS3Fetcher() *S3Fetcher {
return &S3Fetcher{
timeout: 30 * time.Second,
}
}
// FetchURLs 从S3存储桶获取文件URL列表
func (sf *S3Fetcher) FetchURLs(s3Config *model.S3Config) ([]string, error) {
if s3Config == nil {
return nil, fmt.Errorf("S3配置不能为空")
}
// 验证必需的配置
if s3Config.Endpoint == "" {
return nil, fmt.Errorf("S3端点地址不能为空")
}
if s3Config.BucketName == "" {
return nil, fmt.Errorf("存储桶名称不能为空")
}
if s3Config.AccessKeyID == "" {
return nil, fmt.Errorf("访问密钥ID不能为空")
}
if s3Config.SecretAccessKey == "" {
return nil, fmt.Errorf("访问密钥不能为空")
}
// 创建S3客户端
client, err := sf.createS3Client(s3Config)
if err != nil {
return nil, fmt.Errorf("创建S3客户端失败: %w", err)
}
// 获取对象列表
objects, err := sf.listObjects(client, s3Config)
if err != nil {
return nil, fmt.Errorf("获取对象列表失败: %w", err)
}
// 过滤和转换为URL
urls := sf.convertObjectsToURLs(objects, s3Config)
log.Printf("从S3存储桶 %s 获取到 %d 个文件URL", s3Config.BucketName, len(urls))
return urls, nil
}
// createS3Client 创建S3客户端
func (sf *S3Fetcher) createS3Client(s3Config *model.S3Config) (*s3.Client, error) {
// 设置默认地区
region := s3Config.Region
if region == "" {
region = "us-east-1"
}
// 创建凭证
creds := credentials.NewStaticCredentialsProvider(
s3Config.AccessKeyID,
s3Config.SecretAccessKey,
"",
)
// 创建配置
cfg, err := config.LoadDefaultConfig(context.TODO(),
config.WithRegion(region),
config.WithCredentialsProvider(creds),
)
if err != nil {
return nil, fmt.Errorf("加载AWS配置失败: %w", err)
}
// 创建S3客户端选项
options := func(o *s3.Options) {
if s3Config.Endpoint != "" {
o.BaseEndpoint = aws.String(s3Config.Endpoint)
}
o.UsePathStyle = s3Config.UsePathStyle
}
client := s3.NewFromConfig(cfg, options)
return client, nil
}
// listObjects 列出存储桶中的对象
func (sf *S3Fetcher) listObjects(client *s3.Client, s3Config *model.S3Config) ([]types.Object, error) {
ctx, cancel := context.WithTimeout(context.Background(), sf.timeout)
defer cancel()
var allObjects []types.Object
var continuationToken *string
// 设置前缀(文件夹路径)
prefix := strings.TrimPrefix(s3Config.FolderPath, "/")
if prefix != "" && !strings.HasSuffix(prefix, "/") {
prefix += "/"
}
// 设置分隔符(如果不包含子文件夹)
var delimiter *string
if !s3Config.IncludeSubfolders {
delimiter = aws.String("/")
}
// 确定使用的ListObjects版本
listVersion := s3Config.ListObjectsVersion
if listVersion == "" {
listVersion = "v2" // 默认使用v2
}
for {
if listVersion == "v1" {
// 使用ListObjects (v1)
input := &s3.ListObjectsInput{
Bucket: aws.String(s3Config.BucketName),
Prefix: aws.String(prefix),
Delimiter: delimiter,
MaxKeys: aws.Int32(1000),
}
if continuationToken != nil {
input.Marker = continuationToken
}
result, err := client.ListObjects(ctx, input)
if err != nil {
return nil, fmt.Errorf("ListObjects失败: %w", err)
}
allObjects = append(allObjects, result.Contents...)
if !aws.ToBool(result.IsTruncated) {
break
}
if len(result.Contents) > 0 {
continuationToken = result.Contents[len(result.Contents)-1].Key
}
} else {
// 使用ListObjectsV2 (v2)
input := &s3.ListObjectsV2Input{
Bucket: aws.String(s3Config.BucketName),
Prefix: aws.String(prefix),
Delimiter: delimiter,
MaxKeys: aws.Int32(1000),
ContinuationToken: continuationToken,
}
result, err := client.ListObjectsV2(ctx, input)
if err != nil {
return nil, fmt.Errorf("ListObjectsV2失败: %w", err)
}
allObjects = append(allObjects, result.Contents...)
if !aws.ToBool(result.IsTruncated) {
break
}
continuationToken = result.NextContinuationToken
}
}
return allObjects, nil
}
// convertObjectsToURLs 将S3对象转换为URL列表
func (sf *S3Fetcher) convertObjectsToURLs(objects []types.Object, s3Config *model.S3Config) []string {
var urls []string
// 编译文件扩展名正则表达式
var extensionRegexes []*regexp.Regexp
for _, ext := range s3Config.FileExtensions {
if ext != "" {
// 确保扩展名以点开头
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
// 转义特殊字符并创建正则表达式
pattern := regexp.QuoteMeta(ext) + "$"
if regex, err := regexp.Compile("(?i)" + pattern); err == nil {
extensionRegexes = append(extensionRegexes, regex)
} else {
log.Printf("警告: 无效的文件扩展名正则表达式 '%s': %v", ext, err)
}
}
}
for _, obj := range objects {
if obj.Key == nil {
continue
}
key := aws.ToString(obj.Key)
// 跳过以/结尾的对象(文件夹)
if strings.HasSuffix(key, "/") {
continue
}
// 如果设置了文件扩展名过滤,检查是否匹配
if len(extensionRegexes) > 0 {
matched := false
for _, regex := range extensionRegexes {
if regex.MatchString(key) {
matched = true
break
}
}
if !matched {
continue
}
}
// 生成URL
fileURL := sf.generateURL(key, s3Config)
if fileURL != "" {
urls = append(urls, fileURL)
}
}
return urls
}
// generateURL 生成文件的访问URL
func (sf *S3Fetcher) generateURL(key string, s3Config *model.S3Config) string {
// 如果设置了自定义域名
if s3Config.CustomDomain != "" {
return sf.generateCustomDomainURL(key, s3Config)
}
// 使用S3标准URL格式
return sf.generateS3URL(key, s3Config)
}
// generateCustomDomainURL 生成自定义域名URL
func (sf *S3Fetcher) generateCustomDomainURL(key string, s3Config *model.S3Config) string {
baseURL := strings.TrimSuffix(s3Config.CustomDomain, "/")
// 处理key路径
path := key
if s3Config.RemoveBucket {
// 如果需要移除bucket名称并且key以bucket名称开头
bucketPrefix := s3Config.BucketName + "/"
if strings.HasPrefix(path, bucketPrefix) {
path = strings.TrimPrefix(path, bucketPrefix)
}
}
// 对路径进行适当的URL编码但保留路径分隔符
path = sf.encodeURLPath(path)
// 确保路径以/开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return baseURL + path
}
// generateS3URL 生成S3标准URL
func (sf *S3Fetcher) generateS3URL(key string, s3Config *model.S3Config) string {
// 对key进行适当的URL编码但保留路径分隔符
encodedKey := sf.encodeURLPath(key)
if s3Config.UsePathStyle {
// Path-style URL: http://endpoint/bucket/key
endpoint := strings.TrimSuffix(s3Config.Endpoint, "/")
return fmt.Sprintf("%s/%s/%s", endpoint, s3Config.BucketName, encodedKey)
} else {
// Virtual-hosted-style URL: http://bucket.endpoint/key
endpoint := strings.TrimSuffix(s3Config.Endpoint, "/")
// 解析endpoint以获取主机名
if parsedURL, err := url.Parse(endpoint); err == nil {
scheme := parsedURL.Scheme
if scheme == "" {
scheme = "https"
}
host := parsedURL.Host
if host == "" {
host = parsedURL.Path
}
return fmt.Sprintf("%s://%s.%s/%s", scheme, s3Config.BucketName, host, encodedKey)
}
// 如果解析失败回退到path-style
return fmt.Sprintf("%s/%s/%s", endpoint, s3Config.BucketName, encodedKey)
}
}
// encodeURLPath 对URL路径进行编码保留路径分隔符但编码其他特殊字符
func (sf *S3Fetcher) encodeURLPath(path string) string {
// 分割路径为各个部分
parts := strings.Split(path, "/")
// 对每个部分进行URL编码
for i, part := range parts {
if part != "" {
// 使用PathEscape对每个路径段进行编码
// 这会将空格编码为%20这是URL路径中的标准做法
parts[i] = url.PathEscape(part)
}
}
// 重新组合路径
return strings.Join(parts, "/")
}