dbmethods.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "regexp"
  7. "time"
  8. "github.com/pkg/errors"
  9. )
  // DBMethods wraps a *sql.DB, adding optional debug logging around every
  // call and placeholder rewriting for the configured driver.
  //
  // Field order is part of the API: positional composite literals
  // (DBMethods{db, debug, driver}) depend on it.
  type DBMethods struct {
  	// DB is the underlying handle every method delegates to.
  	DB *sql.DB
  	// Debug, when true, makes each method time its call and hand the
  	// outcome to the package's log helper.
  	Debug bool
  	// Driver names the SQL driver; "mysql" makes fixQuery rewrite $N
  	// placeholders to '?'.
  	Driver string
  }
  // Patterns compiled once at package initialisation.
  var (
  	// rLogSpacesAll matches any run of whitespace. It is not referenced in
  	// this chunk; presumably the log helper uses it to collapse a query onto
  	// one line (verify).
  	rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  	// rLogSpacesEnd matches whitespace immediately before a trailing ';'.
  	rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  	// rSqlParam matches numbered placeholders such as $1 or $12. fixQuery
  	// replaces them for the mysql driver.
  	rSqlParam = regexp.MustCompile(`\$\d+`)
  )
  18. type queryFunc func(ctx context.Context, tx *Tx) error
  19. func (db *DBMethods) fixQuery(query string) string {
  20. if db.Driver == "mysql" {
  21. return rSqlParam.ReplaceAllString(query, "?")
  22. }
  23. return query
  24. }
  25. func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
  26. if db.Debug {
  27. t := time.Now()
  28. tx, err := db.DB.BeginTx(ctx, opts)
  29. log("[func Begin]", t, err, true, "")
  30. return &Tx{tx, db.Debug, db.Driver, t}, err
  31. }
  32. tx, err := db.DB.BeginTx(ctx, opts)
  33. if err != nil {
  34. return nil, err
  35. }
  36. return &Tx{tx, db.Debug, db.Driver, time.Now()}, err
  37. }
  38. func (db *DBMethods) Close() error {
  39. if db.Debug {
  40. t := time.Now()
  41. err := db.DB.Close()
  42. log("[func Close]", t, err, false, "")
  43. return err
  44. }
  45. return db.DB.Close()
  46. }
  47. func (db *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
  48. if db.Debug {
  49. t := time.Now()
  50. res, err := db.DB.ExecContext(ctx, db.fixQuery(query), args...)
  51. log("[func Exec]", t, err, false, db.fixQuery(query), args...)
  52. return res, err
  53. }
  54. return db.DB.ExecContext(ctx, db.fixQuery(query), args...)
  55. }
  56. func (db *DBMethods) Ping(ctx context.Context) error {
  57. if db.Debug {
  58. t := time.Now()
  59. err := db.DB.PingContext(ctx)
  60. log("[func Ping]", t, err, false, "")
  61. return err
  62. }
  63. return db.DB.PingContext(ctx)
  64. }
  65. func (db *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
  66. if db.Debug {
  67. t := time.Now()
  68. stm, err := db.DB.PrepareContext(ctx, db.fixQuery(query))
  69. log("[func Prepare]", t, err, false, db.fixQuery(query))
  70. return stm, err
  71. }
  72. return db.DB.PrepareContext(ctx, db.fixQuery(query))
  73. }
  74. func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
  75. if db.Debug {
  76. t := time.Now()
  77. rows, err := db.DB.QueryContext(ctx, db.fixQuery(query), args...)
  78. log("[func Query]", t, err, false, db.fixQuery(query), args...)
  79. return rows, err
  80. }
  81. return db.DB.QueryContext(ctx, db.fixQuery(query), args...)
  82. }
  83. func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
  84. if db.Debug {
  85. t := time.Now()
  86. row := db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
  87. log("[func QueryRow]", t, nil, false, db.fixQuery(query), args...)
  88. return row
  89. }
  90. return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
  91. }
  92. func (db *DBMethods) Transaction(ctx context.Context, queries queryFunc) error {
  93. if queries == nil {
  94. return fmt.Errorf("queries is not set for transaction")
  95. }
  96. tx, err := db.Begin(ctx, nil)
  97. if err != nil {
  98. return err
  99. }
  100. if err := queries(ctx, tx); err != nil {
  101. return errors.Wrap(err, tx.Rollback().Error())
  102. }
  103. return tx.Commit()
  104. }