diff --git a/core/config.go b/core/config.go index 8eb1619..c4cffa0 100644 --- a/core/config.go +++ b/core/config.go @@ -1,13 +1,26 @@ package core +import ( + "os" + "path/filepath" +) + var ( - BOT_TOKEN string - ADMIN_ID int64 + BOT_TOKEN string + ADMIN_ID int64 + DB_FILE string + DEBUG_MODE bool ) func InitGlobalVariables(botToken string, adminID int64) { BOT_TOKEN = botToken ADMIN_ID = adminID + + // 设置数据库文件路径 + DB_FILE = filepath.Join("/app/data", "q58.db") + + // 从环境变量中读取调试模式设置 + DEBUG_MODE = os.Getenv("DEBUG_MODE") == "true" } func IsAdmin(userID int64) bool { diff --git a/core/database.go b/core/database.go index 7f83fdd..83d5249 100644 --- a/core/database.go +++ b/core/database.go @@ -4,6 +4,7 @@ import ( "database/sql" "os" "path/filepath" + "strings" "sync" "time" @@ -11,24 +12,25 @@ import ( ) type Database struct { - db *sql.DB - dbFile string - keywordsCache []string - whitelistCache []string - cacheTime time.Time - mu sync.Mutex + db *sql.DB + keywordsCache []string + whitelistCache []string + promptRepliesCache map[string]string + keywordsCacheTime time.Time + whitelistCacheTime time.Time + promptRepliesCacheTime time.Time + mu sync.Mutex } -func NewDatabase(dbFile string) (*Database, error) { - os.MkdirAll(filepath.Dir(dbFile), os.ModePerm) - db, err := sql.Open("sqlite", dbFile) +func NewDatabase() (*Database, error) { + os.MkdirAll(filepath.Dir(DB_FILE), os.ModePerm) + db, err := sql.Open("sqlite", DB_FILE) if err != nil { return nil, err } database := &Database{ - db: db, - dbFile: dbFile, + db: db, } if err := database.createTables(); err != nil { @@ -46,6 +48,8 @@ func (d *Database) createTables() error { `CREATE TABLE IF NOT EXISTS whitelist (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)`, `CREATE INDEX IF NOT EXISTS idx_domain ON whitelist(domain)`, + `CREATE TABLE IF NOT EXISTS prompt_replies + (prompt TEXT PRIMARY KEY, reply TEXT NOT NULL)`, } for _, query := range queries { @@ -82,7 +86,7 @@ func (d *Database) AddKeyword(keyword string) error { if err != nil { return err } - d.invalidateCache() + d.invalidateCache("keywords") return nil } @@ -91,7 +95,7 @@ func (d *Database) RemoveKeyword(keyword string) error { if err != nil { return err } - d.invalidateCache() + d.invalidateCache("keywords") return nil } @@ -99,13 +103,13 @@ func (d *Database) GetAllKeywords() ([]string, error) { d.mu.Lock() defer d.mu.Unlock() - if d.keywordsCache == nil || time.Since(d.cacheTime) > 5*time.Minute { + if d.keywordsCache == nil || time.Since(d.keywordsCacheTime) > 5*time.Minute { keywords, err := d.executeQuery("SELECT keyword FROM keywords") if err != nil { return nil, err } d.keywordsCache = keywords - d.cacheTime = time.Now() + d.keywordsCacheTime = time.Now() } return d.keywordsCache, nil @@ -134,7 +138,7 @@ func (d *Database) RemoveKeywordsContaining(substring string) ([]string, error) return nil, err } - d.invalidateCache() + d.invalidateCache("keywords") return removedKeywords, nil } @@ -143,7 +147,7 @@ func (d *Database) AddWhitelist(domain string) error { if err != nil { return err } - d.invalidateCache() + d.invalidateCache("whitelist") return nil } @@ -152,7 +156,7 @@ func (d *Database) RemoveWhitelist(domain string) error { if err != nil { return err } - d.invalidateCache() + d.invalidateCache("whitelist") return nil } @@ -160,13 +164,13 @@ func (d *Database) GetAllWhitelist() ([]string, error) { d.mu.Lock() defer d.mu.Unlock() - if d.whitelistCache == nil || time.Since(d.cacheTime) > 5*time.Minute { + if d.whitelistCache == nil || time.Since(d.whitelistCacheTime) > 5*time.Minute { whitelist, err := d.executeQuery("SELECT domain FROM whitelist") if err != nil { return nil, err } d.whitelistCache = whitelist - d.cacheTime = time.Now() + d.whitelistCacheTime = time.Now() } return d.whitelistCache, nil @@ -194,12 +198,85 @@ func (d *Database) WhitelistExists(domain string) (bool, error) { return count > 0, nil } -func (d *Database) invalidateCache() { +func (d *Database) AddPromptReply(prompt, reply string) error { + _, err := d.db.Exec("INSERT OR REPLACE INTO prompt_replies (prompt, reply) VALUES (?, ?)", strings.ToLower(prompt), reply) + if err != nil { + return err + } + d.invalidateCache("promptReplies") + return nil +} + +func (d *Database) DeletePromptReply(prompt string) error { + _, err := d.db.Exec("DELETE FROM prompt_replies WHERE prompt = ?", strings.ToLower(prompt)) + if err != nil { + return err + } + d.invalidateCache("promptReplies") + return nil +} + +func (d *Database) fetchAllPromptReplies() (map[string]string, error) { + rows, err := d.db.Query("SELECT prompt, reply FROM prompt_replies") + if err != nil { + return nil, err + } + defer rows.Close() + + promptReplies := make(map[string]string) + for rows.Next() { + var prompt, reply string + if err := rows.Scan(&prompt, &reply); err != nil { + return nil, err + } + promptReplies[prompt] = reply + } + return promptReplies, nil +} + +func (d *Database) GetAllPromptReplies() (map[string]string, error) { d.mu.Lock() defer d.mu.Unlock() - d.keywordsCache = nil - d.whitelistCache = nil - d.cacheTime = time.Time{} + + if d.promptRepliesCache == nil || time.Since(d.promptRepliesCacheTime) > 5*time.Minute { + promptReplies, err := d.fetchAllPromptReplies() + if err != nil { + return nil, err + } + d.promptRepliesCache = promptReplies + d.promptRepliesCacheTime = time.Now() + } + + // 返回一个副本以防止外部修改缓存 + result := make(map[string]string, len(d.promptRepliesCache)) + for k, v := range d.promptRepliesCache { + result[k] = v + } + return result, nil +} + +func (d *Database) invalidateCache(cacheType string) { + d.mu.Lock() + defer d.mu.Unlock() + switch cacheType { + case "keywords": + d.keywordsCache = nil + d.keywordsCacheTime = time.Time{} + case "whitelist": + d.whitelistCache = nil + d.whitelistCacheTime = time.Time{} + case "promptReplies": + d.promptRepliesCache = nil + d.promptRepliesCacheTime = time.Time{} + default: + // 清除所有缓存 + d.keywordsCache = nil + d.whitelistCache = nil + d.promptRepliesCache = nil + d.keywordsCacheTime = time.Time{} + d.whitelistCacheTime = time.Time{} + d.promptRepliesCacheTime = time.Time{} + } } func (d *Database) Close() error { diff --git a/service/guard.go b/service/guard.go index 5b68689..af50a68 100644 --- a/service/guard.go +++ b/service/guard.go @@ -82,7 +82,7 @@ func startBot() error { return fmt.Errorf("failed to create bot: %w", err) } - bot.Debug = debugMode + bot.Debug = core.DEBUG_MODE log.Printf("Authorized on account %s", bot.Self.UserName) @@ -91,7 +91,7 @@ func startBot() error { return fmt.Errorf("error registering commands: %w", err) } - linkFilter, err := NewLinkFilter(dbFile) + linkFilter, err := NewLinkFilter(core.DB_FILE) if err != nil { return fmt.Errorf("failed to create LinkFilter: %v", err) } diff --git a/service/init.go b/service/init.go index 372d3ac..b6eee12 100644 --- a/service/init.go +++ b/service/init.go @@ -1,21 +1,13 @@ package service import ( - "os" "time" "github.com/woodchen-ink/Q58Bot/core" ) -var ( - dbFile string - debugMode bool -) - func Init(botToken string, adminID int64) { core.InitGlobalVariables(botToken, adminID) - dbFile = "/app/data/q58.db" - debugMode = os.Getenv("DEBUG_MODE") == "true" // 设置时区 loc := time.FixedZone("Asia/Singapore", 8*60*60) diff --git a/service/prompt_reply.go b/service/prompt_reply.go index 5a0f6f3..e5220a8 100644 --- a/service/prompt_reply.go +++ b/service/prompt_reply.go @@ -2,6 +2,7 @@ package service import ( "fmt" + "log" "strings" "sync" @@ -13,23 +14,35 @@ import ( var ( promptReplies = make(map[string]string) promptMutex sync.RWMutex + db *core.Database ) -func SetPromptReply(prompt, reply string) { - promptMutex.Lock() - defer promptMutex.Unlock() - promptReplies[strings.ToLower(prompt)] = reply +func InitPromptService(database *core.Database) error { + db = database + return loadPromptRepliesFromDB() } -func DeletePromptReply(prompt string) { - promptMutex.Lock() - defer promptMutex.Unlock() - delete(promptReplies, strings.ToLower(prompt)) +func loadPromptRepliesFromDB() error { + var err error + promptReplies, err = db.GetAllPromptReplies() + return err +} + +func SetPromptReply(prompt, reply string) error { + return db.AddPromptReply(prompt, reply) +} + +func DeletePromptReply(prompt string) error { + return db.DeletePromptReply(prompt) } func GetPromptReply(message string) (string, bool) { - promptMutex.RLock() - defer promptMutex.RUnlock() + promptReplies, err := db.GetAllPromptReplies() + if err != nil { + log.Printf("Error getting prompt replies: %v", err) + return "", false + } + for prompt, reply := range promptReplies { if strings.Contains(strings.ToLower(message), prompt) { return reply, true @@ -77,14 +90,22 @@ func HandlePromptCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message) { bot.Send(tgbotapi.NewMessage(message.Chat.ID, "请同时提供提示词和回复。")) return } - SetPromptReply(promptAndReply[0], promptAndReply[1]) + err := SetPromptReply(promptAndReply[0], promptAndReply[1]) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("设置提示词失败:%v", err))) + return + } bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("已设置提示词 '%s' 的回复。", promptAndReply[0]))) case "delete": if len(args) < 3 { bot.Send(tgbotapi.NewMessage(message.Chat.ID, "使用方法: /prompt delete <提示词>")) return } - DeletePromptReply(args[2]) + err := DeletePromptReply(args[2]) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("删除提示词失败:%v", err))) + return + } bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("已删除提示词 '%s' 的回复。", args[2]))) case "list": bot.Send(tgbotapi.NewMessage(message.Chat.ID, ListPromptReplies()))