dbmethods.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "regexp"
  7. "github.com/pkg/errors"
  8. )
  9. type DBMethods struct {
  10. DB *sql.DB
  11. Driver string
  12. }
  // qFunc is the callback run by Transaction against an open transaction.
  // Returning nil lets Transaction commit; returning an error makes it roll
  // back and report that error.
  type qFunc func(context.Context, *sql.Tx) error
  var (
  	// r matches a "$" followed by one or more decimal digits, i.e.
  	// PostgreSQL-style numbered placeholders such as $1 or $12. It is
  	// compiled once at package initialisation.
  	r = regexp.MustCompile(`\$\d+`)
  )
  15. func (db *DBMethods) fixQuery(query string) string {
  16. if db.Driver == "mysql" {
  17. return r.ReplaceAllString(query, "?")
  18. }
  19. return query
  20. }
  // Begin opens a transaction on the underlying *sql.DB. ctx governs the
  // transaction's lifetime, and opts may be nil to use the driver defaults.
  func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
  	tx, err := db.DB.BeginTx(ctx, opts)
  	return tx, err
  }
  24. func (db *DBMethods) Close() error {
  25. return db.DB.Close()
  26. }
  // Exec runs a statement that returns no rows (INSERT, UPDATE, DDL, ...).
  // The query's placeholders are first adapted to the driver via fixQuery;
  // args are passed through unchanged.
  func (db *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
  	stmt := db.fixQuery(query)
  	return db.DB.ExecContext(ctx, stmt, args...)
  }
  // Ping checks that the database can be reached, honouring ctx for
  // cancellation and deadlines.
  func (db *DBMethods) Ping(ctx context.Context) error {
  	handle := db.DB
  	return handle.PingContext(ctx)
  }
  // Prepare creates a prepared statement on the underlying *sql.DB after
  // adapting the query's placeholders to the driver via fixQuery.
  func (db *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
  	adapted := db.fixQuery(query)
  	return db.DB.PrepareContext(ctx, adapted)
  }
  // Query runs a statement that returns rows. The query's placeholders are
  // adapted to the driver via fixQuery before execution; the caller must
  // close the returned *sql.Rows.
  func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
  	sqlText := db.fixQuery(query)
  	return db.DB.QueryContext(ctx, sqlText, args...)
  }
  // QueryRow runs a statement expected to return at most one row, after
  // adapting its placeholders via fixQuery. As with database/sql, errors are
  // deferred until Scan is called on the returned *sql.Row.
  func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
  	rewritten := db.fixQuery(query)
  	return db.DB.QueryRowContext(ctx, rewritten, args...)
  }
  // Transaction runs queries inside a single database transaction opened by
  // Begin with default options (nil *sql.TxOptions).
  //
  // Behaviour:
  //   - queries == nil: an error is returned and no transaction is started.
  //   - Begin fails: its error is returned unchanged.
  //   - queries returns nil: the transaction is committed and Commit's error
  //     (if any) is returned.
  //   - queries returns an error: the transaction is rolled back and that
  //     error is returned. If the rollback also fails, the rollback error is
  //     added as context with errors.Wrapf, and the original error stays the
  //     wrapped cause.
  //   - queries panics: the transaction is rolled back and the panic is
  //     re-raised.
  //
  // NOTE(review): queries receives the raw *sql.Tx, so statements it runs do
  // not pass through fixQuery. With Driver == "mysql" the callback must write
  // "?" placeholders itself.
  func (db *DBMethods) Transaction(ctx context.Context, queries qFunc) error {
  	if queries == nil {
  		return fmt.Errorf("queries is not set for transaction")
  	}
  	tx, err := db.Begin(ctx, nil)
  	if err != nil {
  		return err
  	}
  	// If queries panics, end the transaction so it is not left open, then
  	// propagate the panic unchanged.
  	defer func() {
  		if p := recover(); p != nil {
  			_ = tx.Rollback() // best effort: the panic is what the caller sees
  			panic(p)
  		}
  	}()
  	if err := queries(ctx, tx); err != nil {
  		// Rollback returns nil on success. Only attach its error when it
  		// actually failed; calling .Error() on a nil error would panic.
  		if rbErr := tx.Rollback(); rbErr != nil {
  			return errors.Wrapf(err, "rollback failed: %v", rbErr)
  		}
  		return err
  	}
  	return tx.Commit()
  }