Browse Source

Add transaction example, fix name params func error

Volodymyr Tkach 2 years ago
parent
commit
adb5759781
3 changed files with 43 additions and 7 deletions
  1. 1 3
      gosql/common/common.go
  2. 1 1
      gosql/common/dbmethods.go
  3. 41 3
      main.go

+ 1 - 3
gosql/common/common.go

@@ -17,8 +17,6 @@ import (
 	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slices"
 )
 )
 
 
-type queryFunc func(ctx context.Context, tx *Tx) error
-
 type Engine interface {
 type Engine interface {
 	Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
 	Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
 	Close() error
 	Close() error
@@ -27,7 +25,7 @@ type Engine interface {
 	Prepare(ctx context.Context, query string) (*sql.Stmt, error)
 	Prepare(ctx context.Context, query string) (*sql.Stmt, error)
 	Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
 	Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
 	QueryRow(ctx context.Context, query string, args ...any) *sql.Row
 	QueryRow(ctx context.Context, query string, args ...any) *sql.Row
-	Transaction(ctx context.Context, queries queryFunc) error
+	Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
 }
 }
 
 
 var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
 var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)

+ 1 - 1
gosql/common/dbmethods.go

@@ -99,7 +99,7 @@ func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *s
 	return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
 	return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
 }
 }
 
 
-func (db *DBMethods) Transaction(ctx context.Context, queries queryFunc) error {
+func (db *DBMethods) Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error {
 	if queries == nil {
 	if queries == nil {
 		return fmt.Errorf("queries is not set for transaction")
 		return fmt.Errorf("queries is not set for transaction")
 	}
 	}

+ 41 - 3
main.go

@@ -7,6 +7,7 @@ import (
 	"path/filepath"
 	"path/filepath"
 
 
 	"github.com/vladimirok5959/golang-sql/gosql"
 	"github.com/vladimirok5959/golang-sql/gosql"
+	"github.com/vladimirok5959/golang-sql/gosql/common"
 )
 )
 
 
 func main() {
 func main() {
@@ -31,7 +32,7 @@ func main() {
 	}
 	}
 
 
 	// DB struct here ./db/migrations/20220527233113_test_migration.sql
 	// DB struct here ./db/migrations/20220527233113_test_migration.sql
-	// Insert some data to users table:
+	// Insert some data to users table
 	if _, err := db.Exec(
 	if _, err := db.Exec(
 		context.Background(),
 		context.Background(),
 		"INSERT INTO users (id, name) VALUES ($1, $2)",
 		"INSERT INTO users (id, name) VALUES ($1, $2)",
@@ -40,10 +41,47 @@ func main() {
 		panic(fmt.Sprintf("%s", err))
 		panic(fmt.Sprintf("%s", err))
 	}
 	}
 
 
-	// Select all rows from users table:
+	// Select all rows from users table
 	if rows, err := db.Query(
 	if rows, err := db.Query(
 		context.Background(),
 		context.Background(),
-		"SELECT id, name FROM users ORDER BY id DESC",
+		"SELECT id, name FROM users ORDER BY id ASC",
+	); err == nil {
+		type rowStruct struct {
+			ID   int64
+			Name string
+		}
+		defer rows.Close()
+		for rows.Next() {
+			var row rowStruct
+			if err := rows.Scan(&row.ID, &row.Name); err != nil {
+				panic(fmt.Sprintf("%s", err))
+			}
+			fmt.Printf("ID: %d, Name: %s\n", row.ID, row.Name)
+		}
+		if err := rows.Err(); err != nil {
+			panic(fmt.Sprintf("%s", err))
+		}
+	} else {
+		panic(fmt.Sprintf("%s", err))
+	}
+
+	// Update inside transaction
+	if err := db.Transaction(context.Background(), func(ctx context.Context, tx *common.Tx) error {
+		if _, err := tx.Exec(ctx, "UPDATE users SET name=$1 WHERE id=$2", "John", 1); err != nil {
+			return err
+		}
+		if _, err := tx.Exec(ctx, "UPDATE users SET name=$1 WHERE id=$2", "Alice", 5); err != nil {
+			return err
+		}
+		return nil
+	}); err != nil {
+		panic(fmt.Sprintf("%s", err))
+	}
+
+	// Select all rows from users again
+	if rows, err := db.Query(
+		context.Background(),
+		"SELECT id, name FROM users ORDER BY id ASC",
 	); err == nil {
 	); err == nil {
 		type rowStruct struct {
 		type rowStruct struct {
 			ID   int64
 			ID   int64