From b5588989def9dc7c65e2a341ae12977bf24de62f Mon Sep 17 00:00:00 2001 From: Balakrishnan Balasubramanian Date: Fri, 8 Sep 2023 19:10:36 -0400 Subject: [PATCH] Add function to return address type --- anyhttp.go | 75 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/anyhttp.go b/anyhttp.go index 7582381..ba8355f 100644 --- a/anyhttp.go +++ b/anyhttp.go @@ -14,6 +14,15 @@ import ( "syscall" ) +type AddressType string + +var ( + UnixSocket AddressType = "UnixSocket" + SystemdFD AddressType = "SystemdFD" + TCP AddressType = "TCP" + Unknown AddressType = "Unknown" +) + // UnixSocketConfig has the configuration for Unix socket type UnixSocketConfig struct { @@ -182,59 +191,61 @@ func (s *SysdConfig) GetListener() (net.Listener, error) { return nil, errors.New("neither FDIndex nor FDName set") } -// UnknownAddress Error is returned when address does not match any known syntax -type UnknownAddress struct{} - -func (u UnknownAddress) Error() string { - return "unknown address" -} - // GetListener gets a unix or systemd socket listener -func GetListener(addr string) (net.Listener, error) { +func GetListener(addr string) (AddressType, net.Listener, error) { if strings.HasPrefix(addr, "unix/") { usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/")) - return usc.GetListener() + l, err := usc.GetListener() + return UnixSocket, l, err } if strings.HasPrefix(addr, "sysd/fdidx/") { idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/")) if err != nil { - return nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err) + return Unknown, nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err) } sysdc := NewSysDConfigWithFDIdx(idx) - return sysdc.GetListener() + l, err := sysdc.GetListener() + return SystemdFD, l, err } if strings.HasPrefix(addr, "sysd/fdname/") { sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/")) - return sysdc.GetListener() + l, err := sysdc.GetListener() + return SystemdFD, l, err } - return nil, UnknownAddress{} + if port, err := strconv.Atoi(addr); err == nil { + if port > 0 && port < 65536 { + addr = fmt.Sprintf(":%v", port) + } else { + return Unknown, nil, fmt.Errorf("invalid port: %v", port) + } + } + + if addr == "" { + addr = ":http" + } + + l, err := net.Listen("tcp", addr) + return TCP, l, err +} + +func ListenAndServeHTTP(addr string, h http.Handler) (AddressType, *http.Server, error) { + addrType, listener, err := GetListener(addr) + if err != nil { + return addrType, nil, err + } + srv := &http.Server{Handler: h} + err = srv.Serve(listener) + return addrType, srv, err } // ListenAndServe is the drop-in replacement for `http.ListenAndServe`. // Supports unix and systemd sockets in addition func ListenAndServe(addr string, h http.Handler) error { - - listener, err := GetListener(addr) - if _, isUnknown := err.(UnknownAddress); err != nil && !isUnknown { - return err - } - - if listener != nil { - return http.Serve(listener, h) - } - - if port, err := strconv.Atoi(addr); err == nil { - if port > 0 && port < 65536 { - - return http.ListenAndServe(fmt.Sprintf(":%v", port), h) - } - return fmt.Errorf("invalid port: %v", port) - } - - return http.ListenAndServe(addr, h) + _, _, err := ListenAndServeHTTP(addr, h) + return err } // UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes