Browse Source

Implement UpdateRow func

Volodymyr Tkach 2 years ago
parent
commit
2a74e03e42

+ 56 - 0
gosql/common/common.go

@@ -43,6 +43,7 @@ type Engine interface {
 	SetMaxIdleConns(n int)
 	SetMaxOpenConns(n int)
 	Transaction(ctx context.Context, queries func(ctx context.Context, tx *Tx) error) error
+	UpdateRow(ctx context.Context, row any) error
 }
 
 var rSqlParam = regexp.MustCompile(`\$\d+`)
@@ -202,6 +203,61 @@ func scans(row any) []any {
 	return res
 }
 
+func updateRowString(row any) (string, []any) {
+	v := reflect.ValueOf(row).Elem()
+	t := v.Type()
+	var id int64
+	var table string
+	fields := []string{}
+	values := []string{}
+	args := []any{}
+	position := 1
+	updated_at := currentUnixTimestamp()
+	for i := 0; i < t.NumField(); i++ {
+		if table == "" {
+			if tag := t.Field(i).Tag.Get("table"); tag != "" {
+				table = tag
+			}
+		}
+		tag := t.Field(i).Tag.Get("field")
+		if tag != "" {
+			if id == 0 && tag == "id" {
+				id = v.Field(i).Int()
+			}
+			if tag != "id" && tag != "created_at" {
+				fields = append(fields, tag)
+				values = append(values, "$"+strconv.Itoa(position))
+				if tag == "updated_at" {
+					args = append(args, updated_at)
+				} else {
+					switch t.Field(i).Type.Kind() {
+					case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+						args = append(args, v.Field(i).Int())
+					case reflect.Float32, reflect.Float64:
+						args = append(args, v.Field(i).Float())
+					case reflect.String:
+						args = append(args, v.Field(i).String())
+					}
+				}
+				position++
+			}
+		}
+	}
+	sql := ""
+	args = append(args, id)
+	sql += "UPDATE " + table + " SET "
+	for i, v := range fields {
+		sql += v + " = " + values[i]
+		if i < len(fields)-1 {
+			sql += ", "
+		} else {
+			sql += " "
+		}
+	}
+	sql += "WHERE id = " + "$" + strconv.Itoa(position)
+	return sql, args
+}
+
 func ParseUrl(dbURL string) (*url.URL, error) {
 	databaseURL, err := url.Parse(dbURL)
 	if err != nil {

+ 1 - 0
gosql/common/common_export_test.go

@@ -7,3 +7,4 @@ var Log = log
 var QueryRowByIDString = queryRowByIDString
 var RowExistsString = rowExistsString
 var Scans = scans
+var UpdateRowString = updateRowString

+ 47 - 0
gosql/common/common_test.go

@@ -159,6 +159,53 @@ var _ = Describe("common", func() {
 		})
 	})
 
+	Context("updateRowString", func() {
+		It("convert struct to SQL query", func() {
+			var row struct {
+				ID       int64  `field:"id" table:"users"`
+				Name     string `field:"name"`
+				Value    string `field:"value"`
+				Position int64  `field:"position"`
+			}
+
+			row.ID = 10
+			row.Name = "Name"
+			row.Value = "Value"
+			row.Position = 59
+
+			sql, args := common.UpdateRowString(&row)
+
+			Expect(sql).To(Equal(`UPDATE users SET name = $1, value = $2, position = $3 WHERE id = $4`))
+
+			Expect(len(args)).To(Equal(4))
+			Expect(args[0]).To(Equal("Name"))
+			Expect(args[1]).To(Equal("Value"))
+			Expect(args[2]).To(Equal(int64(59)))
+			Expect(args[3]).To(Equal(int64(10)))
+		})
+
+		It("convert struct to SQL query and populate updated_at", func() {
+			var row struct {
+				ID        int64  `field:"id" table:"users"`
+				CreatedAt int64  `field:"created_at"`
+				UpdatedAt int64  `field:"updated_at"`
+				Name      string `field:"name"`
+			}
+
+			row.ID = 10
+			row.Name = "Name"
+
+			sql, args := common.UpdateRowString(&row)
+
+			Expect(sql).To(Equal(`UPDATE users SET updated_at = $1, name = $2 WHERE id = $3`))
+
+			Expect(len(args)).To(Equal(3))
+			Expect(args[0].(int64) > 0).To(BeTrue())
+			Expect(args[1]).To(Equal("Name"))
+			Expect(args[2]).To(Equal(int64(10)))
+		})
+	})
+
 	Context("ParseUrl", func() {
 		Context("Success", func() {
 			It("for MySQL", func() {

+ 6 - 0
gosql/common/dbmethods.go

@@ -191,3 +191,9 @@ func (d *DBMethods) Transaction(ctx context.Context, callback func(ctx context.C
 	}
 	return tx.Commit()
 }
+
+func (d *DBMethods) UpdateRow(ctx context.Context, row any) error {
+	query, args := updateRowString(row)
+	_, err := d.Exec(ctx, query, args...)
+	return err
+}

+ 6 - 0
gosql/common/tx.go

@@ -137,3 +137,9 @@ func (t *Tx) Rollback() error {
 	t.log("Rollback", t.start, err, true, "")
 	return err
 }
+
+func (t *Tx) UpdateRow(ctx context.Context, row any) error {
+	query, args := updateRowString(row)
+	_, err := t.Exec(ctx, query, args...)
+	return err
+}