common.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. package common
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "regexp"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "github.com/amacneil/dbmate/pkg/dbmate"
  15. _ "github.com/amacneil/dbmate/pkg/driver/mysql"
  16. _ "github.com/amacneil/dbmate/pkg/driver/postgres"
  17. _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
  18. "golang.org/x/exp/slices"
  19. )
  20. type Engine interface {
  21. Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
  22. Close() error
  23. CurrentUnixTimestamp() int64
  24. DeleteRowByID(ctx context.Context, id int64, row any) error
  25. Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
  26. EachPrepared(ctx context.Context, prep *Prepared, logic func(ctx context.Context, rows *Rows) error) error
  27. Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
  28. ExecPrepared(ctx context.Context, prep *Prepared) (sql.Result, error)
  29. InsertRow(ctx context.Context, row any) error
  30. Ping(context.Context) error
  31. Prepare(ctx context.Context, query string) (*sql.Stmt, error)
  32. PrepareSQL(query string, args ...any) *Prepared
  33. Query(ctx context.Context, query string, args ...any) (*Rows, error)
  34. QueryPrepared(ctx context.Context, prep *Prepared) (*Rows, error)
  35. QueryRow(ctx context.Context, query string, args ...any) *Row
  36. QueryRowByID(ctx context.Context, id int64, row any) error
  37. QueryRowPrepared(ctx context.Context, prep *Prepared) *Row
  38. RowExists(ctx context.Context, id int64, row any) bool
  39. SetConnMaxLifetime(d time.Duration)
  40. SetMaxIdleConns(n int)
  41. SetMaxOpenConns(n int)
  42. Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
  43. UpdateRow(ctx context.Context, row any) error
  44. UpdateRowOnly(ctx context.Context, row any, fields ...string) error
  45. }
  46. var rSqlParam = regexp.MustCompile(`\$\d+`)
  47. var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
  48. var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
  49. func currentUnixTimestamp() int64 {
  50. return time.Now().UTC().Unix()
  51. }
  52. func deleteRowByIDString(row any) string {
  53. v := reflect.ValueOf(row).Elem()
  54. t := v.Type()
  55. var table string
  56. for i := 0; i < t.NumField(); i++ {
  57. if table == "" {
  58. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  59. table = tag
  60. }
  61. }
  62. }
  63. return `DELETE FROM ` + table + ` WHERE id = $1`
  64. }
  65. func fixQuery(query string) string {
  66. return rSqlParam.ReplaceAllString(query, "?")
  67. }
  68. func inArray(arr []string, str string) bool {
  69. for _, s := range arr {
  70. if s == str {
  71. return true
  72. }
  73. }
  74. return false
  75. }
  76. func insertRowString(row any) (string, []any) {
  77. v := reflect.ValueOf(row).Elem()
  78. t := v.Type()
  79. var table string
  80. fields := []string{}
  81. values := []string{}
  82. args := []any{}
  83. position := 1
  84. created_at := currentUnixTimestamp()
  85. for i := 0; i < t.NumField(); i++ {
  86. if table == "" {
  87. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  88. table = tag
  89. }
  90. }
  91. tag := t.Field(i).Tag.Get("field")
  92. if tag != "" {
  93. if tag != "id" {
  94. fields = append(fields, tag)
  95. values = append(values, "$"+strconv.Itoa(position))
  96. if tag == "created_at" || tag == "updated_at" {
  97. args = append(args, created_at)
  98. } else {
  99. switch t.Field(i).Type.Kind() {
  100. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  101. args = append(args, v.Field(i).Int())
  102. case reflect.Float32, reflect.Float64:
  103. args = append(args, v.Field(i).Float())
  104. case reflect.String:
  105. args = append(args, v.Field(i).String())
  106. }
  107. }
  108. position++
  109. }
  110. }
  111. }
  112. return `INSERT INTO ` + table + ` (` + strings.Join(fields, ", ") + `) VALUES (` + strings.Join(values, ", ") + `)`, args
  113. }
  114. func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
  115. var values []string
  116. bold := "0"
  117. color := "33"
  118. // Transaction or not
  119. if tx {
  120. bold = "1"
  121. values = append(values, "[TX]")
  122. }
  123. // Function name
  124. if fname != "" {
  125. values = append(values, "[func "+fname+"]")
  126. }
  127. // SQL query
  128. if query != "" {
  129. values = append(values, rLogSpacesEnd.ReplaceAllString(
  130. strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
  131. ))
  132. }
  133. // Params
  134. if len(args) > 0 {
  135. values = append(values, fmt.Sprintf("(%v)", args))
  136. } else {
  137. values = append(values, "(empty)")
  138. }
  139. // Error
  140. if err != nil {
  141. color = "31"
  142. values = append(values, "("+err.Error()+")")
  143. } else {
  144. values = append(values, "(nil)")
  145. }
  146. // Execute time with close color symbols
  147. values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
  148. // Prepend start caption with colors
  149. values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
  150. res := fmt.Sprintln(strings.Join(values, " "))
  151. fmt.Fprint(w, res)
  152. return res
  153. }
  154. func prepareSQL(query string, args ...any) *Prepared {
  155. return &Prepared{query, args}
  156. }
  157. func queryRowByIDString(row any) string {
  158. v := reflect.ValueOf(row).Elem()
  159. t := v.Type()
  160. var table string
  161. fields := []string{}
  162. for i := 0; i < t.NumField(); i++ {
  163. if table == "" {
  164. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  165. table = tag
  166. }
  167. }
  168. tag := t.Field(i).Tag.Get("field")
  169. if tag != "" {
  170. fields = append(fields, tag)
  171. }
  172. }
  173. return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
  174. }
  175. func rowExistsString(row any) string {
  176. v := reflect.ValueOf(row).Elem()
  177. t := v.Type()
  178. var table string
  179. for i := 0; i < t.NumField(); i++ {
  180. if table == "" {
  181. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  182. table = tag
  183. }
  184. }
  185. }
  186. return `SELECT 1 FROM ` + table + ` WHERE id = $1 LIMIT 1`
  187. }
  188. func scans(row any) []any {
  189. v := reflect.ValueOf(row).Elem()
  190. res := make([]interface{}, v.NumField())
  191. for i := 0; i < v.NumField(); i++ {
  192. res[i] = v.Field(i).Addr().Interface()
  193. }
  194. return res
  195. }
  196. func updateRowString(row any, only ...string) (string, []any) {
  197. v := reflect.ValueOf(row).Elem()
  198. t := v.Type()
  199. var id int64
  200. var table string
  201. fields := []string{}
  202. values := []string{}
  203. args := []any{}
  204. position := 1
  205. updated_at := currentUnixTimestamp()
  206. for i := 0; i < t.NumField(); i++ {
  207. if table == "" {
  208. if tag := t.Field(i).Tag.Get("table"); tag != "" {
  209. table = tag
  210. }
  211. }
  212. tag := t.Field(i).Tag.Get("field")
  213. if tag != "" {
  214. if id == 0 && tag == "id" {
  215. id = v.Field(i).Int()
  216. }
  217. if tag != "id" && tag != "created_at" && ((len(only) == 0) || (len(only) > 0 && inArray(only, tag))) {
  218. fields = append(fields, tag)
  219. values = append(values, "$"+strconv.Itoa(position))
  220. if tag == "updated_at" {
  221. args = append(args, updated_at)
  222. } else {
  223. switch t.Field(i).Type.Kind() {
  224. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  225. args = append(args, v.Field(i).Int())
  226. case reflect.Float32, reflect.Float64:
  227. args = append(args, v.Field(i).Float())
  228. case reflect.String:
  229. args = append(args, v.Field(i).String())
  230. }
  231. }
  232. position++
  233. }
  234. }
  235. }
  236. sql := ""
  237. args = append(args, id)
  238. sql += "UPDATE " + table + " SET "
  239. for i, v := range fields {
  240. sql += v + " = " + values[i]
  241. if i < len(fields)-1 {
  242. sql += ", "
  243. } else {
  244. sql += " "
  245. }
  246. }
  247. sql += "WHERE id = " + "$" + strconv.Itoa(position)
  248. return sql, args
  249. }
  250. func ParseUrl(dbURL string) (*url.URL, error) {
  251. databaseURL, err := url.Parse(dbURL)
  252. if err != nil {
  253. return nil, fmt.Errorf("unable to parse URL: %w", err)
  254. }
  255. if databaseURL.Scheme == "" {
  256. return nil, fmt.Errorf("protocol scheme is not defined")
  257. }
  258. protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
  259. if !slices.Contains(protocols, databaseURL.Scheme) {
  260. return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
  261. }
  262. return databaseURL, nil
  263. }
  264. func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
  265. mate := dbmate.New(databaseURL)
  266. mate.AutoDumpSchema = false
  267. mate.Log = io.Discard
  268. if migrationsDir != "" {
  269. mate.MigrationsDir = migrationsDir
  270. }
  271. driver, err := mate.GetDriver()
  272. if err != nil {
  273. return nil, fmt.Errorf("DB get driver error: %w", err)
  274. }
  275. if !skipMigration {
  276. if err := mate.CreateAndMigrate(); err != nil {
  277. return nil, fmt.Errorf("DB migration error: %w", err)
  278. }
  279. }
  280. var db *sql.DB
  281. start := time.Now()
  282. db, err = driver.Open()
  283. if debug {
  284. log(os.Stdout, "Open", start, err, false, "")
  285. }
  286. if err != nil {
  287. return nil, fmt.Errorf("DB open error: %w", err)
  288. }
  289. return db, nil
  290. }