156 lines
3.0 KiB
Go
156 lines
3.0 KiB
Go
package pubsub
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Step is the lifecycle stage of the publisher for a single progress id.
//
//go:generate stringer -type=Step
type Step int

// Publisher lifecycle stages, in the order an id moves through them.
// Publish moves an id from NotStarted to Publishing; closing the publisher's
// channel eventually moves it to Done.
const (
	// NotStarted means no publisher has been registered for the id yet.
	// Subscribers may still register and wait.
	NotStarted Step = iota
	// Publishing means a publisher is registered and updates are relayed.
	Publishing
	// Done means the publisher closed its channel and every subscriber
	// channel has been closed.
	Done
)
|
|
|
|
// subChan is one subscriber's delivery channel together with the sequence
// number of the last update that was successfully sent on it.
type subChan struct {
	// c is the send side of the channel handed out to the subscriber.
	c chan<- string
	// lastUpdateIndex is the update sequence number most recently delivered
	// on c; the zero value means nothing has been delivered yet.
	lastUpdateIndex int
}
|
|
|
|
type subscriberState struct {
|
|
step Step
|
|
m sync.Mutex
|
|
sc []subChan
|
|
}
|
|
|
|
type InvalidState struct{ Step Step }
|
|
|
|
func (i InvalidState) Error() string {
|
|
return fmt.Sprintf("Invalid state: %v", i.Step)
|
|
}
|
|
|
|
// ProgressTracker relays progress updates for an id from a single publisher
// to any number of subscribers, possibly running in different goroutines.
//
// Delivery is lossy: a subscriber is not guaranteed to observe every update,
// but receiving from its channel yields a recent status rather than a
// backlog of stale ones.
type ProgressTracker interface {
	// Publish registers the one publisher for id and returns the channel on
	// which it sends updates. The publisher must close that channel when it
	// is finished. Publish returns an error if a publisher for id is already
	// registered or has already finished.
	Publish(id string) (chan<- string, error)

	// Subscribe returns a channel that receives updates for id. It may be
	// called before any publisher exists. The channel is closed once the
	// publisher finishes. If the publisher has already finished, Subscribe
	// returns a nil channel (receiving from a nil channel blocks forever, so
	// callers should check for nil).
	Subscribe(id string) <-chan string
}
|
|
|
|
type progressTracker struct {
|
|
subscribers map[string]*subscriberState
|
|
m sync.Mutex
|
|
}
|
|
|
|
func NewProgressTracker() ProgressTracker {
|
|
return &progressTracker{
|
|
subscribers: map[string]*subscriberState{},
|
|
}
|
|
}
|
|
|
|
func (pt *progressTracker) Publish(id string) (chan<- string, error) {
|
|
var state *subscriberState
|
|
func() {
|
|
pt.m.Lock()
|
|
defer pt.m.Unlock()
|
|
state = pt.subscribers[id]
|
|
if state == nil {
|
|
state = &subscriberState{step: NotStarted}
|
|
pt.subscribers[id] = state
|
|
}
|
|
}()
|
|
|
|
err := func() error {
|
|
state.m.Lock()
|
|
defer state.m.Unlock()
|
|
if state.step != NotStarted {
|
|
return InvalidState{state.step}
|
|
}
|
|
state.step = Publishing
|
|
return nil
|
|
}()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
prodChan := make(chan string, 100)
|
|
go func() {
|
|
var update string
|
|
prodChanOpen := true
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
lastUpdateIndex := 0
|
|
for range ticker.C {
|
|
LoopReader:
|
|
for prodChanOpen {
|
|
select {
|
|
case update, prodChanOpen = <-prodChan:
|
|
lastUpdateIndex++
|
|
default:
|
|
break LoopReader
|
|
}
|
|
}
|
|
var scs []subChan
|
|
func() {
|
|
state.m.Lock()
|
|
defer state.m.Unlock()
|
|
scs = state.sc
|
|
if !prodChanOpen {
|
|
for _, subChan := range scs {
|
|
close(subChan.c)
|
|
}
|
|
state.step = Done
|
|
}
|
|
}()
|
|
if !prodChanOpen {
|
|
return
|
|
}
|
|
for _, subChan := range scs {
|
|
if subChan.lastUpdateIndex != lastUpdateIndex {
|
|
select {
|
|
case subChan.c <- update:
|
|
subChan.lastUpdateIndex = lastUpdateIndex
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
return prodChan, nil
|
|
}
|
|
|
|
func (pt *progressTracker) Subscribe(id string) <-chan string {
|
|
c := make(chan string, 1)
|
|
sc := subChan{c: c}
|
|
var state *subscriberState
|
|
func() {
|
|
pt.m.Lock()
|
|
defer pt.m.Unlock()
|
|
state = pt.subscribers[id]
|
|
if state == nil {
|
|
pt.subscribers[id] = &subscriberState{
|
|
step: NotStarted,
|
|
sc: []subChan{sc},
|
|
}
|
|
}
|
|
}()
|
|
if state == nil {
|
|
return c
|
|
}
|
|
state.m.Lock()
|
|
defer state.m.Unlock()
|
|
if state.step == Done {
|
|
return nil
|
|
}
|
|
state.sc = append(state.sc, sc)
|
|
return c
|
|
}
|