Browse Source

Requests rate limiter, tests

Volodymyr Tkach 2 years ago
parent
commit
064f06b635

+ 2 - 1
utils/http/helpers/helpers.go

@@ -65,7 +65,8 @@ func ClientIPs(r *http.Request) []string {
 	res := []string{}
 	ips := strings.Split(ra, ",")
 	for _, ip := range ips {
-		res = append(res, strings.Trim(ip, " "))
+		ipPort := strings.Split(ip, ":")
+		res = append(res, strings.Trim(ipPort[0], " "))
 	}
 	return res
 }

+ 51 - 0
utils/http/servlimit/requests.go

@@ -0,0 +1,51 @@
+package servlimit
+
+import (
+	"sync"
+	"time"
+)
+
+type Requests struct {
+	sync.RWMutex
+	counter   map[string]int
+	lastTime  map[string]int64
+	cleanTime int64
+}
+
+func (r *Requests) Count(ip string) int {
+	r.Lock()
+	defer r.Unlock()
+	if v, ok := r.counter[ip]; ok {
+		return v
+	}
+	return 0
+}
+
+func (r *Requests) SetCount(ip string, count int) {
+	r.Lock()
+	defer r.Unlock()
+	r.counter[ip] = count
+}
+
+func (r *Requests) Time(ip string) int64 {
+	r.Lock()
+	defer r.Unlock()
+	if v, ok := r.lastTime[ip]; ok {
+		return v
+	}
+	return 0
+}
+
+func (r *Requests) SetTime(ip string, time int64) {
+	r.Lock()
+	defer r.Unlock()
+	r.lastTime[ip] = time
+}
+
+func (r *Requests) Cleanup() {
+	r.Lock()
+	defer r.Unlock()
+	r.counter = map[string]int{}
+	r.lastTime = map[string]int64{}
+	r.cleanTime = time.Now().UTC().Unix()
+}

+ 51 - 0
utils/http/servlimit/servlimit.go

@@ -0,0 +1,51 @@
+package servlimit
+
+import (
+	"log"
+	"net/http"
+	"time"
+
+	"github.com/vladimirok5959/golang-utils/utils/http/helpers"
+)
+
+var mRequests = &Requests{
+	counter:   map[string]int{},
+	lastTime:  map[string]int64{},
+	cleanTime: time.Now().UTC().Unix(),
+}
+
+func ReqPerSecond(handler http.Handler, requests int) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Cleanup every hour
+		if (time.Now().UTC().Unix() - mRequests.cleanTime) > 3600 {
+			mRequests.Cleanup()
+		}
+
+		ip := helpers.ClientIP(r)
+		reqs := mRequests.Count(ip)
+		ltime := mRequests.Time(ip)
+
+		// Inc counter
+		reqs = reqs + 1
+		mRequests.SetCount(ip, reqs)
+
+		// Reset counter
+		if (time.Now().UTC().Unix() - ltime) >= 1 {
+			reqs = 0
+			mRequests.SetCount(ip, 0)
+		}
+
+		// Restrict access
+		if reqs >= requests {
+			w.WriteHeader(429)
+			if _, err := w.Write([]byte("Too Many Requests\n")); err != nil {
+				log.Printf("%s\n", err.Error())
+			}
+			return
+		}
+
+		mRequests.SetTime(ip, time.Now().UTC().Unix())
+
+		handler.ServeHTTP(w, r)
+	})
+}

+ 3 - 0
utils/http/servlimit/servlimit_export_test.go

@@ -0,0 +1,3 @@
+package servlimit
+
+var MRequests = mRequests

+ 182 - 0
utils/http/servlimit/servlimit_test.go

@@ -0,0 +1,182 @@
+package servlimit_test
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+	"github.com/vladimirok5959/golang-utils/utils/http/servlimit"
+)
+
+var _ = Describe("servlimit", func() {
+	Context("ReqPerSecond", func() {
+		var srv *httptest.Server
+		var client *http.Client
+
+		var getTestHandler = func() http.HandlerFunc {
+			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if _, err := w.Write([]byte("Index")); err != nil {
+					fmt.Printf("%s\n", err.Error())
+				}
+			})
+		}
+
+		BeforeEach(func() {
+			servlimit.MRequests.Cleanup()
+			srv = httptest.NewServer(servlimit.ReqPerSecond(getTestHandler(), 1))
+			client = srv.Client()
+		})
+
+		AfterEach(func() {
+			srv.Close()
+		})
+
+		It("process request", func() {
+			resp, err := client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Index"))
+		})
+
+		It("process multiple requests", func() {
+			resp, err := client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			time.Sleep(1 * time.Second)
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Index"))
+		})
+
+		It("block multiple requests", func() {
+			resp, err := client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusTooManyRequests))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Too Many Requests\n"))
+		})
+
+		It("block more multiple requests", func() {
+			resp, err := client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusTooManyRequests))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusTooManyRequests))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusTooManyRequests))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusTooManyRequests))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Too Many Requests\n"))
+		})
+
+		It("clean requests data in memory", func() {
+			resp, err := client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			servlimit.MRequests.Cleanup()
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Index"))
+		})
+
+		It("process 3 requests per second", func() {
+			srv.Close()
+			srv = httptest.NewServer(servlimit.ReqPerSecond(getTestHandler(), 3))
+			client = srv.Client()
+
+			resp, err := client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			resp, err = client.Get(srv.URL + "/")
+			Expect(err).To(Succeed())
+			defer resp.Body.Close()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusTooManyRequests))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).To(Succeed())
+			Expect(string(body)).To(Equal("Too Many Requests\n"))
+		})
+	})
+})
+
+func TestSuite(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "servlimit")
+}