diff --git a/controllers/postController.go b/controllers/postController.go index f0af29d..1631f98 100644 --- a/controllers/postController.go +++ b/controllers/postController.go @@ -3,7 +3,6 @@ package controllers import ( "GoWeb/app" "GoWeb/models" - "GoWeb/security" "log" "net/http" "time" @@ -15,13 +14,6 @@ type PostController struct { } func (postController *PostController) Login(w http.ResponseWriter, r *http.Request) { - // Validate csrf token - _, err := security.VerifyCsrfToken(r) - if err != nil { - log.Println("Error verifying csrf token") - return - } - username := r.FormValue("username") password := r.FormValue("password") remember := r.FormValue("remember") == "on" @@ -31,7 +23,7 @@ func (postController *PostController) Login(w http.ResponseWriter, r *http.Reque http.Redirect(w, r, "/login", http.StatusFound) } - _, err = models.AuthenticateUser(postController.App, w, username, password, remember) + _, err := models.AuthenticateUser(postController.App, w, username, password, remember) if err != nil { log.Println("Error authenticating user") log.Println(err) @@ -43,13 +35,6 @@ func (postController *PostController) Login(w http.ResponseWriter, r *http.Reque } func (postController *PostController) Register(w http.ResponseWriter, r *http.Request) { - // Validate csrf token - _, err := security.VerifyCsrfToken(r) - if err != nil { - log.Println("Error verifying csrf token") - return - } - username := r.FormValue("username") password := r.FormValue("password") createdAt := time.Now() @@ -60,7 +45,7 @@ func (postController *PostController) Register(w http.ResponseWriter, r *http.Re http.Redirect(w, r, "/register", http.StatusFound) } - _, err = models.CreateUser(postController.App, username, password, createdAt, updatedAt) + _, err := models.CreateUser(postController.App, username, password, createdAt, updatedAt) if err != nil { log.Println("Error creating user") log.Println(err) diff --git a/middleware/csrf.go b/middleware/csrf.go new file mode 100644 index 0000000..de3c04c --- /dev/null +++ b/middleware/csrf.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "GoWeb/security" + "log" + "net/http" +) + +// Csrf validates the CSRF token and returns the handler function if it succeded +func Csrf(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + // Verify csrf token + _, err := security.VerifyCsrfToken(r) + if err != nil { + log.Println("Error verifying csrf token") + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + f(w, r) + } +} diff --git a/middleware/groups.go b/middleware/groups.go new file mode 100644 index 0000000..0bb20a4 --- /dev/null +++ b/middleware/groups.go @@ -0,0 +1,5 @@ +package middleware + +import "net/http" + +type MiddlewareFunc func(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) diff --git a/middleware/wrapper.go b/middleware/wrapper.go new file mode 100644 index 0000000..ce0d4e6 --- /dev/null +++ b/middleware/wrapper.go @@ -0,0 +1,14 @@ +package middleware + +import "net/http" + +// ProcessGroup is a wrapper function for the http.HandleFunc function +// that takes the function you want to execute (f) and the middleware you want +// to execute (m) this should be used when processing multiple groups of middleware at a time +func ProcessGroup(f func(w http.ResponseWriter, r *http.Request), m []MiddlewareFunc) func(w http.ResponseWriter, r *http.Request) { + for _, middleware := range m { + _ = middleware(f) + } + + return f +} diff --git a/routes/postRoutes.go b/routes/postRoutes.go index 58ab50e..a0d93b1 100644 --- a/routes/postRoutes.go +++ b/routes/postRoutes.go @@ -3,6 +3,7 @@ package routes import ( "GoWeb/app" "GoWeb/controllers" + "GoWeb/middleware" "net/http" ) @@ -14,6 +15,6 @@ func PostRoutes(app *app.App) { } // User authentication - http.HandleFunc("/register-handle", postController.Register) - http.HandleFunc("/login-handle", postController.Login) + http.HandleFunc("/register-handle", middleware.Csrf(postController.Register)) + http.HandleFunc("/login-handle", middleware.Csrf(postController.Login)) }