diff --git a/models/session.go b/models/session.go index f298bc8..94100f2 100644 --- a/models/session.go +++ b/models/session.go @@ -64,6 +64,18 @@ func CreateSession(app *app.App, w http.ResponseWriter, userId int64, remember b return session, nil } +func GetSessionByAuthToken(app *app.App, authToken string) (Session, error) { + session := Session{} + + err := app.Db.QueryRow(selectSessionByAuthToken, authToken).Scan(&session.Id, &session.UserId, &session.AuthToken, &session.RememberMe, &session.CreatedAt) + if err != nil { + log.Println("Error getting session by auth token") + return Session{}, err + } + + return session, nil +} + // Generates a random 64-byte string func generateAuthToken(app *app.App) string { // Generate random bytes diff --git a/models/user.go b/models/user.go index a732c0d..c7ca5cd 100644 --- a/models/user.go +++ b/models/user.go @@ -37,16 +37,13 @@ func GetCurrentUser(app *app.App, r *http.Request) (User, error) { return User{}, err } - var userId int64 - - // Query row by AuthToken - err = app.Db.QueryRow(selectSessionIdByAuthToken, cookie.Value).Scan(&userId) + session, err := GetSessionByAuthToken(app, cookie.Value) if err != nil { - log.Println("Error querying session row with session: " + cookie.Value) + log.Println("Error getting session by auth token") return User{}, err } - return GetUserById(app, userId) + return GetUserById(app, session.UserId) } // GetUserById finds a User table row in the database by id and returns a struct representing this row