diff --git a/controller/user.go b/controller/user.go index a06e198..3257938 100644 --- a/controller/user.go +++ b/controller/user.go @@ -11,9 +11,18 @@ import ( "github.com/gin-gonic/gin" ) -type _UserController struct{} +type _UserController struct { + Router *gin.Engine +} -var UserController _UserController +var UserController *_UserController + +func InitializeUserController(router *gin.Engine) { + UserController = &_UserController{ + Router: router, + } + UserController.Router.POST("/login", UserController.Login) +} type LoginFormData struct { Username string `form:"uname"` @@ -48,5 +57,5 @@ func (_UserController) Login(c *gin.Context) { return } } - result.LoginSuccess(session, false) + result.LoginSuccess(session) } diff --git a/response/base-response.go b/response/base_response.go similarity index 100% rename from response/base-response.go rename to response/base_response.go diff --git a/response/user_response.go b/response/user_response.go index bafc299..e796fe7 100644 --- a/response/user_response.go +++ b/response/user_response.go @@ -11,11 +11,11 @@ type LoginResponse struct { Session *model.Session `json:"session,omitempty"` } -func (r *Result) LoginSuccess(session *model.Session, needReset bool) { +func (r *Result) LoginSuccess(session *model.Session) { res := &LoginResponse{} res.Code = http.StatusOK res.Message = "用户已成功登录。" - res.NeedReset = needReset + res.NeedReset = false res.Session = session r.Ctx.JSON(http.StatusOK, res) } diff --git a/router/router.go b/router/router.go index 92cef94..df011b6 100644 --- a/router/router.go +++ b/router/router.go @@ -1,6 +1,7 @@ package router import ( + "electricity_bill_calc/controller" "electricity_bill_calc/response" "log" "runtime/debug" @@ -11,6 +12,9 @@ import ( func Router() *gin.Engine { router := gin.Default() router.Use(Recover) + router.Use(SessionRecovery) + + controller.InitializeUserController(router) return router } diff --git a/router/security.go b/router/security.go index fe13b92..e3c350d 100644 --- a/router/security.go +++ b/router/security.go @@ -2,22 +2,45 @@ package router import ( "electricity_bill_calc/cache" + "electricity_bill_calc/model" "net/http" "strings" "github.com/gin-gonic/gin" ) -func AuthenticatedSession(c *gin.Context) { +// 用于解析Authorization头,并从缓存中获取用户会话信息注入上下文的中间件。 +// 如果没有获取到用户会话信息,将直接跳过会话信息注入。 +// ! 仅通过该中间件是不能保证上下文中一定保存有用户会话信息的。 +func SessionRecovery(c *gin.Context) { auth := c.Request.Header.Get("Authorization") if len(auth) > 0 { token := strings.Fields(auth)[1] session, err := cache.RetreiveSession(token) - if err != nil { - c.AbortWithStatus(http.StatusForbidden) + if err == nil { + c.Set("session", session) } - c.Set("session", session) + } + c.Next() +} + +// 用于对用户会话进行是否企业用户的判断 +// ! 通过该中间件以后,是可以保证上下文中一定具有用户会话信息的。 +func EnterpriseAuthorize(c *gin.Context) { + session, exists := c.Get("session") + if !exists || session.(*model.Session).Type != 0 { + c.AbortWithStatus(http.StatusForbidden) + } + c.Next() +} + +// 用于对用户会话进行是否监管用户或运维用户的判断 +// ! 通过该中间件以后,是可以保证上下文中一定具有用户会话信息的。 +func ManagementAuthorize(c *gin.Context) { + session, exists := c.Get("session") + if !exists || (session.(*model.Session).Type != 1 && session.(*model.Session).Type != 2) { + c.AbortWithStatus(http.StatusForbidden) } c.Next() }