Browse Source

Add Transaction method

Volodymyr Tkach 2 years ago
parent
commit
f7cb2aa407
4 changed files with 23 additions and 0 deletions
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 1 0
      gosql/common/common.go
  4. 19 0
      gosql/common/dbmethods.go

+ 1 - 0
go.mod

@@ -6,6 +6,7 @@ require (
 	github.com/amacneil/dbmate v1.15.0
 	github.com/onsi/ginkgo v1.16.5
 	github.com/onsi/gomega v1.19.0
+	github.com/pkg/errors v0.9.1
 )
 
 require (

+ 2 - 0
go.sum

@@ -39,6 +39,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
 github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
 github.com/onsi/gomega v1.19.0 h1:4ieX6qQjPP/BfC3mpsAtIGGlxTWPeA3Inl/7DtXw1tw=
 github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

+ 1 - 0
gosql/common/common.go

@@ -22,6 +22,7 @@ type Engine interface {
 	Prepare(ctx context.Context, query string) (*sql.Stmt, error)
 	Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
 	QueryRow(ctx context.Context, query string, args ...any) *sql.Row
+	Transaction(ctx context.Context, queries qFunc) error
 }
 
 func ParseUrl(dbURL string) (*url.URL, error) {

+ 19 - 0
gosql/common/dbmethods.go

@@ -3,7 +3,10 @@ package common
 import (
 	"context"
 	"database/sql"
+	"fmt"
 	"regexp"
+
+	"github.com/pkg/errors"
 )
 
 type DBMethods struct {
@@ -12,6 +15,8 @@ type DBMethods struct {
 	Driver string
 }
 
+type qFunc func(ctx context.Context, tx *sql.Tx) error
+
 var r = regexp.MustCompile(`\$\d+`)
 
 func (db *DBMethods) fixQuery(query string) string {
@@ -48,3 +53,17 @@ func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql
 func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
 	return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
 }
+
+func (db *DBMethods) Transaction(ctx context.Context, queries qFunc) error {
+	if queries == nil {
+		return fmt.Errorf("queries is not set for transaction")
+	}
+	tx, err := db.Begin(ctx, nil)
+	if err != nil {
+		return err
+	}
+	if err := queries(ctx, tx); err != nil {
+		return errors.Wrap(err, tx.Rollback().Error())
+	}
+	return tx.Commit()
+}