From d496c9924a4073feaa8f0e681a767e60aa877cca Mon Sep 17 00:00:00 2001 From: wood chen Date: Sat, 28 Sep 2024 15:45:35 +0800 Subject: [PATCH] fix --- core/database.go | 18 ++++++++++++++++++ core/init.go | 4 ++++ service/link_filter/link_filter.go | 6 +++--- service/message_handler.go | 1 + 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/core/database.go b/core/database.go index cdbf946..0ff0659 100644 --- a/core/database.go +++ b/core/database.go @@ -457,3 +457,21 @@ func contains(slice []string, str string) bool { } return false } + +func (d *Database) EnsureTablesExist() error { + tables := []string{"keywords", "whitelist", "prompt_replies", "config"} + for _, table := range tables { + var exists bool + err := d.db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + if !exists { + if err := d.createTables(); err != nil { + return err + } + break + } + } + return nil +} diff --git a/core/init.go b/core/init.go index 8ade2ea..120ba5b 100644 --- a/core/init.go +++ b/core/init.go @@ -57,6 +57,10 @@ func Init() error { return fmt.Errorf("迁移现有关键词失败: %v", err) } + if err := DB.EnsureTablesExist(); err != nil { + return fmt.Errorf("确保数据库表存在失败: %v", err) + } + // 从环境变量中读取调试模式设置 DEBUG_MODE = os.Getenv("DEBUG_MODE") == "true" diff --git a/service/link_filter/link_filter.go b/service/link_filter/link_filter.go index 456b810..9bf6044 100644 --- a/service/link_filter/link_filter.go +++ b/service/link_filter/link_filter.go @@ -28,7 +28,7 @@ func NewLinkFilter() (*LinkFilter, error) { } if err := lf.LoadDataFromDatabase(); err != nil { - return nil, err + return nil, fmt.Errorf("failed to load data from database: %v", err) } return lf, nil @@ -41,12 +41,12 @@ func (lf *LinkFilter) LoadDataFromDatabase() error { var err error lf.Keywords, err = core.DB.GetAllKeywords() if err != nil { - return err + return fmt.Errorf("failed to get keywords: %v", err) } lf.Whitelist, err = core.DB.GetAllWhitelist() if err != nil { - return err + return fmt.Errorf("failed to get whitelist: %v", err) } logger.Printf("Loaded %d Keywords and %d Whitelist entries from database", len(lf.Keywords), len(lf.Whitelist)) diff --git a/service/message_handler.go b/service/message_handler.go index 4e3440e..b26ac5b 100644 --- a/service/message_handler.go +++ b/service/message_handler.go @@ -116,6 +116,7 @@ func RunMessageHandler() error { linkFilter, err := link_filter.NewLinkFilter() if err != nil { + log.Printf("Failed to create LinkFilter: %v", err) return fmt.Errorf("failed to create LinkFilter: %v", err) }