common.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "regexp"
  11. "strings"
  12. "time"
  13. "github.com/amacneil/dbmate/pkg/dbmate"
  14. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  15. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  16. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  17. "golang.org/x/exp/slices"
  18. )
  19. type Engine interface {
  20. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  21. Close() error
  22. DeleteRowByID(ctx context.Context, id int64, row any) error
  23. Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
  24. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  25. Ping(context.Context) error
  26. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  27. Query(ctx context.Context, query string, args ...any) (*Rows, error)
  28. QueryRow(ctx context.Context, query string, args ...any) *Row
  29. QueryRowByID(ctx context.Context, id int64, row any) error
  30. RowExists(ctx context.Context, id int64, row any) bool
  31. SetConnMaxLifetime(d time.Duration)
  32. SetMaxIdleConns(n int)
  33. SetMaxOpenConns(n int)
  34. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  35. }
  var (
	// rSqlParam matches positional placeholders of the form $1, $2, ...
	rSqlParam = regexp.MustCompile(`\$\d+`)
	// rLogSpacesAll matches any run of whitespace (\s already covers \t).
	rLogSpacesAll = regexp.MustCompile(`\s+`)
	// rLogSpacesEnd matches whitespace directly before a trailing semicolon.
	rLogSpacesEnd = regexp.MustCompile(`\s+;$`)
)

// fixQuery rewrites every $N placeholder in query to "?", presumably for
// drivers that expect "?" placeholders (this file imports the MySQL and
// SQLite drivers). It is purely textual: a "$N" inside a quoted SQL string
// literal is rewritten as well.
func fixQuery(query string) string {
	return rSqlParam.ReplaceAllString(query, "?")
}

// log formats one colorized line describing an SQL operation, writes it to w
// and returns it (including the trailing newline).
//
// The line has the form
//
//	[SQL] [TX] [func NAME] QUERY (ARGS) (ERR) ELAPSED ms
//
// [TX] appears only when tx is true, [func NAME] only when fname is non-empty
// and QUERY only when query is non-empty. ARGS prints as "empty" when no args
// are given, and ERR prints as "nil" when err is nil. ELAPSED is the time
// since start in milliseconds with three decimals.
//
// The line is wrapped in ANSI escape codes. It is bold when tx is true, red
// (31) when err is non-nil and yellow (33) otherwise. A nil w is tolerated:
// the line is still returned but not written anywhere.
func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
	// Measure first so the formatting below is not counted in the duration.
	elapsed := time.Since(start)

	bold, color := "0", "33"
	if tx {
		bold = "1"
	}
	if err != nil {
		color = "31"
	}

	values := make([]string, 0, 7)
	values = append(values, "\033["+bold+";"+color+"m[SQL]")
	if tx {
		values = append(values, "[TX]")
	}
	if fname != "" {
		values = append(values, "[func "+fname+"]")
	}
	if query != "" {
		// Put the query on a single line: collapse whitespace runs, trim the
		// ends and drop any space left before a trailing semicolon.
		oneLine := strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " ")
		values = append(values, rLogSpacesEnd.ReplaceAllString(oneLine, ";"))
	}
	if len(args) > 0 {
		values = append(values, fmt.Sprintf("(%v)", args))
	} else {
		values = append(values, "(empty)")
	}
	if err != nil {
		values = append(values, "("+err.Error()+")")
	} else {
		values = append(values, "(nil)")
	}
	// Convert to milliseconds so the number matches the "ms" unit label.
	ms := float64(elapsed) / float64(time.Millisecond)
	// The trailing escape code resets the terminal colors.
	values = append(values, fmt.Sprintf("%.3f ms\033[0m", ms))

	res := strings.Join(values, " ") + "\n"
	if w != nil {
		// Debug logging is best-effort; a failed write must not affect the caller.
		_, _ = io.WriteString(w, res)
	}
	return res
}
  // scans returns the addresses of all fields of the struct that row points to,
// in declaration order. It is presumably used as the destination list for a
// Scan call.
//
// row must be a non-nil pointer to a struct whose fields are all exported;
// otherwise reflect panics.
func scans(row any) []any {
	elem := reflect.ValueOf(row).Elem()
	dest := make([]any, 0, elem.NumField())
	for idx := 0; idx < elem.NumField(); idx++ {
		dest = append(dest, elem.Field(idx).Addr().Interface())
	}
	return dest
}
  // tableTagOf returns the value of the first non-empty `table` struct tag among
// the fields of the struct type t, or "" when no field carries one.
func tableTagOf(t reflect.Type) string {
	for i := 0; i < t.NumField(); i++ {
		if tag := t.Field(i).Tag.Get("table"); tag != "" {
			return tag
		}
	}
	return ""
}

// queryRowByIDString builds a single-row SELECT for the struct type that row
// points to.
//
// The column list is the `field` tag values in declaration order; untagged
// fields are skipped. The table is the first non-empty `table` tag. The id is
// bound as the $1 placeholder.
//
// Tags are concatenated verbatim. They are compile-time struct tags, not user
// input, and are not validated, so a struct without `table` or `field` tags
// yields invalid SQL.
//
// row must be a pointer to a struct. Only its type is inspected, so a typed
// nil pointer is accepted. Any other value makes reflect panic.
func queryRowByIDString(row any) string {
	t := reflect.TypeOf(row).Elem()
	fields := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		if tag := t.Field(i).Tag.Get("field"); tag != "" {
			fields = append(fields, tag)
		}
	}
	return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + tableTagOf(t) + ` WHERE id = $1 LIMIT 1`
}

// rowExistsString builds a query that selects the constant 1 for the row with
// id $1, in the table named by the first `table` tag of the struct type row
// points to. Pointer requirements are the same as for queryRowByIDString.
func rowExistsString(row any) string {
	return `SELECT 1 FROM ` + tableTagOf(reflect.TypeOf(row).Elem()) + ` WHERE id = $1 LIMIT 1`
}

// deleteRowByIDString builds a DELETE of the row(s) with id $1, in the table
// named by the first `table` tag of the struct type row points to. Pointer
// requirements are the same as for queryRowByIDString.
func deleteRowByIDString(row any) string {
	return `DELETE FROM ` + tableTagOf(reflect.TypeOf(row).Elem()) + ` WHERE id = $1`
}
  // ParseUrl parses dbURL and accepts it only if its scheme is one of mysql,
// postgres, postgresql, sqlite or sqlite3.
//
// It returns an error in three cases: the URL cannot be parsed (the parse
// error is wrapped), the URL has no scheme, or the scheme is not in that list.
func ParseUrl(dbURL string) (*url.URL, error) {
	parsed, parseErr := url.Parse(dbURL)
	switch {
	case parseErr != nil:
		return nil, fmt.Errorf("unable to parse URL: %w", parseErr)
	case parsed.Scheme == "":
		return nil, fmt.Errorf("protocol scheme is not defined")
	case !slices.Contains([]string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}, parsed.Scheme):
		return nil, fmt.Errorf("unsupported protocol scheme: %s", parsed.Scheme)
	default:
		return parsed, nil
	}
}
  147. func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
  148. mate := dbmate.New(databaseURL)
  149. mate.AutoDumpSchema = false
  150. mate.Log = io.Discard
  151. if migrationsDir != "" {
  152. mate.MigrationsDir = migrationsDir
  153. }
  154. driver, err := mate.GetDriver()
  155. if err != nil {
  156. return nil, fmt.Errorf("DB get driver error: %w", err)
  157. }
  158. if !skipMigration {
  159. if err := mate.CreateAndMigrate(); err != nil {
  160. return nil, fmt.Errorf("DB migration error: %w", err)
  161. }
  162. }
  163. var db *sql.DB
  164. start := time.Now()
  165. db, err = driver.Open()
  166. if debug {
  167. log(os.Stdout, "Open", start, err, false, "")
  168. }
  169. if err != nil {
  170. return nil, fmt.Errorf("DB open error: %w", err)
  171. }
  172. return db, nil
  173. }