Browse Source

BasicAuth handler

Volodymyr Tkach 2 years ago
parent
commit
828ba2548e
2 changed files with 114 additions and 0 deletions
  1. 44 0
      utils/http/servauth/servauth.go
  2. 70 0
      utils/http/servauth/servauth_test.go

+ 44 - 0
utils/http/servauth/servauth.go

@@ -0,0 +1,44 @@
+package servauth
+
+import (
+	"log"
+	"net/http"
+)
+
+func BasicAuth(handler http.Handler, username, password, realm string) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if realm == "" {
+			realm = "Please enter username and password"
+		}
+
+		u, p, ok := r.BasicAuth()
+		if !ok {
+			w.WriteHeader(401)
+			w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
+			if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
+				log.Printf("%s\n", err.Error())
+			}
+			return
+		}
+
+		if u != username {
+			w.WriteHeader(401)
+			w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
+			if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
+				log.Printf("%s\n", err.Error())
+			}
+			return
+		}
+
+		if p != password {
+			w.WriteHeader(401)
+			w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
+			if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
+				log.Printf("%s\n", err.Error())
+			}
+			return
+		}
+
+		handler.ServeHTTP(w, r)
+	})
+}

+ 70 - 0
utils/http/servauth/servauth_test.go

@@ -0,0 +1,70 @@
+package servauth_test
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+	"github.com/vladimirok5959/golang-utils/utils/http/servauth"
+)
+
+var _ = Describe("servauth", func() {
+	Context("BasicAuth", func() {
+		var srv *httptest.Server
+		var client *http.Client
+
+		var getTestHandler = func() http.HandlerFunc {
+			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if _, err := w.Write([]byte("Index")); err != nil {
+					fmt.Printf("%s\n", err.Error())
+				}
+			})
+		}
+
+		BeforeEach(func() {
+			srv = httptest.NewServer(servauth.BasicAuth(getTestHandler(), "user", "pass", ""))
+			client = srv.Client()
+		})
+
+		AfterEach(func() {
+			srv.Close()
+		})
+
+		It("request credentials", func() {
+			resp, err := client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Unauthorised\n"))
+		})
+
+		It("show with correct credentials", func() {
+			req, err := http.NewRequest("GET", srv.URL+"/", nil)
+			Expect(err).To(Succeed())
+			req.SetBasicAuth("user", "pass")
+
+			resp, err := client.Do(req)
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Index"))
+		})
+	})
+})
+
+func TestSuite(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "servauth")
+}