From 2ff289becfc98c37f205cf1cceb70328530a8995 Mon Sep 17 00:00:00 2001 From: balki <3070606-balki@users.noreply.gitlab.com> Date: Sun, 26 Jun 2022 22:34:53 -0400 Subject: [PATCH] cleanup --- pubsub/pt.go | 122 ++++++++++++++++++++----------------------- pubsub/pt_test.go | 33 +++++------- pubsub/ptexp/main.go | 63 ---------------------- 3 files changed, 70 insertions(+), 148 deletions(-) delete mode 100644 pubsub/ptexp/main.go diff --git a/pubsub/pt.go b/pubsub/pt.go index 005c1e0..346ac6e 100644 --- a/pubsub/pt.go +++ b/pubsub/pt.go @@ -13,10 +13,10 @@ const ( Done ) -type state struct { - step Step - m sync.Mutex - scs []chan<- string +type subscriberState struct { + step Step + m sync.Mutex + subChans []chan<- string } type InvalidState struct{ Step Step } @@ -39,115 +39,105 @@ type ProgressTracker interface { } type progressTracker struct { - subscribers map[string]*state + subscribers map[string]*subscriberState m sync.Mutex } func NewProgressTracker() ProgressTracker { return &progressTracker{ - subscribers: map[string]*state{}, + subscribers: map[string]*subscriberState{}, } } func (pt *progressTracker) Publish(id string) (chan<- string, error) { - var ste *state + var state *subscriberState func() { pt.m.Lock() defer pt.m.Unlock() - ste = pt.subscribers[id] - if ste == nil { - fmt.Println("new pub") - ste = &state{step: NotStarted} - pt.subscribers[id] = ste + state = pt.subscribers[id] + if state == nil { + state = &subscriberState{step: NotStarted} + pt.subscribers[id] = state } }() err := func() error { - ste.m.Lock() - defer ste.m.Unlock() - if ste.step != NotStarted { - return InvalidState{ste.step} + state.m.Lock() + defer state.m.Unlock() + if state.step != NotStarted { + return InvalidState{state.step} } - ste.step = Publishing + state.step = Publishing return nil }() if err != nil { return nil, err } - fmt.Println("About to start loop") - pc := make(chan string, 100) + prodChan := make(chan string, 100) go func() { - var upd string - ok := true + var update string + prodChanOpen := true for { - fmt.Println("In producer loop") - done := false - for !done && ok { - // upd, ok = <-pc - // fmt.Println("got msg", upd, ok) - select { - case upd, ok = <-pc: - fmt.Println("got msg in select", upd, ok) - default: - fmt.Println("in default") - done = true + for prodChanOpen { + update, prodChanOpen = <-prodChan + if !prodChanOpen { + break + } + select { + case update, prodChanOpen = <-prodChan: + default: + break } - fmt.Println(done, ok) } - fmt.Println("About to send to subscribers") var scs []chan<- string func() { - ste.m.Lock() - defer ste.m.Unlock() - fmt.Println("Inside lock") - scs = ste.scs - if !ok { - for _, sc := range scs { - fmt.Println("Closing subs channel") - fmt.Printf("From code: chan is %+v and %#v\n", sc, sc) - close(sc) + state.m.Lock() + defer state.m.Unlock() + scs = state.subChans + if !prodChanOpen { + for _, subChan := range scs { + close(subChan) } - ste.step = Done + state.step = Done } }() - fmt.Println("Len of subs", len(scs)) - if !ok { - fmt.Println("channel closed, good bye") + if !prodChanOpen { return } - for _, sc := range scs { + for _, subChan := range scs { select { - case sc <- upd: + case subChan <- update: default: } } } }() - return pc, nil + return prodChan, nil } func (pt *progressTracker) Subscribe(id string) <-chan string { - c := make(chan string) - pt.m.Lock() - ste := pt.subscribers[id] - if ste == nil { - fmt.Println("new sub") - pt.subscribers[id] = &state{ - step: NotStarted, - scs: []chan<- string{c}, + c := make(chan string, 1) + var state *subscriberState + func() { + pt.m.Lock() + defer pt.m.Unlock() + state = pt.subscribers[id] + if state == nil { + pt.subscribers[id] = &subscriberState{ + step: NotStarted, + subChans: []chan<- string{c}, + } } - } - pt.m.Unlock() - if ste == nil { + }() + if state == nil { return c } - ste.m.Lock() - defer ste.m.Unlock() - if ste.step == Done { + state.m.Lock() + defer state.m.Unlock() + if state.step == Done { return nil } - fmt.Println("appending to scs") - ste.scs = append(ste.scs, c) + state.subChans = append(state.subChans, c) return c } diff --git a/pubsub/pt_test.go b/pubsub/pt_test.go index e060738..26d8623 100644 --- a/pubsub/pt_test.go +++ b/pubsub/pt_test.go @@ -1,9 +1,9 @@ package pubsub import ( - "fmt" "sync" "testing" + "time" ) func TestDupePublisher(t *testing.T) { @@ -15,6 +15,8 @@ func TestDupePublisher(t *testing.T) { if _, err := pt.Publish("foo"); err == nil { t.Fatal("Dupe publisher should give error but got nil") + } else { + t.Logf("Got err: %v", err) } if _, err := pt.Publish("bar"); err != nil { @@ -26,8 +28,6 @@ func TestSubSub(t *testing.T) { pt := NewProgressTracker() c1 := pt.Subscribe("foo") select { - //Not Working: expected 1 expression - //case _, ok <- c1: case <-c1: default: } @@ -42,39 +42,34 @@ func TestSubSub(t *testing.T) { func TestPubSub(t *testing.T) { pt := NewProgressTracker() - pc, _ := pt.Publish("foo") + pc, err := pt.Publish("foo") + if err != nil { + t.Fatalf("Unexpected err: %v", err) + } if pc == nil { t.Fatal("Should not get nil channel") } sc := pt.Subscribe("foo") - fmt.Printf("From test: chan is %+v and %#v\n", sc, sc) if sc == nil { t.Fatal("Should not get nil channel") } wg := sync.WaitGroup{} wg.Add(1) go func() { - fmt.Println("Subscriber start") - c := 0 for range sc { - fmt.Println("Subscriber got msg") - c++ - } - fmt.Println("Subscriber received close") - if c == 0 { - fmt.Println("Should have gotten update") } wg.Done() - fmt.Println("Subscriber Done") }() - fmt.Println("Producer Start") for i := 0; i < 10; i++ { - fmt.Println("Producer Sent") pc <- "blah" + if i == 4 || i == 5 { + time.Sleep(100 * time.Millisecond) + } } close(pc) - fmt.Println("Producer close") - t.Log("Now waiting") - fmt.Println("Now waiting") wg.Wait() + sc2 := pt.Subscribe("foo") + if sc2 != nil { + t.Fatal("Subscriber after publisher done should return nil") + } } diff --git a/pubsub/ptexp/main.go b/pubsub/ptexp/main.go deleted file mode 100644 index 11acfd4..0000000 --- a/pubsub/ptexp/main.go +++ /dev/null @@ -1,63 +0,0 @@ -package main - -import ( - "fmt" - "sync" - "time" - - "gitlab.com/balki/ytui/pubsub" -) - -func main() { - fmt.Println("vim-go") - foo() - // bar() -} - -func foo() { - pt := pubsub.NewProgressTracker() - pc, err := pt.Publish("id1") - sc := pt.Subscribe("id1") - if err != nil { - panic(err) - } - var wg sync.WaitGroup - wg.Add(1) - go func() { - fmt.Println("subscriber loop") - for msg := range sc { - fmt.Println("received", msg) - } - fmt.Println("subscriber loop done") - wg.Done() - }() - for i := 0; i < 10; i++ { - msg := fmt.Sprint("msg_", i) - // fmt.Println("sending: ", msg) - pc <- msg - } - close(pc) - wg.Wait() - fmt.Println("sleeping") - time.Sleep(10 * time.Second) -} - -func bar() { - c := make(chan int, 100) - for i := 0; i < 10; i++ { - c <- i - } - close(c) - - ok := true - var v int - for ok { - select { - case v, ok = <-c: - fmt.Println(v, ok) - default: - fmt.Println("in default") - } - } - fmt.Println("done") -}