154 lines
3.0 KiB
Go
154 lines
3.0 KiB
Go
|
package pubsub
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"sync"
|
||
|
)
|
||
|
|
||
|
// Step is the lifecycle stage of a tracked id.
type Step int

const (
	// NotStarted means no publisher has registered for the id yet;
	// subscribers may already be waiting.
	NotStarted Step = iota
	// Publishing means a publisher is active for the id.
	Publishing
	// Done means the publisher has finished; Publish rejects the id and
	// Subscribe returns a nil channel for it.
	Done
)

// String returns the constant's name, so errors such as InvalidState read
// "Invalid state: Publishing" rather than "Invalid state: 1". Values outside
// the declared constants are rendered as "Step(n)".
func (s Step) String() string {
	switch s {
	case NotStarted:
		return "NotStarted"
	case Publishing:
		return "Publishing"
	case Done:
		return "Done"
	default:
		return fmt.Sprintf("Step(%d)", int(s))
	}
}
|
||
|
|
||
|
type state struct {
|
||
|
step Step
|
||
|
m sync.Mutex
|
||
|
scs []chan<- string
|
||
|
}
|
||
|
|
||
|
type InvalidState struct{ Step Step }
|
||
|
|
||
|
func (i InvalidState) Error() string {
|
||
|
return fmt.Sprintf("Invalid state: %v", i.Step)
|
||
|
}
|
||
|
|
||
|
// One thread publishes progress, one or more threads subscribes to watch the progress
|
||
|
// Subscribers may not get all updates. They will get the latest status when waiting on the channel
|
||
|
type ProgressTracker interface {
|
||
|
// Only one publisher sends update. Should close when done
|
||
|
// Error if there is/was existing publisher
|
||
|
Publish(id string) (chan<- string, error)
|
||
|
|
||
|
// Can subscribe even if there no publisher yet
|
||
|
// If already done, nil channel is returned
|
||
|
// channel will be closed when done
|
||
|
Subscribe(id string) <-chan string
|
||
|
}
|
||
|
|
||
|
type progressTracker struct {
|
||
|
subscribers map[string]*state
|
||
|
m sync.Mutex
|
||
|
}
|
||
|
|
||
|
func NewProgressTracker() ProgressTracker {
|
||
|
return &progressTracker{
|
||
|
subscribers: map[string]*state{},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (pt *progressTracker) Publish(id string) (chan<- string, error) {
|
||
|
var ste *state
|
||
|
func() {
|
||
|
pt.m.Lock()
|
||
|
defer pt.m.Unlock()
|
||
|
ste = pt.subscribers[id]
|
||
|
if ste == nil {
|
||
|
fmt.Println("new pub")
|
||
|
ste = &state{step: NotStarted}
|
||
|
pt.subscribers[id] = ste
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
err := func() error {
|
||
|
ste.m.Lock()
|
||
|
defer ste.m.Unlock()
|
||
|
if ste.step != NotStarted {
|
||
|
return InvalidState{ste.step}
|
||
|
}
|
||
|
ste.step = Publishing
|
||
|
return nil
|
||
|
}()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
fmt.Println("About to start loop")
|
||
|
pc := make(chan string, 100)
|
||
|
go func() {
|
||
|
var upd string
|
||
|
ok := true
|
||
|
for {
|
||
|
fmt.Println("In producer loop")
|
||
|
done := false
|
||
|
for !done && ok {
|
||
|
// upd, ok = <-pc
|
||
|
// fmt.Println("got msg", upd, ok)
|
||
|
select {
|
||
|
case upd, ok = <-pc:
|
||
|
fmt.Println("got msg in select", upd, ok)
|
||
|
default:
|
||
|
fmt.Println("in default")
|
||
|
done = true
|
||
|
}
|
||
|
fmt.Println(done, ok)
|
||
|
}
|
||
|
fmt.Println("About to send to subscribers")
|
||
|
var scs []chan<- string
|
||
|
func() {
|
||
|
ste.m.Lock()
|
||
|
defer ste.m.Unlock()
|
||
|
fmt.Println("Inside lock")
|
||
|
scs = ste.scs
|
||
|
if !ok {
|
||
|
for _, sc := range scs {
|
||
|
fmt.Println("Closing subs channel")
|
||
|
fmt.Printf("From code: chan is %+v and %#v\n", sc, sc)
|
||
|
close(sc)
|
||
|
}
|
||
|
ste.step = Done
|
||
|
}
|
||
|
}()
|
||
|
fmt.Println("Len of subs", len(scs))
|
||
|
if !ok {
|
||
|
fmt.Println("channel closed, good bye")
|
||
|
return
|
||
|
}
|
||
|
for _, sc := range scs {
|
||
|
select {
|
||
|
case sc <- upd:
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
return pc, nil
|
||
|
}
|
||
|
|
||
|
func (pt *progressTracker) Subscribe(id string) <-chan string {
|
||
|
c := make(chan string)
|
||
|
pt.m.Lock()
|
||
|
ste := pt.subscribers[id]
|
||
|
if ste == nil {
|
||
|
fmt.Println("new sub")
|
||
|
pt.subscribers[id] = &state{
|
||
|
step: NotStarted,
|
||
|
scs: []chan<- string{c},
|
||
|
}
|
||
|
}
|
||
|
pt.m.Unlock()
|
||
|
if ste == nil {
|
||
|
return c
|
||
|
}
|
||
|
ste.m.Lock()
|
||
|
defer ste.m.Unlock()
|
||
|
if ste.step == Done {
|
||
|
return nil
|
||
|
}
|
||
|
fmt.Println("appending to scs")
|
||
|
ste.scs = append(ste.scs, c)
|
||
|
return c
|
||
|
}
|