Refactor db and options to separate packages
This commit is contained in:
465
esi/sso.go
Normal file
465
esi/sso.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package esi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
logger "git.site.quack-lab.dev/dave/cylogger"
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
issuerAuthorizeURL = "https://login.eveonline.com/v2/oauth/authorize"
|
||||
issuerTokenURL = "https://login.eveonline.com/v2/oauth/token"
|
||||
)
|
||||
|
||||
type Character struct {
|
||||
ID int64 `gorm:"primaryKey"` // EVE character ID from JWT token
|
||||
CharacterName string `gorm:"uniqueIndex"`
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
UpdatedAt time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type SSO struct {
|
||||
clientID string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
db DB
|
||||
mu sync.Mutex
|
||||
router *router.Router
|
||||
state string
|
||||
callbackChan chan struct {
|
||||
code string
|
||||
state string
|
||||
err error
|
||||
}
|
||||
}
|
||||
|
||||
// NewSSO creates a new SSO instance
|
||||
func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) {
|
||||
logger.Info("Creating new SSO instance for clientID %s with redirectURI %s and scopes %v", clientID, redirectURI, scopes)
|
||||
|
||||
db, err := GetDB()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get database connection %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &SSO{
|
||||
clientID: clientID,
|
||||
redirectURI: redirectURI,
|
||||
scopes: scopes,
|
||||
db: db,
|
||||
}
|
||||
|
||||
if err := s.initDB(); err != nil {
|
||||
logger.Error("Failed to initialize SSO database %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("SSO instance created successfully")
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SetRouter allows the SSO to use an existing fasthttp router
|
||||
func (s *SSO) SetRouter(r *router.Router) {
|
||||
s.router = r
|
||||
logger.Debug("SSO configured to use existing fasthttp router")
|
||||
s.setupCallbackHandler()
|
||||
}
|
||||
|
||||
func (s *SSO) initDB() error {
|
||||
logger.Debug("Initializing SSO database schema")
|
||||
err := s.db.AutoMigrate(&Character{})
|
||||
if err != nil {
|
||||
logger.Error("Failed to migrate Token table %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Debug("SSO database schema initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCharacter returns a valid character object for the given character name
|
||||
// If no token exists, it will start the OAuth flow
|
||||
// If token is expired, it will refresh it automatically
|
||||
func (s *SSO) GetCharacter(ctx context.Context, characterName string) (Character, error) {
|
||||
logger.Debug("Getting token for character %s", characterName)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try to get existing token from DB
|
||||
char, err := s.db.GetCharacterByName(characterName)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
logger.Info("No existing token found for character %s, starting authentication flow", characterName)
|
||||
// No token exists, need to authenticate
|
||||
if err := s.startAuthFlow(ctx, characterName); err != nil {
|
||||
return Character{}, err
|
||||
}
|
||||
// After authentication, fetch the token from DB
|
||||
char, err = s.db.GetCharacterByName(characterName)
|
||||
if err != nil {
|
||||
return Character{}, err
|
||||
}
|
||||
logger.Info("Successfully authenticated character %s", characterName)
|
||||
} else {
|
||||
return Character{}, err
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Found existing token for character %s, expires at %v", characterName, char.ExpiresAt)
|
||||
// Check if character ID is missing (old record)
|
||||
if char.ID == 0 {
|
||||
logger.Info("Character ID missing for %s, extracting from token", characterName)
|
||||
_, eveCharID := parseTokenCharacter(char.AccessToken)
|
||||
if eveCharID > 0 {
|
||||
char.ID = eveCharID
|
||||
logger.Debug("Updating character %s with ID: %d", characterName, eveCharID)
|
||||
if err := s.db.SaveCharacter(char); err != nil {
|
||||
logger.Warning("Failed to update character %s with ID: %v", characterName, err)
|
||||
}
|
||||
} else {
|
||||
logger.Warning("Failed to extract character ID from token for %s", characterName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
if time.Now().After(char.ExpiresAt.Add(-1 * time.Minute)) {
|
||||
logger.Info("Token for character %s is expired or expiring soon, refreshing", characterName)
|
||||
if err := s.refreshToken(ctx, char); err != nil {
|
||||
logger.Warning("Token refresh failed for character %s, re-authenticating: %v", characterName, err)
|
||||
// Refresh failed, need to re-authenticate
|
||||
if err := s.startAuthFlow(ctx, characterName); err != nil {
|
||||
return Character{}, err
|
||||
}
|
||||
// After re-authentication, fetch the token from DB
|
||||
char, err = s.db.GetCharacterByName(characterName)
|
||||
if err != nil {
|
||||
return Character{}, err
|
||||
}
|
||||
logger.Info("Successfully re-authenticated character %s", characterName)
|
||||
} else {
|
||||
logger.Debug("Token refreshed successfully for character %s", characterName)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Token for character %s is still valid", characterName)
|
||||
}
|
||||
|
||||
logger.Debug("Returning valid token for character %s", characterName)
|
||||
return *char, nil
|
||||
}
|
||||
|
||||
func (s *SSO) startAuthFlow(ctx context.Context, characterName string) error {
|
||||
logger.Info("Starting authentication flow for character %s", characterName)
|
||||
|
||||
// Generate PKCE
|
||||
logger.Debug("Generating PKCE parameters")
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.state = randString(24)
|
||||
s.callbackChan = make(chan struct {
|
||||
code string
|
||||
state string
|
||||
err error
|
||||
}, 1)
|
||||
authURL := s.buildAuthURL(challenge, s.state)
|
||||
|
||||
logger.Info("Generated authentication URL for character %s", characterName)
|
||||
logger.Info("Please visit this URL to authenticate: \n%s", authURL)
|
||||
logger.Info("Waiting for authentication...")
|
||||
|
||||
// Setup callback handling
|
||||
if s.router == nil {
|
||||
logger.Error("No router configured for callback handling")
|
||||
return errors.New("no router configured for callback handling")
|
||||
}
|
||||
logger.Debug("Using fasthttp router for callback handling")
|
||||
|
||||
// Wait for callback
|
||||
logger.Debug("Waiting for authentication callback")
|
||||
code, receivedState, err := s.waitForCallback()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if receivedState != s.state {
|
||||
logger.Error("Invalid state parameter received: %s, expected: %s", receivedState, s.state)
|
||||
return errors.New("invalid state parameter")
|
||||
}
|
||||
|
||||
logger.Debug("Received valid callback, exchanging code for token")
|
||||
// Exchange code for token
|
||||
char, err := s.exchangeCodeForToken(ctx, code, verifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save token to DB
|
||||
char.CharacterName = characterName
|
||||
// Extract and set EVE character ID as primary key
|
||||
_, eveCharID := parseTokenCharacter(char.AccessToken)
|
||||
char.ID = eveCharID
|
||||
logger.Debug("Saving token to database for character %s (EVE ID: %d)", characterName, eveCharID)
|
||||
if err := s.db.SaveCharacter(char); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Authentication flow completed successfully for character %s", characterName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSO) buildAuthURL(challenge, state string) string {
|
||||
q := url.Values{}
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", s.clientID)
|
||||
q.Set("redirect_uri", s.redirectURI)
|
||||
if len(s.scopes) > 0 {
|
||||
q.Set("scope", strings.Join(s.scopes, " "))
|
||||
}
|
||||
q.Set("state", state)
|
||||
q.Set("code_challenge", challenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
|
||||
return issuerAuthorizeURL + "?" + q.Encode()
|
||||
}
|
||||
|
||||
func (s *SSO) setupCallbackHandler() {
|
||||
u, err := url.Parse(s.redirectURI)
|
||||
if err != nil {
|
||||
logger.Error("Failed to parse redirect URI for callback handler: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Setting up callback handler on path: %s", u.Path)
|
||||
s.router.GET(u.Path, s.handleCallback)
|
||||
}
|
||||
|
||||
func (s *SSO) handleCallback(ctx *fasthttp.RequestCtx) {
|
||||
s.processCallback(
|
||||
ctx.IsGet(),
|
||||
string(ctx.QueryArgs().Peek("code")),
|
||||
string(ctx.QueryArgs().Peek("state")),
|
||||
func(status int, body string) {
|
||||
ctx.SetStatusCode(status)
|
||||
ctx.WriteString(body)
|
||||
},
|
||||
func(contentType string) {
|
||||
ctx.SetContentType(contentType)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *SSO) processCallback(isGet bool, code, state string, writeResponse func(int, string), setContentType func(string)) {
|
||||
if !isGet {
|
||||
logger.Warning("Invalid callback method")
|
||||
writeResponse(http.StatusMethodNotAllowed, "Method not allowed")
|
||||
s.callbackChan <- struct {
|
||||
code string
|
||||
state string
|
||||
err error
|
||||
}{"", "", errors.New("method not allowed")}
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" || state == "" || state != s.state {
|
||||
logger.Error("Invalid SSO response: code=%s, state=%s, expected_state=%s", code, state, s.state)
|
||||
writeResponse(http.StatusBadRequest, "Invalid SSO response")
|
||||
s.callbackChan <- struct {
|
||||
code string
|
||||
state string
|
||||
err error
|
||||
}{"", "", errors.New("invalid state")}
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received valid callback, exchanging token for code: %s", code)
|
||||
setContentType("text/html")
|
||||
writeResponse(http.StatusOK, "<html><body><h1>Login successful!</h1><p>You can close this window.</p></body></html>")
|
||||
s.callbackChan <- struct {
|
||||
code string
|
||||
state string
|
||||
err error
|
||||
}{code, state, nil}
|
||||
}
|
||||
|
||||
func (s *SSO) waitForCallback() (code, state string, err error) {
|
||||
logger.Debug("Waiting for authentication callback (timeout: 30s)")
|
||||
// Wait for callback through channel
|
||||
select {
|
||||
case result := <-s.callbackChan:
|
||||
if result.err != nil {
|
||||
logger.Error("Callback received with error: %v", result.err)
|
||||
return "", "", result.err
|
||||
}
|
||||
logger.Debug("Callback received successfully")
|
||||
return result.code, result.state, result.err
|
||||
case <-time.After(30 * time.Second):
|
||||
logger.Error("Callback timeout after 30 seconds")
|
||||
return "", "", errors.New("callback timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSO) exchangeCodeForToken(ctx context.Context, code, verifier string) (*Character, error) {
|
||||
logger.Debug("Exchanging authorization code for access token")
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("code", code)
|
||||
form.Set("client_id", s.clientID)
|
||||
form.Set("code_verifier", verifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, issuerTokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
logger.Debug("Sending token exchange request to %s", issuerTokenURL)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Token exchange failed with status %d: %s", resp.StatusCode, string(b))
|
||||
return nil, fmt.Errorf("token exchange failed: %s: %s", resp.Status, string(b))
|
||||
}
|
||||
|
||||
var tr tokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse character info from token
|
||||
name, _ := parseTokenCharacter(tr.AccessToken)
|
||||
logger.Info("Successfully exchanged code for token, character: %s", name)
|
||||
|
||||
return &Character{
|
||||
CharacterName: name,
|
||||
AccessToken: tr.AccessToken,
|
||||
RefreshToken: tr.RefreshToken,
|
||||
ExpiresAt: time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SSO) refreshToken(ctx context.Context, char *Character) error {
|
||||
logger.Debug("Refreshing token for character %s", char.CharacterName)
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", char.RefreshToken)
|
||||
form.Set("client_id", s.clientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, issuerTokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
logger.Debug("Sending token refresh request to %s", issuerTokenURL)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Token refresh failed with status %d: %s", resp.StatusCode, string(b))
|
||||
return fmt.Errorf("token refresh failed: %s: %s", resp.Status, string(b))
|
||||
}
|
||||
|
||||
var tr tokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update token
|
||||
char.AccessToken = tr.AccessToken
|
||||
if tr.RefreshToken != "" {
|
||||
char.RefreshToken = tr.RefreshToken
|
||||
}
|
||||
if tr.ExpiresIn > 0 {
|
||||
char.ExpiresAt = time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second)
|
||||
}
|
||||
|
||||
logger.Debug("Saving refreshed token to database for character %s", char.CharacterName)
|
||||
if err := s.db.SaveCharacter(char); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Token refreshed successfully for character %s", char.CharacterName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
func generatePKCE() (verifier string, challenge string, err error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err = rand.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
v := base64.RawURLEncoding.EncodeToString(buf)
|
||||
h := sha256.Sum256([]byte(v))
|
||||
c := base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return v, c, nil
|
||||
}
|
||||
|
||||
func randString(n int) string {
|
||||
buf := make([]byte, n)
|
||||
_, _ = rand.Read(buf)
|
||||
return base64.RawURLEncoding.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func parseTokenCharacter(jwt string) (name string, id int64) {
|
||||
parts := strings.Split(jwt, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", 0
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", 0
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(payload, &m); err != nil {
|
||||
return "", 0
|
||||
}
|
||||
if v, ok := m["name"].(string); ok {
|
||||
name = v
|
||||
}
|
||||
if v, ok := m["sub"].(string); ok {
|
||||
if idx := strings.LastIndexByte(v, ':'); idx > -1 {
|
||||
if idv, err := strconv.ParseInt(v[idx+1:], 10, 64); err == nil {
|
||||
id = idv
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
Reference in New Issue
Block a user