remove map

This commit is contained in:
Balakrishnan Balasubramanian 2022-06-27 13:05:00 -04:00
parent 1fab1e0095
commit 882f8aea2f
2 changed files with 38 additions and 74 deletions

View File

@ -15,17 +15,6 @@ const (
Done
)
type subChan struct {
c chan<- string
lastUpdateIndex int
}
type subscriberState struct {
step Step
m sync.Mutex
sc []subChan
}
type InvalidState struct{ Step Step }
func (i InvalidState) Error() string {
@ -37,46 +26,42 @@ func (i InvalidState) Error() string {
type ProgressTracker interface {
// Only one publisher sends update. Should close when done
// Error if there is/was existing publisher
Publish(id string) (chan<- string, error)
Publish() (chan<- string, error)
// Can subscribe even if there no publisher yet
// If already done, nil channel is returned
// channel will be closed when done
Subscribe(id string) <-chan string
Subscribe() <-chan string
}
type subChan struct {
c chan<- string
lastUpdateIndex int
}
type progressTracker struct {
subscribers map[string]*subscriberState
m sync.Mutex
step Step
m sync.Mutex
sc []subChan
}
func NewProgressTracker() ProgressTracker {
return &progressTracker{
subscribers: map[string]*subscriberState{},
step: NotStarted,
}
}
func (pt *progressTracker) Publish(id string) (chan<- string, error) {
var state *subscriberState
func() {
func (pt *progressTracker) Publish() (chan<- string, error) {
err := func() error {
pt.m.Lock()
defer pt.m.Unlock()
state = pt.subscribers[id]
if state == nil {
state = &subscriberState{step: NotStarted}
pt.subscribers[id] = state
if pt.step != NotStarted {
return InvalidState{pt.step}
}
}()
err := func() error {
state.m.Lock()
defer state.m.Unlock()
if state.step != NotStarted {
return InvalidState{state.step}
}
state.step = Publishing
pt.step = Publishing
return nil
}()
if err != nil {
return nil, err
}
@ -100,17 +85,17 @@ func (pt *progressTracker) Publish(id string) (chan<- string, error) {
}
var scs []subChan
func() {
state.m.Lock()
defer state.m.Unlock()
scs = state.sc
pt.m.Lock()
defer pt.m.Unlock()
scs = pt.sc
if !prodChanOpen {
for _, subChan := range scs {
close(subChan.c)
}
state.step = Done
pt.step = Done
}
}()
if !prodChanOpen {
for _, subChan := range scs {
close(subChan.c)
}
return
}
for _, subChan := range scs {
@ -127,29 +112,14 @@ func (pt *progressTracker) Publish(id string) (chan<- string, error) {
return prodChan, nil
}
func (pt *progressTracker) Subscribe(id string) <-chan string {
func (pt *progressTracker) Subscribe() <-chan string {
c := make(chan string, 1)
sc := subChan{c: c}
var state *subscriberState
func() {
pt.m.Lock()
defer pt.m.Unlock()
state = pt.subscribers[id]
if state == nil {
pt.subscribers[id] = &subscriberState{
step: NotStarted,
sc: []subChan{sc},
}
}
}()
if state == nil {
return c
}
state.m.Lock()
defer state.m.Unlock()
if state.step == Done {
pt.m.Lock()
defer pt.m.Unlock()
if pt.step == Done {
return nil
}
state.sc = append(state.sc, sc)
pt.sc = append(pt.sc, sc)
return c
}

View File

@ -9,47 +9,43 @@ import (
func TestDupePublisher(t *testing.T) {
pt := NewProgressTracker()
if _, err := pt.Publish("foo"); err != nil {
if _, err := pt.Publish(); err != nil {
t.Fatalf("First publisher should not give error, err:%v", err)
}
if _, err := pt.Publish("foo"); err == nil {
if _, err := pt.Publish(); err == nil {
t.Fatal("Dupe publisher should give error but got nil")
} else {
t.Logf("Got err: %v", err)
}
if _, err := pt.Publish("bar"); err != nil {
t.Fatalf("Different publisher should not give error, err:%v", err)
}
}
func TestSubSub(t *testing.T) {
pt := NewProgressTracker()
c1 := pt.Subscribe("foo")
c1 := pt.Subscribe()
select {
case <-c1:
default:
}
if c1 == nil {
t.Fatal("Subscriber should not get a closed channel")
t.Fatal("Subscriber should not get a nil channel")
}
c2 := pt.Subscribe("foo")
c2 := pt.Subscribe()
if c2 == nil {
t.Fatal("Subscriber should not get a closed channel")
t.Fatal("Subscriber should not get a nil channel")
}
}
func TestPubSub(t *testing.T) {
pt := NewProgressTracker()
pc, err := pt.Publish("foo")
pc, err := pt.Publish()
if err != nil {
t.Fatalf("Unexpected err: %v", err)
}
if pc == nil {
t.Fatal("Should not get nil channel")
}
sc := pt.Subscribe("foo")
sc := pt.Subscribe()
if sc == nil {
t.Fatal("Should not get nil channel")
}
@ -70,8 +66,6 @@ func TestPubSub(t *testing.T) {
pc <- "blah"
time.Sleep(166 * time.Millisecond)
if i == 5 {
// time.Sleep(100 * time.Millisecond)
//time.Sleep(1 * time.Second)
<-testc
}
}
@ -81,7 +75,7 @@ func TestPubSub(t *testing.T) {
t.Fatal("There should be atleast one update")
}
t.Logf("c is :%d", c)
sc2 := pt.Subscribe("foo")
sc2 := pt.Subscribe()
if sc2 != nil {
t.Fatal("Subscriber after publisher done should return nil")
}