From 5f7e674d32acb7818affa43917478e3d448676bd Mon Sep 17 00:00:00 2001 From: max Date: Thu, 6 Apr 2023 08:56:48 -0500 Subject: [PATCH] Add remember me functionality, handle both types of sessions appropriately --- controllers/postController.go | 3 +- models/migrations.go | 9 +++--- models/session.go | 54 ++++++++++++++++++++++++----------- models/user.go | 4 +-- 4 files changed, 47 insertions(+), 23 deletions(-) diff --git a/controllers/postController.go b/controllers/postController.go index 742c8f9..f0af29d 100644 --- a/controllers/postController.go +++ b/controllers/postController.go @@ -24,13 +24,14 @@ func (postController *PostController) Login(w http.ResponseWriter, r *http.Reque username := r.FormValue("username") password := r.FormValue("password") + remember := r.FormValue("remember") == "on" if username == "" || password == "" { log.Println("Tried to login user with empty username or password") http.Redirect(w, r, "/login", http.StatusFound) } - _, err = models.AuthenticateUser(postController.App, w, username, password) + _, err = models.AuthenticateUser(postController.App, w, username, password, remember) if err != nil { log.Println("Error authenticating user") log.Println(err) diff --git a/models/migrations.go b/models/migrations.go index 3b59cd5..a26b595 100644 --- a/models/migrations.go +++ b/models/migrations.go @@ -22,10 +22,11 @@ func RunAllMigrations(app *app.App) error { } session := Session{ - Id: 1, - UserId: 1, - AuthToken: "migrate", - CreatedAt: time.Now(), + Id: 1, + UserId: 1, + AuthToken: "migrate", + RememberMe: false, + CreatedAt: time.Now(), } err = database.Migrate(app, session) if err != nil { diff --git a/models/session.go b/models/session.go index 4801ca8..4b397e0 100644 --- a/models/session.go +++ b/models/session.go @@ -10,10 +10,11 @@ import ( ) type Session struct { - Id int64 - UserId int64 - AuthToken string - CreatedAt time.Time + Id int64 + UserId int64 + AuthToken string + RememberMe bool + CreatedAt time.Time } const sessionColumnsNoId = "\"UserId\", \"AuthToken\", \"CreatedAt\"" @@ -26,13 +27,15 @@ const ( insertSession = "INSERT INTO " + sessionTable + " (" + sessionColumnsNoId + ") VALUES ($1, $2, $3) RETURNING \"Id\"" deleteSessionByAuthToken = "DELETE FROM " + sessionTable + " WHERE \"AuthToken\" = $1" deleteSessionsOlderThan30Days = "DELETE FROM " + sessionTable + " WHERE \"CreatedAt\" < NOW() - INTERVAL '30 days'" + deleteSessionsOlderThan6Hours = "DELETE FROM " + sessionTable + " WHERE \"CreatedAt\" < NOW() - INTERVAL '6 hours' AND \"RememberMe\" = false" ) // CreateSession creates a new session for a user -func CreateSession(app *app.App, w http.ResponseWriter, userId int64) (Session, error) { +func CreateSession(app *app.App, w http.ResponseWriter, userId int64, remember bool) (Session, error) { session := Session{} session.UserId = userId session.AuthToken = generateAuthToken(app) + session.RememberMe = remember session.CreatedAt = time.Now() // If the AuthToken column for any user matches the token, set existingAuthToken to true @@ -47,11 +50,11 @@ func CreateSession(app *app.App, w http.ResponseWriter, userId int64) (Session, // If duplicate token found, recursively call function until unique token is generated if existingAuthToken == true { log.Println("Duplicate token found in sessions table, generating new token...") - return CreateSession(app, w, userId) + return CreateSession(app, w, userId, remember) } // Insert session into database - err = app.Db.QueryRow(insertSession, session.UserId, session.AuthToken, session.CreatedAt).Scan(&session.Id) + err = app.Db.QueryRow(insertSession, session.UserId, session.AuthToken, session.RememberMe, session.CreatedAt).Scan(&session.Id) if err != nil { log.Println("Error inserting session into database") return Session{}, err @@ -76,13 +79,25 @@ func generateAuthToken(app *app.App) string { // createSessionCookie creates a new session cookie func createSessionCookie(app *app.App, w http.ResponseWriter, session Session) { - cookie := &http.Cookie{ - Name: "session", - Value: session.AuthToken, - Path: "/", - MaxAge: 86400, - HttpOnly: true, - Secure: true, + cookie := &http.Cookie{} + if session.RememberMe { + cookie = &http.Cookie{ + Name: "session", + Value: session.AuthToken, + Path: "/", + MaxAge: 2592000 * 1000, // 30 days in ms + HttpOnly: true, + Secure: true, + } + } else { + cookie = &http.Cookie{ + Name: "session", + Value: session.AuthToken, + Path: "/", + MaxAge: 21600 * 1000, // 6 hours in ms + HttpOnly: true, + Secure: true, + } } http.SetCookie(w, cookie) @@ -116,10 +131,17 @@ func DeleteSessionByAuthToken(app *app.App, w http.ResponseWriter, authToken str // ScheduledSessionCleanup deletes expired sessions from the database func ScheduledSessionCleanup(app *app.App) { - // Delete sessions older than 30 days + // Delete sessions older than 30 days (remember me sessions) _, err := app.Db.Exec(deleteSessionsOlderThan30Days) if err != nil { - log.Println("Error deleting expired sessions from database") + log.Println("Error deleting 30 day expired sessions from database") + log.Println(err) + } + + // Delete sessions older than 6 hours + _, err = app.Db.Exec(deleteSessionsOlderThan30Days) + if err != nil { + log.Println("Error deleting 6 hour expired sessions from database") log.Println(err) } diff --git a/models/user.go b/models/user.go index b4d35e2..a732c0d 100644 --- a/models/user.go +++ b/models/user.go @@ -98,7 +98,7 @@ func CreateUser(app *app.App, username string, password string, createdAt time.T } // AuthenticateUser validates the password for the specified user -func AuthenticateUser(app *app.App, w http.ResponseWriter, username string, password string) (Session, error) { +func AuthenticateUser(app *app.App, w http.ResponseWriter, username string, password string, remember bool) (Session, error) { var user User // Query row by username @@ -114,7 +114,7 @@ func AuthenticateUser(app *app.App, w http.ResponseWriter, username string, pass log.Println("Authentication error (incorrect password) for user:" + username) return Session{}, err } else { - return CreateSession(app, w, user.Id) + return CreateSession(app, w, user.Id, remember) } }