diff --git a/pubsub/pt.go b/pubsub/pt.go new file mode 100644 index 0000000..005c1e0 --- /dev/null +++ b/pubsub/pt.go @@ -0,0 +1,153 @@ +package pubsub + +import ( + "fmt" + "sync" +) + +type Step int + +const ( + NotStarted Step = iota + Publishing + Done +) + +type state struct { + step Step + m sync.Mutex + scs []chan<- string +} + +type InvalidState struct{ Step Step } + +func (i InvalidState) Error() string { + return fmt.Sprintf("Invalid state: %v", i.Step) +} + +// One thread publishes progress, one or more threads subscribes to watch the progress +// Subscribers may not get all updates. They will get the latest status when waiting on the channel +type ProgressTracker interface { + // Only one publisher sends update. Should close when done + // Error if there is/was existing publisher + Publish(id string) (chan<- string, error) + + // Can subscribe even if there no publisher yet + // If already done, nil channel is returned + // channel will be closed when done + Subscribe(id string) <-chan string +} + +type progressTracker struct { + subscribers map[string]*state + m sync.Mutex +} + +func NewProgressTracker() ProgressTracker { + return &progressTracker{ + subscribers: map[string]*state{}, + } +} + +func (pt *progressTracker) Publish(id string) (chan<- string, error) { + var ste *state + func() { + pt.m.Lock() + defer pt.m.Unlock() + ste = pt.subscribers[id] + if ste == nil { + fmt.Println("new pub") + ste = &state{step: NotStarted} + pt.subscribers[id] = ste + } + }() + + err := func() error { + ste.m.Lock() + defer ste.m.Unlock() + if ste.step != NotStarted { + return InvalidState{ste.step} + } + ste.step = Publishing + return nil + }() + if err != nil { + return nil, err + } + + fmt.Println("About to start loop") + pc := make(chan string, 100) + go func() { + var upd string + ok := true + for { + fmt.Println("In producer loop") + done := false + for !done && ok { + // upd, ok = <-pc + // fmt.Println("got msg", upd, ok) + select { + case upd, ok = <-pc: + fmt.Println("got msg in select", upd, ok) + default: + fmt.Println("in default") + done = true + } + fmt.Println(done, ok) + } + fmt.Println("About to send to subscribers") + var scs []chan<- string + func() { + ste.m.Lock() + defer ste.m.Unlock() + fmt.Println("Inside lock") + scs = ste.scs + if !ok { + for _, sc := range scs { + fmt.Println("Closing subs channel") + fmt.Printf("From code: chan is %+v and %#v\n", sc, sc) + close(sc) + } + ste.step = Done + } + }() + fmt.Println("Len of subs", len(scs)) + if !ok { + fmt.Println("channel closed, good bye") + return + } + for _, sc := range scs { + select { + case sc <- upd: + default: + } + } + } + }() + return pc, nil +} + +func (pt *progressTracker) Subscribe(id string) <-chan string { + c := make(chan string) + pt.m.Lock() + ste := pt.subscribers[id] + if ste == nil { + fmt.Println("new sub") + pt.subscribers[id] = &state{ + step: NotStarted, + scs: []chan<- string{c}, + } + } + pt.m.Unlock() + if ste == nil { + return c + } + ste.m.Lock() + defer ste.m.Unlock() + if ste.step == Done { + return nil + } + fmt.Println("appending to scs") + ste.scs = append(ste.scs, c) + return c +} diff --git a/pubsub/pt_test.go b/pubsub/pt_test.go new file mode 100644 index 0000000..e060738 --- /dev/null +++ b/pubsub/pt_test.go @@ -0,0 +1,80 @@ +package pubsub + +import ( + "fmt" + "sync" + "testing" +) + +func TestDupePublisher(t *testing.T) { + pt := NewProgressTracker() + + if _, err := pt.Publish("foo"); err != nil { + t.Fatalf("First publisher should not give error, err:%v", err) + } + + if _, err := pt.Publish("foo"); err == nil { + t.Fatal("Dupe publisher should give error but got nil") + } + + if _, err := pt.Publish("bar"); err != nil { + t.Fatalf("Different publisher should not give error, err:%v", err) + } +} + +func TestSubSub(t *testing.T) { + pt := NewProgressTracker() + c1 := pt.Subscribe("foo") + select { + //Not Working: expected 1 expression + //case _, ok <- c1: + case <-c1: + default: + } + if c1 == nil { + t.Fatal("Subscriber should not get a closed channel") + } + c2 := pt.Subscribe("foo") + if c2 == nil { + t.Fatal("Subscriber should not get a closed channel") + } +} + +func TestPubSub(t *testing.T) { + pt := NewProgressTracker() + pc, _ := pt.Publish("foo") + if pc == nil { + t.Fatal("Should not get nil channel") + } + sc := pt.Subscribe("foo") + fmt.Printf("From test: chan is %+v and %#v\n", sc, sc) + if sc == nil { + t.Fatal("Should not get nil channel") + } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + fmt.Println("Subscriber start") + c := 0 + for range sc { + fmt.Println("Subscriber got msg") + c++ + } + fmt.Println("Subscriber received close") + if c == 0 { + fmt.Println("Should have gotten update") + } + wg.Done() + fmt.Println("Subscriber Done") + }() + fmt.Println("Producer Start") + for i := 0; i < 10; i++ { + fmt.Println("Producer Sent") + pc <- "blah" + } + close(pc) + fmt.Println("Producer close") + t.Log("Now waiting") + fmt.Println("Now waiting") + wg.Wait() +} diff --git a/pubsub/ptexp/main.go b/pubsub/ptexp/main.go new file mode 100644 index 0000000..11acfd4 --- /dev/null +++ b/pubsub/ptexp/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "fmt" + "sync" + "time" + + "gitlab.com/balki/ytui/pubsub" +) + +func main() { + fmt.Println("vim-go") + foo() + // bar() +} + +func foo() { + pt := pubsub.NewProgressTracker() + pc, err := pt.Publish("id1") + sc := pt.Subscribe("id1") + if err != nil { + panic(err) + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + fmt.Println("subscriber loop") + for msg := range sc { + fmt.Println("received", msg) + } + fmt.Println("subscriber loop done") + wg.Done() + }() + for i := 0; i < 10; i++ { + msg := fmt.Sprint("msg_", i) + // fmt.Println("sending: ", msg) + pc <- msg + } + close(pc) + wg.Wait() + fmt.Println("sleeping") + time.Sleep(10 * time.Second) +} + +func bar() { + c := make(chan int, 100) + for i := 0; i < 10; i++ { + c <- i + } + close(c) + + ok := true + var v int + for ok { + select { + case v, ok = <-c: + fmt.Println(v, ok) + default: + fmt.Println("in default") + } + } + fmt.Println("done") +}