Browse Source

Add Scans func to Row struct, optimization, move tests

Volodymyr Tkach 2 years ago
parent
commit
9acb32b9ea

+ 13 - 3
gosql/common/common.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/url"
 	"os"
+	"reflect"
 	"regexp"
 	"strings"
 	"time"
@@ -33,9 +34,13 @@ type Engine interface {
 	Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
 }
 
+var rSqlParam = regexp.MustCompile(`\$\d+`)
 var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
 var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
-var rSqlParam = regexp.MustCompile(`\$\d+`)
+
+func fixQuery(query string) string {
+	return rSqlParam.ReplaceAllString(query, "?")
+}
 
 func log(w io.Writer, fname string, start time.Time, err error, tx bool, query string, args ...any) string {
 	var values []string
@@ -87,8 +92,13 @@ func log(w io.Writer, fname string, start time.Time, err error, tx bool, query s
 	return res
 }
 
-func fixQuery(query string) string {
-	return rSqlParam.ReplaceAllString(query, "?")
+func scans(row any) []any {
+	v := reflect.ValueOf(row).Elem()
+	res := make([]interface{}, v.NumField())
+	for i := 0; i < v.NumField(); i++ {
+		res[i] = v.Field(i).Addr().Interface()
+	}
+	return res
 }
 
 func ParseUrl(dbURL string) (*url.URL, error) {

+ 2 - 1
gosql/common/common_export_test.go

@@ -1,4 +1,5 @@
 package common
 
-var Log = log
 var FixQuery = fixQuery
+var Log = log
+var Scans = scans

+ 25 - 9
gosql/common/common_test.go

@@ -15,6 +15,18 @@ import (
 )
 
 var _ = Describe("common", func() {
+	Context("fixQuery", func() {
+		It("replace param for MySQL driver", func() {
+			sql := "select id, name from users where id=$1"
+			Expect(common.FixQuery(sql)).To(Equal("select id, name from users where id=?"))
+		})
+
+		It("replace all params for MySQL driver", func() {
+			sql := "insert into users set name=$1 where id=$2"
+			Expect(common.FixQuery(sql)).To(Equal("insert into users set name=? where id=?"))
+		})
+	})
+
 	Context("log", func() {
 		Context("time", func() {
 			It("calculate one second", func() {
@@ -51,15 +63,19 @@ var _ = Describe("common", func() {
 		})
 	})
 
-	Context("fixQuery", func() {
-		It("replace param for MySQL driver", func() {
-			sql := "select id, name from users where id=$1"
-			Expect(common.FixQuery(sql)).To(Equal("select id, name from users where id=?"))
-		})
-
-		It("replace all params for MySQL driver", func() {
-			sql := "insert into users set name=$1 where id=$2"
-			Expect(common.FixQuery(sql)).To(Equal("insert into users set name=? where id=?"))
+	Context("scans", func() {
+		It("convert struct to array of pointers to this struct fields", func() {
+			var row struct {
+				ID    int64
+				Name  string
+				Value string
+			}
+
+			Expect(common.Scans(&row)).To(Equal([]any{
+				&row.ID,
+				&row.Name,
+				&row.Value,
+			}))
 		})
 	})
 

+ 4 - 0
gosql/common/row.go

@@ -7,3 +7,7 @@ import (
 type Row struct {
 	*sql.Row
 }
+
+func (r *Row) Scans(row any) error {
+	return r.Row.Scan(scans(row)...)
+}

+ 0 - 10
gosql/common/rows.go

@@ -2,22 +2,12 @@ package common
 
 import (
 	"database/sql"
-	"reflect"
 )
 
 type Rows struct {
 	*sql.Rows
 }
 
-func scans(row any) []any {
-	v := reflect.ValueOf(row).Elem()
-	res := make([]interface{}, v.NumField())
-	for i := 0; i < v.NumField(); i++ {
-		res[i] = v.Field(i).Addr().Interface()
-	}
-	return res
-}
-
 func (r *Rows) Scans(row any) error {
 	return r.Rows.Scan(scans(row)...)
 }

+ 0 - 3
gosql/common/rows_export_test.go

@@ -1,3 +0,0 @@
-package common
-
-var Scans = scans

+ 0 - 25
gosql/common/rows_test.go

@@ -1,25 +0,0 @@
-package common_test
-
-import (
-	. "github.com/onsi/ginkgo"
-	. "github.com/onsi/gomega"
-	"github.com/vladimirok5959/golang-sql/gosql/common"
-)
-
-var _ = Describe("common", func() {
-	Context("scans", func() {
-		It("convert struct to array of pointers to this struct fields", func() {
-			var row struct {
-				ID    int64
-				Name  string
-				Value string
-			}
-
-			Expect(common.Scans(&row)).To(Equal([]any{
-				&row.ID,
-				&row.Name,
-				&row.Value,
-			}))
-		})
-	})
-})

+ 20 - 4
main.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"database/sql"
 	"fmt"
 	"io/ioutil"
 	"path/filepath"
@@ -36,7 +37,7 @@ func main() {
 	db.SetMaxOpenConns(8)
 
 	// DB struct here ./db/migrations/20220527233113_test_migration.sql
-	fmt.Println("Insert some data to users table")
+	fmt.Println("Inserting some data to users table")
 	if _, err := db.Exec(
 		context.Background(),
 		"INSERT INTO users (id, name) VALUES ($1, $2)",
@@ -45,7 +46,7 @@ func main() {
 		panic(fmt.Sprintf("%s", err))
 	}
 
-	fmt.Println("Select all rows from users table")
+	fmt.Println("Selecting all rows from users table")
 	if rows, err := db.Query(
 		context.Background(),
 		"SELECT id, name FROM users ORDER BY id ASC",
@@ -69,7 +70,7 @@ func main() {
 		panic(fmt.Sprintf("%s", err))
 	}
 
-	fmt.Println("Update inside transaction")
+	fmt.Println("Updating inside transaction")
 	if err := db.Transaction(context.Background(), func(ctx context.Context, tx *gosql.Tx) error {
 		if _, err := tx.Exec(ctx, "UPDATE users SET name=$1 WHERE id=$2", "John", 1); err != nil {
 			return err
@@ -82,7 +83,7 @@ func main() {
 		panic(fmt.Sprintf("%s", err))
 	}
 
-	fmt.Println("Select all rows from users again")
+	fmt.Println("Selecting all rows from users again")
 	if err := db.Each(
 		context.Background(),
 		"SELECT id, name FROM users ORDER BY id ASC",
@@ -101,6 +102,21 @@ func main() {
 		panic(fmt.Sprintf("%s", err))
 	}
 
+	fmt.Println("Selecting specific user with ID: 5")
+	var row struct {
+		ID   int64
+		Name string
+	}
+	err = db.QueryRow(context.Background(), "SELECT id, name FROM users WHERE id=$1", 5).Scans(&row)
+	if err != nil && err != sql.ErrNoRows {
+		panic(fmt.Sprintf("%s", err))
+	}
+	if err != sql.ErrNoRows {
+		fmt.Printf("ID: %d, Name: %s\n", row.ID, row.Name)
+	} else {
+		fmt.Printf("Record not found\n")
+	}
+
 	// Close DB connection
 	if err := db.Close(); err != nil {
 		panic(fmt.Sprintf("%s", err))