Browse Source

Add QueryRowByID func

Volodymyr Tkach 2 years ago
parent
commit
9732c1d14f

+ 0 - 1
go.mod

@@ -6,7 +6,6 @@ require (
 	github.com/amacneil/dbmate v1.15.0
 	github.com/onsi/ginkgo v1.16.5
 	github.com/onsi/gomega v1.19.0
-	github.com/pkg/errors v0.9.1
 )
 
 require (

+ 0 - 2
go.sum

@@ -39,8 +39,6 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
 github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
 github.com/onsi/gomega v1.19.0 h1:4ieX6qQjPP/BfC3mpsAtIGGlxTWPeA3Inl/7DtXw1tw=
 github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro=
-github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

+ 19 - 0
gosql/common/common.go

@@ -28,6 +28,7 @@ type Engine interface {
 	Prepare(ctx context.Context, query string) (*sql.Stmt, error)
 	Query(ctx context.Context, query string, args ...any) (*Rows, error)
 	QueryRow(ctx context.Context, query string, args ...any) *Row
+	QueryRowByID(ctx context.Context, id int64, row any) error
 	SetConnMaxLifetime(d time.Duration)
 	SetMaxIdleConns(n int)
 	SetMaxOpenConns(n int)
@@ -101,6 +102,24 @@ func scans(row any) []any {
 	return res
 }
 
+func queryRowByIDString(row any) string {
+	v := reflect.ValueOf(row).Elem()
+	t := v.Type()
+	var table string
+	fields := []string{}
+	for i := 0; i < t.NumField(); i++ {
+		if table == "" {
+			if tag := t.Field(i).Tag.Get("table"); tag != "" {
+				table = tag
+			}
+		}
+		if tag := t.Field(i).Tag.Get("field"); tag != "" {
+			fields = append(fields, tag)
+		}
+	}
+	return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
+}
+
 func ParseUrl(dbURL string) (*url.URL, error) {
 	databaseURL, err := url.Parse(dbURL)
 	if err != nil {

+ 1 - 0
gosql/common/common_export_test.go

@@ -3,3 +3,4 @@ package common
 var FixQuery = fixQuery
 var Log = log
 var Scans = scans
+var QueryRowByIDString = queryRowByIDString

+ 12 - 0
gosql/common/common_test.go

@@ -79,6 +79,18 @@ var _ = Describe("common", func() {
 		})
 	})
 
+	Context("queryRowByIDString", func() {
+		It("convert struct to select SQL query", func() {
+			var row struct {
+				ID    int64  `field:"id" table:"users"`
+				Name  string `field:"name"`
+				Value string `field:"value"`
+			}
+
+			Expect(common.QueryRowByIDString(&row)).To(Equal(`SELECT id, name, value FROM users WHERE id = $1 LIMIT 1`))
+		})
+	})
+
 	Context("ParseUrl", func() {
 		Context("Success", func() {
 			It("for MySQL", func() {

+ 5 - 0
gosql/common/dbmethods.go

@@ -102,6 +102,11 @@ func (d *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *Ro
 	return &Row{Row: row}
 }
 
+func (d *DBMethods) QueryRowByID(ctx context.Context, id int64, row any) error {
+	query := queryRowByIDString(row)
+	return d.QueryRow(ctx, query, id).Scans(row)
+}
+
 func (d *DBMethods) SetConnMaxLifetime(t time.Duration) {
 	start := time.Now()
 	d.DB.SetConnMaxLifetime(t)

+ 5 - 0
gosql/common/tx.go

@@ -56,6 +56,11 @@ func (t *Tx) QueryRow(ctx context.Context, query string, args ...any) *Row {
 	return &Row{Row: row}
 }
 
+func (t *Tx) QueryRowByID(ctx context.Context, id int64, row any) error {
+	query := queryRowByIDString(row)
+	return t.QueryRow(ctx, query, id).Scans(row)
+}
+
 func (t *Tx) Rollback() error {
 	err := t.tx.Rollback()
 	t.log("Rollback", t.start, err, true, "")

+ 55 - 0
gosql/gosql_test.go

@@ -44,6 +44,23 @@ var _ = Describe("gosql", func() {
 		// 		Expect(id).To(Equal(2))
 		// 		Expect(name).To(Equal("Bob"))
 
+		// 		Expect(db.Close()).To(Succeed())
+		// 	})
+
+		// 	It("open connection, migrate and select by ID", func() {
+		// 		db, err := gosql.Open("mysql://root:root@127.0.0.1:3306/gosql", migrationsDir, false, false)
+		// 		Expect(err).To(Succeed())
+
+		// 		var rowUser struct {
+		// 			ID   int64  `field:"id" table:"users"`
+		// 			Name string `field:"name"`
+		// 		}
+
+		// 		err = db.QueryRowByID(ctx, 1, &rowUser)
+		// 		Expect(err).To(Succeed())
+		// 		Expect(rowUser.ID).To(Equal(int64(1)))
+		// 		Expect(rowUser.Name).To(Equal("Alice"))
+
 		// 		Expect(db.Close()).To(Succeed())
 		// 	})
 		// })
@@ -64,6 +81,23 @@ var _ = Describe("gosql", func() {
 		// 		Expect(id).To(Equal(2))
 		// 		Expect(name).To(Equal("Bob"))
 
+		// 		Expect(db.Close()).To(Succeed())
+		// 	})
+
+		// 	It("open connection, migrate and select by ID", func() {
+		// 		db, err := gosql.Open("postgres://root:root@127.0.0.1:5432/gosql?sslmode=disable", migrationsDir, false, false)
+		// 		Expect(err).To(Succeed())
+
+		// 		var rowUser struct {
+		// 			ID   int64  `field:"id" table:"users"`
+		// 			Name string `field:"name"`
+		// 		}
+
+		// 		err = db.QueryRowByID(ctx, 1, &rowUser)
+		// 		Expect(err).To(Succeed())
+		// 		Expect(rowUser.ID).To(Equal(int64(1)))
+		// 		Expect(rowUser.Name).To(Equal("Alice"))
+
 		// 		Expect(db.Close()).To(Succeed())
 		// 	})
 		// })
@@ -89,6 +123,27 @@ var _ = Describe("gosql", func() {
 
 				Expect(db.Close()).To(Succeed())
 			})
+
+			It("open connection, migrate and select by ID", func() {
+				f, err := ioutil.TempFile("", "go-sqlite-test-")
+				Expect(err).To(Succeed())
+				f.Close()
+
+				db, err := gosql.Open("sqlite://"+f.Name(), migrationsDir, false, false)
+				Expect(err).To(Succeed())
+
+				var rowUser struct {
+					ID   int64  `field:"id" table:"users"`
+					Name string `field:"name"`
+				}
+
+				err = db.QueryRowByID(ctx, 1, &rowUser)
+				Expect(err).To(Succeed())
+				Expect(rowUser.ID).To(Equal(int64(1)))
+				Expect(rowUser.Name).To(Equal("Alice"))
+
+				Expect(db.Close()).To(Succeed())
+			})
 		})
 
 		It("open connection and skip migration", func() {