common.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "regexp"
  11. "strings"
  12. "time"
  13. "github.com/amacneil/dbmate/pkg/dbmate"
  14. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  15. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  16. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  17. "golang.org/x/exp/slices"
  18. )
  19. type Engine interface {
  20. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  21. Close() error
  22. CurrentUnixTimestamp() int64
  23. DeleteRowByID(ctx context.Context, id int64, row any) error
  24. Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
  25. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  26. Ping(context.Context) error
  27. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  28. Query(ctx context.Context, query string, args ...any) (*Rows, error)
  29. QueryRow(ctx context.Context, query string, args ...any) *Row
  30. QueryRowByID(ctx context.Context, id int64, row any) error
  31. RowExists(ctx context.Context, id int64, row any) bool
  32. SetConnMaxLifetime(d time.Duration)
  33. SetMaxIdleConns(n int)
  34. SetMaxOpenConns(n int)
  35. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  36. }
  37. var rSqlParam = regexp.MustCompile(`\$\d+`)
  38. var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  39. var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  40. func currentUnixTimestamp() int64 {
  41. return time.Now().UTC().Unix()
  42. }
  43. func deleteRowByIDString(row any) string {
  44. v := reflect.ValueOf(row).Elem()
  45. t := v.Type()
  46. var table string
  47. for i := 0; i < t.NumField(); i++ {
  48. if table == "" {
  49. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  50. table = tag
  51. }
  52. }
  53. }
  54. return `DELETE FROM ` + table + ` WHERE id = $1`
  55. }
  56. func fixQuery(query string) string {
  57. return rSqlParam.ReplaceAllString(query, "?")
  58. }
  59. func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
  60. var values []string
  61. bold := "0"
  62. color := "33"
  63. // Transaction or not
  64. if tx {
  65. bold = "1"
  66. values = append(values, "[TX]")
  67. }
  68. // Function name
  69. if fname != "" {
  70. values = append(values, "[func "+fname+"]")
  71. }
  72. // SQL query
  73. if query != "" {
  74. values = append(values, rLogSpacesEnd.ReplaceAllString(
  75. strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
  76. ))
  77. }
  78. // Params
  79. if len(args) > 0 {
  80. values = append(values, fmt.Sprintf("(%v)", args))
  81. } else {
  82. values = append(values, "(empty)")
  83. }
  84. // Error
  85. if err != nil {
  86. color = "31"
  87. values = append(values, "("+err.Error()+")")
  88. } else {
  89. values = append(values, "(nil)")
  90. }
  91. // Execute time with close color symbols
  92. values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
  93. // Prepend start caption with colors
  94. values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
  95. res := fmt.Sprintln(strings.Join(values, " "))
  96. fmt.Fprint(w, res)
  97. return res
  98. }
  99. func queryRowByIDString(row any) string {
  100. v := reflect.ValueOf(row).Elem()
  101. t := v.Type()
  102. var table string
  103. fields := []string{}
  104. for i := 0; i < t.NumField(); i++ {
  105. if table == "" {
  106. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  107. table = tag
  108. }
  109. }
  110. if tag := t.Field(i).Tag.Get("field"); tag != "" {
  111. fields = append(fields, tag)
  112. }
  113. }
  114. return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
  115. }
  116. func rowExistsString(row any) string {
  117. v := reflect.ValueOf(row).Elem()
  118. t := v.Type()
  119. var table string
  120. for i := 0; i < t.NumField(); i++ {
  121. if table == "" {
  122. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  123. table = tag
  124. }
  125. }
  126. }
  127. return `SELECT 1 FROM ` + table + ` WHERE id = $1 LIMIT 1`
  128. }
  129. func scans(row any) []any {
  130. v := reflect.ValueOf(row).Elem()
  131. res := make([]interface{}, v.NumField())
  132. for i := 0; i < v.NumField(); i++ {
  133. res[i] = v.Field(i).Addr().Interface()
  134. }
  135. return res
  136. }
  137. func ParseUrl(dbURL string) (*url.URL, error) {
  138. databaseURL, err := url.Parse(dbURL)
  139. if err != nil {
  140. return nil, fmt.Errorf("unable to parse URL: %w", err)
  141. }
  142. if databaseURL.Scheme == "" {
  143. return nil, fmt.Errorf("protocol scheme is not defined")
  144. }
  145. protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  146. if !slices.Contains(protocols, databaseURL.Scheme) {
  147. return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
  148. }
  149. return databaseURL, nil
  150. }
  151. func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
  152. mate := dbmate.New(databaseURL)
  153. mate.AutoDumpSchema = false
  154. mate.Log = io.Discard
  155. if migrationsDir != "" {
  156. mate.MigrationsDir = migrationsDir
  157. }
  158. driver, err := mate.GetDriver()
  159. if err != nil {
  160. return nil, fmt.Errorf("DB get driver error: %w", err)
  161. }
  162. if !skipMigration {
  163. if err := mate.CreateAndMigrate(); err != nil {
  164. return nil, fmt.Errorf("DB migration error: %w", err)
  165. }
  166. }
  167. var db *sql.DB
  168. start := time.Now()
  169. db, err = driver.Open()
  170. if debug {
  171. log(os.Stdout, "Open", start, err, false, "")
  172. }
  173. if err != nil {
  174. return nil, fmt.Errorf("DB open error: %w", err)
  175. }
  176. return db, nil
  177. }