diff --git a/anyhttp.go b/anyhttp.go index 7582381..5c06154 100644 --- a/anyhttp.go +++ b/anyhttp.go @@ -14,6 +14,20 @@ import ( "syscall" ) +// AddressType of the address passed +type AddressType string + +var ( + // UnixSocket - address is a unix socket, e.g. unix//run/foo.sock + UnixSocket AddressType = "UnixSocket" + // SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket + SystemdFD AddressType = "SystemdFD" + // TCP - address is a TCP address, e.g. :1234 + TCP AddressType = "TCP" + // Unknown - address is not recognized + Unknown AddressType = "Unknown" +) + // UnixSocketConfig has the configuration for Unix socket type UnixSocketConfig struct { @@ -182,59 +196,69 @@ func (s *SysdConfig) GetListener() (net.Listener, error) { return nil, errors.New("neither FDIndex nor FDName set") } -// UnknownAddress Error is returned when address does not match any known syntax -type UnknownAddress struct{} - -func (u UnknownAddress) Error() string { - return "unknown address" -} - // GetListener gets a unix or systemd socket listener -func GetListener(addr string) (net.Listener, error) { +func GetListener(addr string) (AddressType, net.Listener, error) { if strings.HasPrefix(addr, "unix/") { usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/")) - return usc.GetListener() + l, err := usc.GetListener() + return UnixSocket, l, err } if strings.HasPrefix(addr, "sysd/fdidx/") { idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/")) if err != nil { - return nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err) + return Unknown, nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err) } sysdc := NewSysDConfigWithFDIdx(idx) - return sysdc.GetListener() + l, err := sysdc.GetListener() + return SystemdFD, l, err } if strings.HasPrefix(addr, "sysd/fdname/") { sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/")) - return sysdc.GetListener() + l, err := sysdc.GetListener() + return SystemdFD, l, err } - return nil, UnknownAddress{} + if port, err := strconv.Atoi(addr); err == nil { + if port > 0 && port < 65536 { + addr = fmt.Sprintf(":%v", port) + } else { + return Unknown, nil, fmt.Errorf("invalid port: %v", port) + } + } + + if addr == "" { + addr = ":http" + } + + l, err := net.Listen("tcp", addr) + return TCP, l, err +} + +// Serve creates and serve a http server. +func Serve(addr string, h http.Handler) (AddressType, *http.Server, <-chan error, error) { + addrType, listener, err := GetListener(addr) + if err != nil { + return addrType, nil, nil, err + } + srv := &http.Server{Handler: h} + done := make(chan error) + go func() { + done <- srv.Serve(listener) + close(done) + }() + return addrType, srv, done, nil } // ListenAndServe is the drop-in replacement for `http.ListenAndServe`. // Supports unix and systemd sockets in addition func ListenAndServe(addr string, h http.Handler) error { - - listener, err := GetListener(addr) - if _, isUnknown := err.(UnknownAddress); err != nil && !isUnknown { + _, _, done, err := Serve(addr, h) + if err != nil { return err } - - if listener != nil { - return http.Serve(listener, h) - } - - if port, err := strconv.Atoi(addr); err == nil { - if port > 0 && port < 65536 { - - return http.ListenAndServe(fmt.Sprintf(":%v", port), h) - } - return fmt.Errorf("invalid port: %v", port) - } - - return http.ListenAndServe(addr, h) + return <-done } // UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes diff --git a/idle/idle.go b/idle/idle.go index 5f3d2d7..ee3d81f 100644 --- a/idle/idle.go +++ b/idle/idle.go @@ -1,16 +1,75 @@ -// Package idle helps to gracefully shutdown idle servers +// Package idle helps to gracefully shutdown idle (typically http) servers package idle import ( + "fmt" + "net/http" "sync/atomic" "time" ) +var ( + // For simple servers without backgroud jobs, global singleton for simpler API + // Enter/Exit worn't work for global idler as Enter may be called before Wait, use CreateIdler in those cases + gIdler atomic.Pointer[idler] +) + +// Wait waits till the server is idle and returns. i.e. no Ticks in last duration +func Wait(timeout time.Duration) error { + i := CreateIdler(timeout).(*idler) + ok := gIdler.CompareAndSwap(nil, i) + if !ok { + return fmt.Errorf("idler already waiting") + } + i.Wait() + return nil +} + +// Tick records the current time. This will make the server not idle until next Tick or timeout +func Tick() { + i := gIdler.Load() + if i != nil { + i.Tick() + } +} + +// WrapHandler calls Tick() before processing passing request to http.Handler +func WrapHandler(h http.Handler) http.Handler { + if h == nil { + h = http.DefaultServeMux + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Tick() + h.ServeHTTP(w, r) + }) +} + +// WrapIdlerHandler calls idler.Tick() before processing passing request to http.Handler +func WrapIdlerHandler(i Idler, h http.Handler) http.Handler { + if h == nil { + h = http.DefaultServeMux + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + i.Tick() + h.ServeHTTP(w, r) + }) +} + +// Idler helps manage idle servers type Idler interface { - Enter() - Exit() - Wait() + // Tick records the current time. This will make the server not idle until next Tick or timeout Tick() + + // Wait waits till the server is idle and returns. i.e. no Ticks in last duration + Wait() + + // For long running background jobs, use Enter to record start time. Wait will not return while there are active jobs running + Enter() + + // Exit records end of a background job + Exit() + + // Get the channel to wait yourself Chan() <-chan struct{} } @@ -29,6 +88,7 @@ func (i *idler) Exit() { i.active.Add(-1) } +// CreateIdler creates an Idler with given timeout func CreateIdler(timeout time.Duration) Idler { i := &idler{} i.c = make(chan struct{}) diff --git a/idle/idle_test.go b/idle/idle_test.go index bce4c51..069ea57 100644 --- a/idle/idle_test.go +++ b/idle/idle_test.go @@ -5,7 +5,30 @@ import ( "time" ) -func TestIdlerChan(t *testing.T) { - i := CreateIdler(1 * time.Second) +func TestIdlerChan(_ *testing.T) { + i := CreateIdler(10 * time.Millisecond) <-i.Chan() } + +func TestGlobalIdler(t *testing.T) { + err := Wait(10 * time.Millisecond) + if err != nil { + t.Fatalf("idle.Wait failed, %v", err) + } + err = Wait(10 * time.Millisecond) + if err == nil { + t.Fatal("idle.Wait should fail when called second time") + } +} + +func TestIdlerEnterExit(t *testing.T) { + i := CreateIdler(10 * time.Millisecond).(*idler) + i.Enter() + if i.active.Load() != 1 { + t.FailNow() + } + i.Exit() + if i.active.Load() != 0 { + t.FailNow() + } +}