common.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "regexp"
  10. "strings"
  11. "time"
  12. "github.com/amacneil/dbmate/pkg/dbmate"
  13. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  14. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  15. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  16. "golang.org/x/exp/slices"
  17. )
  18. type Engine interface {
  19. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  20. Close() error
  21. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  22. Ping(context.Context) error
  23. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  24. Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  25. QueryRow(ctx context.Context, query string, args ...any) *sql.Row
  26. SetConnMaxLifetime(d time.Duration)
  27. SetMaxIdleConns(n int)
  28. SetMaxOpenConns(n int)
  29. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  30. }
  // Patterns used to normalize SQL for tracing (log) and to rewrite positional
// placeholders (fixQuery). They are compiled once at package initialization.
var (
	// rLogSpacesAll matches any run of whitespace. \s already covers \t, so no
	// separate \t is needed in the pattern.
	rLogSpacesAll = regexp.MustCompile(`\s+`)
	// rLogSpacesEnd matches whitespace immediately before a trailing ';'.
	rLogSpacesEnd = regexp.MustCompile(`\s+;$`)
	// rSqlParam matches a '$' followed by one or more digits, e.g. $1 or $12.
	rSqlParam = regexp.MustCompile(`\$\d+`)
)
  34. func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
  35. var values []string
  36. bold := "0"
  37. color := "33"
  38. // Transaction or not
  39. if tx {
  40. bold = "1"
  41. values = append(values, "[TX]")
  42. }
  43. // Function name
  44. if fname != "" {
  45. values = append(values, "[func "+fname+"]")
  46. }
  47. // SQL query
  48. if query != "" {
  49. values = append(values, rLogSpacesEnd.ReplaceAllString(
  50. strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
  51. ))
  52. }
  53. // Params
  54. if len(args) > 0 {
  55. values = append(values, fmt.Sprintf("(%v)", args))
  56. } else {
  57. values = append(values, "(empty)")
  58. }
  59. // Error
  60. if err != nil {
  61. color = "31"
  62. values = append(values, "("+err.Error()+")")
  63. } else {
  64. values = append(values, "(nil)")
  65. }
  66. // Execute time with close color symbols
  67. values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
  68. // Prepend start caption with colors
  69. values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
  70. res := fmt.Sprintln(strings.Join(values, " "))
  71. fmt.Fprint(w, res)
  72. return res
  73. }
  74. func fixQuery(query string) string {
  75. return rSqlParam.ReplaceAllString(query, "?")
  76. }
  // ParseUrl parses dbURL and verifies that its scheme is one of the supported
// database protocols: mysql, postgres, postgresql, sqlite or sqlite3.
//
// It returns an error when dbURL cannot be parsed, has no scheme, or uses any
// other scheme. url.Parse lowercases the scheme, so the check is effectively
// case-insensitive. Parse errors intentionally omit the raw URL, because
// database URLs may carry credentials (user:password@host).
func ParseUrl(dbURL string) (*url.URL, error) {
	databaseURL, err := url.Parse(dbURL)
	if err != nil {
		// url.Parse returns *url.Error, which formats as `parse "<raw url>": <cause>`.
		// Keep only the cause so a password in the DSN never reaches logs via
		// this error.
		if uerr, ok := err.(*url.Error); ok {
			err = uerr.Err
		}
		return nil, fmt.Errorf("unable to parse URL: %w", err)
	}
	if databaseURL.Scheme == "" {
		return nil, fmt.Errorf("protocol scheme is not defined")
	}
	protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
	if !slices.Contains(protocols, databaseURL.Scheme) {
		return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
	}
	return databaseURL, nil
}
  91. func OpenDB(databaseURL *url.URL, migrationsDir string, debug bool) (*sql.DB, error) {
  92. mate := dbmate.New(databaseURL)
  93. mate.AutoDumpSchema = false
  94. mate.Log = io.Discard
  95. if migrationsDir != "" {
  96. mate.MigrationsDir = migrationsDir
  97. }
  98. driver, err := mate.GetDriver()
  99. if err != nil {
  100. return nil, fmt.Errorf("DB get driver error: %w", err)
  101. }
  102. if err := mate.CreateAndMigrate(); err != nil {
  103. return nil, fmt.Errorf("DB migration error: %w", err)
  104. }
  105. var db *sql.DB
  106. if debug {
  107. t := time.Now()
  108. db, err = driver.Open()
  109. log(os.Stdout, "Open", t, err, false, "")
  110. if err != nil {
  111. return nil, fmt.Errorf("DB open error: %w", err)
  112. }
  113. } else {
  114. db, err = driver.Open()
  115. if err != nil {
  116. return nil, fmt.Errorf("DB open error: %w", err)
  117. }
  118. }
  119. return db, nil
  120. }