package global

import (
	"context"
	"fmt"
	"net/url"

	"electricity_bill_calc/config"
	"electricity_bill_calc/logger"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/samber/lo"
	"go.uber.org/zap"
)

var (
	// DB is the shared PostgreSQL connection pool. It is nil until
	// SetupDatabaseConnection has completed successfully.
	DB *pgxpool.Pool
)

// buildConnString assembles the pgx connection URL from config.DatabaseSettings.
//
// The user, password and database name go through net/url so that reserved
// characters (e.g. '@', ':', '/', '?', '#', '%') cannot corrupt the URL, which
// plain string formatting would allow.
func buildConnString() string {
	query := url.Values{}
	query.Set("sslmode", "disable")
	// 0 disables the connect timeout.
	// NOTE(review): confirm an unbounded connect wait is intended.
	query.Set("connect_timeout", "0")
	query.Set("application_name", "elec_service_go")
	query.Set("pool_max_conns", fmt.Sprintf("%d", config.DatabaseSettings.MaxOpenConns))
	// MaxIdleConns is used as the pool's minimum connection count.
	query.Set("pool_min_conns", fmt.Sprintf("%d", config.DatabaseSettings.MaxIdleConns))
	query.Set("pool_max_conn_lifetime", "60m")
	query.Set("pool_max_conn_idle_time", "10m")
	query.Set("pool_health_check_period", "10s")

	u := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(config.DatabaseSettings.User, config.DatabaseSettings.Pass),
		Host:     fmt.Sprintf("%s:%d", config.DatabaseSettings.Host, config.DatabaseSettings.Port),
		Path:     "/" + config.DatabaseSettings.DB,
		RawQuery: query.Encode(),
	}
	return u.String()
}

// SetupDatabaseConnection creates the global connection pool DB.
//
// It builds the pool configuration from config.DatabaseSettings and attaches a
// QueryLogger (named "PG") as the query tracer. Any error from parsing the
// configuration or creating the pool is logged under the "DB INIT" logger and
// returned. DB is only assigned when the pool was created successfully.
func SetupDatabaseConnection() error {
	poolConfig, err := pgxpool.ParseConfig(buildConnString())
	if err != nil {
		logger.Named("DB INIT").Error("数据库连接初始化失败。", zap.Error(err))
		return err
	}
	poolConfig.ConnConfig.Tracer = QueryLogger{logger: logger.Named("PG")}

	pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
	if err != nil {
		// Previously ignored: DB silently became nil while nil was returned.
		logger.Named("DB INIT").Error("数据库连接初始化失败。", zap.Error(err))
		return err
	}
	DB = pool
	return nil
}

// QueryLogger is a pgx query tracer that writes each query's SQL, its
// arguments and its outcome to the wrapped zap logger.
//
// NOTE(review): arguments are logged verbatim at Info level, which may expose
// sensitive values (credentials, personal data) in the logs.
type QueryLogger struct {
	logger *zap.Logger
}

// Compile-time check that QueryLogger satisfies pgx.QueryTracer.
var _ pgx.QueryTracer = QueryLogger{}

// TraceQueryStart logs the SQL text. It then writes a second entry with one
// field per argument, keyed by the argument's position. The context is
// returned unchanged.
func (ql QueryLogger) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
	ql.logger.Info(fmt.Sprintf("将要执行查询: %s", data.SQL))
	ql.logger.Info("查询参数", lo.Map(data.Args, func(elem any, index int) zap.Field {
		return zap.Any(fmt.Sprintf("[Arg %d]: ", index), elem)
	})...)
	return ctx
}

// TraceQueryEnd logs the command tag of a finished query.
//
// It logs at Error level when data.Err is set and at Info level otherwise.
// For INSERT, UPDATE and DELETE the number of affected rows is attached as the
// "affected" field.
func (ql QueryLogger) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
	logFunc, msgFormat := ql.logger.Info, "命令 [%s] 执行成功。"
	if data.Err != nil {
		logFunc, msgFormat = ql.logger.Error, "命令 [%s] 执行失败。"
	}

	tag := data.CommandTag
	// zap.Error(nil) is a no-op field, so it is safe to always include.
	fields := []zap.Field{zap.Error(data.Err)}
	if tag.Insert() || tag.Update() || tag.Delete() {
		fields = append(fields, zap.Int64("affected", tag.RowsAffected()))
	}
	logFunc(fmt.Sprintf(msgFormat, tag.String()), fields...)
}