Compare commits

..

4 Commits

Author SHA1 Message Date
Maximilian
37391190fb Merge branch 'master' into file_uploads 2023-05-05 12:19:55 -05:00
max
1fb8fdef81 Merge branch 'master' into file_uploads 2023-05-04 09:00:53 -05:00
max
baef0cbe78 Fix a couple deprecated calls and handle errors 2023-05-04 08:40:35 -05:00
tfasano1
d0da1a9114 File Uploading feature by tfasano1 2023-05-04 08:26:44 -05:00
28 changed files with 432 additions and 449 deletions

View File

@ -12,14 +12,12 @@ fine with getting your hands dirty, but I plan on having it ready to go for more
- Routing/controllers
- Templating
- Simple database migration system
- Built in REST client
- CSRF protection
- Middleware
- Minimal user login/registration + sessions
- Config file handling
- Scheduled tasks
- Entire website compiles into a single binary (~10mb) (excluding env.json)
- Minimal dependencies (just standard library, postgres driver, and x/crypto for bcrypt)
- Minimal dependencies (just standard library, postgres driver, and experimental package for bcrypt)
<hr>
@ -39,13 +37,10 @@ fine with getting your hands dirty, but I plan on having it ready to go for more
## How to use 🤔
1. Clone
2. Delete the git folder, so you can start tracking in your own repo
3. Run `go get` to install dependencies
4. Copy env_example.json to env.json and fill in the values
5. Run `go run main.go` to start the server
6. Rename the occurences of "GoWeb" to your app name
7. Start building your app!
8. When you see useful changes to GoWeb you'd like in your project copy them over
2. Run `go get` to install dependencies
3. Copy env_example.json to env.json and fill in the values
4. Run `go run main.go` to start the server
5. Start building your app!
## How to contribute 👨‍💻
@ -59,7 +54,7 @@ fine with getting your hands dirty, but I plan on having it ready to go for more
### License and disclaimer 😤
- You are free to use this project under the terms of the MIT license. See LICENSE for more details.
- You are responsible for the security and everything else regarding your application.
- You and you alone are responsible for the security and everything else regarding your application.
- It is not required, but I ask that when you use this project you give me credit by linking to this repository.
- I also ask that when releasing self-hosted or other end-user applications that you release it under
the [GPLv3](https://www.gnu.org/licenses/gpl-3.0.html) license. This too is not required, but I would appreciate it.

View File

@ -17,25 +17,27 @@ type Scheduled struct {
}
type Task struct {
Funcs []func(app *App)
Interval time.Duration
Funcs []func(app *App)
}
func RunScheduledTasks(app *App, poolSize int, stop <-chan struct{}) {
// Run every time the server starts
for _, f := range app.ScheduledTasks.EveryReboot {
f(app)
}
tasks := []Task{
{Funcs: app.ScheduledTasks.EverySecond, Interval: time.Second},
{Funcs: app.ScheduledTasks.EveryMinute, Interval: time.Minute},
{Funcs: app.ScheduledTasks.EveryHour, Interval: time.Hour},
{Funcs: app.ScheduledTasks.EveryDay, Interval: 24 * time.Hour},
{Funcs: app.ScheduledTasks.EveryWeek, Interval: 7 * 24 * time.Hour},
{Funcs: app.ScheduledTasks.EveryMonth, Interval: 30 * 24 * time.Hour},
{Funcs: app.ScheduledTasks.EveryYear, Interval: 365 * 24 * time.Hour},
{Interval: time.Second, Funcs: app.ScheduledTasks.EverySecond},
{Interval: time.Minute, Funcs: app.ScheduledTasks.EveryMinute},
{Interval: time.Hour, Funcs: app.ScheduledTasks.EveryHour},
{Interval: 24 * time.Hour, Funcs: app.ScheduledTasks.EveryDay},
{Interval: 7 * 24 * time.Hour, Funcs: app.ScheduledTasks.EveryWeek},
{Interval: 30 * 24 * time.Hour, Funcs: app.ScheduledTasks.EveryMonth},
{Interval: 365 * 24 * time.Hour, Funcs: app.ScheduledTasks.EveryYear},
}
// Set up task runners
var wg sync.WaitGroup
runners := make([]chan bool, len(tasks))
for i, task := range tasks {
@ -63,8 +65,10 @@ func RunScheduledTasks(app *App, poolSize int, stop <-chan struct{}) {
}(task, runner)
}
// Wait for all goroutines to finish
wg.Wait()
// Close channels
for _, runner := range runners {
close(runner)
}

View File

@ -3,7 +3,7 @@ package config
import (
"encoding/json"
"flag"
"log/slog"
"log"
"os"
)
@ -23,8 +23,12 @@ type Configuration struct {
}
Template struct {
BaseName string `json:"BaseTemplateName"`
ContentPath string `json:"ContentPath"`
BaseName string `json:"BaseTemplateName"`
}
Upload struct {
BaseName string `json:"UploadDirectoryName"`
MaxSize int64 `json:"MaxUploadSize"`
}
}
@ -34,21 +38,22 @@ func LoadConfig() Configuration {
flag.Parse()
file, err := os.Open(*c)
if err != nil {
panic("unable to open JSON config file: " + err.Error())
log.Fatal("Unable to open JSON config file: ", err)
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
slog.Error("unable to close JSON config file: ", err)
log.Fatal("Unable to close JSON config file: ", err)
}
}(file)
// Decode json config file to Configuration struct named config
decoder := json.NewDecoder(file)
Config := Configuration{}
err = decoder.Decode(&Config)
if err != nil {
panic("unable to decode JSON config file: " + err.Error())
log.Fatal("Unable to decode JSON config file: ", err)
}
return Config

View File

@ -1,65 +0,0 @@
package controllers
import (
"GoWeb/app"
"GoWeb/models"
"GoWeb/security"
"GoWeb/templating"
"net/http"
)
// Get is a wrapper struct for the App struct
type Get struct {
App *app.App
}
func (g *Get) ShowHome(w http.ResponseWriter, _ *http.Request) {
type dataStruct struct {
Test string
}
data := dataStruct{
Test: "Hello World!",
}
templating.RenderTemplate(w, "templates/pages/home.html", data)
}
func (g *Get) ShowRegister(w http.ResponseWriter, r *http.Request) {
type dataStruct struct {
CsrfToken string
}
CsrfToken, err := security.GenerateCsrfToken(w, r)
if err != nil {
return
}
data := dataStruct{
CsrfToken: CsrfToken,
}
templating.RenderTemplate(w, "templates/pages/register.html", data)
}
func (g *Get) ShowLogin(w http.ResponseWriter, r *http.Request) {
type dataStruct struct {
CsrfToken string
}
CsrfToken, err := security.GenerateCsrfToken(w, r)
if err != nil {
return
}
data := dataStruct{
CsrfToken: CsrfToken,
}
templating.RenderTemplate(w, "templates/pages/login.html", data)
}
func (g *Get) Logout(w http.ResponseWriter, r *http.Request) {
models.LogoutUser(g.App, w, r)
http.Redirect(w, r, "/", http.StatusFound)
}

View File

@ -0,0 +1,74 @@
package controllers
import (
"GoWeb/app"
"GoWeb/models"
"GoWeb/security"
"GoWeb/templating"
"net/http"
)
// GetController is a wrapper struct for the App struct
type GetController struct {
App *app.App
}
func (getController *GetController) ShowHome(w http.ResponseWriter, _ *http.Request) {
type dataStruct struct {
Test string
}
data := dataStruct{
Test: "Hello World!",
}
templating.RenderTemplate(getController.App, w, "templates/pages/home.html", data)
}
func (getController *GetController) ShowRegister(w http.ResponseWriter, r *http.Request) {
type dataStruct struct {
CsrfToken string
}
// Create csrf token
CsrfToken, err := security.GenerateCsrfToken(w, r)
if err != nil {
return
}
data := dataStruct{
CsrfToken: CsrfToken,
}
templating.RenderTemplate(getController.App, w, "templates/pages/register.html", data)
}
func (getController *GetController) ShowLogin(w http.ResponseWriter, r *http.Request) {
type dataStruct struct {
CsrfToken string
}
// Create csrf token
CsrfToken, err := security.GenerateCsrfToken(w, r)
if err != nil {
return
}
data := dataStruct{
CsrfToken: CsrfToken,
}
templating.RenderTemplate(getController.App, w, "templates/pages/login.html", data)
}
func (getController *GetController) ShowFile(w http.ResponseWriter, r *http.Request) {
// GET /uploads?name=file.jpg
// will serve file.jpg
name := r.URL.Query().Get("name")
http.ServeFile(w, r, getController.App.Config.Upload.BaseName+name)
}
func (getController *GetController) Logout(w http.ResponseWriter, r *http.Request) {
models.LogoutUser(getController.App, w, r)
http.Redirect(w, r, "/", http.StatusFound)
}

View File

@ -1,52 +0,0 @@
package controllers
import (
"GoWeb/app"
"GoWeb/models"
"log/slog"
"net/http"
"time"
)
// Post is a wrapper struct for the App struct
type Post struct {
App *app.App
}
func (p *Post) Login(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("username")
password := r.FormValue("password")
remember := r.FormValue("remember") == "on"
if username == "" || password == "" {
http.Redirect(w, r, "/login", http.StatusUnauthorized)
}
_, err := models.AuthenticateUser(p.App, w, username, password, remember)
if err != nil {
http.Redirect(w, r, "/login", http.StatusUnauthorized)
return
}
http.Redirect(w, r, "/", http.StatusFound)
}
func (p *Post) Register(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("username")
password := r.FormValue("password")
createdAt := time.Now()
updatedAt := time.Now()
if username == "" || password == "" {
http.Redirect(w, r, "/register", http.StatusUnauthorized)
}
_, err := models.CreateUser(p.App, username, password, createdAt, updatedAt)
if err != nil {
// TODO: if err == bcrypt.ErrPasswordTooLong display error to user, this will require a flash message system with cookies
slog.Error("error creating user: " + err.Error())
http.Redirect(w, r, "/register", http.StatusInternalServerError)
}
http.Redirect(w, r, "/login", http.StatusFound)
}

View File

@ -0,0 +1,131 @@
package controllers
import (
"GoWeb/app"
"GoWeb/models"
"GoWeb/security"
"io"
"log"
"mime/multipart"
"net/http"
"os"
"time"
)
// PostController is a wrapper struct for the App struct
type PostController struct {
App *app.App
}
func (postController *PostController) Login(w http.ResponseWriter, r *http.Request) {
// Validate csrf token
_, err := security.VerifyCsrfToken(r)
if err != nil {
log.Println("Error verifying csrf token")
return
}
username := r.FormValue("username")
password := r.FormValue("password")
remember := r.FormValue("remember") == "on"
if username == "" || password == "" {
log.Println("Tried to login user with empty username or password")
http.Redirect(w, r, "/login", http.StatusFound)
}
_, err = models.AuthenticateUser(postController.App, w, username, password, remember)
if err != nil {
log.Println("Error authenticating user")
log.Println(err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
http.Redirect(w, r, "/", http.StatusFound)
}
func (postController *PostController) Register(w http.ResponseWriter, r *http.Request) {
// Validate csrf token
_, err := security.VerifyCsrfToken(r)
if err != nil {
log.Println("Error verifying csrf token")
return
}
username := r.FormValue("username")
password := r.FormValue("password")
createdAt := time.Now()
updatedAt := time.Now()
if username == "" || password == "" {
log.Println("Tried to create user with empty username or password")
http.Redirect(w, r, "/register", http.StatusFound)
}
_, err = models.CreateUser(postController.App, username, password, createdAt, updatedAt)
if err != nil {
log.Println("Error creating user")
log.Println(err)
return
}
http.Redirect(w, r, "/login", http.StatusFound)
}
func (postController *PostController) FileUpload(w http.ResponseWriter, r *http.Request) {
max := postController.App.Config.Upload.MaxSize
err := r.ParseMultipartForm(max)
if err != nil {
return
}
// FormFile returns the first file for the given key `file`
// it also returns the FileHeader, so we can get the Filename,
// the Header and the size of the file
file, handler, err := r.FormFile("file")
if err != nil {
log.Println("Error Retrieving the File")
log.Println(err)
return
}
defer func(file multipart.File) {
err := file.Close()
if err != nil {
log.Println(err)
}
}(file)
if handler.Size > max {
log.Println("User tried uploading a file which is too large.")
http.Redirect(w, r, "/", http.StatusRequestHeaderFieldsTooLarge)
return
}
// Create a temporary file within upload directory
tempFile, err := os.Create(postController.App.Config.Upload.BaseName + handler.Filename)
if err != nil {
log.Println(err)
http.Redirect(w, r, "/", http.StatusNotAcceptable)
}
defer func(tempFile *os.File) {
err := tempFile.Close()
if err != nil {
log.Println(err)
}
}(tempFile)
// read all the contents of our uploaded file into a
// byte array
fileBytes, err := io.ReadAll(file)
if err != nil {
log.Println(err)
}
_, err = tempFile.Write(fileBytes)
if err != nil {
log.Println(err)
}
http.Redirect(w, r, "/", http.StatusFound)
}

View File

@ -5,26 +5,29 @@ import (
"database/sql"
"fmt"
_ "github.com/lib/pq"
"log/slog"
"log"
)
// Connect returns a new database connection
func Connect(app *app.App) *sql.DB {
// ConnectDB returns a new database connection
func ConnectDB(app *app.App) *sql.DB {
// Set connection parameters from config
postgresConfig := fmt.Sprintf("host=%s port=%s user=%s "+
"password=%s dbname=%s sslmode=disable",
app.Config.Db.Ip, app.Config.Db.Port, app.Config.Db.User, app.Config.Db.Password, app.Config.Db.Name)
// Create connection
db, err := sql.Open("postgres", postgresConfig)
if err != nil {
panic(err)
}
// Test connection
err = db.Ping()
if err != nil {
panic(err)
}
slog.Info("connected to database successfully on " + app.Config.Db.Ip + ":" + app.Config.Db.Port + " using database " + app.Config.Db.Name)
log.Println("Connected to database successfully on " + app.Config.Db.Ip + ":" + app.Config.Db.Port + " using database " + app.Config.Db.Name)
return db
}

View File

@ -5,12 +5,11 @@ import (
"errors"
"fmt"
"github.com/lib/pq"
"log/slog"
"log"
"reflect"
)
// Migrate given a dummy object of any type, it will create a table with the same name
// as the type and create columns with the same name as the fields of the object
// Migrate given a dummy object of any type, it will create a table with the same name as the type and create columns with the same name as the fields of the object
func Migrate(app *app.App, anyStruct interface{}) error {
valueOfStruct := reflect.ValueOf(anyStruct)
typeOfStruct := valueOfStruct.Type()
@ -24,15 +23,10 @@ func Migrate(app *app.App, anyStruct interface{}) error {
for i := 0; i < valueOfStruct.NumField(); i++ {
fieldType := typeOfStruct.Field(i)
fieldName := fieldType.Name
// Create column if dummy for migration is NOT zero value
fieldValue := valueOfStruct.Field(i).Interface()
if !reflect.ValueOf(fieldValue).IsZero() {
if fieldName != "Id" && fieldName != "id" {
err := createColumn(app, tableName, fieldName, fieldType.Type.Name())
if err != nil {
return err
}
if fieldName != "Id" && fieldName != "id" {
err := createColumn(app, tableName, fieldName, fieldType.Type.Name())
if err != nil {
return err
}
}
}
@ -42,46 +36,48 @@ func Migrate(app *app.App, anyStruct interface{}) error {
// createTable creates a table with the given name if it doesn't exist, it is assumed that id will be the primary key
func createTable(app *app.App, tableName string) error {
// Check to see if the table already exists
var tableExists bool
err := app.Db.QueryRow("SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.relname ~ $1 AND pg_catalog.pg_table_is_visible(c.oid))", "^"+tableName+"$").Scan(&tableExists)
if err != nil {
slog.Error("error checking if table exists: " + tableName)
log.Println("Error checking if table exists: " + tableName)
return err
}
if tableExists {
slog.Info("table already exists: " + tableName)
log.Println("Table already exists: " + tableName)
return nil
} else {
sanitizedTableQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"%s\" (\"Id\" serial primary key)", tableName)
_, err := app.Db.Query(sanitizedTableQuery)
if err != nil {
slog.Error("error creating table: " + tableName)
log.Println("Error creating table: " + tableName)
return err
}
slog.Info("table created successfully: " + tableName)
log.Println("Table created successfully: " + tableName)
return nil
}
}
// createColumn creates a column with the given name and type if it doesn't exist
func createColumn(app *app.App, tableName, columnName, columnType string) error {
// Check to see if the column already exists
var columnExists bool
err := app.Db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = $1 AND column_name = $2)", tableName, columnName).Scan(&columnExists)
if err != nil {
slog.Error("error checking if column exists: " + columnName + " in table: " + tableName)
log.Println("Error checking if column exists: " + columnName + " in table: " + tableName)
return err
}
if columnExists {
slog.Info("column already exists: " + columnName + " in table: " + tableName)
log.Println("Column already exists: " + columnName + " in table: " + tableName)
return nil
} else {
postgresType, err := getPostgresType(columnType)
if err != nil {
slog.Error("error creating column: " + columnName + " in table: " + tableName + " with type: " + postgresType)
log.Println("Error creating column: " + columnName + " in table: " + tableName + " with type: " + postgresType)
return err
}
@ -90,11 +86,11 @@ func createColumn(app *app.App, tableName, columnName, columnType string) error
_, err = app.Db.Query(query)
if err != nil {
slog.Error("error creating column: " + columnName + " in table: " + tableName + " with type: " + postgresType)
log.Println("Error creating column: " + columnName + " in table: " + tableName + " with type: " + postgresType)
return err
}
slog.Info("column created successfully:", columnName)
log.Println("Column created successfully:", columnName)
return nil
}

View File

@ -12,7 +12,10 @@
"HttpPort": "8090"
},
"Template": {
"BaseTemplateName": "templates/base.html",
"ContentPath": "templates"
"BaseTemplateName": "templates/base.html"
},
"Upload": {
"UploadDirectoryName": "goweb-uploads/",
"MaxUploadSize": 10485760
}
}

4
go.mod
View File

@ -1,8 +1,8 @@
module GoWeb
go 1.22
go 1.20
require (
github.com/lib/pq v1.10.9
golang.org/x/crypto v0.24.0
golang.org/x/crypto v0.8.0
)

4
go.sum
View File

@ -1,4 +1,4 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=

51
main.go
View File

@ -6,11 +6,10 @@ import (
"GoWeb/database"
"GoWeb/models"
"GoWeb/routes"
"GoWeb/templating"
"context"
"embed"
"errors"
"log/slog"
"log"
"net/http"
"os"
"os/signal"
@ -35,26 +34,31 @@ func main() {
if _, err := os.Stat("logs"); os.IsNotExist(err) {
err := os.Mkdir("logs", 0755)
if err != nil {
panic("failed to create log directory: " + err.Error())
log.Println("Failed to create log directory")
log.Println(err)
return
}
}
// Create log file and set output
file, err := os.OpenFile("logs/"+time.Now().Format("2006-01-02")+".log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
panic("error creating log file: " + err.Error())
file, err := os.OpenFile("logs/"+time.Now().Format("2006-01-02")+".log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
log.SetOutput(file)
// Create upload directory if it doesn't exist
uploadPath := appLoaded.Config.Upload.BaseName
if _, err := os.Stat(uploadPath); errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(uploadPath, os.ModePerm); err != nil {
log.Fatal(err)
}
}
logger := slog.New(slog.NewTextHandler(file, nil))
slog.SetDefault(logger) // Set structured logger globally
// Connect to database and run migrations
appLoaded.Db = database.Connect(&appLoaded)
appLoaded.Db = database.ConnectDB(&appLoaded)
if appLoaded.Config.Db.AutoMigrate {
err = models.RunAllMigrations(&appLoaded)
if err != nil {
slog.Error("error running migrations: " + err.Error())
os.Exit(1)
log.Println(err)
return
}
}
@ -65,24 +69,16 @@ func main() {
}
// Define Routes
routes.Get(&appLoaded)
routes.Post(&appLoaded)
// Prepare templates
err = templating.BuildPages(&appLoaded)
if err != nil {
slog.Error("error building templates: " + err.Error())
os.Exit(1)
}
routes.GetRoutes(&appLoaded)
routes.PostRoutes(&appLoaded)
// Start server
server := &http.Server{Addr: appLoaded.Config.Listen.Ip + ":" + appLoaded.Config.Listen.Port}
go func() {
slog.Info("starting server and listening on " + appLoaded.Config.Listen.Ip + ":" + appLoaded.Config.Listen.Port)
log.Println("Starting server and listening on " + appLoaded.Config.Listen.Ip + ":" + appLoaded.Config.Listen.Port)
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
slog.Error("could not listen on %s: %v\n", appLoaded.Config.Listen.Ip+":"+appLoaded.Config.Listen.Port, err)
os.Exit(1)
if err != nil && err != http.ErrServerClosed {
log.Fatalf("Could not listen on %s: %v\n", appLoaded.Config.Listen.Ip+":"+appLoaded.Config.Listen.Port, err)
}
}()
@ -93,11 +89,10 @@ func main() {
go app.RunScheduledTasks(&appLoaded, 100, stop)
<-interrupt
slog.Info("interrupt signal received. Shutting down server...")
log.Println("Interrupt signal received. Shutting down server...")
err = server.Shutdown(context.Background())
if err != nil {
slog.Error("could not gracefully shutdown the server: %v\n", err)
os.Exit(1)
log.Fatalf("Could not gracefully shutdown the server: %v\n", err)
}
}

View File

@ -1,21 +0,0 @@
package middleware
import (
"GoWeb/security"
"log/slog"
"net/http"
)
// Csrf validates the CSRF token and returns the handler function if it succeeded
func Csrf(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
_, err := security.VerifyCsrfToken(r)
if err != nil {
slog.Info("error verifying csrf token")
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
f(w, r)
}
}

View File

@ -1,5 +0,0 @@
package middleware
import "net/http"
type MiddlewareFunc func(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request)

View File

@ -1,14 +0,0 @@
package middleware
import "net/http"
// ProcessGroup is a wrapper function for the http.HandleFunc function
// that takes the function you want to execute (f) and the middleware you want
// to execute (m) this should be used when processing multiple groups of middleware at a time
func ProcessGroup(f func(w http.ResponseWriter, r *http.Request), m []MiddlewareFunc) func(w http.ResponseWriter, r *http.Request) {
for _, middleware := range m {
_ = middleware(f)
}
return f
}

View File

@ -4,7 +4,7 @@ import (
"GoWeb/app"
"crypto/rand"
"encoding/hex"
"log/slog"
"log"
"net/http"
"time"
)
@ -17,7 +17,7 @@ type Session struct {
CreatedAt time.Time
}
const sessionColumnsNoId = "\"UserId\", \"AuthToken\", \"RememberMe\", \"CreatedAt\""
const sessionColumnsNoId = "\"UserId\", \"AuthToken\",\"RememberMe\", \"CreatedAt\""
const sessionColumns = "\"Id\", " + sessionColumnsNoId
const sessionTable = "public.\"Session\""
@ -42,19 +42,21 @@ func CreateSession(app *app.App, w http.ResponseWriter, userId int64, remember b
var existingAuthToken bool
err := app.Db.QueryRow(selectAuthTokenIfExists, session.AuthToken).Scan(&existingAuthToken)
if err != nil {
slog.Error("error checking for existing auth token" + err.Error())
log.Println("Error checking for existing auth token")
log.Println(err)
return Session{}, err
}
// If duplicate token found, recursively call function until unique token is generated
if existingAuthToken {
slog.Warn("duplicate token found in sessions table, generating new token...")
if existingAuthToken == true {
log.Println("Duplicate token found in sessions table, generating new token...")
return CreateSession(app, w, userId, remember)
}
// Insert session into database
err = app.Db.QueryRow(insertSession, session.UserId, session.AuthToken, session.RememberMe, session.CreatedAt).Scan(&session.Id)
if err != nil {
slog.Error("error inserting session into database")
log.Println("Error inserting session into database")
return Session{}, err
}
@ -62,25 +64,28 @@ func CreateSession(app *app.App, w http.ResponseWriter, userId int64, remember b
return session, nil
}
func SessionByAuthToken(app *app.App, authToken string) (Session, error) {
func GetSessionByAuthToken(app *app.App, authToken string) (Session, error) {
session := Session{}
err := app.Db.QueryRow(selectSessionByAuthToken, authToken).Scan(&session.Id, &session.UserId, &session.AuthToken, &session.RememberMe, &session.CreatedAt)
if err != nil {
log.Println("Error getting session by auth token")
return Session{}, err
}
return session, nil
}
// generateAuthToken generates a random 64-byte string
// Generates a random 64-byte string
func generateAuthToken(app *app.App) string {
// Generate random bytes
b := make([]byte, 64)
_, err := rand.Read(b)
if err != nil {
slog.Error("error generating random bytes for auth token")
log.Println("Error generating random bytes")
}
// Convert random bytes to hex string
return hex.EncodeToString(b)
}
@ -124,9 +129,10 @@ func deleteSessionCookie(app *app.App, w http.ResponseWriter) {
// DeleteSessionByAuthToken deletes a session from the database by AuthToken
func DeleteSessionByAuthToken(app *app.App, w http.ResponseWriter, authToken string) error {
// Delete session from database
_, err := app.Db.Exec(deleteSessionByAuthToken, authToken)
if err != nil {
slog.Error("error deleting session from database")
log.Println("Error deleting session from database")
return err
}
@ -140,14 +146,16 @@ func ScheduledSessionCleanup(app *app.App) {
// Delete sessions older than 30 days (remember me sessions)
_, err := app.Db.Exec(deleteSessionsOlderThan30Days)
if err != nil {
slog.Error("error deleting 30 day expired sessions from database" + err.Error())
log.Println("Error deleting 30 day expired sessions from database")
log.Println(err)
}
// Delete sessions older than 6 hours
_, err = app.Db.Exec(deleteSessionsOlderThan6Hours)
if err != nil {
slog.Error("error deleting 6 hour expired sessions from database" + err.Error())
log.Println("Error deleting 6 hour expired sessions from database")
log.Println(err)
}
slog.Info("deleted expired sessions from database")
log.Println("Deleted expired sessions from database")
}

View File

@ -2,10 +2,9 @@ package models
import (
"GoWeb/app"
"crypto/sha256"
"encoding/hex"
"log/slog"
"log"
"net/http"
"strconv"
"time"
"golang.org/x/crypto/bcrypt"
@ -29,39 +28,45 @@ const (
insertUser = "INSERT INTO " + userTable + " (" + userColumnsNoId + ") VALUES ($1, $2, $3, $4) RETURNING \"Id\""
)
// CurrentUser finds the currently logged-in user by session cookie
func CurrentUser(app *app.App, r *http.Request) (User, error) {
// GetCurrentUser finds the currently logged-in user by session cookie
func GetCurrentUser(app *app.App, r *http.Request) (User, error) {
cookie, err := r.Cookie("session")
if err != nil {
log.Println("Error getting session cookie")
return User{}, err
}
session, err := SessionByAuthToken(app, cookie.Value)
session, err := GetSessionByAuthToken(app, cookie.Value)
if err != nil {
log.Println("Error getting session by auth token")
return User{}, err
}
return UserById(app, session.UserId)
return GetUserById(app, session.UserId)
}
// UserById finds a User table row in the database by id and returns a struct representing this row
func UserById(app *app.App, id int64) (User, error) {
// GetUserById finds a User table row in the database by id and returns a struct representing this row
func GetUserById(app *app.App, id int64) (User, error) {
user := User{}
// Query row by id
err := app.Db.QueryRow(selectUserById, id).Scan(&user.Id, &user.Username, &user.Password, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println("Get user error (user not found) for user id:" + strconv.FormatInt(id, 10))
return User{}, err
}
return user, nil
}
// UserByUsername finds a User table row in the database by username and returns a struct representing this row
func UserByUsername(app *app.App, username string) (User, error) {
// GetUserByUsername finds a User table row in the database by username and returns a struct representing this row
func GetUserByUsername(app *app.App, username string) (User, error) {
user := User{}
// Query row by username
err := app.Db.QueryRow(selectUserByUsername, username).Scan(&user.Id, &user.Username, &user.Password, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println("Get user error (user not found) for user:" + username)
return User{}, err
}
@ -70,14 +75,10 @@ func UserByUsername(app *app.App, username string) (User, error) {
// CreateUser creates a User table row in the database
func CreateUser(app *app.App, username string, password string, createdAt time.Time, updatedAt time.Time) (User, error) {
// Get sha256 hash of password then get bcrypt hash to store
hash256 := sha256.New()
hash256.Write([]byte(password))
hashSum := hash256.Sum(nil)
hashString := hex.EncodeToString(hashSum)
hash, err := bcrypt.GenerateFromPassword([]byte(hashString), bcrypt.DefaultCost)
// Hash password
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
slog.Error("error hashing password: " + err.Error())
log.Println("Error hashing password when creating user")
return User{}, err
}
@ -85,31 +86,28 @@ func CreateUser(app *app.App, username string, password string, createdAt time.T
err = app.Db.QueryRow(insertUser, username, string(hash), createdAt, updatedAt).Scan(&lastInsertId)
if err != nil {
slog.Error("error creating user row: " + err.Error())
log.Println("Error creating user row")
return User{}, err
}
return UserById(app, lastInsertId)
return GetUserById(app, lastInsertId)
}
// AuthenticateUser validates the password for the specified user
func AuthenticateUser(app *app.App, w http.ResponseWriter, username string, password string, remember bool) (Session, error) {
var user User
// Query row by username
err := app.Db.QueryRow(selectUserByUsername, username).Scan(&user.Id, &user.Username, &user.Password, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
slog.Info("user not found: " + username)
log.Println("Authentication error (user not found) for user:" + username)
return Session{}, err
}
// Get sha256 hash of password then check bcrypt
hash256 := sha256.New()
hash256.Write([]byte(password))
hashSum := hash256.Sum(nil)
hashString := hex.EncodeToString(hashSum)
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(hashString))
// Validate password
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
if err != nil { // Failed to validate password, doesn't match
slog.Info("incorrect password:" + username)
log.Println("Authentication error (incorrect password) for user:" + username)
return Session{}, err
} else {
return CreateSession(app, w, user.Id, remember)
@ -118,13 +116,17 @@ func AuthenticateUser(app *app.App, w http.ResponseWriter, username string, pass
// LogoutUser deletes the session cookie and AuthToken from the database
func LogoutUser(app *app.App, w http.ResponseWriter, r *http.Request) {
// Get cookie from request
cookie, err := r.Cookie("session")
if err != nil {
log.Println("Error getting cookie from request")
return
}
// Set token to empty string
err = DeleteSessionByAuthToken(app, w, cookie.Value)
if err != nil {
log.Println("Error deleting session by AuthToken")
return
}
}

View File

@ -1,65 +0,0 @@
package restclient
import (
"bytes"
"encoding/json"
"mime/multipart"
"net/http"
)
// SendRequest sends an HTTP request to a URL and includes the specified headers and body.
// A body can be nil for GET requests, a map[string]string for multipart/form-data requests,
// or a struct for JSON requests
func SendRequest(url string, method string, headers map[string]string, body interface{}) (http.Response, error) {
var reqBody *bytes.Buffer
var contentType string
switch v := body.(type) {
case nil:
reqBody = bytes.NewBuffer([]byte(""))
case map[string]string:
reqBody = &bytes.Buffer{}
writer := multipart.NewWriter(reqBody)
for key, value := range v {
err := writer.WriteField(key, value)
if err != nil {
return http.Response{}, err
}
}
err := writer.Close()
if err != nil {
return http.Response{}, err
}
contentType = writer.FormDataContentType()
default:
jsonBody, err := json.Marshal(body)
if err != nil {
return http.Response{}, err
}
reqBody = bytes.NewBuffer(jsonBody)
contentType = "application/json"
}
req, err := http.NewRequest(method, url, reqBody)
if err != nil {
return http.Response{}, err
}
if contentType != "" {
req.Header.Set("Content-Type", contentType)
}
for key, value := range headers {
req.Header.Add(key, value)
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return http.Response{}, err
}
return *resp, nil
}

View File

@ -4,30 +4,33 @@ import (
"GoWeb/app"
"GoWeb/controllers"
"io/fs"
"log/slog"
"log"
"net/http"
)
// Get defines all project get routes
func Get(app *app.App) {
// GetRoutes defines all project get routes
func GetRoutes(app *app.App) {
// Get controller struct initialize
getController := controllers.Get{
getController := controllers.GetController{
App: app,
}
// Serve static files
staticFS, err := fs.Sub(app.Res, "static")
if err != nil {
slog.Error(err.Error())
log.Println(err)
return
}
staticHandler := http.FileServer(http.FS(staticFS))
http.Handle("/static/", http.StripPrefix("/static/", staticHandler))
slog.Info("serving static files from embedded file system /static")
log.Println("Serving static files from embedded file system /static")
// Pages
http.HandleFunc("/", getController.ShowHome)
http.HandleFunc("/login", getController.ShowLogin)
http.HandleFunc("/register", getController.ShowRegister)
http.HandleFunc("/logout", getController.Logout)
// Files
http.HandleFunc("/uploads", getController.ShowFile)
}

View File

@ -1,20 +0,0 @@
package routes
import (
"GoWeb/app"
"GoWeb/controllers"
"GoWeb/middleware"
"net/http"
)
// Post defines all project post routes
func Post(app *app.App) {
// Post controller struct initialize
postController := controllers.Post{
App: app,
}
// User authentication
http.HandleFunc("/register-handle", middleware.Csrf(postController.Register))
http.HandleFunc("/login-handle", middleware.Csrf(postController.Login))
}

20
routes/postRoutes.go Normal file
View File

@ -0,0 +1,20 @@
package routes
import (
"GoWeb/app"
"GoWeb/controllers"
"net/http"
)
// PostRoutes defines all project post routes
func PostRoutes(app *app.App) {
// Post controller struct initialize
postController := controllers.PostController{
App: app,
}
// User authentication
http.HandleFunc("/register-handle", postController.Register)
http.HandleFunc("/login-handle", postController.Login)
http.HandleFunc("/upload-handle", postController.FileUpload)
}

View File

@ -3,17 +3,19 @@ package security
import (
"crypto/rand"
"encoding/hex"
"log/slog"
"log"
"math"
"net/http"
)
// GenerateCsrfToken generates a csrf token and assigns it to a cookie for double submit cookie csrf protection
func GenerateCsrfToken(w http.ResponseWriter, _ *http.Request) (string, error) {
// Generate random 64 character string (alpha-numeric)
buff := make([]byte, int(math.Ceil(float64(64)/2)))
_, err := rand.Read(buff)
if err != nil {
slog.Error("error creating random buffer for csrf token value" + err.Error())
log.Println("Error creating random buffer for csrf token value")
log.Println(err)
return "", err
}
str := hex.EncodeToString(buff)
@ -37,7 +39,8 @@ func GenerateCsrfToken(w http.ResponseWriter, _ *http.Request) (string, error) {
func VerifyCsrfToken(r *http.Request) (bool, error) {
cookie, err := r.Cookie("csrf_token")
if err != nil {
slog.Info("unable to get csrf_token cookie" + err.Error())
log.Println("Error getting csrf_token cookie")
log.Println(err)
return false, err
}

View File

@ -3,7 +3,7 @@
<head>
<meta charset="UTF-8">
<title>SiteName - {{ template "pageTitle" }}</title>
<link href="/static/css/style.css" rel="stylesheet">
<link rel="stylesheet" href="/static/css/style.css">
</head>
<body>
{{ template "content" . }}

View File

@ -1,5 +1,25 @@
{{ define "pageTitle" }}Home{{ end }}
{{ define "file-upload" }}
<form
enctype="multipart/form-data"
action="/upload-handle"
method="post"
>
<input type="file" accept="*/*" name="file" />
<input type="submit" value="upload" />
</form>
{{ end }}
{{ define "content" }}
{{ .Test }}
<!-- Uncomment below to demo file upload system -->
<!-- {{ template "file-upload" . }} -->
<!-- <p>Upload an image called test.jpg to test the file upload system</p> -->
<!-- <img src="/uploads?name=test.jpg" alt=""> -->
{{ end }}
{{ define "content" }}
{{ end }}

View File

@ -7,11 +7,11 @@
<input name="csrf_token" type="hidden" value="{{ .CsrfToken }}">
<label for="username">Username:</label><br>
<input id="username" name="username" placeholder="John" type="text"><br><br>
<input id="username" name="username" type="text" placeholder="John"><br><br>
<label for="password">Password:</label><br>
<input id="password" name="password" type="password"><br><br>
<label for="remember">Remember Me:</label>
<input id="remember" name="remember" type="checkbox"><br><br>
<input id="remember" type="checkbox" name="remember"><br><br>
<input type="submit" value="Submit">
</form>
</div>

View File

@ -7,7 +7,7 @@
<input name="csrf_token" type="hidden" value="{{ .CsrfToken }}">
<label for="username">Username:</label><br>
<input id="username" name="username" placeholder="John" type="text"><br><br>
<input id="username" name="username" type="text" placeholder="John"><br><br>
<label for="password">Password:</label><br>
<input id="password" name="password" type="password"><br><br>
<input type="submit" value="Submit">

View File

@ -2,83 +2,46 @@ package templating
import (
"GoWeb/app"
"fmt"
"html/template"
"io/fs"
"log/slog"
"log"
"net/http"
)
var templates = make(map[string]*template.Template) // This is only used here, does not need to be in app.App
// RenderTemplate renders and serves a template from the embedded filesystem optionally with given data
func RenderTemplate(app *app.App, w http.ResponseWriter, contentPath string, data any) {
templatePath := app.Config.Template.BaseName
func BuildPages(app *app.App) error {
basePath := app.Config.Template.BaseName
baseContent, err := app.Res.ReadFile(basePath)
templateContent, err := app.Res.ReadFile(templatePath)
if err != nil {
return fmt.Errorf("error reading base file: %w", err)
}
base, err := template.New(basePath).Parse(string(baseContent)) // Sets filepath as name and parses content
if err != nil {
return fmt.Errorf("error parsing base file: %w", err)
}
readFilesRecursively := func(fsys fs.FS, root string) ([]string, error) {
var files []string
err := fs.WalkDir(fsys, root, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return fmt.Errorf("error walking the path %q: %w", path, err)
}
if !d.IsDir() {
files = append(files, path)
}
return nil
})
return files, err
}
// Get all file paths in the directory tree
filePaths, err := readFilesRecursively(app.Res, app.Config.Template.ContentPath)
if err != nil {
return fmt.Errorf("error reading files recursively: %w", err)
}
for _, contentPath := range filePaths { // Create a new template base + content for each page
content, err := app.Res.ReadFile(contentPath)
if err != nil {
return fmt.Errorf("error reading content file %s: %w", contentPath, err)
}
t, err := base.Clone()
if err != nil {
return fmt.Errorf("error cloning base template: %w", err)
}
_, err = t.Parse(string(content))
if err != nil {
return fmt.Errorf("error parsing content: %w", err)
}
templates[contentPath] = t
}
return nil
}
func RenderTemplate(w http.ResponseWriter, contentPath string, data any) {
t, ok := templates[contentPath]
if !ok {
err := fmt.Errorf("template not found for path: %s", contentPath)
slog.Error(err.Error())
http.Error(w, "Template not found", 404)
log.Println(err)
http.Error(w, err.Error(), 500)
return
}
err := t.Execute(w, data) // Execute prebuilt template with dynamic data
t, err := template.New(templatePath).Parse(string(templateContent))
if err != nil {
err = fmt.Errorf("error executing template: %w", err)
slog.Error(err.Error())
log.Println(err)
http.Error(w, err.Error(), 500)
return
}
content, err := app.Res.ReadFile(contentPath)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), 500)
return
}
t, err = t.Parse(string(content))
if err != nil {
log.Println(err)
http.Error(w, err.Error(), 500)
return
}
err = t.Execute(w, data)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), 500)
return
}