From 693c822c12366bf58ca576e9fba49eeb5b612443 Mon Sep 17 00:00:00 2001 From: PhatPhuckDave Date: Fri, 10 Oct 2025 20:00:33 +0200 Subject: [PATCH] Try to make esi use db.go --- db.go | 29 ++++++++++++++++++++++++++--- esi_sso.go | 42 +++++++++++++++++------------------------- main.go | 8 +++++++- 3 files changed, 50 insertions(+), 29 deletions(-) diff --git a/db.go b/db.go index f74a412..c125cd0 100644 --- a/db.go +++ b/db.go @@ -10,6 +10,9 @@ import ( type DB interface { DB() *gorm.DB Raw(sql string, args ...any) *gorm.DB + GetTokenForCharacter(characterName string) (*Token, error) + SaveTokenForCharacter(token *Token) error + AutoMigrate(dst ...interface{}) error } type DBWrapper struct { @@ -21,7 +24,11 @@ var db *DBWrapper func GetDB() (DB, error) { var err error - dbFile := filepath.Join("data.sqlite") + if db != nil { + return db, nil + } + + dbFile := filepath.Join(options.DBPath) db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{ // SkipDefaultTransaction: true, PrepareStmt: true, @@ -34,7 +41,6 @@ func GetDB() (DB, error) { return &DBWrapper{db: db}, nil } - // Just a wrapper func (db *DBWrapper) Raw(sql string, args ...any) *gorm.DB { return db.db.Raw(sql, args...) @@ -42,4 +48,21 @@ func (db *DBWrapper) Raw(sql string, args ...any) *gorm.DB { func (db *DBWrapper) DB() *gorm.DB { return db.db -} \ No newline at end of file +} + +func (db *DBWrapper) GetTokenForCharacter(characterName string) (*Token, error) { + var token Token + err := db.db.Where("character_name = ?", characterName).First(&token).Error + if err != nil { + return nil, err + } + return &token, nil +} + +func (db *DBWrapper) SaveTokenForCharacter(token *Token) error { + return db.db.Save(token).Error +} + +func (db *DBWrapper) AutoMigrate(dst ...interface{}) error { + return db.db.AutoMigrate(dst...) +} diff --git a/esi_sso.go b/esi_sso.go index 659a5d5..5e12c41 100644 --- a/esi_sso.go +++ b/esi_sso.go @@ -12,14 +12,11 @@ import ( "net" "net/http" "net/url" - "os" - "path/filepath" "strconv" "strings" "sync" "time" - "gorm.io/driver/sqlite" "gorm.io/gorm" ) @@ -42,7 +39,7 @@ type SSO struct { clientID string redirectURI string scopes []string - db *gorm.DB + db DB mu sync.Mutex server *http.Server state string @@ -55,10 +52,16 @@ type SSO struct { // NewSSO creates a new SSO instance func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) { + db, err := GetDB() + if err != nil { + return nil, err + } + s := &SSO{ clientID: clientID, redirectURI: redirectURI, scopes: scopes, + db: db, } if err := s.initDB(); err != nil { @@ -69,20 +72,7 @@ func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) { } func (s *SSO) initDB() error { - home, err := os.UserHomeDir() - if err != nil { - return err - } - dbPath := filepath.Join(home, ".industrializer", "sqlite-latest.sqlite") - db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) - if err != nil { - return err - } - if err := db.AutoMigrate(&Token{}); err != nil { - return err - } - s.db = db - return nil + return s.db.AutoMigrate(&Token{}) } // GetToken returns a valid access token for the given character name @@ -93,15 +83,16 @@ func (s *SSO) GetToken(ctx context.Context, characterName string) (string, error defer s.mu.Unlock() // Try to get existing token from DB - var token Token - if err := s.db.Where("character_name = ?", characterName).First(&token).Error; err != nil { + token, err := s.db.GetTokenForCharacter(characterName) + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // No token exists, need to authenticate if err := s.startAuthFlow(ctx, characterName); err != nil { return "", err } // After authentication, fetch the token from DB - if err := s.db.Where("character_name = ?", characterName).First(&token).Error; err != nil { + token, err = s.db.GetTokenForCharacter(characterName) + if err != nil { return "", err } } else { @@ -111,13 +102,14 @@ func (s *SSO) GetToken(ctx context.Context, characterName string) (string, error // Check if token needs refresh if time.Now().After(token.ExpiresAt.Add(-60 * time.Second)) { - if err := s.refreshToken(ctx, &token); err != nil { + if err := s.refreshToken(ctx, token); err != nil { // Refresh failed, need to re-authenticate if err := s.startAuthFlow(ctx, characterName); err != nil { return "", err } // After re-authentication, fetch the token from DB - if err := s.db.Where("character_name = ?", characterName).First(&token).Error; err != nil { + token, err = s.db.GetTokenForCharacter(characterName) + if err != nil { return "", err } } @@ -170,7 +162,7 @@ func (s *SSO) startAuthFlow(ctx context.Context, characterName string) error { // Save token to DB token.CharacterName = characterName - return s.db.Save(&token).Error + return s.db.SaveTokenForCharacter(token) } func (s *SSO) buildAuthURL(challenge, state string) string { @@ -341,7 +333,7 @@ func (s *SSO) refreshToken(ctx context.Context, token *Token) error { token.ExpiresAt = time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second) } - return s.db.Save(token).Error + return s.db.SaveTokenForCharacter(token) } // Utility functions diff --git a/main.go b/main.go index 3033555..fb41432 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ type Options struct { ClientID string RedirectURI string Scopes []string + DBPath string } var options Options @@ -32,7 +33,7 @@ func main() { logger.Error("Failed to load options %v", err) return } - + // Create SSO instance sso, err := NewSSO( options.ClientID, @@ -75,10 +76,15 @@ func LoadOptions() (Options, error) { return Options{}, fmt.Errorf("ESI_SCOPES is required in .env file") } scopes := strings.Fields(rawScopes) + dbPath := os.Getenv("DB_PATH") + if dbPath == "" { + return Options{}, fmt.Errorf("DB_PATH is required in .env file") + } return Options{ ClientID: clientID, RedirectURI: redirectURI, Scopes: scopes, + DBPath: dbPath, }, nil }