Browse Source

Add Each method, refactoring, code optimization

Volodymyr Tkach 2 years ago
parent
commit
1b414f58fb

+ 9 - 14
gosql/common/common.go

@@ -21,11 +21,12 @@ import (
 type Engine interface {
 type Engine interface {
 	Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
 	Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
 	Close() error
 	Close() error
+	Each(ctx context.Context, query string, logic func(ctx context.Context, rows *Rows) error) error
 	Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
 	Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
 	Ping(context.Context) error
 	Ping(context.Context) error
 	Prepare(ctx context.Context, query string) (*sql.Stmt, error)
 	Prepare(ctx context.Context, query string) (*sql.Stmt, error)
-	Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
-	QueryRow(ctx context.Context, query string, args ...any) *sql.Row
+	Query(ctx context.Context, query string, args ...any) (*Rows, error)
+	QueryRow(ctx context.Context, query string, args ...any) *Row
 	SetConnMaxLifetime(d time.Duration)
 	SetConnMaxLifetime(d time.Duration)
 	SetMaxIdleConns(n int)
 	SetMaxIdleConns(n int)
 	SetMaxOpenConns(n int)
 	SetMaxOpenConns(n int)
@@ -127,19 +128,13 @@ func OpenDB(databaseURL *url.URL, migrationsDir string, debug bool) (*sql.DB, er
 	}
 	}
 
 
 	var db *sql.DB
 	var db *sql.DB
-
+	start := time.Now()
+	db, err = driver.Open()
 	if debug {
 	if debug {
-		t := time.Now()
-		db, err = driver.Open()
-		log(os.Stdout, "Open", t, err, false, "")
-		if err != nil {
-			return nil, fmt.Errorf("DB open error: %w", err)
-		}
-	} else {
-		db, err = driver.Open()
-		if err != nil {
-			return nil, fmt.Errorf("DB open error: %w", err)
-		}
+		log(os.Stdout, "Open", start, err, false, "")
+	}
+	if err != nil {
+		return nil, fmt.Errorf("DB open error: %w", err)
 	}
 	}
 
 
 	return db, nil
 	return db, nil

+ 70 - 70
gosql/common/dbmethods.go

@@ -15,109 +15,109 @@ type DBMethods struct {
 	Driver string
 	Driver string
 }
 }
 
 
-func (db *DBMethods) fixQuery(query string) string {
-	if db.Driver == "mysql" {
+func (d *DBMethods) fixQuery(query string) string {
+	if d.Driver == "mysql" {
 		return fixQuery(query)
 		return fixQuery(query)
 	}
 	}
 	return query
 	return query
 }
 }
 
 
-func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
-	if db.Debug {
-		t := time.Now()
-		tx, err := db.DB.BeginTx(ctx, opts)
-		log(os.Stdout, "Begin", t, err, true, "")
-		return &Tx{tx, db.Debug, db.Driver, t}, err
+func (d *DBMethods) log(fname string, start time.Time, err error, tx bool, query string, args ...any) {
+	if d.Debug {
+		log(os.Stdout, fname, start, err, tx, query, args...)
 	}
 	}
+}
 
 
-	tx, err := db.DB.BeginTx(ctx, opts)
-	if err != nil {
-		return nil, err
-	}
-	return &Tx{tx, db.Debug, db.Driver, time.Now()}, err
+func (d *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
+	start := time.Now()
+	tx, err := d.DB.BeginTx(ctx, opts)
+	d.log("Begin", start, err, true, "")
+	return &Tx{tx, d.Debug, d.Driver, start}, err
+}
+
+func (d *DBMethods) Close() error {
+	start := time.Now()
+	err := d.DB.Close()
+	d.log("Close", start, err, false, "")
+	return err
 }
 }
 
 
-func (db *DBMethods) Close() error {
-	if db.Debug {
-		t := time.Now()
-		err := db.DB.Close()
-		log(os.Stdout, "Close", t, err, false, "")
+func (d *DBMethods) Each(ctx context.Context, query string, callback func(ctx context.Context, rows *Rows) error) error {
+	if callback == nil {
+		return fmt.Errorf("callback is not set")
+	}
+	rows, err := d.Query(ctx, query)
+	if err != nil {
 		return err
 		return err
 	}
 	}
-	return db.DB.Close()
+	defer rows.Close()
+	for rows.Next() {
+		if err := callback(ctx, rows); err != nil {
+			return err
+		}
+	}
+	if err := rows.Err(); err != nil {
+		return err
+	}
+	return nil
 }
 }
 
 
-func (db *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
-	if db.Debug {
-		t := time.Now()
-		res, err := db.DB.ExecContext(ctx, db.fixQuery(query), args...)
-		log(os.Stdout, "Exec", t, err, false, db.fixQuery(query), args...)
-		return res, err
-	}
-	return db.DB.ExecContext(ctx, db.fixQuery(query), args...)
+func (d *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
+	start := time.Now()
+	res, err := d.DB.ExecContext(ctx, d.fixQuery(query), args...)
+	d.log("Exec", start, err, false, d.fixQuery(query), args...)
+	return res, err
 }
 }
 
 
-func (db *DBMethods) Ping(ctx context.Context) error {
-	if db.Debug {
-		t := time.Now()
-		err := db.DB.PingContext(ctx)
-		log(os.Stdout, "Ping", t, err, false, "")
-		return err
-	}
-	return db.DB.PingContext(ctx)
+func (d *DBMethods) Ping(ctx context.Context) error {
+	start := time.Now()
+	err := d.DB.PingContext(ctx)
+	d.log("Ping", start, err, false, "")
+	return err
 }
 }
 
 
-func (db *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
-	if db.Debug {
-		t := time.Now()
-		stm, err := db.DB.PrepareContext(ctx, db.fixQuery(query))
-		log(os.Stdout, "Prepare", t, err, false, db.fixQuery(query))
-		return stm, err
-	}
-	return db.DB.PrepareContext(ctx, db.fixQuery(query))
+func (d *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
+	start := time.Now()
+	stm, err := d.DB.PrepareContext(ctx, d.fixQuery(query))
+	d.log("Prepare", start, err, false, d.fixQuery(query))
+	return stm, err
 }
 }
 
 
-func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
-	if db.Debug {
-		t := time.Now()
-		rows, err := db.DB.QueryContext(ctx, db.fixQuery(query), args...)
-		log(os.Stdout, "Query", t, err, false, db.fixQuery(query), args...)
-		return rows, err
-	}
-	return db.DB.QueryContext(ctx, db.fixQuery(query), args...)
+func (d *DBMethods) Query(ctx context.Context, query string, args ...any) (*Rows, error) {
+	start := time.Now()
+	rows, err := d.DB.QueryContext(ctx, d.fixQuery(query), args...)
+	d.log("Query", start, err, false, d.fixQuery(query), args...)
+	return &Rows{Rows: rows}, err
 }
 }
 
 
-func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
-	if db.Debug {
-		t := time.Now()
-		row := db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
-		log(os.Stdout, "QueryRow", t, nil, false, db.fixQuery(query), args...)
-		return row
-	}
-	return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
+func (d *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *Row {
+	start := time.Now()
+	row := d.DB.QueryRowContext(ctx, d.fixQuery(query), args...)
+	d.log("QueryRow", start, nil, false, d.fixQuery(query), args...)
+	return &Row{Row: row}
 }
 }
 
 
-func (db *DBMethods) SetConnMaxLifetime(d time.Duration) {
-	db.DB.SetConnMaxLifetime(d)
+func (d *DBMethods) SetConnMaxLifetime(t time.Duration) {
+	d.DB.SetConnMaxLifetime(t)
 }
 }
 
 
-func (db *DBMethods) SetMaxIdleConns(n int) {
-	db.DB.SetMaxIdleConns(n)
+func (d *DBMethods) SetMaxIdleConns(n int) {
+	d.DB.SetMaxIdleConns(n)
 }
 }
 
 
-func (db *DBMethods) SetMaxOpenConns(n int) {
-	db.DB.SetMaxOpenConns(n)
+func (d *DBMethods) SetMaxOpenConns(n int) {
+	d.DB.SetMaxOpenConns(n)
 }
 }
 
 
-func (db *DBMethods) Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error {
-	if queries == nil {
-		return fmt.Errorf("queries is not set for transaction")
+func (d *DBMethods) Transaction(ctx context.Context, callback func(ctx context.Context, tx *Tx) error) error {
+	if callback == nil {
+		return fmt.Errorf("callback is not set")
 	}
 	}
-	tx, err := db.Begin(ctx, nil)
+	tx, err := d.Begin(ctx, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	if err := queries(ctx, tx); err != nil {
+	if err := callback(ctx, tx); err != nil {
 		rerr := tx.Rollback()
 		rerr := tx.Rollback()
 		if rerr != nil {
 		if rerr != nil {
 			return fmt.Errorf(
 			return fmt.Errorf(

+ 9 - 0
gosql/common/row.go

@@ -0,0 +1,9 @@
+package common
+
+import (
+	"database/sql"
+)
+
+type Row struct {
+	*sql.Row
+}

+ 23 - 0
gosql/common/rows.go

@@ -0,0 +1,23 @@
+package common
+
+import (
+	"database/sql"
+	"reflect"
+)
+
+type Rows struct {
+	*sql.Rows
+}
+
+func scans(row any) []any {
+	v := reflect.ValueOf(row).Elem()
+	res := make([]interface{}, v.NumField())
+	for i := 0; i < v.NumField(); i++ {
+		res[i] = v.Field(i).Addr().Interface()
+	}
+	return res
+}
+
+func (r *Rows) Scans(row any) error {
+	return r.Rows.Scan(scans(row)...)
+}

+ 3 - 0
gosql/common/rows_export_test.go

@@ -0,0 +1,3 @@
+package common
+
+var Scans = scans

+ 25 - 0
gosql/common/rows_test.go

@@ -0,0 +1,25 @@
+package common_test
+
+import (
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+	"github.com/vladimirok5959/golang-sql/gosql/common"
+)
+
+var _ = Describe("common", func() {
+	Context("scans", func() {
+		It("convert struct to array of pointers to this struct fields", func() {
+			var row struct {
+				ID    int64
+				Name  string
+				Value string
+			}
+
+			Expect(common.Scans(&row)).To(Equal([]any{
+				&row.ID,
+				&row.Name,
+				&row.Value,
+			}))
+		})
+	})
+})

+ 31 - 40
gosql/common/tx.go

@@ -13,60 +13,51 @@ type Tx struct {
 
 
 	Debug  bool
 	Debug  bool
 	Driver string
 	Driver string
-	t      time.Time
+	start  time.Time
 }
 }
 
 
-func (db *Tx) fixQuery(query string) string {
-	if db.Driver == "mysql" {
+func (t *Tx) fixQuery(query string) string {
+	if t.Driver == "mysql" {
 		return fixQuery(query)
 		return fixQuery(query)
 	}
 	}
 	return query
 	return query
 }
 }
 
 
-func (db *Tx) Commit() error {
-	if db.Debug {
-		err := db.tx.Commit()
-		log(os.Stdout, "Commit", db.t, err, true, "")
-		return err
+func (t *Tx) log(fname string, start time.Time, err error, tx bool, query string, args ...any) {
+	if t.Debug {
+		log(os.Stdout, fname, start, err, tx, query, args...)
 	}
 	}
-	return db.tx.Commit()
 }
 }
 
 
-func (db *Tx) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
-	if db.Debug {
-		t := time.Now()
-		res, err := db.tx.ExecContext(ctx, db.fixQuery(query), args...)
-		log(os.Stdout, "Exec", t, err, true, db.fixQuery(query), args...)
-		return res, err
-	}
-	return db.tx.ExecContext(ctx, db.fixQuery(query), args...)
+func (t *Tx) Commit() error {
+	err := t.tx.Commit()
+	t.log("Commit", t.start, err, true, "")
+	return err
 }
 }
 
 
-func (db *Tx) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
-	if db.Debug {
-		t := time.Now()
-		rows, err := db.tx.QueryContext(ctx, db.fixQuery(query), args...)
-		log(os.Stdout, "Query", t, err, true, db.fixQuery(query), args...)
-		return rows, err
-	}
-	return db.tx.QueryContext(ctx, db.fixQuery(query), args...)
+func (t *Tx) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
+	start := time.Now()
+	res, err := t.tx.ExecContext(ctx, t.fixQuery(query), args...)
+	t.log("Exec", start, err, true, t.fixQuery(query), args...)
+	return res, err
 }
 }
 
 
-func (db *Tx) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
-	if db.Debug {
-		t := time.Now()
-		row := db.tx.QueryRowContext(ctx, db.fixQuery(query), args...)
-		log(os.Stdout, "QueryRow", t, nil, true, db.fixQuery(query), args...)
-		return row
-	}
-	return db.tx.QueryRowContext(ctx, db.fixQuery(query), args...)
+func (t *Tx) Query(ctx context.Context, query string, args ...any) (*Rows, error) {
+	start := time.Now()
+	rows, err := t.tx.QueryContext(ctx, t.fixQuery(query), args...)
+	t.log("Query", start, err, true, t.fixQuery(query), args...)
+	return &Rows{Rows: rows}, err
 }
 }
 
 
-func (db *Tx) Rollback() error {
-	if db.Debug {
-		err := db.tx.Rollback()
-		log(os.Stdout, "Rollback", db.t, err, true, "")
-		return err
-	}
-	return db.tx.Rollback()
+func (t *Tx) QueryRow(ctx context.Context, query string, args ...any) *Row {
+	start := time.Now()
+	row := t.tx.QueryRowContext(ctx, t.fixQuery(query), args...)
+	t.log("QueryRow", start, nil, true, t.fixQuery(query), args...)
+	return &Row{Row: row}
+}
+
+func (t *Tx) Rollback() error {
+	err := t.tx.Rollback()
+	t.log("Rollback", t.start, err, true, "")
+	return err
 }
 }

+ 4 - 0
gosql/gosql.go

@@ -7,6 +7,10 @@ import (
 	"github.com/vladimirok5959/golang-sql/gosql/engine"
 	"github.com/vladimirok5959/golang-sql/gosql/engine"
 )
 )
 
 
+type Row = common.Row
+
+type Rows = common.Rows
+
 type Tx = common.Tx
 type Tx = common.Tx
 
 
 func Open(dbURL, migrationsDir string, debug bool) (common.Engine, error) {
 func Open(dbURL, migrationsDir string, debug bool) (common.Engine, error) {

+ 11 - 16
main.go

@@ -83,26 +83,21 @@ func main() {
 	}
 	}
 
 
 	fmt.Println("Select all rows from users again")
 	fmt.Println("Select all rows from users again")
-	if rows, err := db.Query(
+	if err := db.Each(
 		context.Background(),
 		context.Background(),
 		"SELECT id, name FROM users ORDER BY id ASC",
 		"SELECT id, name FROM users ORDER BY id ASC",
-	); err == nil {
-		type rowStruct struct {
-			ID   int64
-			Name string
-		}
-		defer rows.Close()
-		for rows.Next() {
-			var row rowStruct
-			if err := rows.Scan(&row.ID, &row.Name); err != nil {
-				panic(fmt.Sprintf("%s", err))
+		func(ctx context.Context, rows *gosql.Rows) error {
+			var row struct {
+				ID   int64
+				Name string
+			}
+			if err := rows.Scans(&row); err != nil {
+				return err
 			}
 			}
 			fmt.Printf("ID: %d, Name: %s\n", row.ID, row.Name)
 			fmt.Printf("ID: %d, Name: %s\n", row.ID, row.Name)
-		}
-		if err := rows.Err(); err != nil {
-			panic(fmt.Sprintf("%s", err))
-		}
-	} else {
+			return nil
+		},
+	); err != nil {
 		panic(fmt.Sprintf("%s", err))
 		panic(fmt.Sprintf("%s", err))
 	}
 	}