diff --git a/anyhttp.go b/anyhttp.go index a05d983..dd1a74d 100644 --- a/anyhttp.go +++ b/anyhttp.go @@ -2,6 +2,7 @@ package anyhttp import ( + "context" "errors" "fmt" "io/fs" @@ -14,6 +15,8 @@ import ( "sync" "syscall" "time" + + "go.balki.me/anyhttp/idle" ) // AddressType of the address passed @@ -200,65 +203,59 @@ func (s *SysdConfig) GetListener() (net.Listener, error) { return nil, errors.New("neither FDIndex nor FDName set") } -// GetListener gets a unix or systemd socket listener -func GetListener(addr string) (AddressType, net.Listener, error) { - if strings.HasPrefix(addr, "unix/") { - usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/")) - l, err := usc.GetListener() - return UnixSocket, l, err - } - - if strings.HasPrefix(addr, "sysd/fdidx/") { - idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/")) - if err != nil { - return Unknown, nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err) - } - sysdc := NewSysDConfigWithFDIdx(idx) - l, err := sysdc.GetListener() - return SystemdFD, l, err - } - - if strings.HasPrefix(addr, "sysd/fdname/") { - sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/")) - l, err := sysdc.GetListener() - return SystemdFD, l, err - } - - if port, err := strconv.Atoi(addr); err == nil { - if port > 0 && port < 65536 { - addr = fmt.Sprintf(":%v", port) - } else { - return Unknown, nil, fmt.Errorf("invalid port: %v", port) - } - } - - if addr == "" { - addr = ":http" - } - - l, err := net.Listen("tcp", addr) - return TCP, l, err -} - // Serve creates and serve a http server. -func Serve(addr string, h http.Handler) (AddressType, *http.Server, <-chan error, error) { - addrType, listener, err := GetListener(addr) +func Serve(addr string, h http.Handler) (addrType AddressType, srv *http.Server, idler idle.Idler, done <-chan error, err error) { + addrType, usc, sysc, err := ParseAddress(addr) if err != nil { - return addrType, nil, nil, err + return } - srv := &http.Server{Handler: h} - done := make(chan error) - go func() { - done <- srv.Serve(listener) - close(done) + + listener, err := func() (net.Listener, error) { + if usc != nil { + return usc.GetListener() + } else if sysc != nil { + return sysc.GetListener() + } else { + if addr == "" { + addr = ":http" + } + return net.Listen("tcp", addr) + } }() - return addrType, srv, done, nil + if err != nil { + return + } + errChan := make(chan error) + done = errChan + if addrType == SystemdFD && sysc.IdleTimeout != nil { + idler = idle.CreateIdler(*sysc.IdleTimeout) + srv = &http.Server{Handler: idle.WrapIdlerHandler(idler, h)} + waitErrChan := make(chan error) + go func() { + waitErrChan <- srv.Serve(listener) + close(waitErrChan) + }() + go func() { + select { + case err := <-waitErrChan: + errChan <- err + case <-idler.Chan(): + errChan <- srv.Shutdown(context.TODO()) + } + }() + } else { + srv = &http.Server{Handler: h} + go func() { + errChan <- srv.Serve(listener) + }() + } + return } // ListenAndServe is the drop-in replacement for `http.ListenAndServe`. // Supports unix and systemd sockets in addition func ListenAndServe(addr string, h http.Handler) error { - _, _, done, err := Serve(addr, h) + _, _, _, done, err := Serve(addr, h) if err != nil { return err }