Browse Source

Disable logs to os.Stdout on tests

Volodymyr Tkach 2 years ago
parent
commit
de61ceb051
4 changed files with 23 additions and 21 deletions
  1. 2 3
      gosql/common/common.go
  2. 7 6
      gosql/common/common_test.go
  3. 8 7
      gosql/common/dbmethods.go
  4. 6 5
      gosql/common/tx.go

+ 2 - 3
gosql/common/common.go

@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"io"
 	"net/url"
-	"os"
 	"regexp"
 	"strings"
 	"time"
@@ -35,7 +34,7 @@ var rLogSpacesAll = regexp.MustCompile(`[\s\t]+`)
 var rLogSpacesEnd = regexp.MustCompile(`[\s\t]+;$`)
 var rSqlParam = regexp.MustCompile(`\$\d+`)
 
-func log(m string, s time.Time, e error, tx bool, query string, args ...any) string {
+func log(w io.Writer, m string, s time.Time, e error, tx bool, query string, args ...any) string {
 	var tmsg string
 
 	if tx {
@@ -68,7 +67,7 @@ func log(m string, s time.Time, e error, tx bool, query string, args ...any) str
 	}
 
 	res := fmt.Sprintln("\033[" + color + "m[SQL]" + tmsg + qmsg + astr + estr + fmt.Sprintf(" %.3f ms", time.Since(s).Seconds()) + "\033[0m")
-	fmt.Fprintln(os.Stdout, res)
+	fmt.Fprintln(w, res)
 	return res
 }
 

+ 7 - 6
gosql/common/common_test.go

@@ -2,6 +2,7 @@ package common_test
 
 import (
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net/url"
 	"path/filepath"
@@ -17,34 +18,34 @@ var _ = Describe("common", func() {
 	Context("log", func() {
 		Context("time", func() {
 			It("calculate one second", func() {
-				str := common.Log("[func Exec]", time.Now().Add(time.Second*-1), nil, false, "")
+				str := common.Log(io.Discard, "[func Exec]", time.Now().Add(time.Second*-1), nil, false, "")
 				Expect(str).To(Equal("\x1b[0;33m[SQL] [func Exec] (empty) (nil) 1.000 ms\x1b[0m\n"))
 			})
 		})
 
 		Context("format", func() {
 			It("with func name", func() {
-				str := common.Log("[func Exec]", time.Now(), nil, false, "")
+				str := common.Log(io.Discard, "[func Exec]", time.Now(), nil, false, "")
 				Expect(str).To(Equal("\x1b[0;33m[SQL] [func Exec] (empty) (nil) 0.000 ms\x1b[0m\n"))
 			})
 
 			It("with sql query", func() {
-				str := common.Log("[func Exec]", time.Now(), nil, false, "select * from users")
+				str := common.Log(io.Discard, "[func Exec]", time.Now(), nil, false, "select * from users")
 				Expect(str).To(Equal("\x1b[0;33m[SQL] [func Exec] select * from users (empty) (nil) 0.000 ms\x1b[0m\n"))
 			})
 
 			It("with error message", func() {
-				str := common.Log("[func Exec]", time.Now(), fmt.Errorf("Exec error"), false, "select * from users")
+				str := common.Log(io.Discard, "[func Exec]", time.Now(), fmt.Errorf("Exec error"), false, "select * from users")
 				Expect(str).To(Equal("\x1b[0;33m[SQL] [func Exec] select * from users (empty) \x1b[0m\x1b[0;31m(Exec error) 0.000 ms\x1b[0m\n"))
 			})
 
 			It("with transaction flag", func() {
-				str := common.Log("[func Exec]", time.Now(), fmt.Errorf("Exec error"), true, "select * from users")
+				str := common.Log(io.Discard, "[func Exec]", time.Now(), fmt.Errorf("Exec error"), true, "select * from users")
 				Expect(str).To(Equal("\x1b[1;33m[SQL] [TX] [func Exec] select * from users (empty) \x1b[0m\x1b[0;31m(Exec error) 0.000 ms\x1b[0m\n"))
 			})
 
 			It("with sql query arguments", func() {
-				str := common.Log("[func Exec]", time.Now(), fmt.Errorf("Exec error"), true, "select * from users where id=$1", 100)
+				str := common.Log(io.Discard, "[func Exec]", time.Now(), fmt.Errorf("Exec error"), true, "select * from users where id=$1", 100)
 				Expect(str).To(Equal("\x1b[1;33m[SQL] [TX] [func Exec] select * from users where id=$1 ([100]) \x1b[0m\x1b[0;31m(Exec error) 0.000 ms\x1b[0m\n"))
 			})
 		})

+ 8 - 7
gosql/common/dbmethods.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
+	"os"
 	"time"
 
 	"github.com/pkg/errors"
@@ -27,7 +28,7 @@ func (db *DBMethods) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error
 	if db.Debug {
 		t := time.Now()
 		tx, err := db.DB.BeginTx(ctx, opts)
-		log("[func Begin]", t, err, true, "")
+		log(os.Stdout, "[func Begin]", t, err, true, "")
 		return &Tx{tx, db.Debug, db.Driver, t}, err
 	}
 
@@ -42,7 +43,7 @@ func (db *DBMethods) Close() error {
 	if db.Debug {
 		t := time.Now()
 		err := db.DB.Close()
-		log("[func Close]", t, err, false, "")
+		log(os.Stdout, "[func Close]", t, err, false, "")
 		return err
 	}
 	return db.DB.Close()
@@ -52,7 +53,7 @@ func (db *DBMethods) Exec(ctx context.Context, query string, args ...any) (sql.R
 	if db.Debug {
 		t := time.Now()
 		res, err := db.DB.ExecContext(ctx, db.fixQuery(query), args...)
-		log("[func Exec]", t, err, false, db.fixQuery(query), args...)
+		log(os.Stdout, "[func Exec]", t, err, false, db.fixQuery(query), args...)
 		return res, err
 	}
 	return db.DB.ExecContext(ctx, db.fixQuery(query), args...)
@@ -62,7 +63,7 @@ func (db *DBMethods) Ping(ctx context.Context) error {
 	if db.Debug {
 		t := time.Now()
 		err := db.DB.PingContext(ctx)
-		log("[func Ping]", t, err, false, "")
+		log(os.Stdout, "[func Ping]", t, err, false, "")
 		return err
 	}
 	return db.DB.PingContext(ctx)
@@ -72,7 +73,7 @@ func (db *DBMethods) Prepare(ctx context.Context, query string) (*sql.Stmt, erro
 	if db.Debug {
 		t := time.Now()
 		stm, err := db.DB.PrepareContext(ctx, db.fixQuery(query))
-		log("[func Prepare]", t, err, false, db.fixQuery(query))
+		log(os.Stdout, "[func Prepare]", t, err, false, db.fixQuery(query))
 		return stm, err
 	}
 	return db.DB.PrepareContext(ctx, db.fixQuery(query))
@@ -82,7 +83,7 @@ func (db *DBMethods) Query(ctx context.Context, query string, args ...any) (*sql
 	if db.Debug {
 		t := time.Now()
 		rows, err := db.DB.QueryContext(ctx, db.fixQuery(query), args...)
-		log("[func Query]", t, err, false, db.fixQuery(query), args...)
+		log(os.Stdout, "[func Query]", t, err, false, db.fixQuery(query), args...)
 		return rows, err
 	}
 	return db.DB.QueryContext(ctx, db.fixQuery(query), args...)
@@ -92,7 +93,7 @@ func (db *DBMethods) QueryRow(ctx context.Context, query string, args ...any) *s
 	if db.Debug {
 		t := time.Now()
 		row := db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)
-		log("[func QueryRow]", t, nil, false, db.fixQuery(query), args...)
+		log(os.Stdout, "[func QueryRow]", t, nil, false, db.fixQuery(query), args...)
 		return row
 	}
 	return db.DB.QueryRowContext(ctx, db.fixQuery(query), args...)

+ 6 - 5
gosql/common/tx.go

@@ -3,6 +3,7 @@ package common
 import (
 	"context"
 	"database/sql"
+	"os"
 
 	"time"
 )
@@ -25,7 +26,7 @@ func (db *Tx) fixQuery(query string) string {
 func (db *Tx) Commit() error {
 	if db.Debug {
 		err := db.tx.Commit()
-		log("[func Commit]", db.t, err, true, "")
+		log(os.Stdout, "[func Commit]", db.t, err, true, "")
 		return err
 	}
 	return db.tx.Commit()
@@ -35,7 +36,7 @@ func (db *Tx) Exec(ctx context.Context, query string, args ...any) (sql.Result,
 	if db.Debug {
 		t := time.Now()
 		res, err := db.tx.ExecContext(ctx, db.fixQuery(query), args...)
-		log("[func Exec]", t, err, true, db.fixQuery(query), args...)
+		log(os.Stdout, "[func Exec]", t, err, true, db.fixQuery(query), args...)
 		return res, err
 	}
 	return db.tx.ExecContext(ctx, db.fixQuery(query), args...)
@@ -45,7 +46,7 @@ func (db *Tx) Query(ctx context.Context, query string, args ...any) (*sql.Rows,
 	if db.Debug {
 		t := time.Now()
 		rows, err := db.tx.QueryContext(ctx, db.fixQuery(query), args...)
-		log("[func Query]", t, err, true, db.fixQuery(query), args...)
+		log(os.Stdout, "[func Query]", t, err, true, db.fixQuery(query), args...)
 		return rows, err
 	}
 	return db.tx.QueryContext(ctx, db.fixQuery(query), args...)
@@ -55,7 +56,7 @@ func (db *Tx) QueryRow(ctx context.Context, query string, args ...any) *sql.Row
 	if db.Debug {
 		t := time.Now()
 		row := db.tx.QueryRowContext(ctx, db.fixQuery(query), args...)
-		log("[func QueryRow]", t, nil, true, db.fixQuery(query), args...)
+		log(os.Stdout, "[func QueryRow]", t, nil, true, db.fixQuery(query), args...)
 		return row
 	}
 	return db.tx.QueryRowContext(ctx, db.fixQuery(query), args...)
@@ -64,7 +65,7 @@ func (db *Tx) QueryRow(ctx context.Context, query string, args ...any) *sql.Row
 func (db *Tx) Rollback() error {
 	if db.Debug {
 		err := db.tx.Rollback()
-		log("[func Rollback]", db.t, err, true, "")
+		log(os.Stdout, "[func Rollback]", db.t, err, true, "")
 		return err
 	}
 	return db.tx.Rollback()