remove map

This commit is contained in:
Balakrishnan Balasubramanian 2022-06-27 13:05:00 -04:00
parent 1fab1e0095
commit 882f8aea2f
2 changed files with 38 additions and 74 deletions

View File

@ -15,17 +15,6 @@ const (
Done Done
) )
type subChan struct {
c chan<- string
lastUpdateIndex int
}
type subscriberState struct {
step Step
m sync.Mutex
sc []subChan
}
type InvalidState struct{ Step Step } type InvalidState struct{ Step Step }
func (i InvalidState) Error() string { func (i InvalidState) Error() string {
@ -37,46 +26,42 @@ func (i InvalidState) Error() string {
type ProgressTracker interface { type ProgressTracker interface {
// Only one publisher sends update. Should close when done // Only one publisher sends update. Should close when done
// Error if there is/was existing publisher // Error if there is/was existing publisher
Publish(id string) (chan<- string, error) Publish() (chan<- string, error)
// Can subscribe even if there no publisher yet // Can subscribe even if there no publisher yet
// If already done, nil channel is returned // If already done, nil channel is returned
// channel will be closed when done // channel will be closed when done
Subscribe(id string) <-chan string Subscribe() <-chan string
}
type subChan struct {
c chan<- string
lastUpdateIndex int
} }
type progressTracker struct { type progressTracker struct {
subscribers map[string]*subscriberState step Step
m sync.Mutex m sync.Mutex
sc []subChan
} }
func NewProgressTracker() ProgressTracker { func NewProgressTracker() ProgressTracker {
return &progressTracker{ return &progressTracker{
subscribers: map[string]*subscriberState{}, step: NotStarted,
} }
} }
func (pt *progressTracker) Publish(id string) (chan<- string, error) { func (pt *progressTracker) Publish() (chan<- string, error) {
var state *subscriberState err := func() error {
func() {
pt.m.Lock() pt.m.Lock()
defer pt.m.Unlock() defer pt.m.Unlock()
state = pt.subscribers[id] if pt.step != NotStarted {
if state == nil { return InvalidState{pt.step}
state = &subscriberState{step: NotStarted}
pt.subscribers[id] = state
} }
}() pt.step = Publishing
err := func() error {
state.m.Lock()
defer state.m.Unlock()
if state.step != NotStarted {
return InvalidState{state.step}
}
state.step = Publishing
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -100,17 +85,17 @@ func (pt *progressTracker) Publish(id string) (chan<- string, error) {
} }
var scs []subChan var scs []subChan
func() { func() {
state.m.Lock() pt.m.Lock()
defer state.m.Unlock() defer pt.m.Unlock()
scs = state.sc scs = pt.sc
if !prodChanOpen { if !prodChanOpen {
for _, subChan := range scs { pt.step = Done
close(subChan.c)
}
state.step = Done
} }
}() }()
if !prodChanOpen { if !prodChanOpen {
for _, subChan := range scs {
close(subChan.c)
}
return return
} }
for _, subChan := range scs { for _, subChan := range scs {
@ -127,29 +112,14 @@ func (pt *progressTracker) Publish(id string) (chan<- string, error) {
return prodChan, nil return prodChan, nil
} }
func (pt *progressTracker) Subscribe(id string) <-chan string { func (pt *progressTracker) Subscribe() <-chan string {
c := make(chan string, 1) c := make(chan string, 1)
sc := subChan{c: c} sc := subChan{c: c}
var state *subscriberState pt.m.Lock()
func() { defer pt.m.Unlock()
pt.m.Lock() if pt.step == Done {
defer pt.m.Unlock()
state = pt.subscribers[id]
if state == nil {
pt.subscribers[id] = &subscriberState{
step: NotStarted,
sc: []subChan{sc},
}
}
}()
if state == nil {
return c
}
state.m.Lock()
defer state.m.Unlock()
if state.step == Done {
return nil return nil
} }
state.sc = append(state.sc, sc) pt.sc = append(pt.sc, sc)
return c return c
} }

View File

@ -9,47 +9,43 @@ import (
func TestDupePublisher(t *testing.T) { func TestDupePublisher(t *testing.T) {
pt := NewProgressTracker() pt := NewProgressTracker()
if _, err := pt.Publish("foo"); err != nil { if _, err := pt.Publish(); err != nil {
t.Fatalf("First publisher should not give error, err:%v", err) t.Fatalf("First publisher should not give error, err:%v", err)
} }
if _, err := pt.Publish("foo"); err == nil { if _, err := pt.Publish(); err == nil {
t.Fatal("Dupe publisher should give error but got nil") t.Fatal("Dupe publisher should give error but got nil")
} else { } else {
t.Logf("Got err: %v", err) t.Logf("Got err: %v", err)
} }
if _, err := pt.Publish("bar"); err != nil {
t.Fatalf("Different publisher should not give error, err:%v", err)
}
} }
func TestSubSub(t *testing.T) { func TestSubSub(t *testing.T) {
pt := NewProgressTracker() pt := NewProgressTracker()
c1 := pt.Subscribe("foo") c1 := pt.Subscribe()
select { select {
case <-c1: case <-c1:
default: default:
} }
if c1 == nil { if c1 == nil {
t.Fatal("Subscriber should not get a closed channel") t.Fatal("Subscriber should not get a nil channel")
} }
c2 := pt.Subscribe("foo") c2 := pt.Subscribe()
if c2 == nil { if c2 == nil {
t.Fatal("Subscriber should not get a closed channel") t.Fatal("Subscriber should not get a nil channel")
} }
} }
func TestPubSub(t *testing.T) { func TestPubSub(t *testing.T) {
pt := NewProgressTracker() pt := NewProgressTracker()
pc, err := pt.Publish("foo") pc, err := pt.Publish()
if err != nil { if err != nil {
t.Fatalf("Unexpected err: %v", err) t.Fatalf("Unexpected err: %v", err)
} }
if pc == nil { if pc == nil {
t.Fatal("Should not get nil channel") t.Fatal("Should not get nil channel")
} }
sc := pt.Subscribe("foo") sc := pt.Subscribe()
if sc == nil { if sc == nil {
t.Fatal("Should not get nil channel") t.Fatal("Should not get nil channel")
} }
@ -70,8 +66,6 @@ func TestPubSub(t *testing.T) {
pc <- "blah" pc <- "blah"
time.Sleep(166 * time.Millisecond) time.Sleep(166 * time.Millisecond)
if i == 5 { if i == 5 {
// time.Sleep(100 * time.Millisecond)
//time.Sleep(1 * time.Second)
<-testc <-testc
} }
} }
@ -81,7 +75,7 @@ func TestPubSub(t *testing.T) {
t.Fatal("There should be atleast one update") t.Fatal("There should be atleast one update")
} }
t.Logf("c is :%d", c) t.Logf("c is :%d", c)
sc2 := pt.Subscribe("foo") sc2 := pt.Subscribe()
if sc2 != nil { if sc2 != nil {
t.Fatal("Subscriber after publisher done should return nil") t.Fatal("Subscriber after publisher done should return nil")
} }