123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324 |
- package common
- import (
- "context"
- "database/sql"
- "fmt"
- "io"
- "net/url"
- "os"
- "reflect"
- "regexp"
- "strconv"
- "strings"
- "time"
- "github.com/amacneil/dbmate/pkg/dbmate"
- _ "github.com/amacneil/dbmate/pkg/driver/mysql"
- _ "github.com/amacneil/dbmate/pkg/driver/postgres"
- _ "github.com/amacneil/dbmate/pkg/driver/sqlite"
- "golang.org/x/exp/slices"
- )
- type Engine interface {
- Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
- Close() error
- CurrentUnixTimestamp() int64
- DeleteRowByID(ctx context.Context, id int64, row any) error
- Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error, args ...any) error
- EachPrepared(ctx context.Context, prep *Prepared, logic func(ctx context.Context, rows *Rows) error) error
- Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
- ExecPrepared(ctx context.Context, prep *Prepared) (sql.Result, error)
- InsertRow(ctx context.Context, row any) error
- Ping(context.Context) error
- Prepare(ctx context.Context, query string) (*sql.Stmt, error)
- PrepareSQL(query string, args ...any) *Prepared
- Query(ctx context.Context, query string, args ...any) (*Rows, error)
- QueryPrepared(ctx context.Context, prep *Prepared) (*Rows, error)
- QueryRow(ctx context.Context, query string, args ...any) *Row
- QueryRowByID(ctx context.Context, id int64, row any) error
- QueryRowPrepared(ctx context.Context, prep *Prepared) *Row
- RowExists(ctx context.Context, id int64, row any) bool
- SetConnMaxLifetime(d time.Duration)
- SetMaxIdleConns(n int)
- SetMaxOpenConns(n int)
- Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
- UpdateRow(ctx context.Context, row any) error
- UpdateRowOnly(ctx context.Context, row any, fields ...string) error
- }
- var rSqlParam = regexp.MustCompile(`\$\d+`)
- var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
- var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
- func currentUnixTimestamp() int64 {
- return time.Now().UTC().Unix()
- }
- func deleteRowByIDString(row any) string {
- v := reflect.ValueOf(row).Elem()
- t := v.Type()
- var table string
- for i := 0; i < t.NumField(); i++ {
- if table == "" {
- if tag := t.Field(i).Tag.Get("table"); tag != "" {
- table = tag
- }
- }
- }
- return `DELETE FROM ` + table + ` WHERE id = $1`
- }
- func fixQuery(query string) string {
- return rSqlParam.ReplaceAllString(query, "?")
- }
- func inArray(arr []string, str string) bool {
- for _, s := range arr {
- if s == str {
- return true
- }
- }
- return false
- }
- func insertRowString(row any) (string, []any) {
- v := reflect.ValueOf(row).Elem()
- t := v.Type()
- var table string
- fields := []string{}
- values := []string{}
- args := []any{}
- position := 1
- created_at := currentUnixTimestamp()
- for i := 0; i < t.NumField(); i++ {
- if table == "" {
- if tag := t.Field(i).Tag.Get("table"); tag != "" {
- table = tag
- }
- }
- tag := t.Field(i).Tag.Get("field")
- if tag != "" {
- if tag != "id" {
- fields = append(fields, tag)
- values = append(values, "$"+strconv.Itoa(position))
- if tag == "created_at" || tag == "updated_at" {
- args = append(args, created_at)
- } else {
- switch t.Field(i).Type.Kind() {
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- args = append(args, v.Field(i).Int())
- case reflect.Float32, reflect.Float64:
- args = append(args, v.Field(i).Float())
- case reflect.String:
- args = append(args, v.Field(i).String())
- }
- }
- position++
- }
- }
- }
- return `INSERT INTO ` + table + ` (` + strings.Join(fields, ", ") + `) VALUES (` + strings.Join(values, ", ") + `)`, args
- }
- func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
- var values []string
- bold := "0"
- color := "33"
- // Transaction or not
- if tx {
- bold = "1"
- values = append(values, "[TX]")
- }
- // Function name
- if fname != "" {
- values = append(values, "[func "+fname+"]")
- }
- // SQL query
- if query != "" {
- values = append(values, rLogSpacesEnd.ReplaceAllString(
- strings.Trim(rLogSpacesAll.ReplaceAllString(query, " "), " "), ";",
- ))
- }
- // Params
- if len(args) > 0 {
- values = append(values, fmt.Sprintf("(%v)", args))
- } else {
- values = append(values, "(empty)")
- }
- // Error
- if err != nil {
- color = "31"
- values = append(values, "("+err.Error()+")")
- } else {
- values = append(values, "(nil)")
- }
- // Execute time with close color symbols
- values = append(values, fmt.Sprintf("%.3f ms\033[0m", time.Since(start).Seconds()))
- // Prepend start caption with colors
- values = append([]string{"\033[" + bold + ";" + color + "m[SQL]"}, values...)
- res := fmt.Sprintln(strings.Join(values, " "))
- fmt.Fprint(w, res)
- return res
- }
- func prepareSQL(query string, args ...any) *Prepared {
- return &Prepared{query, args}
- }
- func queryRowByIDString(row any) string {
- v := reflect.ValueOf(row).Elem()
- t := v.Type()
- var table string
- fields := []string{}
- for i := 0; i < t.NumField(); i++ {
- if table == "" {
- if tag := t.Field(i).Tag.Get("table"); tag != "" {
- table = tag
- }
- }
- tag := t.Field(i).Tag.Get("field")
- if tag != "" {
- fields = append(fields, tag)
- }
- }
- return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
- }
- func rowExistsString(row any) string {
- v := reflect.ValueOf(row).Elem()
- t := v.Type()
- var table string
- for i := 0; i < t.NumField(); i++ {
- if table == "" {
- if tag := t.Field(i).Tag.Get("table"); tag != "" {
- table = tag
- }
- }
- }
- return `SELECT 1 FROM ` + table + ` WHERE id = $1 LIMIT 1`
- }
- func scans(row any) []any {
- v := reflect.ValueOf(row).Elem()
- res := make([]interface{}, v.NumField())
- for i := 0; i < v.NumField(); i++ {
- res[i] = v.Field(i).Addr().Interface()
- }
- return res
- }
- func updateRowString(row any, only ...string) (string, []any) {
- v := reflect.ValueOf(row).Elem()
- t := v.Type()
- var id int64
- var table string
- fields := []string{}
- values := []string{}
- args := []any{}
- position := 1
- updated_at := currentUnixTimestamp()
- for i := 0; i < t.NumField(); i++ {
- if table == "" {
- if tag := t.Field(i).Tag.Get("table"); tag != "" {
- table = tag
- }
- }
- tag := t.Field(i).Tag.Get("field")
- if tag != "" {
- if id == 0 && tag == "id" {
- id = v.Field(i).Int()
- }
- if tag != "id" && tag != "created_at" && ((len(only) == 0) || (len(only) > 0 && inArray(only, tag))) {
- fields = append(fields, tag)
- values = append(values, "$"+strconv.Itoa(position))
- if tag == "updated_at" {
- args = append(args, updated_at)
- } else {
- switch t.Field(i).Type.Kind() {
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- args = append(args, v.Field(i).Int())
- case reflect.Float32, reflect.Float64:
- args = append(args, v.Field(i).Float())
- case reflect.String:
- args = append(args, v.Field(i).String())
- }
- }
- position++
- }
- }
- }
- sql := ""
- args = append(args, id)
- sql += "UPDATE " + table + " SET "
- for i, v := range fields {
- sql += v + " = " + values[i]
- if i < len(fields)-1 {
- sql += ", "
- } else {
- sql += " "
- }
- }
- sql += "WHERE id = " + "$" + strconv.Itoa(position)
- return sql, args
- }
- func ParseUrl(dbURL string) (*url.URL, error) {
- databaseURL, err := url.Parse(dbURL)
- if err != nil {
- return nil, fmt.Errorf("unable to parse URL: %w", err)
- }
- if databaseURL.Scheme == "" {
- return nil, fmt.Errorf("protocol scheme is not defined")
- }
- protocols := []string{"mysql", "postgres", "postgresql", "sqlite", "sqlite3"}
- if !slices.Contains(protocols, databaseURL.Scheme) {
- return nil, fmt.Errorf("unsupported protocol scheme: %s", databaseURL.Scheme)
- }
- return databaseURL, nil
- }
- func OpenDB(databaseURL *url.URL, migrationsDir string, skipMigration bool, debug bool) (*sql.DB, error) {
- mate := dbmate.New(databaseURL)
- mate.AutoDumpSchema = false
- mate.Log = io.Discard
- if migrationsDir != "" {
- mate.MigrationsDir = migrationsDir
- }
- driver, err := mate.GetDriver()
- if err != nil {
- return nil, fmt.Errorf("DB get driver error: %w", err)
- }
- if !skipMigration {
- if err := mate.CreateAndMigrate(); err != nil {
- return nil, fmt.Errorf("DB migration error: %w", err)
- }
- }
- var db *sql.DB
- start := time.Now()
- db, err = driver.Open()
- if debug {
- log(os.Stdout, "Open", start, err, false, "")
- }
- if err != nil {
- return nil, fmt.Errorf("DB open error: %w", err)
- }
- return db, nil
- }
|