diff --git a/pkg/api/backend_test.go b/pkg/api/backend_test.go index 1a71c4b..6d3cbf2 100644 --- a/pkg/api/backend_test.go +++ b/pkg/api/backend_test.go @@ -3,6 +3,7 @@ package api import ( "context" "errors" + "net" "net/http" "testing" "time" @@ -69,7 +70,7 @@ func TestBackendFetchCredential(t *testing.T) { srvCtx, srvCancel := context.WithTimeout(context.Background(), time.Minute) defer srvCancel() - go startTestBackend(srvCtx, "localhost:5555") + startTestBackend(srvCtx, "localhost:5555") for _, ex := range examples { t.Run(ex.name, func(t *testing.T) { @@ -139,17 +140,42 @@ func startTestBackend(ctx context.Context, listenAddr string) { }) server := &http.Server{Addr: listenAddr, Handler: router} - go mustStartServer(server) + mustStartServer(server) - <-ctx.Done() - if err := server.Shutdown(context.Background()); err != nil && err != http.ErrServerClosed { - panic(err) - } + go func() { + <-ctx.Done() + if err := server.Shutdown(context.Background()); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() } func mustStartServer(server *http.Server) { - err := server.ListenAndServe() - if err != nil && err != http.ErrServerClosed { + go func() { + err := server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + if err := waitForServer(server.Addr, 5); err != nil { panic(err) } } + +func waitForServer(addr string, n int) error { + var lastErr error + + for i := 0; i < n; i++ { + conn, err := net.Dial("tcp", addr) + if err == nil { + conn.Close() + return nil + } + + lastErr = err + time.Sleep(time.Millisecond * 100) + } + + return lastErr +}