common.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "strings"
  10. "time"
  11. "github.com/amacneil/dbmate/pkg/dbmate"
  12. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  13. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  14. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  15. "golang.org/x/exp/slices"
  16. )
  17. type Engine interface {
  18. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  19. Close() error
  20. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  21. Ping(context.Context) error
  22. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  23. Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  24. QueryRow(ctx context.Context, query string, args ...any) *sql.Row
  25. Transaction(ctx context.Context, queries queryFunc) error
  26. }
  27. func log(m string, s time.Time, e error, tx bool, query string, args ...any) {
  28. var tmsg string
  29. if tx {
  30. tmsg = " [TX]"
  31. }
  32. if m != "" {
  33. tmsg = tmsg + " " + m
  34. }
  35. qmsg := query
  36. if qmsg != "" {
  37. qmsg = strings.Trim(rLogSpacesAll.ReplaceAllString(qmsg, " "), " ")
  38. qmsg = rLogSpacesEnd.ReplaceAllString(qmsg, ";")
  39. qmsg = " " + qmsg
  40. }
  41. astr := " (empty)"
  42. if len(args) > 0 {
  43. astr = fmt.Sprintf(" (%v)", args)
  44. }
  45. estr := " (nil)"
  46. if e != nil {
  47. estr = " \033[0m\033[0;31m(" + e.Error() + ")"
  48. }
  49. color := "0;33"
  50. if tx {
  51. color = "1;33"
  52. }
  53. fmt.Fprintln(os.Stdout, "\033["+color+"m[SQL]"+tmsg+qmsg+astr+estr+fmt.Sprintf(" %.3f ms", time.Since(s).Seconds())+"\033[0m")
  54. }
  // ParseUrl parses dbURL and checks that its scheme is one of the supported
  // database schemes: mysql, postgres, postgresql, sqlite or sqlite3.
  //
  // It returns an error when dbURL cannot be parsed (wrapping the url.Parse
  // error), when it has no scheme, or when the scheme is not supported.
  func ParseUrl(dbURL string) (*url.URL, error) {
  	u, err := url.Parse(dbURL)
  	if err != nil {
  		return nil, fmt.Errorf("unable to parse URL: %w", err)
  	}

  	scheme := u.Scheme
  	if scheme == "" {
  		return nil, fmt.Errorf("protocol scheme is not defined")
  	}

  	supported := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  	if slices.Contains(supported, scheme) {
  		return u, nil
  	}
  	return nil, fmt.Errorf("unsupported protocol scheme: %s", scheme)
  }
  69. func OpenDB(databaseURL *url.URL, migrationsDir string) (*sql.DB, error) {
  70. mate := dbmate.New(databaseURL)
  71. mate.AutoDumpSchema = false
  72. mate.Log = io.Discard
  73. if migrationsDir != "" {
  74. mate.MigrationsDir = migrationsDir
  75. }
  76. driver, err := mate.GetDriver()
  77. if err != nil {
  78. return nil, fmt.Errorf("DB get driver error: %w", err)
  79. }
  80. if err := mate.CreateAndMigrate(); err != nil {
  81. return nil, fmt.Errorf("DB migration error: %w", err)
  82. }
  83. db, err := driver.Open()
  84. if err != nil {
  85. return nil, fmt.Errorf("DB open error: %w", err)
  86. }
  87. return db, nil
  88. }