common.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "regexp"
  11. "strings"
  12. "time"
  13. "github.com/amacneil/dbmate/pkg/dbmate"
  14. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  15. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  16. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  17. "golang.org/x/exp/slices"
  18. )
  19. type Engine interface {
  20. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  21. Close() error
  22. CurrentUnixTimestamp() int64
  23. DeleteRowByID(ctx context.Context, id int64, row any) error
  24. Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
  25. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  26. Ping(context.Context) error
  27. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  28. Query(ctx context.Context, query string, args ...any) (*Rows, error)
  29. QueryRow(ctx context.Context, query string, args ...any) *Row
  30. QueryRowByID(ctx context.Context, id int64, row any) error
  31. RowExists(ctx context.Context, id int64, row any) bool
  32. SetConnMaxLifetime(d time.Duration)
  33. SetMaxIdleConns(n int)
  34. SetMaxOpenConns(n int)
  35. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  36. }
  37. var rSqlParam = regexp.MustCompile(`\$\d+`)
  38. var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  39. var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  40. func fixQuery(query string) string {
  41. return rSqlParam.ReplaceAllString(query, "?")
  42. }
  43. func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
  44. var values []string
  45. bold := "0"
  46. color := "33"
  47. // Transaction or not
  48. if tx {
  49. bold = "1"
  50. values = append(values, "[TX]")
  51. }
  52. // Function name
  53. if fname != "" {
  54. values = append(values, "[func "+fname+"]")
  55. }
  56. // SQL query
  57. if query != "" {
  58. values = append(values, rLogSpacesEnd.ReplaceAllString(
  59. strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
  60. ))
  61. }
  62. // Params
  63. if len(args) > 0 {
  64. values = append(values, fmt.Sprintf("(%v)", args))
  65. } else {
  66. values = append(values, "(empty)")
  67. }
  68. // Error
  69. if err != nil {
  70. color = "31"
  71. values = append(values, "("+err.Error()+")")
  72. } else {
  73. values = append(values, "(nil)")
  74. }
  75. // Execute time with close color symbols
  76. values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
  77. // Prepend start caption with colors
  78. values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
  79. res := fmt.Sprintln(strings.Join(values, " "))
  80. fmt.Fprint(w, res)
  81. return res
  82. }
  83. func scans(row any) []any {
  84. v := reflect.ValueOf(row).Elem()
  85. res := make([]interface{}, v.NumField())
  86. for i := 0; i < v.NumField(); i++ {
  87. res[i] = v.Field(i).Addr().Interface()
  88. }
  89. return res
  90. }
  91. func queryRowByIDString(row any) string {
  92. v := reflect.ValueOf(row).Elem()
  93. t := v.Type()
  94. var table string
  95. fields := []string{}
  96. for i := 0; i < t.NumField(); i++ {
  97. if table == "" {
  98. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  99. table = tag
  100. }
  101. }
  102. if tag := t.Field(i).Tag.Get("field"); tag != "" {
  103. fields = append(fields, tag)
  104. }
  105. }
  106. return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
  107. }
  108. func rowExistsString(row any) string {
  109. v := reflect.ValueOf(row).Elem()
  110. t := v.Type()
  111. var table string
  112. for i := 0; i < t.NumField(); i++ {
  113. if table == "" {
  114. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  115. table = tag
  116. }
  117. }
  118. }
  119. return `SELECT 1 FROM ` + table + ` WHERE id = $1 LIMIT 1`
  120. }
  121. func deleteRowByIDString(row any) string {
  122. v := reflect.ValueOf(row).Elem()
  123. t := v.Type()
  124. var table string
  125. for i := 0; i < t.NumField(); i++ {
  126. if table == "" {
  127. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  128. table = tag
  129. }
  130. }
  131. }
  132. return `DELETE FROM ` + table + ` WHERE id = $1`
  133. }
  134. func currentUnixTimestamp() int64 {
  135. return time.Now().UTC().Unix()
  136. }
  137. func ParseUrl(dbURL string) (*url.URL, error) {
  138. databaseURL, err := url.Parse(dbURL)
  139. if err != nil {
  140. return nil, fmt.Errorf("unable to parse URL: %w", err)
  141. }
  142. if databaseURL.Scheme == "" {
  143. return nil, fmt.Errorf("protocol scheme is not defined")
  144. }
  145. protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  146. if !slices.Contains(protocols, databaseURL.Scheme) {
  147. return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
  148. }
  149. return databaseURL, nil
  150. }
  151. func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
  152. mate := dbmate.New(databaseURL)
  153. mate.AutoDumpSchema = false
  154. mate.Log = io.Discard
  155. if migrationsDir != "" {
  156. mate.MigrationsDir = migrationsDir
  157. }
  158. driver, err := mate.GetDriver()
  159. if err != nil {
  160. return nil, fmt.Errorf("DB get driver error: %w", err)
  161. }
  162. if !skipMigration {
  163. if err := mate.CreateAndMigrate(); err != nil {
  164. return nil, fmt.Errorf("DB migration error: %w", err)
  165. }
  166. }
  167. var db *sql.DB
  168. start := time.Now()
  169. db, err = driver.Open()
  170. if debug {
  171. log(os.Stdout, "Open", start, err, false, "")
  172. }
  173. if err != nil {
  174. return nil, fmt.Errorf("DB open error: %w", err)
  175. }
  176. return db, nil
  177. }