common.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "regexp"
  11. "strings"
  12. "time"
  13. "github.com/amacneil/dbmate/pkg/dbmate"
  14. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  15. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  16. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  17. "golang.org/x/exp/slices"
  18. )
  19. type Engine interface {
  20. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  21. Close() error
  22. CurrentUnixTimestamp() int64
  23. DeleteRowByID(ctx context.Context, id int64, row any) error
  24. Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
  25. EachPrepared(ctx context.Context, prep *Prepared, logic func(ctx context.Context, rows *Rows) error) error
  26. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  27. ExecPrepared(ctx context.Context, prep *Prepared) (sql.Result, error)
  28. Ping(context.Context) error
  29. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  30. Query(ctx context.Context, query string, args ...any) (*Rows, error)
  31. QueryPrepared(ctx context.Context, prep *Prepared) (*Rows, error)
  32. QueryRow(ctx context.Context, query string, args ...any) *Row
  33. QueryRowByID(ctx context.Context, id int64, row any) error
  34. QueryRowPrepared(ctx context.Context, prep *Prepared) *Row
  35. RowExists(ctx context.Context, id int64, row any) bool
  36. SetConnMaxLifetime(d time.Duration)
  37. SetMaxIdleConns(n int)
  38. SetMaxOpenConns(n int)
  39. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  40. }
  41. var rSqlParam = regexp.MustCompile(`\$\d+`)
  42. var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  43. var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  44. func currentUnixTimestamp() int64 {
  45. return time.Now().UTC().Unix()
  46. }
  47. func deleteRowByIDString(row any) string {
  48. v := reflect.ValueOf(row).Elem()
  49. t := v.Type()
  50. var table string
  51. for i := 0; i < t.NumField(); i++ {
  52. if table == "" {
  53. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  54. table = tag
  55. }
  56. }
  57. }
  58. return `DELETE FROM ` + table + ` WHERE id = $1`
  59. }
  60. func fixQuery(query string) string {
  61. return rSqlParam.ReplaceAllString(query, "?")
  62. }
  63. func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
  64. var values []string
  65. bold := "0"
  66. color := "33"
  67. // Transaction or not
  68. if tx {
  69. bold = "1"
  70. values = append(values, "[TX]")
  71. }
  72. // Function name
  73. if fname != "" {
  74. values = append(values, "[func "+fname+"]")
  75. }
  76. // SQL query
  77. if query != "" {
  78. values = append(values, rLogSpacesEnd.ReplaceAllString(
  79. strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
  80. ))
  81. }
  82. // Params
  83. if len(args) > 0 {
  84. values = append(values, fmt.Sprintf("(%v)", args))
  85. } else {
  86. values = append(values, "(empty)")
  87. }
  88. // Error
  89. if err != nil {
  90. color = "31"
  91. values = append(values, "("+err.Error()+")")
  92. } else {
  93. values = append(values, "(nil)")
  94. }
  95. // Execute time with close color symbols
  96. values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
  97. // Prepend start caption with colors
  98. values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
  99. res := fmt.Sprintln(strings.Join(values, " "))
  100. fmt.Fprint(w, res)
  101. return res
  102. }
  103. func queryRowByIDString(row any) string {
  104. v := reflect.ValueOf(row).Elem()
  105. t := v.Type()
  106. var table string
  107. fields := []string{}
  108. for i := 0; i < t.NumField(); i++ {
  109. if table == "" {
  110. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  111. table = tag
  112. }
  113. }
  114. if tag := t.Field(i).Tag.Get("field"); tag != "" {
  115. fields = append(fields, tag)
  116. }
  117. }
  118. return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
  119. }
  120. func rowExistsString(row any) string {
  121. v := reflect.ValueOf(row).Elem()
  122. t := v.Type()
  123. var table string
  124. for i := 0; i < t.NumField(); i++ {
  125. if table == "" {
  126. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  127. table = tag
  128. }
  129. }
  130. }
  131. return `SELECT 1 FROM ` + table + ` WHERE id = $1 LIMIT 1`
  132. }
  133. func scans(row any) []any {
  134. v := reflect.ValueOf(row).Elem()
  135. res := make([]interface{}, v.NumField())
  136. for i := 0; i < v.NumField(); i++ {
  137. res[i] = v.Field(i).Addr().Interface()
  138. }
  139. return res
  140. }
  141. func ParseUrl(dbURL string) (*url.URL, error) {
  142. databaseURL, err := url.Parse(dbURL)
  143. if err != nil {
  144. return nil, fmt.Errorf("unable to parse URL: %w", err)
  145. }
  146. if databaseURL.Scheme == "" {
  147. return nil, fmt.Errorf("protocol scheme is not defined")
  148. }
  149. protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  150. if !slices.Contains(protocols, databaseURL.Scheme) {
  151. return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
  152. }
  153. return databaseURL, nil
  154. }
  155. func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
  156. mate := dbmate.New(databaseURL)
  157. mate.AutoDumpSchema = false
  158. mate.Log = io.Discard
  159. if migrationsDir != "" {
  160. mate.MigrationsDir = migrationsDir
  161. }
  162. driver, err := mate.GetDriver()
  163. if err != nil {
  164. return nil, fmt.Errorf("DB get driver error: %w", err)
  165. }
  166. if !skipMigration {
  167. if err := mate.CreateAndMigrate(); err != nil {
  168. return nil, fmt.Errorf("DB migration error: %w", err)
  169. }
  170. }
  171. var db *sql.DB
  172. start := time.Now()
  173. db, err = driver.Open()
  174. if debug {
  175. log(os.Stdout, "Open", start, err, false, "")
  176. }
  177. if err != nil {
  178. return nil, fmt.Errorf("DB open error: %w", err)
  179. }
  180. return db, nil
  181. }
  182. func PrepareSQL(query string, args ...any) *Prepared {
  183. return &Prepared{query, args}
  184. }