common.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "regexp"
  11. "strings"
  12. "time"
  13. "github.com/amacneil/dbmate/pkg/dbmate"
  14. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  15. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  16. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  17. "golang.org/x/exp/slices"
  18. )
  19. type Engine interface {
  20. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  21. Close() error
  22. Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
  23. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  24. Ping(context.Context) error
  25. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  26. Query(ctx context.Context, query string, args ...any) (*Rows, error)
  27. QueryRow(ctx context.Context, query string, args ...any) *Row
  28. QueryRowByID(ctx context.Context, id int64, row any) error
  29. RowExists(ctx context.Context, id int64, row any) bool
  30. SetConnMaxLifetime(d time.Duration)
  31. SetMaxIdleConns(n int)
  32. SetMaxOpenConns(n int)
  33. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  34. }
  // Regular expressions shared by fixQuery and log, compiled once at package
// initialisation.
var (
	// rSqlParam matches a PostgreSQL-style positional placeholder: $1, $2, ...
	rSqlParam = regexp.MustCompile(`\$\d+`)
	// rLogSpacesAll matches a run of whitespace inside a query.
	rLogSpacesAll = regexp.MustCompile(`\s+`)
	// rLogSpacesEnd matches whitespace immediately before a final semicolon.
	rLogSpacesEnd = regexp.MustCompile(`\s+;$`)
)

// fixQuery rewrites PostgreSQL-style placeholders ($1, $2, ...) into the
// positional "?" style (as used by e.g. MySQL and SQLite).
//
// The rewrite is purely textual. Placeholder numbers are discarded, so args
// must already be in the order the placeholders appear, a repeated $N needs
// its argument repeated, and a "$N" inside a string literal is rewritten too.
func fixQuery(query string) string {
	return rSqlParam.ReplaceAllString(query, "?")
}

// log formats one debug line describing an SQL operation, writes it to w and
// returns it.
//
// The line has the form
//
//	ESC[<bold>;<color>m[SQL] [TX] [func fname] <query> (<args>) (<err>) <t> ms ESC[0m
//
// where "[TX]" and bold weight appear only when tx is true, "[func fname]"
// only when fname is non-empty, and the query (whitespace collapsed, no space
// before a trailing ";") only when non-empty. The color is yellow (33), or red
// (31) when err is non-nil. "(empty)" / "(nil)" stand in for missing args /
// error. <t> is the time elapsed since start in milliseconds. The returned
// string ends with a newline.
func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
	// Measure first so the reported time excludes the formatting below.
	elapsed := time.Since(start)

	bold, color := "0", "33" // normal weight, yellow
	if tx {
		bold = "1"
	}
	if err != nil {
		color = "31" // red
	}

	values := make([]string, 0, 7)
	values = append(values, "\033["+bold+";"+color+"m[SQL]")
	if tx {
		values = append(values, "[TX]")
	}
	if fname != "" {
		values = append(values, "[func "+fname+"]")
	}
	if query != "" {
		// Collapse every whitespace run to one space, then drop a space left
		// in front of a trailing semicolon.
		q := strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " ")
		values = append(values, rLogSpacesEnd.ReplaceAllString(q, ";"))
	}
	if len(args) > 0 {
		values = append(values, fmt.Sprintf("(%v)", args))
	} else {
		values = append(values, "(empty)")
	}
	if err != nil {
		values = append(values, "("+err.Error()+")")
	} else {
		values = append(values, "(nil)")
	}
	// The previous code printed Seconds() under an "ms" label; convert the
	// Duration to real milliseconds. The ANSI reset closes the color.
	ms := float64(elapsed) / float64(time.Millisecond)
	values = append(values, fmt.Sprintf("%.3f ms\033[0m", ms))

	res := strings.Join(values, " ") + "\n"
	// Debug output is best effort: a failed write must not affect the
	// database operation being logged.
	_, _ = fmt.Fprint(w, res)
	return res
}
  // scans returns a pointer to every field of the struct that row points to,
// in declaration order, presumably for use as Scan destinations.
//
// row must be a pointer to a struct whose fields are all exported; reflect
// panics otherwise. NOTE(review): every field is included, whereas
// queryRowByIDString only selects fields carrying a `field` tag, so structs
// used with both should tag every field.
func scans(row any) []any {
	target := reflect.ValueOf(row).Elem()
	count := target.NumField()
	dest := make([]any, 0, count)
	for idx := 0; idx < count; idx++ {
		dest = append(dest, target.Field(idx).Addr().Interface())
	}
	return dest
}
  // rowTableName returns the first non-empty `table` struct tag found on the
// fields of struct type t, or "" when no field carries one.
func rowTableName(t reflect.Type) string {
	for i := 0; i < t.NumField(); i++ {
		if tag := t.Field(i).Tag.Get("table"); tag != "" {
			return tag
		}
	}
	return ""
}

// queryRowByIDString builds a SELECT that loads one row by id.
//
// row must be a pointer to a struct (a typed nil pointer is fine; only the
// type is inspected). The table name is the first `table` tag on its fields.
// The selected columns are the `field` tags in declaration order; fields
// without a `field` tag are skipped. The id is bound as placeholder $1.
func queryRowByIDString(row any) string {
	t := reflect.TypeOf(row).Elem()
	fields := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		if tag := t.Field(i).Tag.Get("field"); tag != "" {
			fields = append(fields, tag)
		}
	}
	return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + rowTableName(t) + ` WHERE id = $1 LIMIT 1`
}

// rowExistsString builds a query that yields a single row when a row with
// id $1 exists in the table named by row's first `table` tag.
// row must be a pointer to a struct.
func rowExistsString(row any) string {
	return `SELECT 1 FROM ` + rowTableName(reflect.TypeOf(row).Elem()) + ` WHERE id = $1 LIMIT 1`
}

// deleteRowByIDString builds a DELETE for the row with id $1 in the table
// named by row's first `table` tag. row must be a pointer to a struct.
//
// No LIMIT clause is emitted. PostgreSQL does not allow LIMIT on DELETE, and
// SQLite only accepts it when built with SQLITE_ENABLE_UPDATE_DELETE_LIMIT.
// Filtering on id already targets a single row when id is unique.
func deleteRowByIDString(row any) string {
	return `DELETE FROM ` + rowTableName(reflect.TypeOf(row).Elem()) + ` WHERE id = $1`
}
  // ParseUrl parses dbURL and checks that its scheme is one of the database
// protocols this package supports: mysql, postgres, postgresql, sqlite or
// sqlite3.
//
// It returns an error when the URL cannot be parsed, has no scheme, or uses
// any other scheme.
func ParseUrl(dbURL string) (*url.URL, error) {
	supported := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}

	u, err := url.Parse(dbURL)
	switch {
	case err != nil:
		return nil, fmt.Errorf("unable to parse URL: %w", err)
	case u.Scheme == "":
		return nil, fmt.Errorf("protocol scheme is not defined")
	case !slices.Contains(supported, u.Scheme):
		return nil, fmt.Errorf("unsupported protocol scheme: %s", u.Scheme)
	}
	return u, nil
}
  146. func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
  147. mate := dbmate.New(databaseURL)
  148. mate.AutoDumpSchema = false
  149. mate.Log = io.Discard
  150. if migrationsDir != "" {
  151. mate.MigrationsDir = migrationsDir
  152. }
  153. driver, err := mate.GetDriver()
  154. if err != nil {
  155. return nil, fmt.Errorf("DB get driver error: %w", err)
  156. }
  157. if !skipMigration {
  158. if err := mate.CreateAndMigrate(); err != nil {
  159. return nil, fmt.Errorf("DB migration error: %w", err)
  160. }
  161. }
  162. var db *sql.DB
  163. start := time.Now()
  164. db, err = driver.Open()
  165. if debug {
  166. log(os.Stdout, "Open", start, err, false, "")
  167. }
  168. if err != nil {
  169. return nil, fmt.Errorf("DB open error: %w", err)
  170. }
  171. return db, nil
  172. }