Browse Source

Added context supports

Listen/Unlisten break long loops
Volodymyr Tkach 2 years ago
parent
commit
233f347067
4 changed files with 49 additions and 18 deletions
  1. 1 1
      README.md
  2. 5 2
      cmd/cli/main.go
  3. 27 2
      pubsub/pubsub.go
  4. 16 13
      pubsub/pubsub_test.go

+ 1 - 1
README.md

@@ -33,7 +33,7 @@ ps.OnPong(func(c *pubsub.Connection, start, end time.Time) {
     log.Printf("OnPong (ID: %d), start: %d, end: %d\n", c.ID, start.Unix(), end.Unix())
 })
 
-ps.Listen("community-points-channel-v1", "<UserID>")
+ps.Listen(context.Background(), "community-points-channel-v1", "<UserID>")
 
 interrupt := make(chan os.Signal, 1)
 signal.Notify(interrupt, os.Interrupt)

+ 5 - 2
cmd/cli/main.go

@@ -4,6 +4,7 @@ package main
 
 import (
 	"bufio"
+	"context"
 	"fmt"
 	"log"
 	"os"
@@ -49,6 +50,8 @@ func main() {
 	interrupt := make(chan os.Signal, 1)
 	signal.Notify(interrupt, os.Interrupt)
 
+	ctx := context.Background()
+
 	go func(interrupt chan os.Signal) {
 		reader := bufio.NewReader(os.Stdin)
 		for {
@@ -63,7 +66,7 @@ func main() {
 					if len([]rune(cmd)) > 7 {
 						param := string([]rune(cmd)[7:len([]rune(cmd))])
 						fmt.Printf("Listen: (%s)\n", param)
-						ps.Listen(param)
+						ps.Listen(ctx, param)
 					} else {
 						fmt.Printf("Parameter is not set\n")
 					}
@@ -71,7 +74,7 @@ func main() {
 					if len([]rune(cmd)) > 9 {
 						param := string([]rune(cmd)[9:len([]rune(cmd))])
 						fmt.Printf("Unlisten: (%s)\n", param)
-						ps.Unlisten(param)
+						ps.Unlisten(ctx, param)
 					} else {
 						fmt.Printf("Parameter is not set\n")
 					}

+ 27 - 2
pubsub/pubsub.go

@@ -4,6 +4,7 @@
 package pubsub
 
 import (
+	"context"
 	"fmt"
 	"net/url"
 	"strings"
@@ -76,7 +77,7 @@ func (p *PubSub) newConnection() *Connection {
 // New TCP connection will be created for every 50 topics.
 //
 // https://dev.twitch.tv/docs/pubsub/#connection-management
-func (p *PubSub) Listen(topic string, params ...interface{}) {
+func (p *PubSub) Listen(ctx context.Context, topic string, params ...interface{}) {
 	p.Lock()
 	defer p.Unlock()
 
@@ -91,6 +92,12 @@ func (p *PubSub) Listen(topic string, params ...interface{}) {
 	// Check topic in connection
 	// Don't continue if already present
 	for _, c := range p.Connections {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
 		if c.HasTopic(t) {
 			return
 		}
@@ -98,6 +105,12 @@ func (p *PubSub) Listen(topic string, params ...interface{}) {
 
 	// Add topic to first not busy connection
 	for _, c := range p.Connections {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
 		if c.TopicsCount() < TwitchApiMaxTopics {
 			c.AddTopic(t)
 			return
@@ -114,7 +127,7 @@ func (p *PubSub) Listen(topic string, params ...interface{}) {
 // Connection count will automatically decrease of needs.
 //
 // https://dev.twitch.tv/docs/pubsub/#connection-management
-func (p *PubSub) Unlisten(topic string, params ...interface{}) {
+func (p *PubSub) Unlisten(ctx context.Context, topic string, params ...interface{}) {
 	p.Lock()
 	defer p.Unlock()
 
@@ -122,6 +135,12 @@ func (p *PubSub) Unlisten(topic string, params ...interface{}) {
 
 	// Search and unlisten
 	for _, c := range p.Connections {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
 		if c.HasTopic(t) {
 			c.RemoveTopic(t)
 
@@ -133,6 +152,12 @@ func (p *PubSub) Unlisten(topic string, params ...interface{}) {
 
 	// Remove empty connections
 	for i, c := range p.Connections {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
 		if c.TopicsCount() <= 0 {
 			_ = c.Close()
 			delete(p.Connections, i)

+ 16 - 13
pubsub/pubsub_test.go

@@ -1,6 +1,7 @@
 package pubsub_test
 
 import (
+	"context"
 	"fmt"
 	"net/url"
 	"testing"
@@ -12,6 +13,8 @@ import (
 )
 
 var _ = Describe("PubSub", func() {
+	var ctx = context.Background()
+
 	Context("PubSub", func() {
 		var ps *pubsub.PubSub
 
@@ -28,22 +31,22 @@ var _ = Describe("PubSub", func() {
 				Expect(len(ps.Connections)).To(Equal(0))
 
 				for i := 1; i <= 45; i++ {
-					ps.Listen("community-points-channel-v1", 1, i)
+					ps.Listen(ctx, "community-points-channel-v1", 1, i)
 				}
 				Expect(len(ps.Connections)).To(Equal(1))
 
 				for i := 1; i <= 5; i++ {
-					ps.Listen("community-points-channel-v1", 1, i)
+					ps.Listen(ctx, "community-points-channel-v1", 1, i)
 				}
 				Expect(len(ps.Connections)).To(Equal(1))
 
 				for i := 1; i <= 50; i++ {
-					ps.Listen("community-points-channel-v1", 2, i)
+					ps.Listen(ctx, "community-points-channel-v1", 2, i)
 				}
 				Expect(len(ps.Connections)).To(Equal(2))
 
 				for i := 1; i <= 50; i++ {
-					ps.Listen("community-points-channel-v1", 3, i)
+					ps.Listen(ctx, "community-points-channel-v1", 3, i)
 				}
 				Expect(len(ps.Connections)).To(Equal(3))
 			})
@@ -54,18 +57,18 @@ var _ = Describe("PubSub", func() {
 				Expect(len(ps.Connections)).To(Equal(0))
 
 				for i := 1; i <= 50; i++ {
-					ps.Listen("community-points-channel-v1", 1, i)
+					ps.Listen(ctx, "community-points-channel-v1", 1, i)
 				}
 				Expect(len(ps.Connections)).To(Equal(1))
 
-				ps.Listen("community-points-channel-v1", 2, 1)
+				ps.Listen(ctx, "community-points-channel-v1", 2, 1)
 				Expect(len(ps.Connections)).To(Equal(2))
 
-				ps.Unlisten("community-points-channel-v1", 2, 1)
+				ps.Unlisten(ctx, "community-points-channel-v1", 2, 1)
 				Expect(len(ps.Connections)).To(Equal(1))
 
 				for i := 1; i <= 50; i++ {
-					ps.Unlisten("community-points-channel-v1", 1, i)
+					ps.Unlisten(ctx, "community-points-channel-v1", 1, i)
 				}
 				Expect(len(ps.Connections)).To(Equal(0))
 			})
@@ -76,11 +79,11 @@ var _ = Describe("PubSub", func() {
 				Expect(len(ps.Connections)).To(Equal(0))
 
 				for i := 1; i <= 50; i++ {
-					ps.Listen("community-points-channel-v1", 1, i)
+					ps.Listen(ctx, "community-points-channel-v1", 1, i)
 				}
 				Expect(len(ps.Connections)).To(Equal(1))
 
-				ps.Listen("community-points-channel-v1", 2, 1)
+				ps.Listen(ctx, "community-points-channel-v1", 2, 1)
 				Expect(len(ps.Connections)).To(Equal(2))
 
 				Expect(ps.Topics()).To(ContainElements(
@@ -93,7 +96,7 @@ var _ = Describe("PubSub", func() {
 			It("checks topics", func() {
 				Expect(len(ps.Connections)).To(Equal(0))
 
-				ps.Listen("community-points-channel-v1", 1)
+				ps.Listen(ctx, "community-points-channel-v1", 1)
 				Expect(ps.HasTopic("unknown")).To(BeFalse())
 				Expect(ps.HasTopic("community-points-channel-v1", 1)).To(BeTrue())
 			})
@@ -103,12 +106,12 @@ var _ = Describe("PubSub", func() {
 			It("return topics count", func() {
 				Expect(ps.TopicsCount()).To(Equal(0))
 				for i := 1; i <= 50; i++ {
-					ps.Listen("community-points-channel-v1", 1, i)
+					ps.Listen(ctx, "community-points-channel-v1", 1, i)
 				}
 				Expect(ps.TopicsCount()).To(Equal(50))
 
 				for i := 1; i <= 5; i++ {
-					ps.Listen("community-points-channel-v1", 2, i)
+					ps.Listen(ctx, "community-points-channel-v1", 2, i)
 				}
 				Expect(ps.TopicsCount()).To(Equal(55))
 			})