From 9aa32d99b966b7a6c26bdcef3a065c0276b3829e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E6=B6=9B?= Date: Thu, 1 Jun 2023 12:04:03 +0800 Subject: [PATCH] =?UTF-8?q?feat(user):=E9=80=9A=E8=BF=87=E5=AE=8C=E6=88=90?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E6=A3=80=E7=B4=A2=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E7=BB=A7=E7=BB=AD=E7=A1=AE=E5=AE=9A=E9=A1=B9=E7=9B=AE=E7=9A=84?= =?UTF-8?q?=E5=9F=BA=E6=9C=AC=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cache/search.go | 9 +++ config/settings.go | 2 +- controller/user.go | 44 ++++++++++++- repository/user.go | 135 ++++++++++++++++++++++++++++++++++++-- response/base_response.go | 2 +- security/security.go | 2 +- tools/utils.go | 35 ++++++++++ 7 files changed, 219 insertions(+), 10 deletions(-) diff --git a/cache/search.go b/cache/search.go index d83bcd6..fb4b8df 100644 --- a/cache/search.go +++ b/cache/search.go @@ -1,11 +1,16 @@ package cache import ( + "electricity_bill_calc/logger" "fmt" "strings" "time" + + "go.uber.org/zap" ) +var log = logger.Named("Cache") + func assembleSearchKey(entityName string, additional ...string) string { var keys = make([]string, 0) keys = append(keys, strings.ToUpper(entityName)) @@ -70,6 +75,10 @@ func RetrievePagedSearch[T any](entityName string, conditions ...string) (*T, in if err != nil { return nil, -1, err } + if instance == nil || count == nil { + log.Warn("检索结果或者检索总数为空。", zap.String("searchKey", searchKey), zap.String("countKey", countKey)) + return nil, -1, nil + } return instance, count.Count, nil } diff --git a/config/settings.go b/config/settings.go index a46dd93..b8403ec 100644 --- a/config/settings.go +++ b/config/settings.go @@ -31,7 +31,7 @@ type RedisSetting struct { type ServiceSetting struct { MaxSessionLife time.Duration - ItemsPageSize int + ItemsPageSize uint CacheLifeTime time.Duration HostSerial int64 } diff --git a/controller/user.go b/controller/user.go index 981f6dc..f500569 100644 --- a/controller/user.go +++ b/controller/user.go @@ -4,10 +4,12 @@ import ( "electricity_bill_calc/cache" "electricity_bill_calc/exceptions" "electricity_bill_calc/model" + "electricity_bill_calc/repository" "electricity_bill_calc/response" "electricity_bill_calc/security" "electricity_bill_calc/service" "net/http" + "strconv" "github.com/gofiber/fiber/v2" ) @@ -15,6 +17,7 @@ import ( func InitializeUserHandlers(router *fiber.App) { router.Delete("/login", security.MustAuthenticated, doLogout) router.Post("/login", doLogin) + router.Get("/account", security.ManagementAuthorize, searchUsers) } type _LoginForm struct { @@ -53,13 +56,48 @@ func doLogin(c *fiber.Ctx) error { func doLogout(c *fiber.Ctx) error { result := response.NewResult(c) - session := c.Locals("session") - if session == nil { + session, err := _retreiveSession(c) + if err != nil { return result.Success("用户会话已结束。") } - _, err := cache.ClearSession(session.(*model.Session).Token) + _, err = cache.ClearSession(session.Token) if err != nil { return result.Error(http.StatusInternalServerError, err.Error()) } return result.Success("用户已成功登出系统。") } + +func searchUsers(c *fiber.Ctx) error { + result := response.NewResult(c) + requestPage, err := strconv.Atoi(c.Query("page", "1")) + if err != nil { + return result.NotAccept("查询参数[page]格式不正确。") + } + requestKeyword := c.Query("keyword") + requestUserType, err := strconv.Atoi(c.Query("type", "-1")) + if err != nil { + return result.NotAccept("查询参数[type]格式不正确。") + } + var requestUserStat *bool + state, err := strconv.ParseBool(c.Query("state")) + if err != nil { + requestUserStat = nil + } else { + requestUserStat = &state + } + users, total, err := repository.UserRepository.FindUser( + &requestKeyword, + int16(requestUserType), + requestUserStat, + uint(requestPage), + ) + if err != nil { + return result.NotFound(err.Error()) + } + return result.Json( + http.StatusOK, + "已取得符合条件的用户集合。", + response.NewPagedResponse(requestPage, total).ToMap(), + fiber.Map{"accounts": users}, + ) +} diff --git a/repository/user.go b/repository/user.go index fa42f04..dc164b1 100644 --- a/repository/user.go +++ b/repository/user.go @@ -2,6 +2,7 @@ package repository import ( "electricity_bill_calc/cache" + "electricity_bill_calc/config" "electricity_bill_calc/global" "electricity_bill_calc/logger" "electricity_bill_calc/model" @@ -28,7 +29,7 @@ var UserRepository = _UserRepository{ // 使用用户名查询指定用户的基本信息 func (ur _UserRepository) FindUserByUsername(username string) (*model.User, error) { ur.log.Info("根据用户名查询指定用户的基本信息。", zap.String("username", username)) - if cachedUser, _ := cache.RetreiveEntity[model.User]("user", username); cachedUser != nil { + if cachedUser, _ := cache.RetrieveEntity[model.User]("user", username); cachedUser != nil { ur.log.Info("已经从缓存获取到了符合指定用户名条件的用户基本信息。", zap.String("username", username)) return cachedUser, nil } @@ -48,7 +49,7 @@ func (ur _UserRepository) FindUserByUsername(username string) (*model.User, erro // 使用用户唯一编号查询指定用户的基本信息 func (ur _UserRepository) FindUserById(uid string) (*model.User, error) { ur.log.Info("根据用户唯一编号查询指定用户的基本信息。", zap.String("user id", uid)) - if cachedUser, _ := cache.RetreiveEntity[model.User]("user", uid); cachedUser != nil { + if cachedUser, _ := cache.RetrieveEntity[model.User]("user", uid); cachedUser != nil { ur.log.Info("已经从缓存获取到了符合指定用户唯一编号的用户基本信息。") return cachedUser, nil } @@ -68,7 +69,7 @@ func (ur _UserRepository) FindUserById(uid string) (*model.User, error) { // 使用用户的唯一编号获取用户的详细信息 func (ur _UserRepository) FindUserDetailById(uid string) (*model.UserDetail, error) { ur.log.Info("根据用户唯一编号查询指定用户的详细信息。", zap.String("user id", uid)) - if cachedUser, _ := cache.RetreiveEntity[model.UserDetail]("user_detail", uid); cachedUser != nil { + if cachedUser, _ := cache.RetrieveEntity[model.UserDetail]("user_detail", uid); cachedUser != nil { ur.log.Info("已经从缓存获取到了符合指定用户唯一编号的用户详细信息。") return cachedUser, nil } @@ -88,7 +89,7 @@ func (ur _UserRepository) FindUserDetailById(uid string) (*model.UserDetail, err // 使用用户唯一编号获取用户的综合详细信息 func (ur _UserRepository) FindUserInformation(uid string) (*model.UserWithDetail, error) { ur.log.Info("根据用户唯一编号查询用户的综合详细信息", zap.String("user id", uid)) - if cachedUser, _ := cache.RetreiveEntity[model.UserWithDetail]("user_information", uid); cachedUser != nil { + if cachedUser, _ := cache.RetrieveEntity[model.UserWithDetail]("user_information", uid); cachedUser != nil { ur.log.Info("已经从缓存获取到了符合指定用户唯一编号的用户综合详细信息。") return cachedUser, nil } @@ -141,8 +142,31 @@ func (ur _UserRepository) IsUserExists(uid string) (bool, error) { return userCount > 0, nil } +// 检查指定用户名在数据库中是否已经存在 +func (ur _UserRepository) IsUsernameExists(username string) (bool, error) { + ur.log.Info("检查指定用户名在数据库中是否已经存在。", zap.String("username", username)) + if exists, _ := cache.CheckExists("user", username); exists { + ur.log.Info("已经从缓存获取到了符合指定用户名的用户基本信息。") + return exists, nil + } + ctx, cancel := global.TimeoutContext() + defer cancel() + + var userCount int + sql, params, _ := ur.ds.From("user").Select(goqu.COUNT("*")).Where(goqu.Ex{"username": username}).Prepared(true).ToSQL() + if err := pgxscan.Get(ctx, global.DB, &userCount, sql, params...); err != nil { + ur.log.Error("从数据库查询指定用户名的用户基本信息失败。", zap.String("username", username), zap.Error(err)) + return false, err + } + if userCount > 0 { + cache.CacheExists([]string{"user", fmt.Sprintf("user:%s", username)}, "user", username) + } + return userCount > 0, nil +} + // 创建一个新用户 func (ur _UserRepository) CreateUser(user model.User, detail model.UserDetail, operator *string) (bool, error) { + ur.log.Info("创建一个新用户。", zap.String("username", user.Username)) ctx, cancel := global.TimeoutContext() defer cancel() tx, err := global.DB.Begin(ctx) @@ -197,3 +221,106 @@ func (ur _UserRepository) CreateUser(user model.User, detail model.UserDetail, o } return userResult.RowsAffected() > 0 && detailResult.RowsAffected() > 0, nil } + +// 根据给定的条件检索用户 +func (ur _UserRepository) FindUser(keyword *string, userType int16, state *bool, page uint) (*[]model.UserWithDetail, int64, error) { + ur.log.Info("根据给定的条件检索用户。", zap.Uint("page", page), zap.Stringp("keyword", keyword), zap.Int16("user type", userType), zap.Boolp("state", state)) + if users, total, err := cache.RetrievePagedSearch[[]model.UserWithDetail]("user_with_detail", []string{ + fmt.Sprintf("%d", page), + tools.CondFn( + func(v int16) bool { + return v != -1 + }, + userType, + fmt.Sprintf("%d", userType), + "UNDEF", + ), + tools.DefaultStrTo("%s", state, "UNDEF"), + tools.DefaultTo(keyword, ""), + }...); err == nil && users != nil && total != -1 { + return users, total, nil + } + + ctx, cancel := global.TimeoutContext() + defer cancel() + + var ( + userWithDetails []model.UserWithDetail + userCount int64 + ) + userQuery := ur.ds. + From(goqu.T("user").As("u")). + Join(goqu.T("user_detail").As("ud"), goqu.On(goqu.Ex{"ud.id": goqu.I("u.id")})). + Select( + "u.id", "u.username", "u.reset_needed", "u.type", "u.enabled", + "ud.name", "ud.abbr", "ud.region", "ud.address", "ud.contact", "ud.phone", + "ud.unit_service_fee", "ud.service_expiration", + "ud.created_at", "ud.created_by", "ud.last_modified_at", "ud.last_modified_by", + ) + countQuery := ur.ds. + From(goqu.T("user").As("u")). + Join(goqu.T("user_detail").As("ud"), goqu.On(goqu.Ex{"ud.id": goqu.I("u.id")})). + Select(goqu.COUNT("*")) + + if keyword != nil && len(*keyword) > 0 { + pattern := fmt.Sprintf("%%%s%%", *keyword) + userQuery = userQuery.Where( + goqu.Or( + goqu.Ex{"u.username": goqu.Op{"like": pattern}}, + goqu.Ex{"ud.name": goqu.Op{"like": pattern}}, + goqu.Ex{"ud.abbr": goqu.Op{"like": pattern}}, + ), + ) + countQuery = countQuery.Where( + goqu.Or( + goqu.Ex{"u.username": goqu.Op{"like": pattern}}, + goqu.Ex{"ud.name": goqu.Op{"like": pattern}}, + goqu.Ex{"ud.abbr": goqu.Op{"like": pattern}}, + ), + ) + } + + if userType != -1 { + userQuery = userQuery.Where(goqu.Ex{"u.type": userType}) + countQuery = countQuery.Where(goqu.Ex{"u.type": userType}) + } + + if state != nil { + userQuery = userQuery.Where(goqu.Ex{"u.enabled": state}) + countQuery = countQuery.Where(goqu.Ex{"u.enabled": state}) + } + + currentPosition := (page - 1) * config.ServiceSettings.ItemsPageSize + userQuery = userQuery.Offset(currentPosition).Limit(config.ServiceSettings.ItemsPageSize) + + userSql, userParams, _ := userQuery.Prepared(true).ToSQL() + countSql, countParams, _ := countQuery.Prepared(true).ToSQL() + if err := pgxscan.Select(ctx, global.DB, &userWithDetails, userSql, userParams...); err != nil { + ur.log.Error("从数据库查询用户列表失败。", zap.Error(err)) + return nil, 0, err + } + if err := pgxscan.Get(ctx, global.DB, &userCount, countSql, countParams...); err != nil { + ur.log.Error("从数据库查询用户列表总数失败。", zap.Error(err)) + return nil, 0, err + } + cache.CachePagedSearch( + userWithDetails, + userCount, + []string{"user"}, + "user_with_detail", + []string{ + fmt.Sprintf("%d", page), + tools.CondFn( + func(v int16) bool { + return v != -1 + }, + userType, + fmt.Sprintf("%d", userType), + "UNDEF", + ), + tools.DefaultStrTo("%s", state, "UNDEF"), + tools.DefaultTo(keyword, ""), + }..., + ) + return &userWithDetails, userCount, nil +} diff --git a/response/base_response.go b/response/base_response.go index c156f9f..01ce30b 100644 --- a/response/base_response.go +++ b/response/base_response.go @@ -17,7 +17,7 @@ type BaseResponse struct { type PagedResponse struct { Page int `json:"current"` - Size int `json:"pageSize"` + Size uint `json:"pageSize"` Total int64 `json:"total"` } diff --git a/security/security.go b/security/security.go index 151b097..dda0f61 100644 --- a/security/security.go +++ b/security/security.go @@ -15,7 +15,7 @@ import ( func SessionRecovery(c *fiber.Ctx) error { if auth := c.Get("Authorization", ""); len(auth) > 0 { token := strings.Fields(auth)[1] - session, err := cache.RetreiveSession(token) + session, err := cache.RetrieveSession(token) if err == nil && session != nil { c.Locals("session", session) diff --git a/tools/utils.go b/tools/utils.go index 459f1dd..a0333d9 100644 --- a/tools/utils.go +++ b/tools/utils.go @@ -2,6 +2,7 @@ package tools import ( "encoding/json" + "fmt" "strings" "github.com/mozillazg/go-pinyin" @@ -51,3 +52,37 @@ func PartitionSlice[T any](slice []T, chunkSize int) [][]T { } return divided } + +// 判断指定指针是否为空,如果为空,则返回指定默认值(指针形式) +func DefaultTo[T any](originValue *T, defaultValue T) T { + if originValue == nil { + return defaultValue + } + return *originValue +} + +// 判断指定的指针是否为空,如果为空,则返回指定的默认字符串,或者返回指针所指内容的字符串形式。 +func DefaultStrTo[T any](format string, originValue *T, defaultStr string) string { + if originValue == nil { + return defaultStr + } + return fmt.Sprintf(format, originValue) +} + +// 判断指定表达式的值,根据表达式的值返回指定的值。相当于其他语言中的三目运算符。 +func Cond[T any](expr bool, trueValue, falseValue T) T { + if expr { + return trueValue + } + return falseValue +} + +// 使用给定的函数对指定的值进行判断,根据表达式的值返回指定的值。 +func CondFn[T, R any](exprFn func(val T) bool, value T, trueValue, falseValue R) R { + return Cond(exprFn(value), trueValue, falseValue) +} + +// 使用给定的函数对指定的值进行判断,如果表达式为真,则返回指定的值,否则返回另一个值。 +func CondOr[T any](exprFn func(val T) bool, value, elseValue T) T { + return CondFn(exprFn, value, value, elseValue) +}