diff --git a/anyhttp.go b/anyhttp.go index 743b064..7ec1f23 100644 --- a/anyhttp.go +++ b/anyhttp.go @@ -205,7 +205,7 @@ func (s *SysdConfig) GetListener() (net.Listener, error) { // Serve creates and serve a http server. func Serve(addr string, h http.Handler) (addrType AddressType, srv *http.Server, idler idle.Idler, done <-chan error, err error) { - addrType, usc, sysc, err := ParseAddress(addr) + addrType, usc, sysc, err := parseAddress(addr) if err != nil { return } @@ -215,12 +215,11 @@ func Serve(addr string, h http.Handler) (addrType AddressType, srv *http.Server, return usc.GetListener() } else if sysc != nil { return sysc.GetListener() - } else { - if addr == "" { - addr = ":http" - } - return net.Listen("tcp", addr) } + if addr == "" { + addr = ":http" + } + return net.Listen("tcp", addr) }() if err != nil { return @@ -268,8 +267,7 @@ func UnsetSystemdListenVars() { _ = os.Unsetenv("LISTEN_FDNAMES") } -func ParseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sysc *SysdConfig, err error) { - // addrType = Unknown +func parseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sysc *SysdConfig, err error) { usc = nil sysc = nil err = nil @@ -302,11 +300,11 @@ func ParseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sys err = fmt.Errorf("unix socket address error. Multiple remove_existing found: %v", val) return } - if removeExisting, berr := strconv.ParseBool(val[0]); berr != nil { + if removeExisting, berr := strconv.ParseBool(val[0]); berr == nil { + usc.RemoveExisting = removeExisting + } else { err = fmt.Errorf("unix socket address error. Bad remove_existing: %v, err: %w", val, berr) return - } else { - usc.RemoveExisting = removeExisting } } else { err = fmt.Errorf("unix socket address error. Bad option; key: %v, val: %v", key, val) @@ -333,44 +331,44 @@ func ParseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sys err = fmt.Errorf("systemd socket fd address error. Multiple idx found: %v", val) return } - if idx, ierr := strconv.Atoi(val[0]); ierr != nil { + if idx, ierr := strconv.Atoi(val[0]); ierr == nil { + sysc.FDIndex = &idx + } else { err = fmt.Errorf("systemd socket fd address error. Bad idx: %v, err: %w", val, ierr) return - } else { - sysc.FDIndex = &idx } } else if key == "check_pid" { if len(val) != 1 { err = fmt.Errorf("systemd socket fd address error. Multiple check_pid found: %v", val) return } - if checkPID, berr := strconv.ParseBool(val[0]); berr != nil { + if checkPID, berr := strconv.ParseBool(val[0]); berr == nil { + sysc.CheckPID = checkPID + } else { err = fmt.Errorf("systemd socket fd address error. Bad check_pid: %v, err: %w", val, berr) return - } else { - sysc.CheckPID = checkPID } } else if key == "unset_env" { if len(val) != 1 { err = fmt.Errorf("systemd socket fd address error. Multiple unset_env found: %v", val) return } - if unsetEnv, berr := strconv.ParseBool(val[0]); berr != nil { + if unsetEnv, berr := strconv.ParseBool(val[0]); berr == nil { + sysc.UnsetEnv = unsetEnv + } else { err = fmt.Errorf("systemd socket fd address error. Bad unset_env: %v, err: %w", val, berr) return - } else { - sysc.UnsetEnv = unsetEnv } } else if key == "idle_timeout" { if len(val) != 1 { err = fmt.Errorf("systemd socket fd address error. Multiple idle_timeout found: %v", val) return } - if timeout, terr := time.ParseDuration(val[0]); terr != nil { + if timeout, terr := time.ParseDuration(val[0]); terr == nil { + sysc.IdleTimeout = &timeout + } else { err = fmt.Errorf("systemd socket fd address error. Bad idle_timeout: %v, err: %w", val, terr) return - } else { - sysc.IdleTimeout = &timeout } } else { err = fmt.Errorf("systemd socket fd address error. Bad option; key: %v, val: %v", key, val) diff --git a/anyhttp_test.go b/anyhttp_test.go new file mode 100644 index 0000000..74f58ad --- /dev/null +++ b/anyhttp_test.go @@ -0,0 +1,132 @@ +package anyhttp + +import ( + "encoding/json" + "testing" + "time" +) + +func Test_parseAddress(t *testing.T) { + tests := []struct { + name string // description of this test case + // Named input parameters for target function. + addr string + wantAddrType AddressType + wantUsc *UnixSocketConfig + wantSysc *SysdConfig + wantErr bool + }{ + { + name: "tcp port", + addr: ":8080", + wantAddrType: TCP, + wantUsc: nil, + wantSysc: nil, + wantErr: false, + }, + { + name: "unix address", + addr: "unix?path=/run/foo.sock&mode=660", + wantAddrType: UnixSocket, + wantUsc: &UnixSocketConfig{ + SocketPath: "/run/foo.sock", + SocketMode: 0660, + RemoveExisting: true, + }, + wantSysc: nil, + wantErr: false, + }, + { + name: "systemd address", + addr: "sysd?name=foo.socket", + wantAddrType: SystemdFD, + wantUsc: nil, + wantSysc: &SysdConfig{ + FDIndex: nil, + FDName: ptr("foo.socket"), + CheckPID: true, + UnsetEnv: true, + IdleTimeout: nil, + }, + wantErr: false, + }, + { + name: "systemd address with index", + addr: "sysd?idx=0&idle_timeout=30m", + wantAddrType: SystemdFD, + wantUsc: nil, + wantSysc: &SysdConfig{ + FDIndex: ptr(0), + FDName: nil, + CheckPID: true, + UnsetEnv: true, + IdleTimeout: ptr(30 * time.Minute), + }, + wantErr: false, + }, + { + name: "systemd address. Bad example", + addr: "sysd?idx=0&idle_timeout=30m&name=foo", + wantAddrType: SystemdFD, + wantUsc: nil, + wantSysc: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotAddrType, gotUsc, gotSysc, gotErr := parseAddress(tt.addr) + if gotErr != nil { + if !tt.wantErr { + t.Errorf("parseAddress() failed: %v", gotErr) + } + return + } + if tt.wantErr { + t.Fatal("parseAddress() succeeded unexpectedly") + } + + if gotAddrType != tt.wantAddrType { + t.Errorf("parseAddress() addrType = %v, want %v", gotAddrType, tt.wantAddrType) + } + + if !check(gotUsc, tt.wantUsc) { + t.Errorf("parseAddress() Usc = %v, want %v", gotUsc, tt.wantUsc) + } + if !check(gotSysc, tt.wantSysc) { + if (gotSysc == nil || tt.wantSysc == nil) || + !(check(gotSysc.FDIndex, tt.wantSysc.FDIndex) && + check(gotSysc.FDName, tt.wantSysc.FDName) && + check(gotSysc.IdleTimeout, tt.wantSysc.IdleTimeout)) { + t.Errorf("parseAddress() Sysc = %v, want %v", asJSON(gotSysc), asJSON(tt.wantSysc)) + } + } + }) + } +} + +// Helpers + +// print value instead of pointer +func asJSON[T any](val T) string { + op, err := json.Marshal(val) + if err != nil { + return err.Error() + } + return string(op) +} + +func ptr[T any](val T) *T { + return &val +} + +// nil safe equal check +func check[T comparable](got, want *T) bool { + if (got == nil) != (want == nil) { + return false + } + if got == nil { + return true + } + return *got == *want +}