common.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "regexp"
  9. "strings"
  10. "time"
  11. "github.com/amacneil/dbmate/pkg/dbmate"
  12. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  13. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  14. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  15. "golang.org/x/exp/slices"
  16. )
  17. type Engine interface {
  18. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  19. Close() error
  20. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  21. Ping(context.Context) error
  22. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  23. Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  24. QueryRow(ctx context.Context, query string, args ...any) *sql.Row
  25. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  26. }
  27. var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  28. var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  29. var rSqlParam = regexp.MustCompile(`\$\d+`)
  30. func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
  31. var values []string
  32. bold := "0"
  33. color := "33"
  34. // Transaction or not
  35. if tx {
  36. bold = "1"
  37. values = append(values, "[TX]")
  38. }
  39. // Function name
  40. if fname != "" {
  41. values = append(values, fname)
  42. }
  43. // SQL query
  44. if query != "" {
  45. values = append(values, rLogSpacesEnd.ReplaceAllString(
  46. strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
  47. ))
  48. }
  49. // Params
  50. if len(args) > 0 {
  51. values = append(values, fmt.Sprintf("(%v)", args))
  52. } else {
  53. values = append(values, "(empty)")
  54. }
  55. // Error
  56. if err != nil {
  57. color = "31"
  58. values = append(values, "("+err.Error()+")")
  59. } else {
  60. values = append(values, "(nil)")
  61. }
  62. // Execute time with close color symbols
  63. values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
  64. // Prepend start caption with colors
  65. values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
  66. res := fmt.Sprintln(strings.Join(values, " "))
  67. fmt.Fprint(w, res)
  68. return res
  69. }
  70. func fixQuery(query string) string {
  71. return rSqlParam.ReplaceAllString(query, "?")
  72. }
  73. func ParseUrl(dbURL string) (*url.URL, error) {
  74. databaseURL, err := url.Parse(dbURL)
  75. if err != nil {
  76. return nil, fmt.Errorf("unable to parse URL: %w", err)
  77. }
  78. if databaseURL.Scheme == "" {
  79. return nil, fmt.Errorf("protocol scheme is not defined")
  80. }
  81. protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  82. if !slices.Contains(protocols, databaseURL.Scheme) {
  83. return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
  84. }
  85. return databaseURL, nil
  86. }
  87. func OpenDB(databaseURL *url.URL, migrationsDir string) (*sql.DB, error) {
  88. mate := dbmate.New(databaseURL)
  89. mate.AutoDumpSchema = false
  90. mate.Log = io.Discard
  91. if migrationsDir != "" {
  92. mate.MigrationsDir = migrationsDir
  93. }
  94. driver, err := mate.GetDriver()
  95. if err != nil {
  96. return nil, fmt.Errorf("DB get driver error: %w", err)
  97. }
  98. if err := mate.CreateAndMigrate(); err != nil {
  99. return nil, fmt.Errorf("DB migration error: %w", err)
  100. }
  101. db, err := driver.Open()
  102. if err != nil {
  103. return nil, fmt.Errorf("DB open error: %w", err)
  104. }
  105. return db, nil
  106. }