Browse Source

Fix, cover by tests

Volodymyr Tkach 2 years ago
parent
commit
e09c226966
2 changed files with 5 additions and 4 deletions
  1. 3 3
      utils/http/servauth/servauth.go
  2. 2 1
      utils/http/servauth/servauth_test.go

+ 3 - 3
utils/http/servauth/servauth.go

@@ -13,8 +13,8 @@ func BasicAuth(handler http.Handler, username, password, realm string) http.Hand
 
 		u, p, ok := r.BasicAuth()
 		if !ok {
-			w.WriteHeader(401)
 			w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
+			w.WriteHeader(401)
 			if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
 				log.Printf("%s\n", err.Error())
 			}
@@ -22,8 +22,8 @@ func BasicAuth(handler http.Handler, username, password, realm string) http.Hand
 		}
 
 		if u != username {
-			w.WriteHeader(401)
 			w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
+			w.WriteHeader(401)
 			if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
 				log.Printf("%s\n", err.Error())
 			}
@@ -31,8 +31,8 @@ func BasicAuth(handler http.Handler, username, password, realm string) http.Hand
 		}
 
 		if p != password {
-			w.WriteHeader(401)
 			w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
+			w.WriteHeader(401)
 			if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
 				log.Printf("%s\n", err.Error())
 			}

+ 2 - 1
utils/http/servauth/servauth_test.go

@@ -26,7 +26,7 @@ var _ = Describe("servauth", func() {
 		}
 
 		BeforeEach(func() {
-			srv = httptest.NewServer(servauth.BasicAuth(getTestHandler(), "user", "pass", ""))
+			srv = httptest.NewServer(servauth.BasicAuth(getTestHandler(), "user", "pass", "msg"))
 			client = srv.Client()
 		})
 
@@ -40,6 +40,7 @@ var _ = Describe("servauth", func() {
 			defer resp.Body.Close()
 
 			Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
+			Expect(resp.Header["Www-Authenticate"]).To(Equal([]string{`Basic realm="msg"`}))
 
 			body, err := io.ReadAll(resp.Body)
 			Expect(err).To(Succeed())