package global import ( "context" "fmt" "time" "electricity_bill_calc/config" "electricity_bill_calc/logger" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/samber/lo" "go.uber.org/zap" ) var ( DB *pgxpool.Pool ) func SetupDatabaseConnection() error { connConfig := &pgx.ConnConfig{ Config: pgconn.Config{ Host: config.DatabaseSettings.Host, Port: uint16(config.DatabaseSettings.Port), User: config.DatabaseSettings.User, Password: config.DatabaseSettings.Pass, Database: config.DatabaseSettings.DB, TLSConfig: nil, ConnectTimeout: 0 * time.Second, RuntimeParams: map[string]string{"application_name": "elec_service_go"}, }, Tracer: QueryLogger{ logger: logger.Named("PG"), }, } poolConfig := &pgxpool.Config{ ConnConfig: connConfig, MaxConnLifetime: 60 * time.Minute, MaxConnIdleTime: 10 * time.Minute, HealthCheckPeriod: 10 * time.Second, MaxConns: int32(config.DatabaseSettings.MaxOpenConns), MinConns: int32(config.DatabaseSettings.MaxIdleConns), } DB, _ = pgxpool.NewWithConfig(context.Background(), poolConfig) return nil } type QueryLogger struct { logger *zap.Logger } func (ql QueryLogger) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { ql.logger.Info(fmt.Sprintf("将要执行查询: %s", data.SQL)) ql.logger.Info("查询参数", lo.Map(data.Args, func(elem any, index int) zap.Field { return zap.Any(fmt.Sprintf("[Arg %d]: ", index), elem) })...) return ctx } func (ql QueryLogger) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { var logFunc func(string, ...zap.Field) var templateString string if data.Err != nil { logFunc = ql.logger.Error templateString = "命令 [%s] 执行失败。" } else { logFunc = ql.logger.Info templateString = "命令 [%s] 执行成功。" } switch { case data.CommandTag.Update(): fallthrough case data.CommandTag.Delete(): fallthrough case data.CommandTag.Insert(): logFunc( fmt.Sprintf(templateString, data.CommandTag.String()), zap.Error(data.Err), zap.Any("affected", data.CommandTag.RowsAffected())) case data.CommandTag.Select(): logFunc( fmt.Sprintf(templateString, data.CommandTag.String()), zap.Error(data.Err)) default: logFunc( fmt.Sprintf(templateString, data.CommandTag.String()), zap.Error(data.Err)) } }