fix: session authentication middleware

This commit is contained in:
2025-06-09 12:56:59 +01:00
parent ea9c5f0902
commit 14ce860c37
12 changed files with 283 additions and 97 deletions
+7 -4
View File
@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/go-playground/validator/v10"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/labstack/echo-contrib/session"
"github.com/gorilla/sessions"
"github.com/hazemKrimi/crimson-vault/internal/lib"
"github.com/hazemKrimi/crimson-vault/internal/models"
@@ -34,12 +34,15 @@ func (api *API) Initialize() {
api.instance = ech
api.db = db
// TODO: Change and store the secret separately when finilizing v1.
api.instance.Use(session.Middleware(sessions.NewCookieStore([]byte("SECRET"))))
api.instance.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
// TODO: Change and store the secret separately when finilizing v1.
api.instance.Use(session.Middleware(sessions.NewCookieStore([]byte("SECRET"))))
api.ClientRoutes()
api.UserRoutes()
api.AuthRoutes()
api.instance.Logger.Fatal(api.instance.Start(fmt.Sprintf(":%d", lib.DEFAULT_PORT)))
}
+84
View File
@@ -0,0 +1,84 @@
package api
import (
"fmt"
"log"
"net/http"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/hazemKrimi/crimson-vault/internal/lib"
"github.com/hazemKrimi/crimson-vault/internal/types"
)
func (api *API) LoginHandler(context echo.Context) error {
var body types.LoginRequestBody
if err := context.Bind(&body); err != nil {
log.Println(fmt.Sprintf("Error logging User in: %v.", err))
return context.String(http.StatusBadRequest, "Invalid JSON!")
}
if err := context.Validate(body); err != nil {
return err
}
var user types.User
if err := api.db.GetUserByUsername(body.Username, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
if match := lib.CheckPasswordHash(body.Password, user.Password); !match {
return context.String(http.StatusBadRequest, "Invalid credentials!")
}
sess, err := session.Get("session", context)
if err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
if err := api.db.UpdateUserSessionID(&user); err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
if err := lib.CreateSession(sess, context, &user); err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
log.Println(fmt.Sprintf("User with ID %s logged in.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) LogoutHandler(context echo.Context) error {
sessionId, ok := context.Get("sessionId").(string)
if !ok {
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
if err := api.db.DeleteUserSessionID(sessionId); err != nil {
log.Println(fmt.Sprintf("Error deleting User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
sess, err := session.Get("session", context)
if err != nil {
log.Println(fmt.Sprintf("Error deleting User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
if err := lib.DeleteSession(sess, context); err != nil {
log.Println(fmt.Sprintf("Error deleting User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting User session!")
}
log.Println(fmt.Sprintf("User with SessionID %s logged out.", sessionId))
return context.String(http.StatusOK, "Logged out successfully!")
}
+18 -5
View File
@@ -3,29 +3,42 @@ package api
import (
"net/http"
"github.com/google/uuid"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/hazemKrimi/crimson-vault/internal/types"
)
func SessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
func (api *API) AuthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(context echo.Context) error {
sess, err := session.Get("session", context)
if err != nil {
if err != nil || sess.IsNew {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
cookie, err := context.Cookie("session")
id, ok := sess.Values["sessionId"].(string)
if !ok || id == "" {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
sessionId, err := uuid.Parse(id)
if err != nil {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
if sess.IsNew || cookie.Value == "" || sess.Values["id"] == "" {
var user types.User
if err := api.db.GetUserBySessionId(sessionId, &user); err != nil {
return context.String(http.StatusUnauthorized, "User not authenticated!")
}
context.Set("id", sess.Values["id"])
context.Set("sessionId", sess.Values["sessionId"])
context.Set("username", sess.Values["username"])
return next(context)
}
}
+10 -3
View File
@@ -3,7 +3,7 @@ package api
import "github.com/labstack/echo/v4/middleware"
func (api *API) ClientRoutes() {
clients := api.instance.Group("/clients")
clients := api.instance.Group("/api/clients")
clients.GET("/", api.GetAllClientsHandler)
clients.POST("/", api.CreateClientHandler)
@@ -13,14 +13,21 @@ func (api *API) ClientRoutes() {
}
func (api *API) UserRoutes() {
users := api.instance.Group("/users")
users := api.instance.Group("/api/users")
users.GET("/", api.GetAllUsersHandler)
users.POST("/", api.CreateUserHandler)
users.GET("/:id", api.GetUserHandler)
users.PUT("/:id", api.UpdateUserHandler, SessionMiddleware)
users.PUT("/:id", api.UpdateUserHandler, api.AuthSessionMiddleware)
users.PUT("/:id/security", api.UpdateUserSecurityDetailsHandler)
users.PUT("/:id/logo", api.UpdateUserLogoHandler, middleware.BodyLimit("2M"))
users.DELETE("/:id", api.DeleteUserHandler)
users.DELETE("/:id/logo", api.DeleteUserLogoHandler)
}
func (api *API) AuthRoutes() {
auth := api.instance.Group("/api/auth")
auth.POST("/login", api.LoginHandler)
auth.DELETE("/logout", api.LogoutHandler, api.AuthSessionMiddleware)
}
+36 -68
View File
@@ -7,9 +7,9 @@ import (
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/hazemKrimi/crimson-vault/internal/lib"
"github.com/hazemKrimi/crimson-vault/internal/types"
"github.com/labstack/echo-contrib/session"
@@ -28,21 +28,26 @@ func (api *API) CreateUserHandler(context echo.Context) error {
return err
}
user := api.db.CreateUser(body)
user, err := api.db.CreateUser(body)
if err != nil {
log.Println(fmt.Sprintf("Error creating User: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User!")
}
sess, err := session.Get("session", context)
if err != nil {
api.db.DeleteUser(user.ID)
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
lib.ConstructSession(sess, user)
if err := sess.Save(context.Request(), context.Response()); err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error saving User session!")
if err := lib.CreateSession(sess, context, &user); err != nil {
log.Println(fmt.Sprintf("Error creating User session: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error creating User session!")
}
log.Println(fmt.Sprintf("User created with ID %d.", user.ID))
log.Println(fmt.Sprintf("User created with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
@@ -58,39 +63,27 @@ func (api *API) GetAllUsersHandler(context echo.Context) error {
}
func (api *API) GetUserHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to get a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error getting User!")
return context.String(http.StatusBadRequest, "ID is required to get a User!")
}
var user types.User
if err := api.db.GetUser(uint32(id), &user); err != nil {
if err := api.db.GetUserById(id, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
log.Println(fmt.Sprintf("Got User with ID %d.", user.ID))
log.Println(fmt.Sprintf("Got User with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) UpdateUserHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to update a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error updating User!")
return context.String(http.StatusBadRequest, "ID is required to update a User!")
}
var body types.UpdateUserRequestBody
@@ -106,25 +99,19 @@ func (api *API) UpdateUserHandler(context echo.Context) error {
var user types.User
if err := api.db.UpdateUser(uint32(id), body, &user); err != nil {
if err := api.db.UpdateUser(id, body, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
log.Println(fmt.Sprintf("Updated user with ID %d.", user.ID))
log.Println(fmt.Sprintf("Updated user with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) UpdateUserSecurityDetailsHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to create security details for a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error while creating security details for User!")
return context.String(http.StatusBadRequest, "ID is required to create security details for a User!")
}
var body types.UpdateUserSecurityDetailsBody
@@ -140,30 +127,24 @@ func (api *API) UpdateUserSecurityDetailsHandler(context echo.Context) error {
var user types.User
if err := api.db.UpdateUserSecurityDetails(uint32(id), body, &user); err != nil {
if err := api.db.UpdateUserSecurityDetails(id, body, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
log.Println(fmt.Sprintf("Updated security details of user with ID %d.", user.ID))
log.Println(fmt.Sprintf("Updated security details of user with ID %s.", user.ID))
return context.JSON(http.StatusOK, user)
}
func (api *API) UpdateUserLogoHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to update logo for User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error updating logo for User!")
return context.String(http.StatusBadRequest, "ID is required to update logo for User!")
}
var user types.User
if err := api.db.GetUser(uint32(id), &user); err != nil {
if err := api.db.GetUserById(id, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
@@ -239,19 +220,13 @@ func (api *API) UpdateUserLogoHandler(context echo.Context) error {
}
func (api *API) DeleteUserHandler(context echo.Context) error {
idString := context.Param("id")
id, err := uuid.Parse(context.Param("id"))
if idString == "" {
if err != nil {
return context.String(http.StatusBadRequest, "ID is required to delete a User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
if err != nil {
return context.String(http.StatusInternalServerError, "Unexpected error deleting User!")
}
if err := api.db.DeleteUser(uint32(id)); err != nil {
if err := api.db.DeleteUser(id); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
@@ -260,22 +235,15 @@ func (api *API) DeleteUserHandler(context echo.Context) error {
}
func (api *API) DeleteUserLogoHandler(context echo.Context) error {
idString := context.Param("id")
if idString == "" {
return context.String(http.StatusBadRequest, "ID is required to delete logo of User!")
}
id, err := strconv.ParseUint(idString, 10, 32)
id, err := uuid.Parse(context.Param("id"))
if err != nil {
log.Println(fmt.Sprintf("Error deleting logo of User: %v.", err))
return context.String(http.StatusInternalServerError, "Unexpected error deleting logo of User!")
return context.String(http.StatusBadRequest, "ID is required to delete logo of User!")
}
var user types.User
if err := api.db.GetUser(uint32(id), &user); err != nil {
if err := api.db.GetUserById(id, &user); err != nil {
return context.String(http.StatusNotFound, "User not found!")
}
@@ -286,6 +254,6 @@ func (api *API) DeleteUserLogoHandler(context echo.Context) error {
return context.String(http.StatusInternalServerError, "Unexpected error deleting logo of User!")
}
log.Println(fmt.Sprintf("Deleted logo of User with ID %d.", user.ID))
log.Println(fmt.Sprintf("Deleted logo of User with ID %s.", user.ID))
return context.String(http.StatusOK, "User logo deleted successfully!")
}