ytui/pubsub/pt.go

154 lines
3.0 KiB
Go
Raw Normal View History

2022-06-26 19:30:26 -04:00
package pubsub
import (
	"fmt"
	"reflect"
	"sync"
	"time"
)
// Step is the lifecycle stage of the progress stream for a single id.
type Step int

// The stages are listed in the order an id moves through them.
const (
	// NotStarted means no publisher has claimed the id yet; subscribers may
	// already be registered for it.
	NotStarted Step = 0
	// Publishing means a publisher has claimed the id and is sending updates.
	Publishing Step = 1
	// Done means the publisher has finished and every subscriber channel for
	// the id has been closed.
	Done Step = 2
)
// state is the per-id bookkeeping shared by the publisher and subscribers of
// that id.
type state struct {
	// m guards step and scs.
	m sync.Mutex

	// step is the current lifecycle stage of the id.
	step Step
	// scs holds the send side of every subscriber channel registered for the id.
	scs []chan<- string
}
// InvalidState is the error Publish returns when the requested id has already
// been claimed by a publisher. Step holds the stage the id was in.
type InvalidState struct {
	Step Step
}

// Error implements the error interface, reporting the offending step.
func (e InvalidState) Error() string {
	msg := fmt.Sprintf("Invalid state: %v", e.Step)
	return msg
}
// ProgressTracker fans out progress updates for an id from a single publisher
// to any number of subscribers.
//
// Delivery is lossy: a subscriber is not guaranteed to see every update, but
// whenever it waits on its channel it should receive the latest status.
type ProgressTracker interface {
	// Publish registers the sole publisher for id and returns the channel the
	// publisher sends updates on. The publisher should close that channel when
	// it is finished. An error is returned if id already has, or previously
	// had, a publisher.
	Publish(id string) (chan<- string, error)

	// Subscribe returns a channel of updates for id. It may be called before
	// any publisher exists for id. The channel is closed once the publisher is
	// finished. If id is already finished, Subscribe returns nil; receiving
	// from a nil channel blocks forever, so callers should check for nil first.
	Subscribe(id string) <-chan string
}
// progressTracker is the map-backed ProgressTracker built by NewProgressTracker.
type progressTracker struct {
	// m guards subscribers.
	m sync.Mutex
	// subscribers maps each id to its state. The entry is created by whichever
	// of Publish or Subscribe first sees the id; nothing in this file removes it.
	subscribers map[string]*state
}
// NewProgressTracker returns an empty ProgressTracker that is ready for use.
func NewProgressTracker() ProgressTracker {
	pt := new(progressTracker)
	pt.subscribers = make(map[string]*state)
	return pt
}
// Publish claims id for a single publisher and returns the channel on which
// the publisher sends progress updates.
//
// The publisher should close the returned channel when it is finished; that
// closes every subscriber channel for id and moves id to the Done step.
// InvalidState is returned if id is not in the NotStarted step, i.e. it is
// being, or has been, published already.
//
// Up to 100 updates are buffered; a background goroutine relays them to the
// subscribers and exits once the returned channel is closed.
func (pt *progressTracker) Publish(id string) (chan<- string, error) {
	pt.m.Lock()
	ste := pt.subscribers[id]
	if ste == nil {
		ste = &state{step: NotStarted}
		pt.subscribers[id] = ste
	}
	pt.m.Unlock()

	ste.m.Lock()
	if ste.step != NotStarted {
		err := InvalidState{ste.step}
		ste.m.Unlock()
		return nil, err
	}
	ste.step = Publishing
	ste.m.Unlock()

	pc := make(chan string, 100)
	go relayProgress(ste, pc)
	return pc, nil
}

// progressRescanInterval bounds how long a subscriber that registers while the
// publisher is idle waits before it is offered the most recent update.
const progressRescanInterval = 100 * time.Millisecond

// relayProgress forwards updates from pc to the subscriber channels registered
// on ste. When pc is closed it closes every subscriber channel and sets ste's
// step to Done.
//
// Each subscriber has at most one undelivered update, and a newer update
// replaces an older undelivered one. So a subscriber that waits on its
// channel receives the most recent status, never receives the same update
// twice, and receives nothing before the first real update. The loop blocks in
// reflect.Select instead of spinning. Its only periodic work is a rescan, every
// progressRescanInterval, for subscribers registered after the latest update;
// Subscribe does not notify the relay.
func relayProgress(ste *state, pc <-chan string) {
	// Fixed positions of the receive cases; one send case per pending
	// subscriber follows them.
	const (
		updateCase = iota
		rescanCase
		firstSendCase
	)

	var (
		latest    string
		hasLatest bool
		// pending holds, per subscriber channel, the update it has not received yet.
		pending = make(map[chan<- string]string)
		// offered records subscribers that were already handed an update, so a
		// rescan only picks up subscribers registered since then.
		offered = make(map[chan<- string]bool)
	)

	// snapshot copies the subscriber list so it can be used without holding ste.m.
	snapshot := func() []chan<- string {
		ste.m.Lock()
		defer ste.m.Unlock()
		return append([]chan<- string(nil), ste.scs...)
	}

	ticker := time.NewTicker(progressRescanInterval)
	defer ticker.Stop()

	for {
		cases := make([]reflect.SelectCase, firstSendCase, firstSendCase+len(pending))
		cases[updateCase] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(pc)}
		cases[rescanCase] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ticker.C)}
		targets := make([]chan<- string, 0, len(pending))
		for sc, upd := range pending {
			cases = append(cases, reflect.SelectCase{
				Dir:  reflect.SelectSend,
				Chan: reflect.ValueOf(sc),
				Send: reflect.ValueOf(upd),
			})
			targets = append(targets, sc)
		}

		chosen, recv, recvOK := reflect.Select(cases)
		switch {
		case chosen >= firstSendCase:
			// The subscriber took its update; nothing is owed until the next one.
			delete(pending, targets[chosen-firstSendCase])
		case chosen == rescanCase:
			if hasLatest {
				for _, sc := range snapshot() {
					if !offered[sc] {
						offered[sc] = true
						pending[sc] = latest
					}
				}
			}
		case recvOK:
			// A new update replaces whatever each subscriber had not yet received.
			latest, hasLatest = recv.String(), true
			for _, sc := range snapshot() {
				offered[sc] = true
				pending[sc] = latest
			}
		default:
			// The publisher closed pc. Hand the final status to any subscriber
			// blocked on receive right now (best effort, never blocks), then
			// close every subscriber channel and mark the id Done.
			ste.m.Lock()
			for _, sc := range ste.scs {
				upd, owed := pending[sc]
				if !owed && hasLatest && !offered[sc] {
					upd, owed = latest, true
				}
				if owed {
					select {
					case sc <- upd:
					default:
					}
				}
				close(sc)
			}
			ste.step = Done
			ste.m.Unlock()
			return
		}
	}
}
// Subscribe returns a channel that receives progress updates for id.
//
// It may be called before any publisher exists for id; in that case the state
// for id is created here with the new channel already registered. The returned
// channel is unbuffered and is closed once the publisher finishes. If the id is
// already Done, Subscribe returns nil instead. Receiving from a nil channel
// blocks forever, so callers must check for nil before receiving.
func (pt *progressTracker) Subscribe(id string) <-chan string {
	c := make(chan string)

	pt.m.Lock()
	ste := pt.subscribers[id]
	if ste == nil {
		pt.subscribers[id] = &state{
			step: NotStarted,
			scs:  []chan<- string{c},
		}
		pt.m.Unlock()
		return c
	}
	pt.m.Unlock()

	ste.m.Lock()
	defer ste.m.Unlock()
	// A Done id has already had its subscriber channels closed; registering c
	// now would leave it open forever.
	if ste.step == Done {
		return nil
	}
	ste.scs = append(ste.scs, c)
	return c
}