electricity_bill_calc_service/service/charge.go

290 lines
7.8 KiB
Go

package service
import (
"context"
"database/sql"
"electricity_bill_calc/cache"
"electricity_bill_calc/config"
"electricity_bill_calc/exceptions"
"electricity_bill_calc/global"
"electricity_bill_calc/logger"
"electricity_bill_calc/model"
"fmt"
"strconv"
"time"
"github.com/fufuok/utils"
"github.com/samber/lo"
"github.com/uptrace/bun"
"go.uber.org/zap"
)
// _ChargeService groups the charge-record operations (create, settle, refund,
// cancel, list) of this package behind a single value.
type _ChargeService struct {
	// l is this service's logger.
	l *zap.Logger
}

// ChargeService is the package-level instance callers use; its logger is
// obtained from logger.Named("Service", "Charge").
var ChargeService = _ChargeService{l: logger.Named("Service", "Charge")}
// CreateChargeRecord inserts charge as a new charge record inside a
// transaction.
//
// When extendWithIgnoreSettle is true, the user's service expiration is also
// set to charge.ChargeTo in the same transaction. This is the same update
// SettleCharge performs, done here at creation time.
//
// After a successful commit the "charge" cache relation is abolished.
func (c _ChargeService) CreateChargeRecord(charge *model.UserCharge, extendWithIgnoreSettle bool) error {
	ctx, cancel := global.TimeoutContext()
	defer cancel()

	tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{})
	if err != nil {
		return err
	}

	if _, err = tx.NewInsert().Model(charge).Exec(ctx); err != nil {
		tx.Rollback()
		return err
	}

	if extendWithIgnoreSettle {
		// updateUserExpiration already rolls tx back when it fails.
		if err = c.updateUserExpiration(&tx, ctx, charge.UserId, charge.ChargeTo); err != nil {
			return err
		}
	}

	if err = tx.Commit(); err != nil {
		tx.Rollback()
		return err
	}

	cache.AbolishRelation("charge")
	return nil
}
// SettleCharge marks the charge record identified by (seq, uid) as settled,
// recording the settlement time. In the same transaction it sets the user's
// service expiration to that record's ChargeTo.
//
// Returns a NotFound exception when no record matches seq and uid. Any other
// database error is returned unchanged. On success the
// "charge:<uid>:<seq>" cache relation is abolished.
func (c _ChargeService) SettleCharge(seq int64, uid string) error {
	ctx, cancel := global.TimeoutContext()
	defer cancel()

	tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{})
	if err != nil {
		return err
	}

	// record is already a pointer. bun's Model expects *struct, so pass it
	// directly instead of taking its address (which would yield **struct).
	record := new(model.UserCharge)
	err = tx.NewSelect().Model(record).
		Where("seq = ?", seq).
		Where("user_id = ?", uid).
		Scan(ctx)
	if err != nil {
		tx.Rollback()
		// bun's Scan reports a missing row as sql.ErrNoRows. Every other failure
		// must be propagated: returning nil here would tell the caller the
		// charge was settled when nothing was changed.
		if err == sql.ErrNoRows {
			return exceptions.NewNotFoundError("未找到匹配指定条件的计费记录。")
		}
		return err
	}

	_, err = tx.NewUpdate().Model((*model.UserCharge)(nil)).
		Where("seq = ?", seq).
		Where("user_id = ?", uid).
		Set("settled = ?", true).
		Set("settled_at = ?", time.Now()).
		Exec(ctx)
	if err != nil {
		tx.Rollback()
		return err
	}

	// updateUserExpiration rolls tx back itself when it fails.
	if err = c.updateUserExpiration(&tx, ctx, uid, record.ChargeTo); err != nil {
		return err
	}

	// A failed Commit already ends the transaction; no Rollback is needed.
	if err = tx.Commit(); err != nil {
		return err
	}

	cache.AbolishRelation(fmt.Sprintf("charge:%s:%d", uid, seq))
	return nil
}
// RefundCharge flags the charge record identified by (seq, uid) as refunded,
// recording the refund time. In the same transaction it resets the user's
// service expiration to the value computed by lastValidChargeTo.
//
// Returns a NotFound exception when no record matches seq and uid, or when
// lastValidChargeTo fails. On success the "charge:<uid>:<seq>" cache relation
// is abolished.
func (c _ChargeService) RefundCharge(seq int64, uid string) error {
	ctx, cancel := global.TimeoutContext()
	defer cancel()

	tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{})
	if err != nil {
		return err
	}

	refundedAt := time.Now()
	result, err := tx.NewUpdate().
		Model((*model.UserCharge)(nil)).
		Where("seq = ?", seq).
		Where("user_id = ?", uid).
		Set("refunded = ?", true).
		Set("refunded_at = ?", refundedAt).
		Exec(ctx)
	if err != nil {
		tx.Rollback()
		return err
	}
	if affected, _ := result.RowsAffected(); affected == 0 {
		tx.Rollback()
		return exceptions.NewNotFoundError("未找到匹配指定条件的计费记录。")
	}

	// The refund update above is visible inside this transaction, so the
	// lookup already excludes the refunded record.
	expiration, lookupErr := c.lastValidChargeTo(&tx, &ctx, uid)
	if lookupErr != nil {
		tx.Rollback()
		return exceptions.NewNotFoundError("未找到最后合法的计费时间。")
	}

	// updateUserExpiration rolls tx back itself when it fails.
	if err = c.updateUserExpiration(&tx, ctx, uid, expiration); err != nil {
		return err
	}

	if err = tx.Commit(); err != nil {
		tx.Rollback()
		return err
	}

	cache.AbolishRelation(fmt.Sprintf("charge:%s:%d", uid, seq))
	return nil
}
// CancelCharge flags the charge record identified by (seq, uid) as cancelled,
// recording the cancellation time. It then resets the user's service
// expiration to the value computed by lastValidChargeTo.
//
// Both steps run in a single transaction. The cancellation and the expiration
// update are therefore committed or discarded together. A failure while
// recomputing the expiration can no longer leave a cancelled charge with a
// stale user expiration.
//
// Returns a NotFound exception when no record matches seq and uid, or when
// lastValidChargeTo fails. On success the "user", "user:<uid>", "charge" and
// "charge:<uid>:<seq>" cache relations are abolished.
func (c _ChargeService) CancelCharge(seq int64, uid string) error {
	ctx, cancel := global.TimeoutContext()
	defer cancel()

	tx, err := global.DB.BeginTx(ctx, &sql.TxOptions{})
	if err != nil {
		return err
	}

	res, err := tx.NewUpdate().Model((*model.UserCharge)(nil)).
		Where("seq = ?", seq).
		Where("user_id = ?", uid).
		Set("cancelled = ?", true).
		Set("cancelled_at = ?", time.Now()).
		Exec(ctx)
	if err != nil {
		tx.Rollback()
		return err
	}
	if rows, _ := res.RowsAffected(); rows == 0 {
		tx.Rollback()
		return exceptions.NewNotFoundError("未找到匹配指定条件的计费记录。")
	}

	// The cancellation above is visible inside this transaction, so the
	// lookup already excludes the cancelled record.
	lastValidExpiration, err := c.lastValidChargeTo(&tx, &ctx, uid)
	if err != nil {
		tx.Rollback()
		return exceptions.NewNotFoundError("未找到最后合法的计费时间。")
	}

	// updateUserExpiration rolls tx back itself when it fails.
	if err = c.updateUserExpiration(&tx, ctx, uid, lastValidExpiration); err != nil {
		return err
	}

	// A failed Commit already ends the transaction; no Rollback is needed.
	if err = tx.Commit(); err != nil {
		return err
	}

	cache.AbolishRelation("user")
	cache.AbolishRelation(fmt.Sprintf("user:%s", uid))
	cache.AbolishRelation("charge")
	cache.AbolishRelation(fmt.Sprintf("charge:%s:%d", uid, seq))
	return nil
}
// updateUserExpiration sets service_expiration to expiration on the
// model.UserDetail row whose id equals uid. It uses the caller's transaction
// tx and does not commit it.
//
// If the update fails, tx is rolled back before the error is returned, so
// callers must not keep using tx in that case. The "user:<uid>" cache
// relation is abolished on every call, whether or not the update succeeded.
func (_ChargeService) updateUserExpiration(tx *bun.Tx, ctx context.Context, uid string, expiration time.Time) error {
	update := tx.NewUpdate().
		Model((*model.UserDetail)(nil)).
		Set("service_expiration = ?", expiration).
		Where("id = ?", uid)

	_, execErr := update.Exec(ctx)
	if execErr != nil {
		tx.Rollback()
	}

	cache.AbolishRelation(fmt.Sprintf("user:%s", uid))
	return execErr
}
// ListPagedChargeRecord returns one page of charge records joined with their
// "Detail" relation, plus the total number of records matching the filters.
//
// Filters:
//   - keyword (optional) is substring-matched with LIKE against the joined
//     detail's name or abbr.
//   - beginDate and endDate (optional) are "2006-01-02" dates in the local time
//     zone. They bound created_at inclusively, from the start of beginDate to
//     the end of endDate. An unparseable date is returned as an error.
//   - page is 1-based; values below 1 are treated as 1.
//
// Results are looked up in, and stored into, the "charge_with_name" cache,
// keyed by the normalized filter values. Nothing is cached when the query
// fails.
func (_ChargeService) ListPagedChargeRecord(keyword, beginDate, endDate string, page int) ([]model.ChargeWithName, int64, error) {
	// Clamp so the computed OFFSET can never be negative.
	if page < 1 {
		page = 1
	}

	var (
		charges   = make([]model.UserCharge, 0)
		condition = []string{strconv.Itoa(page)}
	)
	cond := global.DB.NewSelect().Model(&charges).Relation("Detail")

	if len(keyword) != 0 {
		keywordCond := "%" + keyword + "%"
		// NOTE(review): confirm that `d` is the alias the joined Detail
		// relation receives in the generated SQL.
		cond = cond.WhereGroup(" and ", func(q *bun.SelectQuery) *bun.SelectQuery {
			return q.Where("d.name like ?", keywordCond).
				WhereOr("d.abbr like ?", keywordCond)
		})
		condition = append(condition, keyword)
	}
	if len(beginDate) != 0 {
		beginTime, err := time.ParseInLocation("2006-01-02", beginDate, time.Local)
		if err != nil {
			return make([]model.ChargeWithName, 0), 0, err
		}
		beginTime = utils.BeginOfDay(beginTime)
		cond = cond.Where("c.created_at >= ?", beginTime)
		condition = append(condition, strconv.FormatInt(beginTime.Unix(), 10))
	}
	if len(endDate) != 0 {
		endTime, err := time.ParseInLocation("2006-01-02", endDate, time.Local)
		if err != nil {
			return make([]model.ChargeWithName, 0), 0, err
		}
		endTime = utils.EndOfDay(endTime)
		cond = cond.Where("c.created_at <= ?", endTime)
		condition = append(condition, strconv.FormatInt(endTime.Unix(), 10))
	}

	if cachedTotal, err := cache.RetreiveCount("charge_with_name", condition...); cachedTotal != -1 && err == nil {
		if cachedCharges, _ := cache.RetreiveSearch[[]model.ChargeWithName]("charge_with_name", condition...); cachedCharges != nil {
			return *cachedCharges, cachedTotal, nil
		}
	}

	ctx, cancel := global.TimeoutContext()
	defer cancel()
	pageSize := config.ServiceSettings.ItemsPageSize
	total, err := cond.Limit(pageSize).Offset((page - 1) * pageSize).ScanAndCount(ctx)
	if err != nil {
		// Never cache (or return) results derived from a failed query.
		return make([]model.ChargeWithName, 0), 0, err
	}

	relations := make([]string, 0, len(charges)+1)
	relations = append(relations, "charge")
	chargesWithName := make([]model.ChargeWithName, 0, len(charges))
	for _, charge := range charges {
		item := model.ChargeWithName{UserCharge: charge}
		// Leave the detail zero-valued rather than panicking on a missing join.
		if charge.Detail != nil {
			item.UserDetail = *charge.Detail
		}
		chargesWithName = append(chargesWithName, item)
		relations = append(relations, fmt.Sprintf("charge:%s:%d", charge.UserId, charge.Seq))
	}

	cache.CacheCount(relations, "charge_with_name", int64(total), condition...)
	// Store the same type the lookup above reads back ([]model.ChargeWithName).
	cache.CacheSearch(chargesWithName, relations, "charge_with_name", condition...)
	return chargesWithName, int64(total), nil
}
// lastValidChargeTo returns the latest charge_to among uid's user_charge rows
// that are settled and neither cancelled nor refunded. The result is truncated
// with utils.BeginOfDay. When there are no such rows it returns the zero
// time.Time.
//
// The query runs on the caller's transaction tx. A query error is returned to
// the caller instead of being reported as "no valid charges". Callers write
// the result into the user's service expiration, so returning the zero time
// on a transient database failure would wipe that expiration.
func (_ChargeService) lastValidChargeTo(tx *bun.Tx, ctx *context.Context, uid string) (time.Time, error) {
	var records []string
	err := tx.NewSelect().Table("user_charge").
		Where("settled = ? and cancelled = ? and refunded = ? and user_id = ?", true, false, false, uid).
		Column("charge_to").
		Scan(*ctx, &records)
	if err != nil {
		return time.Time{}, fmt.Errorf("querying valid charges of user %q: %w", uid, err)
	}

	days := lo.Map(records, func(elem string, _ int) time.Time {
		// A value that fails to parse as RFC 3339 is treated as the zero time.
		t, _ := time.Parse(time.RFC3339, elem)
		return utils.BeginOfDay(t)
	})
	latest := lo.Reduce(days, func(acc, day time.Time, _ int) time.Time {
		if day.After(acc) {
			return day
		}
		return acc
	}, time.Time{})
	return latest, nil
}