servauth.go 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package servauth
import (
	"crypto/sha256"
	"crypto/subtle"
	"log"
	"net/http"
	"time"

	"github.com/vladimirok5959/golang-utils/utils/http/helpers"
)
  8. var mRequests = &Requests{
  9. counter: map[string]int{},
  10. lastTime: map[string]int64{},
  11. cleanTime: time.Now().UTC().Unix(),
  12. }
  13. func BasicAuth(handler http.Handler, username, password, realm string) http.Handler {
  14. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  15. if username != "" {
  16. // Cleanup every hour
  17. if (time.Now().UTC().Unix() - mRequests.cleanTime) > 3600 {
  18. mRequests.Cleanup()
  19. }
  20. ip := helpers.ClientIP(r)
  21. reqs := mRequests.Count(ip)
  22. ltime := mRequests.Time(ip)
  23. // Reset counter
  24. if (time.Now().UTC().Unix() - ltime) >= 30 {
  25. reqs = 0
  26. mRequests.SetCount(ip, reqs)
  27. mRequests.SetTime(ip, time.Now().UTC().Unix())
  28. }
  29. // Restrict access
  30. if reqs >= 5 {
  31. w.Header().Set("Retry-After", "30")
  32. w.WriteHeader(429)
  33. if _, err := w.Write([]byte("Too Many Requests\n")); err != nil {
  34. log.Printf("%s\n", err.Error())
  35. }
  36. return
  37. }
  38. if realm == "" {
  39. realm = "Please enter username and password"
  40. }
  41. u, p, ok := r.BasicAuth()
  42. if !ok {
  43. w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
  44. w.WriteHeader(401)
  45. if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
  46. log.Printf("%s\n", err.Error())
  47. }
  48. return
  49. }
  50. if u != username {
  51. // Inc counter
  52. reqs = reqs + 1
  53. mRequests.SetCount(ip, reqs)
  54. mRequests.SetTime(ip, time.Now().UTC().Unix())
  55. w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
  56. w.WriteHeader(401)
  57. if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
  58. log.Printf("%s\n", err.Error())
  59. }
  60. return
  61. }
  62. if p != password {
  63. // Inc counter
  64. reqs = reqs + 1
  65. mRequests.SetCount(ip, reqs)
  66. mRequests.SetTime(ip, time.Now().UTC().Unix())
  67. w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
  68. w.WriteHeader(401)
  69. if _, err := w.Write([]byte("Unauthorised\n")); err != nil {
  70. log.Printf("%s\n", err.Error())
  71. }
  72. return
  73. }
  74. // Reset counter
  75. if reqs > 0 {
  76. reqs = 0
  77. mRequests.SetCount(ip, reqs)
  78. mRequests.SetTime(ip, time.Now().UTC().Unix())
  79. }
  80. }
  81. handler.ServeHTTP(w, r)
  82. })
  83. }