Browse Source

Add RowExists func

Volodymyr Tkach 2 years ago
parent
commit
8126af33e2

+ 15 - 0
gosql/common/common.go

@@ -29,6 +29,7 @@ type Engine interface {
 	Query(ctx context.Context, query string, args ...any) (*Rows, error)
 	QueryRow(ctx context.Context, query string, args ...any) *Row
 	QueryRowByID(ctx context.Context, id int64, row any) error
+	RowExists(ctx context.Context, id int64, row any) bool
 	SetConnMaxLifetime(d time.Duration)
 	SetMaxIdleConns(n int)
 	SetMaxOpenConns(n int)
@@ -120,6 +121,20 @@ func queryRowByIDString(row any) string {
 	return `SELECT ` + strings.Join(fields, ", ") + ` FROM ` + table + ` WHERE id = $1 LIMIT 1`
 }
 
+func rowExistsString(row any) string {
+	v := reflect.ValueOf(row).Elem()
+	t := v.Type()
+	var table string
+	for i := 0; i < t.NumField(); i++ {
+		if table == "" {
+			if tag := t.Field(i).Tag.Get("table"); tag != "" {
+				table = tag
+			}
+		}
+	}
+	return `SELECT 1 FROM ` + table + ` WHERE id = $1 LIMIT 1`
+}
+
 func ParseUrl(dbURL string) (*url.URL, error) {
 	databaseURL, err := url.Parse(dbURL)
 	if err != nil {

+ 1 - 0
gosql/common/common_export_test.go

@@ -4,3 +4,4 @@ var FixQuery = fixQuery
 var Log = log
 var Scans = scans
 var QueryRowByIDString = queryRowByIDString
+var RowExistsString = rowExistsString

+ 13 - 1
gosql/common/common_test.go

@@ -80,7 +80,7 @@ var _ = Describe("common", func() {
 	})
 
 	Context("queryRowByIDString", func() {
-		It("convert struct to select SQL query", func() {
+		It("convert struct to SQL query", func() {
 			var row struct {
 				ID    int64  `field:"id" table:"users"`
 				Name  string `field:"name"`
@@ -91,6 +91,18 @@ var _ = Describe("common", func() {
 		})
 	})
 
+	Context("rowExistsString", func() {
+		It("convert struct to SQL query", func() {
+			var row struct {
+				ID    int64  `field:"id" table:"users"`
+				Name  string `field:"name"`
+				Value string `field:"value"`
+			}
+
+			Expect(common.RowExistsString(&row)).To(Equal(`SELECT 1 FROM users WHERE id = $1 LIMIT 1`))
+		})
+	})
+
 	Context("ParseUrl", func() {
 		Context("Success", func() {
 			It("for MySQL", func() {

+ 9 - 0
gosql/common/dbmethods.go

@@ -107,6 +107,15 @@ func (d *DBMethods) QueryRowByID(ctx context.Context, id int64, row any) error {
 	return d.QueryRow(ctx, query, id).Scans(row)
 }
 
+func (d *DBMethods) RowExists(ctx context.Context, id int64, row any) bool {
+	var exists int
+	query := rowExistsString(row)
+	if err := d.QueryRow(ctx, query, id).Scan(&exists); err == nil && exists == 1 {
+		return true
+	}
+	return false
+}
+
 func (d *DBMethods) SetConnMaxLifetime(t time.Duration) {
 	start := time.Now()
 	d.DB.SetConnMaxLifetime(t)

+ 9 - 0
gosql/common/tx.go

@@ -61,6 +61,15 @@ func (t *Tx) QueryRowByID(ctx context.Context, id int64, row any) error {
 	return t.QueryRow(ctx, query, id).Scans(row)
 }
 
+func (t *Tx) RowExists(ctx context.Context, id int64, row any) bool {
+	var exists int
+	query := rowExistsString(row)
+	if err := t.QueryRow(ctx, query, id).Scan(&exists); err == nil && exists == 1 {
+		return true
+	}
+	return false
+}
+
 func (t *Tx) Rollback() error {
 	err := t.tx.Rollback()
 	t.log("Rollback", t.start, err, true, "")

+ 58 - 0
gosql/gosql_test.go

@@ -61,6 +61,24 @@ var _ = Describe("gosql", func() {
 		// 		Expect(rowUser.ID).To(Equal(int64(1)))
 		// 		Expect(rowUser.Name).To(Equal("Alice"))
 
+		// 		Expect(db.Close()).To(Succeed())
+		// 	})
+
+		// 	It("open connection, migrate and check row", func() {
+		// 		db, err := gosql.Open("mysql://root:root@127.0.0.1:3306/gosql", migrationsDir, false, false)
+		// 		Expect(err).To(Succeed())
+
+		// 		var rowUser struct {
+		// 			ID   int64  `field:"id" table:"users"`
+		// 			Name string `field:"name"`
+		// 		}
+
+		// 		Expect(db.RowExists(ctx, 1, &rowUser)).To(BeTrue())
+		// 		Expect(db.RowExists(ctx, 2, &rowUser)).To(BeTrue())
+		// 		Expect(db.RowExists(ctx, 3, &rowUser)).To(BeFalse())
+		// 		Expect(db.RowExists(ctx, 4, &rowUser)).To(BeFalse())
+		// 		Expect(db.RowExists(ctx, 5, &rowUser)).To(BeFalse())
+
 		// 		Expect(db.Close()).To(Succeed())
 		// 	})
 		// })
@@ -98,6 +116,24 @@ var _ = Describe("gosql", func() {
 		// 		Expect(rowUser.ID).To(Equal(int64(1)))
 		// 		Expect(rowUser.Name).To(Equal("Alice"))
 
+		// 		Expect(db.Close()).To(Succeed())
+		// 	})
+
+		// 	It("open connection, migrate and check row", func() {
+		// 		db, err := gosql.Open("postgres://root:root@127.0.0.1:5432/gosql?sslmode=disable", migrationsDir, false, false)
+		// 		Expect(err).To(Succeed())
+
+		// 		var rowUser struct {
+		// 			ID   int64  `field:"id" table:"users"`
+		// 			Name string `field:"name"`
+		// 		}
+
+		// 		Expect(db.RowExists(ctx, 1, &rowUser)).To(BeTrue())
+		// 		Expect(db.RowExists(ctx, 2, &rowUser)).To(BeTrue())
+		// 		Expect(db.RowExists(ctx, 3, &rowUser)).To(BeFalse())
+		// 		Expect(db.RowExists(ctx, 4, &rowUser)).To(BeFalse())
+		// 		Expect(db.RowExists(ctx, 5, &rowUser)).To(BeFalse())
+
 		// 		Expect(db.Close()).To(Succeed())
 		// 	})
 		// })
@@ -144,6 +180,28 @@ var _ = Describe("gosql", func() {
 
 				Expect(db.Close()).To(Succeed())
 			})
+
+			It("open connection, migrate and check row", func() {
+				f, err := ioutil.TempFile("", "go-sqlite-test-")
+				Expect(err).To(Succeed())
+				f.Close()
+
+				db, err := gosql.Open("sqlite://"+f.Name(), migrationsDir, false, false)
+				Expect(err).To(Succeed())
+
+				var rowUser struct {
+					ID   int64  `field:"id" table:"users"`
+					Name string `field:"name"`
+				}
+
+				Expect(db.RowExists(ctx, 1, &rowUser)).To(BeTrue())
+				Expect(db.RowExists(ctx, 2, &rowUser)).To(BeTrue())
+				Expect(db.RowExists(ctx, 3, &rowUser)).To(BeFalse())
+				Expect(db.RowExists(ctx, 4, &rowUser)).To(BeFalse())
+				Expect(db.RowExists(ctx, 5, &rowUser)).To(BeFalse())
+
+				Expect(db.Close()).To(Succeed())
+			})
 		})
 
 		It("open connection and skip migration", func() {