|
@@ -3,7 +3,10 @@ package common
|
|
|
import (
|
|
|
"context"
|
|
|
"database/sql"
|
|
|
+ "fmt"
|
|
|
"regexp"
|
|
|
+
|
|
|
+ "github.com/pkg/errors"
|
|
|
)
|
|
|
|
|
|
type DBMethods struct {
|
|
@@ -12,6 +15,8 @@ type DBMethods struct {
|
|
|
Driver string
|
|
|
}
|
|
|
|
|
|
+type qFunc func(ctx context.Context, tx *sql.Tx) error
|
|
|
+
|
|
|
var r = regexp.MustCompile(`\$\d+`)
|
|
|
|
|
|
func (db *DBMethods) fixQuery(query string) string {
|
|
@@ -48,3 +53,17 @@ func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql
|
|
|
func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
|
|
|
return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
|
|
|
}
|
|
|
+
|
|
|
+func (db *DBMethods) Transaction(ctx context.Context, queries qFunc) error {
|
|
|
+ if queries == nil {
|
|
|
+ return fmt.Errorf("queries is not set for transaction")
|
|
|
+ }
|
|
|
+ tx, err := db.Begin(ctx, nil)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if err := queries(ctx, tx); err != nil {
|
|
|
+ return errors.Wrap(err, tx.Rollback().Error())
|
|
|
+ }
|
|
|
+ return tx.Commit()
|
|
|
+}
|