From bbef4be4a7c8ee56d910786f98278438303ec2c8 Mon Sep 17 00:00:00 2001 From: Balakrishnan Balasubramanian Date: Sun, 30 Apr 2023 23:37:46 -0400 Subject: [PATCH] Refactor environment parsing Systemd environment variables LISTEN* are unset by default and saved for future calls --- anyhttp.go | 82 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/anyhttp.go b/anyhttp.go index 790212b..7582381 100644 --- a/anyhttp.go +++ b/anyhttp.go @@ -10,6 +10,7 @@ import ( "os" "strconv" "strings" + "sync" "syscall" ) @@ -39,6 +40,39 @@ func NewUnixSocketConfig(socketPath string) UnixSocketConfig { return usc } +type sysdEnvData struct { + pid int + fdNames []string + fdNamesStr string + numFds int +} + +var sysdEnvParser = struct { + sysdOnce sync.Once + data sysdEnvData + err error +}{} + +func parse() (sysdEnvData, error) { + p := &sysdEnvParser + p.sysdOnce.Do(func() { + p.data.pid, p.err = strconv.Atoi(os.Getenv("LISTEN_PID")) + if p.err != nil { + p.err = fmt.Errorf("invalid LISTEN_PID, err: %w", p.err) + return + } + p.data.numFds, p.err = strconv.Atoi(os.Getenv("LISTEN_FDS")) + if p.err != nil { + p.err = fmt.Errorf("invalid LISTEN_FDS, err: %w", p.err) + return + } + p.data.fdNamesStr = os.Getenv("LISTEN_FDNAMES") + p.data.fdNames = strings.Split(p.data.fdNamesStr, ":") + + }) + return p.data, p.err +} + // SysdConfig has the configuration for the socket activated fd type SysdConfig struct { // Integer value starting at 0. Either index or name is required @@ -54,7 +88,7 @@ type SysdConfig struct { // DefaultSysdConfig has the default values for SysdConfig var DefaultSysdConfig = SysdConfig{ CheckPID: true, - UnsetEnv: false, + UnsetEnv: true, } // NewSysDConfigWithFDIdx creates SysdConfig with defaults and fdIdx @@ -112,53 +146,47 @@ func (s *SysdConfig) GetListener() (net.Listener, error) { defer UnsetSystemdListenVars() } - if s.CheckPID { - pid, err := strconv.Atoi(os.Getenv("LISTEN_PID")) - if err != nil { - return nil, fmt.Errorf("invalid LISTEN_PID, err: %w", err) - } - if pid != os.Getpid() { - return nil, fmt.Errorf("unexpected PID, current:%v, LISTEN_PID: %v", os.Getpid(), pid) - } - } - - numFds, err := strconv.Atoi(os.Getenv("LISTEN_FDS")) + envData, err := parse() if err != nil { - return nil, fmt.Errorf("invalid LISTEN_FDS, err: %w", err) + return nil, err } - fdNames := strings.Split(os.Getenv("LISTEN_FDNAMES"), ":") + if s.CheckPID { + if envData.pid != os.Getpid() { + return nil, fmt.Errorf("unexpected PID, current:%v, LISTEN_PID: %v", os.Getpid(), envData.pid) + } + } if s.FDIndex != nil { idx := *s.FDIndex - if idx < 0 || idx >= numFds { - return nil, fmt.Errorf("invalid fd index, expected between 0 and %v, got: %v", numFds, idx) + if idx < 0 || idx >= envData.numFds { + return nil, fmt.Errorf("invalid fd index, expected between 0 and %v, got: %v", envData.numFds, idx) } fd := StartFD + idx - if idx < len(fdNames) { - return makeFdListener(fd, fdNames[idx]) + if idx < len(envData.fdNames) { + return makeFdListener(fd, envData.fdNames[idx]) } return makeFdListener(fd, fmt.Sprintf("sysdfd_%d", fd)) } if s.FDName != nil { - for idx, name := range fdNames { + for idx, name := range envData.fdNames { if name == *s.FDName { fd := StartFD + idx return makeFdListener(fd, name) } } - return nil, fmt.Errorf("fdName not found: %q, LISTEN_FDNAMES:%q", *s.FDName, os.Getenv("LISTEN_FDNAMES")) + return nil, fmt.Errorf("fdName not found: %q, LISTEN_FDNAMES:%q", *s.FDName, envData.fdNamesStr) } return nil, errors.New("neither FDIndex nor FDName set") } // UnknownAddress Error is returned when address does not match any known syntax -type UnknownAddress string +type UnknownAddress struct{} func (u UnknownAddress) Error() string { - return fmt.Sprintf("unknown address: %q", string(u)) + return "unknown address" } // GetListener gets a unix or systemd socket listener @@ -182,7 +210,7 @@ func GetListener(addr string) (net.Listener, error) { return sysdc.GetListener() } - return nil, UnknownAddress(addr) + return nil, UnknownAddress{} } // ListenAndServe is the drop-in replacement for `http.ListenAndServe`. @@ -190,7 +218,7 @@ func GetListener(addr string) (net.Listener, error) { func ListenAndServe(addr string, h http.Handler) error { listener, err := GetListener(addr) - if _, ok := err.(UnknownAddress); err != nil && !ok { + if _, isUnknown := err.(UnknownAddress); err != nil && !isUnknown { return err } @@ -211,7 +239,7 @@ func ListenAndServe(addr string, h http.Handler) error { // UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes func UnsetSystemdListenVars() { - os.Unsetenv("LISTEN_PID") - os.Unsetenv("LISTEN_FDS") - os.Unsetenv("LISTEN_FDNAMES") + _ = os.Unsetenv("LISTEN_PID") + _ = os.Unsetenv("LISTEN_FDS") + _ = os.Unsetenv("LISTEN_FDNAMES") }