Try to make esi use db.go
This commit is contained in:
29
db.go
29
db.go
@@ -10,6 +10,9 @@ import (
|
|||||||
type DB interface {
|
type DB interface {
|
||||||
DB() *gorm.DB
|
DB() *gorm.DB
|
||||||
Raw(sql string, args ...any) *gorm.DB
|
Raw(sql string, args ...any) *gorm.DB
|
||||||
|
GetTokenForCharacter(characterName string) (*Token, error)
|
||||||
|
SaveTokenForCharacter(token *Token) error
|
||||||
|
AutoMigrate(dst ...interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type DBWrapper struct {
|
type DBWrapper struct {
|
||||||
@@ -21,7 +24,11 @@ var db *DBWrapper
|
|||||||
func GetDB() (DB, error) {
|
func GetDB() (DB, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
dbFile := filepath.Join("data.sqlite")
|
if db != nil {
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dbFile := filepath.Join(options.DBPath)
|
||||||
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{
|
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{
|
||||||
// SkipDefaultTransaction: true,
|
// SkipDefaultTransaction: true,
|
||||||
PrepareStmt: true,
|
PrepareStmt: true,
|
||||||
@@ -34,7 +41,6 @@ func GetDB() (DB, error) {
|
|||||||
return &DBWrapper{db: db}, nil
|
return &DBWrapper{db: db}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Just a wrapper
|
// Just a wrapper
|
||||||
func (db *DBWrapper) Raw(sql string, args ...any) *gorm.DB {
|
func (db *DBWrapper) Raw(sql string, args ...any) *gorm.DB {
|
||||||
return db.db.Raw(sql, args...)
|
return db.db.Raw(sql, args...)
|
||||||
@@ -42,4 +48,21 @@ func (db *DBWrapper) Raw(sql string, args ...any) *gorm.DB {
|
|||||||
|
|
||||||
func (db *DBWrapper) DB() *gorm.DB {
|
func (db *DBWrapper) DB() *gorm.DB {
|
||||||
return db.db
|
return db.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DBWrapper) GetTokenForCharacter(characterName string) (*Token, error) {
|
||||||
|
var token Token
|
||||||
|
err := db.db.Where("character_name = ?", characterName).First(&token).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DBWrapper) SaveTokenForCharacter(token *Token) error {
|
||||||
|
return db.db.Save(token).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DBWrapper) AutoMigrate(dst ...interface{}) error {
|
||||||
|
return db.db.AutoMigrate(dst...)
|
||||||
|
}
|
||||||
|
|||||||
42
esi_sso.go
42
esi_sso.go
@@ -12,14 +12,11 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,7 +39,7 @@ type SSO struct {
|
|||||||
clientID string
|
clientID string
|
||||||
redirectURI string
|
redirectURI string
|
||||||
scopes []string
|
scopes []string
|
||||||
db *gorm.DB
|
db DB
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
server *http.Server
|
server *http.Server
|
||||||
state string
|
state string
|
||||||
@@ -55,10 +52,16 @@ type SSO struct {
|
|||||||
|
|
||||||
// NewSSO creates a new SSO instance
|
// NewSSO creates a new SSO instance
|
||||||
func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) {
|
func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) {
|
||||||
|
db, err := GetDB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
s := &SSO{
|
s := &SSO{
|
||||||
clientID: clientID,
|
clientID: clientID,
|
||||||
redirectURI: redirectURI,
|
redirectURI: redirectURI,
|
||||||
scopes: scopes,
|
scopes: scopes,
|
||||||
|
db: db,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.initDB(); err != nil {
|
if err := s.initDB(); err != nil {
|
||||||
@@ -69,20 +72,7 @@ func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSO) initDB() error {
|
func (s *SSO) initDB() error {
|
||||||
home, err := os.UserHomeDir()
|
return s.db.AutoMigrate(&Token{})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dbPath := filepath.Join(home, ".industrializer", "sqlite-latest.sqlite")
|
|
||||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := db.AutoMigrate(&Token{}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.db = db
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetToken returns a valid access token for the given character name
|
// GetToken returns a valid access token for the given character name
|
||||||
@@ -93,15 +83,16 @@ func (s *SSO) GetToken(ctx context.Context, characterName string) (string, error
|
|||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// Try to get existing token from DB
|
// Try to get existing token from DB
|
||||||
var token Token
|
token, err := s.db.GetTokenForCharacter(characterName)
|
||||||
if err := s.db.Where("character_name = ?", characterName).First(&token).Error; err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
// No token exists, need to authenticate
|
// No token exists, need to authenticate
|
||||||
if err := s.startAuthFlow(ctx, characterName); err != nil {
|
if err := s.startAuthFlow(ctx, characterName); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
// After authentication, fetch the token from DB
|
// After authentication, fetch the token from DB
|
||||||
if err := s.db.Where("character_name = ?", characterName).First(&token).Error; err != nil {
|
token, err = s.db.GetTokenForCharacter(characterName)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -111,13 +102,14 @@ func (s *SSO) GetToken(ctx context.Context, characterName string) (string, error
|
|||||||
|
|
||||||
// Check if token needs refresh
|
// Check if token needs refresh
|
||||||
if time.Now().After(token.ExpiresAt.Add(-60 * time.Second)) {
|
if time.Now().After(token.ExpiresAt.Add(-60 * time.Second)) {
|
||||||
if err := s.refreshToken(ctx, &token); err != nil {
|
if err := s.refreshToken(ctx, token); err != nil {
|
||||||
// Refresh failed, need to re-authenticate
|
// Refresh failed, need to re-authenticate
|
||||||
if err := s.startAuthFlow(ctx, characterName); err != nil {
|
if err := s.startAuthFlow(ctx, characterName); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
// After re-authentication, fetch the token from DB
|
// After re-authentication, fetch the token from DB
|
||||||
if err := s.db.Where("character_name = ?", characterName).First(&token).Error; err != nil {
|
token, err = s.db.GetTokenForCharacter(characterName)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -170,7 +162,7 @@ func (s *SSO) startAuthFlow(ctx context.Context, characterName string) error {
|
|||||||
|
|
||||||
// Save token to DB
|
// Save token to DB
|
||||||
token.CharacterName = characterName
|
token.CharacterName = characterName
|
||||||
return s.db.Save(&token).Error
|
return s.db.SaveTokenForCharacter(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSO) buildAuthURL(challenge, state string) string {
|
func (s *SSO) buildAuthURL(challenge, state string) string {
|
||||||
@@ -341,7 +333,7 @@ func (s *SSO) refreshToken(ctx context.Context, token *Token) error {
|
|||||||
token.ExpiresAt = time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second)
|
token.ExpiresAt = time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.db.Save(token).Error
|
return s.db.SaveTokenForCharacter(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Utility functions
|
// Utility functions
|
||||||
|
|||||||
8
main.go
8
main.go
@@ -17,6 +17,7 @@ type Options struct {
|
|||||||
ClientID string
|
ClientID string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
Scopes []string
|
Scopes []string
|
||||||
|
DBPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
var options Options
|
var options Options
|
||||||
@@ -32,7 +33,7 @@ func main() {
|
|||||||
logger.Error("Failed to load options %v", err)
|
logger.Error("Failed to load options %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create SSO instance
|
// Create SSO instance
|
||||||
sso, err := NewSSO(
|
sso, err := NewSSO(
|
||||||
options.ClientID,
|
options.ClientID,
|
||||||
@@ -75,10 +76,15 @@ func LoadOptions() (Options, error) {
|
|||||||
return Options{}, fmt.Errorf("ESI_SCOPES is required in .env file")
|
return Options{}, fmt.Errorf("ESI_SCOPES is required in .env file")
|
||||||
}
|
}
|
||||||
scopes := strings.Fields(rawScopes)
|
scopes := strings.Fields(rawScopes)
|
||||||
|
dbPath := os.Getenv("DB_PATH")
|
||||||
|
if dbPath == "" {
|
||||||
|
return Options{}, fmt.Errorf("DB_PATH is required in .env file")
|
||||||
|
}
|
||||||
|
|
||||||
return Options{
|
return Options{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
RedirectURI: redirectURI,
|
RedirectURI: redirectURI,
|
||||||
Scopes: scopes,
|
Scopes: scopes,
|
||||||
|
DBPath: dbPath,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user