Browse Source

Protect BasicAuth from bruteforce

Volodymyr Tkach 2 years ago
parent
commit
ec706623f8

+ 51 - 0
utils/http/servauth/requests.go

@@ -0,0 +1,51 @@
+package servauth
+
+import (
+	"sync"
+	"time"
+)
+
+type Requests struct {
+	sync.RWMutex
+	counter   map[string]int
+	lastTime  map[string]int64
+	cleanTime int64
+}
+
+func (r *Requests) Count(ip string) int {
+	r.Lock()
+	defer r.Unlock()
+	if v, ok := r.counter[ip]; ok {
+		return v
+	}
+	return 0
+}
+
+func (r *Requests) SetCount(ip string, count int) {
+	r.Lock()
+	defer r.Unlock()
+	r.counter[ip] = count
+}
+
+func (r *Requests) Time(ip string) int64 {
+	r.Lock()
+	defer r.Unlock()
+	if v, ok := r.lastTime[ip]; ok {
+		return v
+	}
+	return 0
+}
+
+func (r *Requests) SetTime(ip string, time int64) {
+	r.Lock()
+	defer r.Unlock()
+	r.lastTime[ip] = time
+}
+
+func (r *Requests) Cleanup() {
+	r.Lock()
+	defer r.Unlock()
+	r.counter = map[string]int{}
+	r.lastTime = map[string]int64{}
+	r.cleanTime = time.Now().UTC().Unix()
+}

+ 49 - 1
utils/http/servauth/servauth.go

@@ -3,13 +3,46 @@ package servauth
 import (
 import (
 	"log"
 	"log"
 	"net/http"
 	"net/http"
+	"time"
+
+	"github.com/vladimirok5959/golang-utils/utils/http/helpers"
 )
 )
 
 
-// TODO: protect from bruteforce
+var mRequests = &Requests{
+	counter:   map[string]int{},
+	lastTime:  map[string]int64{},
+	cleanTime: time.Now().UTC().Unix(),
+}
 
 
 func BasicAuth(handler http.Handler, username, password, realm string) http.Handler {
 func BasicAuth(handler http.Handler, username, password, realm string) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if username != "" {
 		if username != "" {
+			// Cleanup every hour
+			if (time.Now().UTC().Unix() - mRequests.cleanTime) > 3600 {
+				mRequests.Cleanup()
+			}
+
+			ip := helpers.ClientIP(r)
+			reqs := mRequests.Count(ip)
+			ltime := mRequests.Time(ip)
+
+			// Reset counter
+			if (time.Now().UTC().Unix() - ltime) >= 30 {
+				reqs = 0
+				mRequests.SetCount(ip, reqs)
+				mRequests.SetTime(ip, time.Now().UTC().Unix())
+			}
+
+			// Restrict access
+			if reqs >= 5 {
+				w.Header().Set("Retry-After", "30")
+				w.WriteHeader(429)
+				if _, err := w.Write([]byte("Too Many Requests\n")); err != nil {
+					log.Printf("%s\n", err.Error())
+				}
+				return
+			}
+
 			if realm == "" {
 			if realm == "" {
 				realm = "Please enter username and password"
 				realm = "Please enter username and password"
 			}
 			}
@@ -25,6 +58,11 @@ func BasicAuth(handler http.Handler, username, password, realm string) http.Hand
 			}
 			}
 
 
 			if u != username {
 			if u != username {
+				// Inc counter
+				reqs = reqs + 1
+				mRequests.SetCount(ip, reqs)
+				mRequests.SetTime(ip, time.Now().UTC().Unix())
+
 				w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
 				w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
 				w.WriteHeader(401)
 				w.WriteHeader(401)
 				if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
 				if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
@@ -34,6 +72,11 @@ func BasicAuth(handler http.Handler, username, password, realm string) http.Hand
 			}
 			}
 
 
 			if p != password {
 			if p != password {
+				// Inc counter
+				reqs = reqs + 1
+				mRequests.SetCount(ip, reqs)
+				mRequests.SetTime(ip, time.Now().UTC().Unix())
+
 				w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
 				w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
 				w.WriteHeader(401)
 				w.WriteHeader(401)
 				if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
 				if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
@@ -41,6 +84,11 @@ func BasicAuth(handler http.Handler, username, password, realm string) http.Hand
 				}
 				}
 				return
 				return
 			}
 			}
+
+			// Reset counter
+			reqs = 0
+			mRequests.SetCount(ip, reqs)
+			mRequests.SetTime(ip, time.Now().UTC().Unix())
 		}
 		}
 
 
 		handler.ServeHTTP(w, r)
 		handler.ServeHTTP(w, r)

+ 52 - 0
utils/http/servauth/servauth_test.go

@@ -112,6 +112,58 @@ var _ = Describe("servauth", func() {
 			Expect(err).To(Succeed())
 			Expect(err).To(Succeed())
 			Expect(string(body)).To(Equal("Unauthorised\n"))
 			Expect(string(body)).To(Equal("Unauthorised\n"))
 		})
 		})
+
+		It("block requests to 30 seconds on 5 times wrong entered credentials", func() {
+			req, err := http.NewRequest("GET", srv.URL+"/", nil)
+			Expect(err).To(Succeed())
+			req.SetBasicAuth("user", "wrong")
+
+			// 1
+			resp, err := client.Do(req)
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
+
+			// 2
+			resp, err = client.Do(req)
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
+
+			// 3
+			resp, err = client.Do(req)
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
+
+			// 4
+			resp, err = client.Do(req)
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
+
+			// 5
+			resp, err = client.Do(req)
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
+
+			// 6
+			resp, err = client.Do(req)
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusTooManyRequests))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Too Many Requests\n"))
+		})
 	})
 	})
 })
 })