common.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "regexp"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "github.com/amacneil/dbmate/pkg/dbmate"
  15. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  16. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  17. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  18. "golang.org/x/exp/slices"
  19. )
  20. type Engine interface {
  21. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  22. Close() error
  23. CurrentUnixTimestamp() int64
  24. DeleteRowByID(ctx context.Context, id int64, row any) error
  25. Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
  26. EachPrepared(ctx context.Context, prep *Prepared, logic func(ctx context.Context, rows *Rows) error) error
  27. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  28. ExecPrepared(ctx context.Context, prep *Prepared) (sql.Result, error)
  29. InsertRow(ctx context.Context, row any) error
  30. Ping(context.Context) error
  31. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  32. Query(ctx context.Context, query string, args ...any) (*Rows, error)
  33. QueryPrepared(ctx context.Context, prep *Prepared) (*Rows, error)
  34. QueryRow(ctx context.Context, query string, args ...any) *Row
  35. QueryRowByID(ctx context.Context, id int64, row any) error
  36. QueryRowPrepared(ctx context.Context, prep *Prepared) *Row
  37. RowExists(ctx context.Context, id int64, row any) bool
  38. SetConnMaxLifetime(d time.Duration)
  39. SetMaxIdleConns(n int)
  40. SetMaxOpenConns(n int)
  41. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  42. }
  43. var rSqlParam = regexp.MustCompile(`\$\d+`)
  44. var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  45. var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  46. func currentUnixTimestamp() int64 {
  47. return time.Now().UTC().Unix()
  48. }
  49. func deleteRowByIDString(row any) string {
  50. v := reflect.ValueOf(row).Elem()
  51. t := v.Type()
  52. var table string
  53. for i := 0; i < t.NumField(); i++ {
  54. if table == "" {
  55. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  56. table = tag
  57. }
  58. }
  59. }
  60. return `DELETE FROM ` + table + ` WHERE id = $1`
  61. }
  62. func fixQuery(query string) string {
  63. return rSqlParam.ReplaceAllString(query, "?")
  64. }
  65. func insertRowString(row any) (string, []any) {
  66. v := reflect.ValueOf(row).Elem()
  67. t := v.Type()
  68. var table string
  69. fields := []string{}
  70. values := []string{}
  71. args := []any{}
  72. position := 1
  73. for i := 0; i < t.NumField(); i++ {
  74. if table == "" {
  75. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  76. table = tag
  77. }
  78. }
  79. if tag := t.Field(i).Tag.Get("field"); tag != "" && tag != "id" {
  80. fields = append(fields, tag)
  81. values = append(values, "$"+strconv.Itoa(position))
  82. switch t.Field(i).Type.Kind() {
  83. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  84. args = append(args, v.Field(i).Int())
  85. case reflect.Float32, reflect.Float64:
  86. args = append(args, v.Field(i).Float())
  87. case reflect.String:
  88. args = append(args, v.Field(i).String())
  89. }
  90. position++
  91. }
  92. }
  93. return `INSERT INTO ` + table + ` (` + strings.Join(fields, ", ") + `) VALUES (` + strings.Join(values, ", ") + `)`, args
  94. }
  95. func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
  96. var values []string
  97. bold := "0"
  98. color := "33"
  99. // Transaction or not
  100. if tx {
  101. bold = "1"
  102. values = append(values, "[TX]")
  103. }
  104. // Function name
  105. if fname != "" {
  106. values = append(values, "[func "+fname+"]")
  107. }
  108. // SQL query
  109. if query != "" {
  110. values = append(values, rLogSpacesEnd.ReplaceAllString(
  111. strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
  112. ))
  113. }
  114. // Params
  115. if len(args) > 0 {
  116. values = append(values, fmt.Sprintf("(%v)", args))
  117. } else {
  118. values = append(values, "(empty)")
  119. }
  120. // Error
  121. if err != nil {
  122. color = "31"
  123. values = append(values, "("+err.Error()+")")
  124. } else {
  125. values = append(values, "(nil)")
  126. }
  127. // Execute time with close color symbols
  128. values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
  129. // Prepend start caption with colors
  130. values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
  131. res := fmt.Sprintln(strings.Join(values, " "))
  132. fmt.Fprint(w, res)
  133. return res
  134. }
  135. func queryRowByIDString(row any) string {
  136. v := reflect.ValueOf(row).Elem()
  137. t := v.Type()
  138. var table string
  139. fields := []string{}
  140. for i := 0; i < t.NumField(); i++ {
  141. if table == "" {
  142. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  143. table = tag
  144. }
  145. }
  146. if tag := t.Field(i).Tag.Get("field"); tag != "" {
  147. fields = append(fields, tag)
  148. }
  149. }
  150. return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
  151. }
  152. func rowExistsString(row any) string {
  153. v := reflect.ValueOf(row).Elem()
  154. t := v.Type()
  155. var table string
  156. for i := 0; i < t.NumField(); i++ {
  157. if table == "" {
  158. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  159. table = tag
  160. }
  161. }
  162. }
  163. return `SELECT 1 FROM ` + table + ` WHERE id = $1 LIMIT 1`
  164. }
  165. func scans(row any) []any {
  166. v := reflect.ValueOf(row).Elem()
  167. res := make([]interface{}, v.NumField())
  168. for i := 0; i < v.NumField(); i++ {
  169. res[i] = v.Field(i).Addr().Interface()
  170. }
  171. return res
  172. }
  173. func ParseUrl(dbURL string) (*url.URL, error) {
  174. databaseURL, err := url.Parse(dbURL)
  175. if err != nil {
  176. return nil, fmt.Errorf("unable to parse URL: %w", err)
  177. }
  178. if databaseURL.Scheme == "" {
  179. return nil, fmt.Errorf("protocol scheme is not defined")
  180. }
  181. protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  182. if !slices.Contains(protocols, databaseURL.Scheme) {
  183. return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
  184. }
  185. return databaseURL, nil
  186. }
  187. func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
  188. mate := dbmate.New(databaseURL)
  189. mate.AutoDumpSchema = false
  190. mate.Log = io.Discard
  191. if migrationsDir != "" {
  192. mate.MigrationsDir = migrationsDir
  193. }
  194. driver, err := mate.GetDriver()
  195. if err != nil {
  196. return nil, fmt.Errorf("DB get driver error: %w", err)
  197. }
  198. if !skipMigration {
  199. if err := mate.CreateAndMigrate(); err != nil {
  200. return nil, fmt.Errorf("DB migration error: %w", err)
  201. }
  202. }
  203. var db *sql.DB
  204. start := time.Now()
  205. db, err = driver.Open()
  206. if debug {
  207. log(os.Stdout, "Open", start, err, false, "")
  208. }
  209. if err != nil {
  210. return nil, fmt.Errorf("DB open error: %w", err)
  211. }
  212. return db, nil
  213. }
  214. func PrepareSQL(query string, args ...any) *Prepared {
  215. return &Prepared{query, args}
  216. }