Browse Source

Log also transactions

Volodymyr Tkach 2 years ago
parent
commit
8ea616bfba
3 changed files with 126 additions and 49 deletions
  1. 39 1
      gosql/common/common.go
  2. 16 48
      gosql/common/dbmethods.go
  3. 71 0
      gosql/common/tx.go

+ 39 - 1
gosql/common/common.go

@@ -6,6 +6,9 @@ import (
 	"fmt"
 	"io"
 	"net/url"
+	"os"
+	"strings"
+	"time"
 
 	"github.com/amacneil/dbmate/pkg/dbmate"
 	_ "github.com/amacneil/dbmate/pkg/driver/mysql"
@@ -15,7 +18,7 @@ import (
 )
 
 type Engine interface {
-	Begin(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
+	Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
 	Close() error
 	Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
 	Ping(context.Context) error
@@ -25,6 +28,41 @@ type Engine interface {
 	Transaction(ctx context.Context, queries queryFunc) error
 }
 
+func log(m string, s time.Time, e error, tx bool, query string, args ...any) {
+	var tmsg string
+
+	if tx {
+		tmsg = " [TX]"
+	}
+	if m != "" {
+		tmsg = tmsg + " " + m
+	}
+
+	qmsg := query
+	if qmsg != "" {
+		qmsg = strings.Trim(rLogSpacesAll.ReplaceAllString(qmsg, " "), " ")
+		qmsg = rLogSpacesEnd.ReplaceAllString(qmsg, ";")
+		qmsg = " " + qmsg
+	}
+
+	astr := " (empty)"
+	if len(args) > 0 {
+		astr = fmt.Sprintf(" (%v)", args)
+	}
+
+	estr := " (nil)"
+	if e != nil {
+		estr = " \033[0m\033[0;31m(" + e.Error() + ")"
+	}
+
+	color := "0;33"
+	if tx {
+		color = "1;33"
+	}
+
+	fmt.Fprintln(os.Stdout, "\033["+color+"m[SQL]"+tmsg+qmsg+astr+estr+fmt.Sprintf(" %.3f ms", time.Since(s).Seconds())+"\033[0m")
+}
+
 func ParseUrl(dbURL string) (*url.URL, error) {
 	databaseURL, err := url.Parse(dbURL)
 	if err != nil {

+ 16 - 48
gosql/common/dbmethods.go

@@ -4,9 +4,7 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
-	"os"
 	"regexp"
-	"strings"
 	"time"
 
 	"github.com/pkg/errors"
@@ -23,42 +21,7 @@ var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
 var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
 var rSqlParam = regexp.MustCompile(`\$\d+`)
 
-type queryFunc func(ctx context.Context, tx *sql.Tx) error
-
-func (db *DBMethods) log(m string, s time.Time, e error, tx bool, query string, args ...any) {
-	var tmsg string
-
-	if tx {
-		tmsg = " [TX]"
-	}
-	if m != "" {
-		tmsg = tmsg + " " + m
-	}
-
-	qmsg := query
-	if qmsg != "" {
-		qmsg = strings.Trim(rLogSpacesAll.ReplaceAllString(qmsg, " "), " ")
-		qmsg = rLogSpacesEnd.ReplaceAllString(qmsg, ";")
-		qmsg = " " + qmsg
-	}
-
-	astr := " (empty)"
-	if len(args) > 0 {
-		astr = fmt.Sprintf(" (%v)", args)
-	}
-
-	estr := " (nil)"
-	if e != nil {
-		estr = " \033[0m\033[0;31m(" + e.Error() + ")"
-	}
-
-	color := "0;33"
-	if tx {
-		color = "1;33"
-	}
-
-	fmt.Fprintln(os.Stdout, "\033["+color+"m[SQL]"+tmsg+qmsg+astr+estr+fmt.Sprintf(" %.3f ms", time.Since(s).Seconds())+"\033[0m")
-}
+type queryFunc func(ctx context.Context, tx *Tx) error
 
 func (db *DBMethods) fixQuery(query string) string {
 	if db.Driver == "mysql" {
@@ -67,21 +30,26 @@ func (db *DBMethods) fixQuery(query string) string {
 	return query
 }
 
-func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
 	if db.Debug {
 		t := time.Now()
 		tx, err := db.DB.BeginTx(ctx, opts)
-		db.log("[func Begin]", t, err, true, "")
-		return tx, err
+		log("[func Begin]", t, err, true, "")
+		return &Tx{tx, db.Debug, db.Driver, t}, err
+	}
+
+	tx, err := db.DB.BeginTx(ctx, opts)
+	if err != nil {
+		return nil, err
 	}
-	return db.DB.BeginTx(ctx, opts)
+	return &Tx{tx, db.Debug, db.Driver, time.Now()}, err
 }
 
 func (db *DBMethods) Close() error {
 	if db.Debug {
 		t := time.Now()
 		err := db.DB.Close()
-		db.log("[func Close]", t, err, false, "")
+		log("[func Close]", t, err, false, "")
 		return err
 	}
 	return db.DB.Close()
@@ -91,7 +59,7 @@ func (db *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.R
 	if db.Debug {
 		t := time.Now()
 		res, err := db.DB.ExecContext(ctx, db.fixQuery(query), args...)
-		db.log("[func Exec]", t, err, false, db.fixQuery(query), args...)
+		log("[func Exec]", t, err, false, db.fixQuery(query), args...)
 		return res, err
 	}
 	return db.DB.ExecContext(ctx, db.fixQuery(query), args...)
@@ -101,7 +69,7 @@ func (db *DBMethods) Ping(ctx context.Context) error {
 	if db.Debug {
 		t := time.Now()
 		err := db.DB.PingContext(ctx)
-		db.log("[func Ping]", t, err, false, "")
+		log("[func Ping]", t, err, false, "")
 		return err
 	}
 	return db.DB.PingContext(ctx)
@@ -111,7 +79,7 @@ func (db *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, erro
 	if db.Debug {
 		t := time.Now()
 		stm, err := db.DB.PrepareContext(ctx, db.fixQuery(query))
-		db.log("[func Prepare]", t, err, false, db.fixQuery(query))
+		log("[func Prepare]", t, err, false, db.fixQuery(query))
 		return stm, err
 	}
 	return db.DB.PrepareContext(ctx, db.fixQuery(query))
@@ -121,7 +89,7 @@ func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql
 	if db.Debug {
 		t := time.Now()
 		rows, err := db.DB.QueryContext(ctx, db.fixQuery(query), args...)
-		db.log("[func Query]", t, err, false, db.fixQuery(query), args...)
+		log("[func Query]", t, err, false, db.fixQuery(query), args...)
 		return rows, err
 	}
 	return db.DB.QueryContext(ctx, db.fixQuery(query), args...)
@@ -131,7 +99,7 @@ func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *s
 	if db.Debug {
 		t := time.Now()
 		row := db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
-		db.log("[func QueryRow]", t, nil, false, db.fixQuery(query), args...)
+		log("[func QueryRow]", t, nil, false, db.fixQuery(query), args...)
 		return row
 	}
 	return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)

+ 71 - 0
gosql/common/tx.go

@@ -0,0 +1,71 @@
+package common
+
+import (
+	"context"
+	"database/sql"
+
+	"time"
+)
+
+type Tx struct {
+	tx *sql.Tx
+
+	Debug  bool
+	Driver string
+	t      time.Time
+}
+
+func (db *Tx) fixQuery(query string) string {
+	if db.Driver == "mysql" {
+		return rSqlParam.ReplaceAllString(query, "?")
+	}
+	return query
+}
+
+func (db *Tx) Commit() error {
+	if db.Debug {
+		err := db.tx.Commit()
+		log("[func Commit]", db.t, err, true, "")
+		return err
+	}
+	return db.tx.Commit()
+}
+
+func (db *Tx) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
+	if db.Debug {
+		t := time.Now()
+		res, err := db.tx.ExecContext(ctx, db.fixQuery(query), args...)
+		log("[func Exec]", t, err, true, db.fixQuery(query), args...)
+		return res, err
+	}
+	return db.tx.ExecContext(ctx, db.fixQuery(query), args...)
+}
+
+func (db *Tx) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
+	if db.Debug {
+		t := time.Now()
+		rows, err := db.tx.QueryContext(ctx, db.fixQuery(query), args...)
+		log("[func Query]", t, err, true, db.fixQuery(query), args...)
+		return rows, err
+	}
+	return db.tx.QueryContext(ctx, db.fixQuery(query), args...)
+}
+
+func (db *Tx) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
+	if db.Debug {
+		t := time.Now()
+		row := db.tx.QueryRowContext(ctx, db.fixQuery(query), args...)
+		log("[func QueryRow]", t, nil, true, db.fixQuery(query), args...)
+		return row
+	}
+	return db.tx.QueryRowContext(ctx, db.fixQuery(query), args...)
+}
+
+func (db *Tx) Rollback() error {
+	if db.Debug {
+		err := db.tx.Rollback()
+		log("[func Rollback]", db.t, err, true, "")
+		return err
+	}
+	return db.tx.Rollback()
+}