dbmethods.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "os"
  7. "regexp"
  8. "strings"
  9. "time"
  10. "github.com/pkg/errors"
  11. )
  // DBMethods wraps a *sql.DB with optional debug logging and driver-specific
// placeholder rewriting. Every wrapper method delegates to DB.
type DBMethods struct {
	// DB is the database handle every call is delegated to.
	DB *sql.DB
	// Debug, when true, makes each wrapper time its call and print a
	// colorized summary line to stdout.
	Debug bool
	// Driver selects query rewriting: "mysql" converts $N placeholders to
	// "?"; any other value leaves queries untouched.
	Driver string
}
  // Patterns compiled once at package init and shared by the logger and
// the placeholder rewriter.
var (
	// rLogSpacesAll matches any run of whitespace in a query being logged.
	rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
	// rLogSpacesEnd matches whitespace sitting directly before a trailing ";".
	rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
	// rSqlParam matches "$N"-style positional placeholders ($1, $2, ...).
	rSqlParam = regexp.MustCompile(`\$\d+`)
)

// queryFunc is the callback executed by DBMethods.Transaction. It receives the
// open transaction; returning a non-nil error causes a rollback.
type queryFunc func(ctx context.Context, tx *sql.Tx) error
  21. func (db *DBMethods) log(m string, s time.Time, e error, tx bool, query string, args ...any) {
  22. var tmsg string
  23. if tx {
  24. tmsg = " [TX]"
  25. }
  26. if m != "" {
  27. tmsg = tmsg + " " + m
  28. }
  29. qmsg := query
  30. if qmsg != "" {
  31. qmsg = strings.Trim(rLogSpacesAll.ReplaceAllString(qmsg, " "), " ")
  32. qmsg = rLogSpacesEnd.ReplaceAllString(qmsg, ";")
  33. qmsg = " " + qmsg
  34. }
  35. astr := " (empty)"
  36. if len(args) > 0 {
  37. astr = fmt.Sprintf(" (%v)", args)
  38. }
  39. estr := " (nil)"
  40. if e != nil {
  41. estr = " \033[0m\033[0;31m(" + e.Error() + ")"
  42. }
  43. color := "0;33"
  44. if tx {
  45. color = "1;33"
  46. }
  47. fmt.Fprintln(os.Stdout, "\033["+color+"m[SQL]"+tmsg+qmsg+astr+estr+fmt.Sprintf(" %.3f ms", time.Since(s).Seconds())+"\033[0m")
  48. }
  49. func (db *DBMethods) fixQuery(query string) string {
  50. if db.Driver == "mysql" {
  51. return rSqlParam.ReplaceAllString(query, "?")
  52. }
  53. return query
  54. }
  55. func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
  56. if db.Debug {
  57. t := time.Now()
  58. tx, err := db.DB.BeginTx(ctx, opts)
  59. db.log("[func Begin]", t, err, true, "")
  60. return tx, err
  61. }
  62. return db.DB.BeginTx(ctx, opts)
  63. }
  64. func (db *DBMethods) Close() error {
  65. if db.Debug {
  66. t := time.Now()
  67. err := db.DB.Close()
  68. db.log("[func Close]", t, err, false, "")
  69. return err
  70. }
  71. return db.DB.Close()
  72. }
  73. func (db *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
  74. if db.Debug {
  75. t := time.Now()
  76. res, err := db.DB.ExecContext(ctx, db.fixQuery(query), args...)
  77. db.log("[func Exec]", t, err, false, db.fixQuery(query), args...)
  78. return res, err
  79. }
  80. return db.DB.ExecContext(ctx, db.fixQuery(query), args...)
  81. }
  82. func (db *DBMethods) Ping(ctx context.Context) error {
  83. if db.Debug {
  84. t := time.Now()
  85. err := db.DB.PingContext(ctx)
  86. db.log("[func Ping]", t, err, false, "")
  87. return err
  88. }
  89. return db.DB.PingContext(ctx)
  90. }
  91. func (db *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
  92. if db.Debug {
  93. t := time.Now()
  94. stm, err := db.DB.PrepareContext(ctx, db.fixQuery(query))
  95. db.log("[func Prepare]", t, err, false, db.fixQuery(query))
  96. return stm, err
  97. }
  98. return db.DB.PrepareContext(ctx, db.fixQuery(query))
  99. }
  100. func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
  101. if db.Debug {
  102. t := time.Now()
  103. rows, err := db.DB.QueryContext(ctx, db.fixQuery(query), args...)
  104. db.log("[func Query]", t, err, false, db.fixQuery(query), args...)
  105. return rows, err
  106. }
  107. return db.DB.QueryContext(ctx, db.fixQuery(query), args...)
  108. }
  109. func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
  110. if db.Debug {
  111. t := time.Now()
  112. row := db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
  113. db.log("[func QueryRow]", t, nil, false, db.fixQuery(query), args...)
  114. return row
  115. }
  116. return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
  117. }
  118. func (db *DBMethods) Transaction(ctx context.Context, queries queryFunc) error {
  119. if queries == nil {
  120. return fmt.Errorf("queries is not set for transaction")
  121. }
  122. tx, err := db.Begin(ctx, nil)
  123. if err != nil {
  124. return err
  125. }
  126. if err := queries(ctx, tx); err != nil {
  127. return errors.Wrap(err, tx.Rollback().Error())
  128. }
  129. return tx.Commit()
  130. }