diff --git a/pubsub/pt.go b/pubsub/pt.go index 6c82cfd..74f225b 100644 --- a/pubsub/pt.go +++ b/pubsub/pt.go @@ -15,17 +15,6 @@ const ( Done ) -type subChan struct { - c chan<- string - lastUpdateIndex int -} - -type subscriberState struct { - step Step - m sync.Mutex - sc []subChan -} - type InvalidState struct{ Step Step } func (i InvalidState) Error() string { @@ -37,46 +26,42 @@ func (i InvalidState) Error() string { type ProgressTracker interface { // Only one publisher sends update. Should close when done // Error if there is/was existing publisher - Publish(id string) (chan<- string, error) + Publish() (chan<- string, error) // Can subscribe even if there no publisher yet // If already done, nil channel is returned // channel will be closed when done - Subscribe(id string) <-chan string + Subscribe() <-chan string +} + +type subChan struct { + c chan<- string + lastUpdateIndex int } type progressTracker struct { - subscribers map[string]*subscriberState - m sync.Mutex + step Step + m sync.Mutex + sc []subChan } func NewProgressTracker() ProgressTracker { return &progressTracker{ - subscribers: map[string]*subscriberState{}, + step: NotStarted, } } -func (pt *progressTracker) Publish(id string) (chan<- string, error) { - var state *subscriberState - func() { +func (pt *progressTracker) Publish() (chan<- string, error) { + err := func() error { pt.m.Lock() defer pt.m.Unlock() - state = pt.subscribers[id] - if state == nil { - state = &subscriberState{step: NotStarted} - pt.subscribers[id] = state + if pt.step != NotStarted { + return InvalidState{pt.step} } - }() - - err := func() error { - state.m.Lock() - defer state.m.Unlock() - if state.step != NotStarted { - return InvalidState{state.step} - } - state.step = Publishing + pt.step = Publishing return nil }() + if err != nil { return nil, err } @@ -100,17 +85,17 @@ func (pt *progressTracker) Publish(id string) (chan<- string, error) { } var scs []subChan func() { - state.m.Lock() - defer state.m.Unlock() - scs = state.sc + pt.m.Lock() + defer pt.m.Unlock() + scs = pt.sc if !prodChanOpen { - for _, subChan := range scs { - close(subChan.c) - } - state.step = Done + pt.step = Done } }() if !prodChanOpen { + for _, subChan := range scs { + close(subChan.c) + } return } for _, subChan := range scs { @@ -127,29 +112,14 @@ func (pt *progressTracker) Publish(id string) (chan<- string, error) { return prodChan, nil } -func (pt *progressTracker) Subscribe(id string) <-chan string { +func (pt *progressTracker) Subscribe() <-chan string { c := make(chan string, 1) sc := subChan{c: c} - var state *subscriberState - func() { - pt.m.Lock() - defer pt.m.Unlock() - state = pt.subscribers[id] - if state == nil { - pt.subscribers[id] = &subscriberState{ - step: NotStarted, - sc: []subChan{sc}, - } - } - }() - if state == nil { - return c - } - state.m.Lock() - defer state.m.Unlock() - if state.step == Done { + pt.m.Lock() + defer pt.m.Unlock() + if pt.step == Done { return nil } - state.sc = append(state.sc, sc) + pt.sc = append(pt.sc, sc) return c } diff --git a/pubsub/pt_test.go b/pubsub/pt_test.go index fdd20bc..01ab8fe 100644 --- a/pubsub/pt_test.go +++ b/pubsub/pt_test.go @@ -9,47 +9,43 @@ import ( func TestDupePublisher(t *testing.T) { pt := NewProgressTracker() - if _, err := pt.Publish("foo"); err != nil { + if _, err := pt.Publish(); err != nil { t.Fatalf("First publisher should not give error, err:%v", err) } - if _, err := pt.Publish("foo"); err == nil { + if _, err := pt.Publish(); err == nil { t.Fatal("Dupe publisher should give error but got nil") } else { t.Logf("Got err: %v", err) } - - if _, err := pt.Publish("bar"); err != nil { - t.Fatalf("Different publisher should not give error, err:%v", err) - } } func TestSubSub(t *testing.T) { pt := NewProgressTracker() - c1 := pt.Subscribe("foo") + c1 := pt.Subscribe() select { case <-c1: default: } if c1 == nil { - t.Fatal("Subscriber should not get a closed channel") + t.Fatal("Subscriber should not get a nil channel") } - c2 := pt.Subscribe("foo") + c2 := pt.Subscribe() if c2 == nil { - t.Fatal("Subscriber should not get a closed channel") + t.Fatal("Subscriber should not get a nil channel") } } func TestPubSub(t *testing.T) { pt := NewProgressTracker() - pc, err := pt.Publish("foo") + pc, err := pt.Publish() if err != nil { t.Fatalf("Unexpected err: %v", err) } if pc == nil { t.Fatal("Should not get nil channel") } - sc := pt.Subscribe("foo") + sc := pt.Subscribe() if sc == nil { t.Fatal("Should not get nil channel") } @@ -70,8 +66,6 @@ func TestPubSub(t *testing.T) { pc <- "blah" time.Sleep(166 * time.Millisecond) if i == 5 { - // time.Sleep(100 * time.Millisecond) - //time.Sleep(1 * time.Second) <-testc } } @@ -81,7 +75,7 @@ func TestPubSub(t *testing.T) { t.Fatal("There should be atleast one update") } t.Logf("c is :%d", c) - sc2 := pt.Subscribe("foo") + sc2 := pt.Subscribe() if sc2 != nil { t.Fatal("Subscriber after publisher done should return nil") }