fix: session authentication middleware

This commit is contained in:
2025-06-09 12:56:59 +01:00
parent ea9c5f0902
commit 14ce860c37
12 changed files with 283 additions and 97 deletions
+2 -1
View File
@@ -8,6 +8,7 @@ require (
github.com/labstack/echo-contrib v0.17.4
github.com/labstack/echo/v4 v4.13.4
github.com/spf13/cobra v1.9.1
golang.org/x/crypto v0.38.0
gorm.io/driver/sqlite v1.5.7
gorm.io/gorm v1.30.0
)
@@ -16,6 +17,7 @@ require (
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/context v1.1.2 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -29,7 +31,6 @@ require (
github.com/spf13/pflag v1.0.6 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/crypto v0.38.0 // indirect
golang.org/x/net v0.40.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
+2
View File
@@ -13,6 +13,8 @@ github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
+7 -4
View File
@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/go-playground/validator/v10"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/labstack/echo-contrib/session"
"github.com/gorilla/sessions"
"github.com/hazemKrimi/crimson-vault/internal/lib"
"github.com/hazemKrimi/crimson-vault/internal/models"
@@ -34,12 +34,15 @@ func (api *API) Initialize() {
api.instance = ech
api.db = db
// TODO: Change and store the secret separately when finilizing v1.
api.instance.Use(session.Middleware(sessions.NewCookieStore([]byte("SECRET"))))
api.instance.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
// TODO: Change and store the secret separately when finilizing v1.
api.instance.Use(session.Middleware(sessions.NewCookieStore([]byte("SECRET"))))
api.ClientRoutes()
api.UserRoutes()
api.AuthRoutes()
api.instance.Logger.Fatal(api.instance.Start(fmt.Sprintf(":%d", lib.DEFAULT_PORT)))
}
+84
View File
@@ -0,0 +1,84 @@
package api
import (
"fmt"
"log"
"net/http"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/hazemKrimi/crimson-vault/internal/lib"
"github.com/hazemKrimi/crimson-vault/internal/types"
)
func (api *API) LoginHandler(context echo.Context) error {
var body types.LoginRequestBody
if err := context.Bind(&body); err != nil {
log.Println(fmt.Sprintf("Error logging User in: %v.", err))
return context.String(http.StatusBadRequest, "Invalid JSON!")
}
if err := context.Validate(body); err != nil {
return err
}
var user types.User
if err := api.db.GetUserByUsername(body.Username, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
if match := lib.CheckPasswordHash(body.Password, user.Password); !match {
return context.String(http.StatusBadRequest, "Invalid credentials!")
}
sess, err := session.Get("session", context)
if err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
if err := api.db.UpdateUserSessionID(&user); err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
if err := lib.CreateSession(sess, context, &user); err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
log.Println(fmt.Sprintf("User with ID %s logged in.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) LogoutHandler(context echo.Context) error {
sessionId, ok := context.Get("sessionId").(string)
if !ok {
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
if err := api.db.DeleteUserSessionID(sessionId); err != nil {
log.Println(fmt.Sprintf("Error deleting User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
sess, err := session.Get("session", context)
if err != nil {
log.Println(fmt.Sprintf("Error deleting User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
if err := lib.DeleteSession(sess, context); err != nil {
log.Println(fmt.Sprintf("Error deleting User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
log.Println(fmt.Sprintf("User with SessionID %s logged out.", sessionId))
return context.String(http.StatusOK, "Logged out successfully!")
}
+18 -5
View File
@@ -3,29 +3,42 @@ package api
import (
"net/http"
"github.com/google/uuid"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/hazemKrimi/crimson-vault/internal/types"
)
func SessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
func (api *API) AuthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(context echo.Context) error {
sess, err := session.Get("session", context)
if err != nil {
if err != nil || sess.IsNew {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
cookie, err := context.Cookie("session")
id, ok := sess.Values["sessionId"].(string)
if !ok || id == "" {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
sessionId, err := uuid.Parse(id)
if err != nil {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
if sess.IsNew || cookie.Value == "" || sess.Values["id"] == "" {
var user types.User
if err := api.db.GetUserBySessionId(sessionId, &user); err != nil {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
context.Set("id", sess.Values["id"])
context.Set("sessionId", sess.Values["sessionId"])
context.Set("username", sess.Values["username"])
return next(context)
}
}
+10 -3
View File
@@ -3,7 +3,7 @@ package api
import "github.com/labstack/echo/v4/middleware"
func (api *API) ClientRoutes() {
clients := api.instance.Group("/clients")
clients := api.instance.Group("/api/clients")
clients.GET("/", api.GetAllClientsHandler)
clients.POST("/", api.CreateClientHandler)
@@ -13,14 +13,21 @@ func (api *API) ClientRoutes() {
}
func (api *API) UserRoutes() {
users := api.instance.Group("/users")
users := api.instance.Group("/api/users")
users.GET("/", api.GetAllUsersHandler)
users.POST("/", api.CreateUserHandler)
users.GET("/:id", api.GetUserHandler)
users.PUT("/:id", api.UpdateUserHandler, SessionMiddleware)
users.PUT("/:id", api.UpdateUserHandler, api.AuthSessionMiddleware)
users.PUT("/:id/security", api.UpdateUserSecurityDetailsHandler)
users.PUT("/:id/logo", api.UpdateUserLogoHandler, middleware.BodyLimit("2M"))
users.DELETE("/:id", api.DeleteUserHandler)
users.DELETE("/:id/logo", api.DeleteUserLogoHandler)
}
func (api *API) AuthRoutes() {
auth := api.instance.Group("/api/auth")
auth.POST("/login", api.LoginHandler)
auth.DELETE("/logout", api.LogoutHandler, api.AuthSessionMiddleware)
}
+36 -68
View File
@@ -7,9 +7,9 @@ import (
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/hazemKrimi/crimson-vault/internal/lib"
"github.com/hazemKrimi/crimson-vault/internal/types"
"github.com/labstack/echo-contrib/session"
@@ -28,21 +28,26 @@ func (api *API) CreateUserHandler(context echo.Context) error {
return err
}
user := api.db.CreateUser(body)
user, err := api.db.CreateUser(body)
if err != nil {
log.Println(fmt.Sprintf("Error creating User: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User!")
}
sess, err := session.Get("session", context)
if err != nil {
api.db.DeleteUser(user.ID)
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
lib.ConstructSession(sess, user)
if err := sess.Save(context.Request(), context.Response()); err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
if err := lib.CreateSession(sess, context, &user); err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
log.Println(fmt.Sprintf("User created with ID %d.", user.ID))
log.Println(fmt.Sprintf("User created with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
@@ -58,39 +63,27 @@ func (api *API) GetAllUsersHandler(context echo.Context) error {
}
func (api *API) GetUserHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to get a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error getting User!")
return context.String(http.StatusBadRequest, "ID is required to get a User!")
}
var user types.User
if err := api.db.GetUser(uint32(id), &user); err != nil {
if err := api.db.GetUserById(id, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
log.Println(fmt.Sprintf("Got User with ID %d.", user.ID))
log.Println(fmt.Sprintf("Got User with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) UpdateUserHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to update a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error updating User!")
return context.String(http.StatusBadRequest, "ID is required to update a User!")
}
var body types.UpdateUserRequestBody
@@ -106,25 +99,19 @@ func (api *API) UpdateUserHandler(context echo.Context) error {
var user types.User
if err := api.db.UpdateUser(uint32(id), body, &user); err != nil {
if err := api.db.UpdateUser(id, body, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
log.Println(fmt.Sprintf("Updated user with ID %d.", user.ID))
log.Println(fmt.Sprintf("Updated user with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) UpdateUserSecurityDetailsHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to create security details for a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error while creating security details for User!")
return context.String(http.StatusBadRequest, "ID is required to create security details for a User!")
}
var body types.UpdateUserSecurityDetailsBody
@@ -140,30 +127,24 @@ func (api *API) UpdateUserSecurityDetailsHandler(context echo.Context) error {
var user types.User
if err := api.db.UpdateUserSecurityDetails(uint32(id), body, &user); err != nil {
if err := api.db.UpdateUserSecurityDetails(id, body, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
log.Println(fmt.Sprintf("Updated security details of user with ID %d.", user.ID))
log.Println(fmt.Sprintf("Updated security details of user with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) UpdateUserLogoHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to update logo for User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error updating logo for User!")
return context.String(http.StatusBadRequest, "ID is required to update logo for User!")
}
var user types.User
if err := api.db.GetUser(uint32(id), &user); err != nil {
if err := api.db.GetUserById(id, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
@@ -239,19 +220,13 @@ func (api *API) UpdateUserLogoHandler(context echo.Context) error {
}
func (api *API) DeleteUserHandler(context echo.Context) error {
idString := context.Param("id")
id, err := uuid.Parse(context.Param("id"))
if idString == "" {
if err != nil {
return context.String(http.StatusBadRequest, "ID is required to delete a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error deleting User!")
}
if err := api.db.DeleteUser(uint32(id)); err != nil {
if err := api.db.DeleteUser(id); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
@@ -260,22 +235,15 @@ func (api *API) DeleteUserHandler(context echo.Context) error {
}
func (api *API) DeleteUserLogoHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to delete logo of User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
log.Println(fmt.Sprintf("Error deleting logo of User: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting logo of User!")
return context.String(http.StatusBadRequest, "ID is required to delete logo of User!")
}
var user types.User
if err := api.db.GetUser(uint32(id), &user); err != nil {
if err := api.db.GetUserById(id, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
@@ -286,6 +254,6 @@ func (api *API) DeleteUserLogoHandler(context echo.Context) error {
return context.String(http.StatusInternalServerError, "Unexpected error deleting logo of User!")
}
log.Println(fmt.Sprintf("Deleted logo of User with ID %d.", user.ID))
log.Println(fmt.Sprintf("Deleted logo of User with ID %s.", user.ID))
return context.String(http.StatusOK, "User logo deleted successfully!")
}
+46 -2
View File
@@ -1,10 +1,15 @@
package lib
import (
"net/http"
"os"
"path/filepath"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/labstack/echo/v4"
"golang.org/x/crypto/bcrypt"
"github.com/hazemKrimi/crimson-vault/internal/types"
)
@@ -20,11 +25,50 @@ func GetConfigDirectory() (string, error) {
return config, nil
}
func ConstructSession(session *sessions.Session, user types.User) {
func SaveSession(session *sessions.Session, context echo.Context) error {
if err := session.Save(context.Request(), context.Response()); err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
}
return nil
}
func CreateSession(session *sessions.Session, context echo.Context, user *types.User) error {
if err := uuid.Validate(user.SessionID); err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
}
session.Options = &sessions.Options{
Path: "/",
MaxAge: 3600,
HttpOnly: true,
}
session.Values["id"] = user.ID
session.Values["sessionId"] = user.SessionID
session.Values["username"] = user.Username
if err := SaveSession(session, context); err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
}
return nil
}
func DeleteSession(session *sessions.Session, context echo.Context) error {
session.Options.MaxAge = -1
if err := SaveSession(session, context); err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
}
return nil
}
func HashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
func CheckPasswordHash(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
+67 -10
View File
@@ -3,6 +3,8 @@ package models
import (
"strings"
"github.com/google/uuid"
"github.com/hazemKrimi/crimson-vault/internal/lib"
"github.com/hazemKrimi/crimson-vault/internal/types"
)
@@ -10,8 +12,10 @@ func (db *DB) MigrateUsers() {
db.instance.AutoMigrate(&types.User{})
}
func (db *DB) CreateUser(body types.CreateUserRequestBody) types.User {
func (db *DB) CreateUser(body types.CreateUserRequestBody) (types.User, error) {
user := types.User{
ID: uuid.New().String(),
SessionID: uuid.New().String(),
Name: body.Name,
FiscalCode: body.FiscalCode,
Address: body.Address,
@@ -21,8 +25,13 @@ func (db *DB) CreateUser(body types.CreateUserRequestBody) types.User {
Email: body.Email,
}
db.instance.Create(&user)
return user
result := db.instance.Create(&user)
if result.Error != nil {
return types.User{}, result.Error
}
return user, nil
}
func (db *DB) GetUsers() ([]types.User, error) {
@@ -37,8 +46,8 @@ func (db *DB) GetUsers() ([]types.User, error) {
return users, nil
}
func (db *DB) GetUser(id uint32, user *types.User) error {
result := db.instance.Where("id = ?", id).First(user, id)
func (db *DB) GetUserById(id uuid.UUID, user *types.User) error {
result := db.instance.Where("id = ?", id).First(user)
if result.Error != nil {
return result.Error
@@ -47,7 +56,27 @@ func (db *DB) GetUser(id uint32, user *types.User) error {
return nil
}
func (db *DB) UpdateUser(id uint32, body types.UpdateUserRequestBody, user *types.User) error {
func (db *DB) GetUserBySessionId(sessionId uuid.UUID, user *types.User) error {
result := db.instance.Where("session_id = ?", sessionId).First(user)
if result.Error != nil {
return result.Error
}
return nil
}
func (db *DB) GetUserByUsername(username string, user *types.User) error {
result := db.instance.Where("username = ?", username).First(user)
if result.Error != nil {
return result.Error
}
return nil
}
func (db *DB) UpdateUser(id uuid.UUID, body types.UpdateUserRequestBody, user *types.User) error {
result := db.instance.Where("id = ?", id).First(user, id)
if result.Error != nil {
@@ -71,16 +100,22 @@ func (db *DB) UpdateUser(id uint32, body types.UpdateUserRequestBody, user *type
return nil
}
func (db *DB) UpdateUserSecurityDetails(id uint32, body types.UpdateUserSecurityDetailsBody, user *types.User) error {
func (db *DB) UpdateUserSecurityDetails(id uuid.UUID, body types.UpdateUserSecurityDetailsBody, user *types.User) error {
result := db.instance.Where("id = ?", id).First(user, id)
if result.Error != nil {
return result.Error
}
hashedPassword, err := lib.HashPassword(body.Password)
if err != nil {
return err
}
result = db.instance.Model(user).Updates(types.User{
Username: strings.ToLower(body.Username),
Password: body.Password,
Password: hashedPassword,
})
if result.Error != nil {
@@ -102,8 +137,20 @@ func (db *DB) UpdateUserLogo(path string, user *types.User) error {
return nil
}
func (db *DB) DeleteUser(id uint32) error {
result := db.instance.Delete(&types.User{}, id)
func (db *DB) UpdateUserSessionID(user *types.User) error {
result := db.instance.Model(user).Updates(types.User{
SessionID: uuid.New().String(),
})
if result.Error != nil {
return result.Error
}
return nil
}
func (db *DB) DeleteUser(id uuid.UUID) error {
result := db.instance.Unscoped().Delete(&types.User{}, id)
if result.Error != nil {
return result.Error
@@ -123,3 +170,13 @@ func (db *DB) DeleteUserLogo(user *types.User) error {
return nil
}
func (db *DB) DeleteUserSessionID(sessionId string) error {
result := db.instance.Model(&types.User{}).Where("session_id = ?", sessionId).Update("session_id", "")
if result.Error != nil {
return result.Error
}
return nil
}
+6
View File
@@ -0,0 +1,6 @@
package types
type LoginRequestBody struct {
Username string `json:"username" validate:"required"`
Password string `json:"password" validate:"password"`
}
+1 -1
View File
@@ -7,7 +7,7 @@ import (
)
type Client struct {
ID uint32 `json:"id"`
ID uint32 `json:"id" gorm:"primaryKey"`
CreatedAt time.Time `json:"createAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `json:"deletedAt" gorm:"index"`
+4 -3
View File
@@ -7,7 +7,8 @@ import (
)
type User struct {
ID uint32 `json:"id"`
ID string `json:"id" gorm:"primaryKey"`
SessionID string `json:"-"`
CreatedAt time.Time `json:"createAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `json:"deletedAt" gorm:"index"`
@@ -19,8 +20,8 @@ type User struct {
Country string `json:"country"`
Phone string `json:"phone"`
Email string `json:"email"`
Username string `json:"username"`
Password string `json:"password"`
Username string `json:"username" gorm:"unique"`
Password string `json:"-"`
}
type CreateUserRequestBody struct {