This commit is contained in:
Balakrishnan Balasubramanian 2022-06-26 22:34:53 -04:00
parent d5126bfb71
commit 2ff289becf
3 changed files with 70 additions and 148 deletions

View File

@ -13,10 +13,10 @@ const (
Done Done
) )
type state struct { type subscriberState struct {
step Step step Step
m sync.Mutex m sync.Mutex
scs []chan<- string subChans []chan<- string
} }
type InvalidState struct{ Step Step } type InvalidState struct{ Step Step }
@ -39,115 +39,105 @@ type ProgressTracker interface {
} }
type progressTracker struct { type progressTracker struct {
subscribers map[string]*state subscribers map[string]*subscriberState
m sync.Mutex m sync.Mutex
} }
func NewProgressTracker() ProgressTracker { func NewProgressTracker() ProgressTracker {
return &progressTracker{ return &progressTracker{
subscribers: map[string]*state{}, subscribers: map[string]*subscriberState{},
} }
} }
func (pt *progressTracker) Publish(id string) (chan<- string, error) { func (pt *progressTracker) Publish(id string) (chan<- string, error) {
var ste *state var state *subscriberState
func() { func() {
pt.m.Lock() pt.m.Lock()
defer pt.m.Unlock() defer pt.m.Unlock()
ste = pt.subscribers[id] state = pt.subscribers[id]
if ste == nil { if state == nil {
fmt.Println("new pub") state = &subscriberState{step: NotStarted}
ste = &state{step: NotStarted} pt.subscribers[id] = state
pt.subscribers[id] = ste
} }
}() }()
err := func() error { err := func() error {
ste.m.Lock() state.m.Lock()
defer ste.m.Unlock() defer state.m.Unlock()
if ste.step != NotStarted { if state.step != NotStarted {
return InvalidState{ste.step} return InvalidState{state.step}
} }
ste.step = Publishing state.step = Publishing
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err return nil, err
} }
fmt.Println("About to start loop") prodChan := make(chan string, 100)
pc := make(chan string, 100)
go func() { go func() {
var upd string var update string
ok := true prodChanOpen := true
for { for {
fmt.Println("In producer loop") for prodChanOpen {
done := false update, prodChanOpen = <-prodChan
for !done && ok { if !prodChanOpen {
// upd, ok = <-pc break
// fmt.Println("got msg", upd, ok) }
select { select {
case upd, ok = <-pc: case update, prodChanOpen = <-prodChan:
fmt.Println("got msg in select", upd, ok)
default: default:
fmt.Println("in default") break
done = true
} }
fmt.Println(done, ok)
} }
fmt.Println("About to send to subscribers")
var scs []chan<- string var scs []chan<- string
func() { func() {
ste.m.Lock() state.m.Lock()
defer ste.m.Unlock() defer state.m.Unlock()
fmt.Println("Inside lock") scs = state.subChans
scs = ste.scs if !prodChanOpen {
if !ok { for _, subChan := range scs {
for _, sc := range scs { close(subChan)
fmt.Println("Closing subs channel")
fmt.Printf("From code: chan is %+v and %#v\n", sc, sc)
close(sc)
} }
ste.step = Done state.step = Done
} }
}() }()
fmt.Println("Len of subs", len(scs)) if !prodChanOpen {
if !ok {
fmt.Println("channel closed, good bye")
return return
} }
for _, sc := range scs { for _, subChan := range scs {
select { select {
case sc <- upd: case subChan <- update:
default: default:
} }
} }
} }
}() }()
return pc, nil return prodChan, nil
} }
func (pt *progressTracker) Subscribe(id string) <-chan string { func (pt *progressTracker) Subscribe(id string) <-chan string {
c := make(chan string) c := make(chan string, 1)
var state *subscriberState
func() {
pt.m.Lock() pt.m.Lock()
ste := pt.subscribers[id] defer pt.m.Unlock()
if ste == nil { state = pt.subscribers[id]
fmt.Println("new sub") if state == nil {
pt.subscribers[id] = &state{ pt.subscribers[id] = &subscriberState{
step: NotStarted, step: NotStarted,
scs: []chan<- string{c}, subChans: []chan<- string{c},
} }
} }
pt.m.Unlock() }()
if ste == nil { if state == nil {
return c return c
} }
ste.m.Lock() state.m.Lock()
defer ste.m.Unlock() defer state.m.Unlock()
if ste.step == Done { if state.step == Done {
return nil return nil
} }
fmt.Println("appending to scs") state.subChans = append(state.subChans, c)
ste.scs = append(ste.scs, c)
return c return c
} }

View File

@ -1,9 +1,9 @@
package pubsub package pubsub
import ( import (
"fmt"
"sync" "sync"
"testing" "testing"
"time"
) )
func TestDupePublisher(t *testing.T) { func TestDupePublisher(t *testing.T) {
@ -15,6 +15,8 @@ func TestDupePublisher(t *testing.T) {
if _, err := pt.Publish("foo"); err == nil { if _, err := pt.Publish("foo"); err == nil {
t.Fatal("Dupe publisher should give error but got nil") t.Fatal("Dupe publisher should give error but got nil")
} else {
t.Logf("Got err: %v", err)
} }
if _, err := pt.Publish("bar"); err != nil { if _, err := pt.Publish("bar"); err != nil {
@ -26,8 +28,6 @@ func TestSubSub(t *testing.T) {
pt := NewProgressTracker() pt := NewProgressTracker()
c1 := pt.Subscribe("foo") c1 := pt.Subscribe("foo")
select { select {
//Not Working: expected 1 expression
//case _, ok <- c1:
case <-c1: case <-c1:
default: default:
} }
@ -42,39 +42,34 @@ func TestSubSub(t *testing.T) {
func TestPubSub(t *testing.T) { func TestPubSub(t *testing.T) {
pt := NewProgressTracker() pt := NewProgressTracker()
pc, _ := pt.Publish("foo") pc, err := pt.Publish("foo")
if err != nil {
t.Fatalf("Unexpected err: %v", err)
}
if pc == nil { if pc == nil {
t.Fatal("Should not get nil channel") t.Fatal("Should not get nil channel")
} }
sc := pt.Subscribe("foo") sc := pt.Subscribe("foo")
fmt.Printf("From test: chan is %+v and %#v\n", sc, sc)
if sc == nil { if sc == nil {
t.Fatal("Should not get nil channel") t.Fatal("Should not get nil channel")
} }
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
fmt.Println("Subscriber start")
c := 0
for range sc { for range sc {
fmt.Println("Subscriber got msg")
c++
}
fmt.Println("Subscriber received close")
if c == 0 {
fmt.Println("Should have gotten update")
} }
wg.Done() wg.Done()
fmt.Println("Subscriber Done")
}() }()
fmt.Println("Producer Start")
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
fmt.Println("Producer Sent")
pc <- "blah" pc <- "blah"
if i == 4 || i == 5 {
time.Sleep(100 * time.Millisecond)
}
} }
close(pc) close(pc)
fmt.Println("Producer close")
t.Log("Now waiting")
fmt.Println("Now waiting")
wg.Wait() wg.Wait()
sc2 := pt.Subscribe("foo")
if sc2 != nil {
t.Fatal("Subscriber after publisher done should return nil")
}
} }

View File

@ -1,63 +0,0 @@
package main
import (
"fmt"
"sync"
"time"
"gitlab.com/balki/ytui/pubsub"
)
func main() {
fmt.Println("vim-go")
foo()
// bar()
}
func foo() {
pt := pubsub.NewProgressTracker()
pc, err := pt.Publish("id1")
sc := pt.Subscribe("id1")
if err != nil {
panic(err)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
fmt.Println("subscriber loop")
for msg := range sc {
fmt.Println("received", msg)
}
fmt.Println("subscriber loop done")
wg.Done()
}()
for i := 0; i < 10; i++ {
msg := fmt.Sprint("msg_", i)
// fmt.Println("sending: ", msg)
pc <- msg
}
close(pc)
wg.Wait()
fmt.Println("sleeping")
time.Sleep(10 * time.Second)
}
func bar() {
c := make(chan int, 100)
for i := 0; i < 10; i++ {
c <- i
}
close(c)
ok := true
var v int
for ok {
select {
case v, ok = <-c:
fmt.Println(v, ok)
default:
fmt.Println("in default")
}
}
fmt.Println("done")
}