package service import ( "context" "database/sql" "electricity_bill_calc/cache" "electricity_bill_calc/config" "electricity_bill_calc/exceptions" "electricity_bill_calc/global" "electricity_bill_calc/logger" "electricity_bill_calc/model" "fmt" "strconv" "time" "github.com/fufuok/utils" "github.com/samber/lo" "github.com/uptrace/bun" "go.uber.org/zap" ) type _ChargeService struct { l *zap.Logger } var ChargeService = _ChargeService{ l: logger.Named("Service", "Charge"), } func (c _ChargeService) CreateChargeRecord(charge *model.UserCharge, extendWithIgnoreSettle bool) error { ctx, cancel := global.TimeoutContext() defer cancel() tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } _, err = tx.NewInsert().Model(charge).Exec(ctx) if err != nil { tx.Rollback() return err } if extendWithIgnoreSettle { err := c.updateUserExpiration(&tx, ctx, charge.UserId, charge.ChargeTo) if err != nil { return err } } err = tx.Commit() if err != nil { tx.Rollback() return err } cache.AbolishRelation("charge") return nil } func (c _ChargeService) SettleCharge(seq int64, uid string) error { ctx, cancel := global.TimeoutContext() defer cancel() tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } var record = new(model.UserCharge) err = tx.NewSelect().Model(&record). Where("seq = ?", seq). Where("user_id = ?", uid). Scan(ctx) if err != nil { return nil } if record == nil { return exceptions.NewNotFoundError("未找到匹配指定条件的计费记录。") } currentTime := time.Now() _, err = tx.NewUpdate().Model((*model.UserCharge)(nil)). Where("seq = ?", seq). Where("user_id = ?", uid). Set("settled = ?", true). Set("settled_at = ?", currentTime). Exec(ctx) if err != nil { tx.Rollback() return err } err = c.updateUserExpiration(&tx, ctx, uid, record.ChargeTo) if err != nil { return err } err = tx.Commit() if err != nil { tx.Rollback() return err } cache.AbolishRelation(fmt.Sprintf("charge:%s:%d", uid, seq)) return nil } func (c _ChargeService) RefundCharge(seq int64, uid string) error { ctx, cancel := global.TimeoutContext() defer cancel() tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } currentTime := time.Now() res, err := tx.NewUpdate().Model((*model.UserCharge)(nil)). Where("seq = ?", seq). Where("user_id = ?", uid). Set("refunded = ?", true). Set("refunded_at = ?", currentTime). Exec(ctx) if err != nil { tx.Rollback() return err } if rows, _ := res.RowsAffected(); rows == 0 { tx.Rollback() return exceptions.NewNotFoundError("未找到匹配指定条件的计费记录。") } lastValidExpriation, err := c.lastValidChargeTo(&tx, &ctx, uid) if err != nil { tx.Rollback() return exceptions.NewNotFoundError("未找到最后合法的计费时间。") } err = c.updateUserExpiration(&tx, ctx, uid, lastValidExpriation) if err != nil { return err } err = tx.Commit() if err != nil { tx.Rollback() return err } cache.AbolishRelation(fmt.Sprintf("charge:%s:%d", uid, seq)) return nil } func (c _ChargeService) CancelCharge(seq int64, uid string) error { ctx, cancel := global.TimeoutContext() defer cancel() tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } currentTime := time.Now() res, err := tx.NewUpdate().Model((*model.UserCharge)(nil)). Where("seq = ?", seq). Where("user_id = ?", uid). Set("cancelled = ?", true). Set("cancelled_at = ?", currentTime). Exec(ctx) if err != nil { tx.Rollback() return err } if rows, _ := res.RowsAffected(); rows == 0 { tx.Rollback() return exceptions.NewNotFoundError("未找到匹配指定条件的计费记录。") } err = tx.Commit() if err != nil { tx.Rollback() return err } tx, err = global.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } lastValidExpriation, err := c.lastValidChargeTo(&tx, &ctx, uid) if err != nil { return exceptions.NewNotFoundError("未找到最后合法的计费时间。") } err = c.updateUserExpiration(&tx, ctx, uid, lastValidExpriation) if err != nil { return err } err = tx.Commit() if err != nil { tx.Rollback() return err } cache.AbolishRelation("user") cache.AbolishRelation(fmt.Sprintf("user:%s", uid)) cache.AbolishRelation("charge") cache.AbolishRelation(fmt.Sprintf("charge:%s:%d", uid, seq)) return nil } func (ch _ChargeService) updateUserExpiration(tx *bun.Tx, ctx context.Context, uid string, expiration model.Date) error { _, err := tx.NewUpdate().Model((*model.UserDetail)(nil)). Set("service_expiration = ?", expiration). Where("id = ?", uid). Exec(ctx) if err != nil { tx.Rollback() } cache.AbolishRelation(fmt.Sprintf("user:%s", uid)) return err } func (_ChargeService) ListPagedChargeRecord(keyword, beginDate, endDate string, page int) ([]model.ChargeWithName, int64, error) { var ( cond = global.DB.NewSelect() condition = make([]string, 0) charges = make([]model.UserCharge, 0) ) cond = cond.Model(&charges).Relation("Detail") condition = append(condition, strconv.Itoa(page)) if len(keyword) != 0 { keywordCond := "%" + keyword + "%" cond = cond.WhereGroup(" and ", func(q *bun.SelectQuery) *bun.SelectQuery { return q.Where("d.mame like ?", keywordCond). WhereOr("d.abbr like ?", keywordCond) }) condition = append(condition, keyword) } if len(beginDate) != 0 { beginTime, err := time.ParseInLocation("2006-01-02", beginDate, time.Local) beginTime = utils.BeginOfDay(beginTime) if err != nil { return make([]model.ChargeWithName, 0), 0, err } cond = cond.Where("c.created_at >= ?", beginTime) condition = append(condition, strconv.FormatInt(beginTime.Unix(), 10)) } if len(endDate) != 0 { endTime, err := time.ParseInLocation("2006-01-02", endDate, time.Local) endTime = utils.EndOfDay(endTime) if err != nil { return make([]model.ChargeWithName, 0), 0, err } cond = cond.Where("c.created_at <= ?", endTime) condition = append(condition, strconv.FormatInt(endTime.Unix(), 10)) } if cachedTotal, err := cache.RetreiveCount("charge_with_name", condition...); cachedTotal != -1 && err == nil { if cachedCharges, _ := cache.RetreiveSearch[[]model.ChargeWithName]("charge_with_name", condition...); cachedCharges != nil { return *cachedCharges, cachedTotal, nil } } startItem := (page - 1) * config.ServiceSettings.ItemsPageSize var ( total int err error ) ctx, cancel := global.TimeoutContext() defer cancel() total, err = cond.Limit(config.ServiceSettings.ItemsPageSize).Offset(startItem).ScanAndCount(ctx) relations := []string{"charge"} chargesWithName := make([]model.ChargeWithName, 0) for _, c := range charges { chargesWithName = append(chargesWithName, model.ChargeWithName{ UserCharge: c, UserDetail: *c.Detail, }) relations = append(relations, fmt.Sprintf("charge:%s:%d", c.UserId, c.Seq)) } cache.CacheCount(relations, "charge_with_name", int64(total), condition...) cache.CacheSearch(chargesWithName, relations, "charge_with_name", condition...) return chargesWithName, int64(total), err } func (_ChargeService) lastValidChargeTo(tx *bun.Tx, ctx *context.Context, uid string) (model.Date, error) { var records []model.Date err := tx.NewSelect().Table("user_charge"). Where("settled = ? and cancelled = ? and refunded = ? and user_id = ?", true, false, false, uid). Column("charge_to"). Scan(*ctx, &records) if err != nil { return model.NewEmptyDate(), nil } lastValid := lo.Reduce(records, func(acc, elem model.Date, index int) model.Date { if elem.Time.After(acc.Time) { return elem } else { return acc } }, model.NewEmptyDate()) return lastValid, nil }