123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- package common
- import (
- "context"
- "database/sql"
- "fmt"
- "os"
- "regexp"
- "strings"
- "time"
- "github.com/pkg/errors"
- )
// DBMethods bundles a database handle with the settings consulted by the
// helper methods defined on it.
type DBMethods struct {
	// DB is the underlying handle every method delegates to.
	DB *sql.DB
	// Debug, when true, makes the methods print a timing line for each call.
	Debug bool
	// Driver names the SQL driver in use. The value "mysql" causes $N
	// placeholders in queries to be rewritten to "?".
	Driver string
}
// Patterns compiled once at package scope and shared by the helpers below.
var (
	// rLogSpacesAll matches any run of whitespace; used to collapse a
	// multi-line query into a single line for logging.
	rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
	// rLogSpacesEnd matches whitespace immediately preceding a trailing ";".
	rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
	// rSqlParam matches numbered placeholders such as $1 or $12.
	rSqlParam = regexp.MustCompile(`\$\d+`)
)
// queryFunc is the callback signature accepted by Transaction. It receives
// the caller's context and the open transaction, and reports failure through
// the returned error.
type queryFunc func(ctx context.Context, tx *sql.Tx) error
- func (db *DBMethods) log(m string, s time.Time, e error, tx bool, query string, args ...any) {
- var tmsg string
- if tx {
- tmsg = " [TX]"
- }
- if m != "" {
- tmsg = tmsg + " " + m
- }
- qmsg := query
- if qmsg != "" {
- qmsg = strings.Trim(rLogSpacesAll.ReplaceAllString(qmsg, " "), " ")
- qmsg = rLogSpacesEnd.ReplaceAllString(qmsg, ";")
- qmsg = " " + qmsg
- }
- astr := " (empty)"
- if len(args) > 0 {
- astr = fmt.Sprintf(" (%v)", args)
- }
- estr := " (nil)"
- if e != nil {
- estr = " \033[0m\033[0;31m(" + e.Error() + ")"
- }
- color := "0;33"
- if tx {
- color = "1;33"
- }
- fmt.Fprintln(os.Stdout, "\033["+color+"m[SQL]"+tmsg+qmsg+astr+estr+fmt.Sprintf(" %.3f ms", time.Since(s).Seconds())+"\033[0m")
- }
- func (db *DBMethods) fixQuery(query string) string {
- if db.Driver == "mysql" {
- return rSqlParam.ReplaceAllString(query, "?")
- }
- return query
- }
- func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
- if db.Debug {
- t := time.Now()
- tx, err := db.DB.BeginTx(ctx, opts)
- db.log("[func Begin]", t, err, true, "")
- return tx, err
- }
- return db.DB.BeginTx(ctx, opts)
- }
- func (db *DBMethods) Close() error {
- if db.Debug {
- t := time.Now()
- err := db.DB.Close()
- db.log("[func Close]", t, err, false, "")
- return err
- }
- return db.DB.Close()
- }
- func (db *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
- if db.Debug {
- t := time.Now()
- res, err := db.DB.ExecContext(ctx, db.fixQuery(query), args...)
- db.log("[func Exec]", t, err, false, db.fixQuery(query), args...)
- return res, err
- }
- return db.DB.ExecContext(ctx, db.fixQuery(query), args...)
- }
- func (db *DBMethods) Ping(ctx context.Context) error {
- if db.Debug {
- t := time.Now()
- err := db.DB.PingContext(ctx)
- db.log("[func Ping]", t, err, false, "")
- return err
- }
- return db.DB.PingContext(ctx)
- }
- func (db *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
- if db.Debug {
- t := time.Now()
- stm, err := db.DB.PrepareContext(ctx, db.fixQuery(query))
- db.log("[func Prepare]", t, err, false, db.fixQuery(query))
- return stm, err
- }
- return db.DB.PrepareContext(ctx, db.fixQuery(query))
- }
- func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
- if db.Debug {
- t := time.Now()
- rows, err := db.DB.QueryContext(ctx, db.fixQuery(query), args...)
- db.log("[func Query]", t, err, false, db.fixQuery(query), args...)
- return rows, err
- }
- return db.DB.QueryContext(ctx, db.fixQuery(query), args...)
- }
- func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
- if db.Debug {
- t := time.Now()
- row := db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
- db.log("[func QueryRow]", t, nil, false, db.fixQuery(query), args...)
- return row
- }
- return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
- }
- func (db *DBMethods) Transaction(ctx context.Context, queries queryFunc) error {
- if queries == nil {
- return fmt.Errorf("queries is not set for transaction")
- }
- tx, err := db.Begin(ctx, nil)
- if err != nil {
- return err
- }
- if err := queries(ctx, tx); err != nil {
- return errors.Wrap(err, tx.Rollback().Error())
- }
- return tx.Commit()
- }
|