common.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "regexp"
  10. "strings"
  11. "time"
  12. "github.com/amacneil/dbmate/pkg/dbmate"
  13. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  14. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  15. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  16. "golang.org/x/exp/slices"
  17. )
// queryFunc is a callback that receives a context and a *Tx and reports
// failure through its error result. It is the parameter type accepted by
// Engine.Transaction.
type queryFunc func(ctx context.Context, tx *Tx) error
  19. type Engine interface {
  20. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  21. Close() error
  22. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  23. Ping(context.Context) error
  24. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  25. Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  26. QueryRow(ctx context.Context, query string, args ...any) *sql.Row
  27. Transaction(ctx context.Context, queries queryFunc) error
  28. }
  29. var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  30. var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  31. var rSqlParam = regexp.MustCompile(`\$\d+`)
  32. func log(m string, s time.Time, e error, tx bool, query string, args ...any) string {
  33. var tmsg string
  34. if tx {
  35. tmsg = " [TX]"
  36. }
  37. if m != "" {
  38. tmsg = tmsg + " " + m
  39. }
  40. qmsg := query
  41. if qmsg != "" {
  42. qmsg = strings.Trim(rLogSpacesAll.ReplaceAllString(qmsg, " "), " ")
  43. qmsg = rLogSpacesEnd.ReplaceAllString(qmsg, ";")
  44. qmsg = " " + qmsg
  45. }
  46. astr := " (empty)"
  47. if len(args) > 0 {
  48. astr = fmt.Sprintf(" (%v)", args)
  49. }
  50. estr := " (nil)"
  51. if e != nil {
  52. estr = " \033[0m\033[0;31m(" + e.Error() + ")"
  53. }
  54. color := "0;33"
  55. if tx {
  56. color = "1;33"
  57. }
  58. res := fmt.Sprintln("\033[" + color + "m[SQL]" + tmsg + qmsg + astr + estr + fmt.Sprintf(" %.3f ms", time.Since(s).Seconds()) + "\033[0m")
  59. fmt.Fprintln(os.Stdout, res)
  60. return res
  61. }
  62. func fixQuery(query string) string {
  63. return rSqlParam.ReplaceAllString(query, "?")
  64. }
  65. func ParseUrl(dbURL string) (*url.URL, error) {
  66. databaseURL, err := url.Parse(dbURL)
  67. if err != nil {
  68. return nil, fmt.Errorf("unable to parse URL: %w", err)
  69. }
  70. if databaseURL.Scheme == "" {
  71. return nil, fmt.Errorf("protocol scheme is not defined")
  72. }
  73. protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  74. if !slices.Contains(protocols, databaseURL.Scheme) {
  75. return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
  76. }
  77. return databaseURL, nil
  78. }
  79. func OpenDB(databaseURL *url.URL, migrationsDir string) (*sql.DB, error) {
  80. mate := dbmate.New(databaseURL)
  81. mate.AutoDumpSchema = false
  82. mate.Log = io.Discard
  83. if migrationsDir != "" {
  84. mate.MigrationsDir = migrationsDir
  85. }
  86. driver, err := mate.GetDriver()
  87. if err != nil {
  88. return nil, fmt.Errorf("DB get driver error: %w", err)
  89. }
  90. if err := mate.CreateAndMigrate(); err != nil {
  91. return nil, fmt.Errorf("DB migration error: %w", err)
  92. }
  93. db, err := driver.Open()
  94. if err != nil {
  95. return nil, fmt.Errorf("DB open error: %w", err)
  96. }
  97. return db, nil
  98. }