diff --git a/anyhttp.go b/anyhttp.go index d76f7e9..8d316a7 100644 --- a/anyhttp.go +++ b/anyhttp.go @@ -23,9 +23,9 @@ import ( type AddressType string var ( - // UnixSocket - address is a unix socket, e.g. unix//run/foo.sock + // UnixSocket - address is a unix socket, e.g. unix?path=/run/foo.sock UnixSocket AddressType = "UnixSocket" - // SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket + // SystemdFD - address is a systemd fd, e.g. sysd?name=myapp.socket SystemdFD AddressType = "SystemdFD" // TCP - address is a TCP address, e.g. :1234 TCP AddressType = "TCP" @@ -203,6 +203,34 @@ func (s *SysdConfig) GetListener() (net.Listener, error) { return nil, errors.New("neither FDIndex nor FDName set") } +// GetListener is low level function for use with non-http servers. e.g. tcp, smtp +// Caller should handle idle timeout if needed +func GetListener(addr string) (net.Listener, AddressType, any /* cfg */, error) { + + addrType, unixSocketConfig, sysdConfig, perr := parseAddress(addr) + if perr != nil { + return nil, Unknown, nil, perr + } + if unixSocketConfig != nil { + listener, err := unixSocketConfig.GetListener() + if err != nil { + return nil, Unknown, nil, err + } + return listener, addrType, unixSocketConfig, nil + } else if sysdConfig != nil { + listener, err := sysdConfig.GetListener() + if err != nil { + return nil, Unknown, nil, err + } + return listener, addrType, sysdConfig, nil + } + if addr == "" { + addr = ":http" + } + listener, err := net.Listen("tcp", addr) + return listener, TCP, nil, err +} + type ServerCtx struct { AddressType AddressType Listener net.Listener @@ -229,53 +257,14 @@ func (s *ServerCtx) Shutdown(ctx context.Context) error { return <-s.Done } -// Serve creates and serve a http server. -func Serve(addr string, h http.Handler) (*ServerCtx, error) { - var ctx ServerCtx - var err error - ctx.AddressType, ctx.UnixSocketConfig, ctx.SysdConfig, err = parseAddress(addr) - if err != nil { - return nil, err - } +// ServeTLS creates and serves a HTTPS server. +func ServeTLS(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) { + return serve(addr, h, certFile, keyFile) +} - ctx.Listener, err = func() (net.Listener, error) { - if ctx.UnixSocketConfig != nil { - return ctx.UnixSocketConfig.GetListener() - } else if ctx.SysdConfig != nil { - return ctx.SysdConfig.GetListener() - } - if addr == "" { - addr = ":http" - } - return net.Listen("tcp", addr) - }() - if err != nil { - return nil, err - } - errChan := make(chan error) - ctx.Done = errChan - if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil { - ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout) - ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)} - waitErrChan := make(chan error) - go func() { - waitErrChan <- ctx.Server.Serve(ctx.Listener) - }() - go func() { - select { - case err := <-waitErrChan: - errChan <- err - case <-ctx.Idler.Chan(): - errChan <- ctx.Server.Shutdown(context.TODO()) - } - }() - } else { - ctx.Server = &http.Server{Handler: h} - go func() { - errChan <- ctx.Server.Serve(ctx.Listener) - }() - } - return &ctx, nil +// Serve creates and serves a HTTP server. +func Serve(addr string, h http.Handler) (*ServerCtx, error) { + return serve(addr, h, "", "") } // ListenAndServe is the drop-in replacement for `http.ListenAndServe`. @@ -288,6 +277,14 @@ func ListenAndServe(addr string, h http.Handler) error { return ctx.Wait() } +func ListenAndServeTLS(addr string, certFile string, keyFile string, h http.Handler) error { + ctx, err := ServeTLS(addr, h, certFile, keyFile) + if err != nil { + return err + } + return ctx.Wait() +} + // UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes func UnsetSystemdListenVars() { _ = os.Unsetenv("LISTEN_PID") @@ -389,3 +386,55 @@ func parseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sys } return } + +func serve(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) { + + serveFn := func() func(ctx *ServerCtx) error { + if certFile != "" { + return func(ctx *ServerCtx) error { + return ctx.Server.ServeTLS(ctx.Listener, certFile, keyFile) + } + } + return func(ctx *ServerCtx) error { + return ctx.Server.Serve(ctx.Listener) + } + }() + var ctx ServerCtx + var err error + var cfg any + + ctx.Listener, ctx.AddressType, cfg, err = GetListener(addr) + if err != nil { + return nil, err + } + switch ctx.AddressType { + case UnixSocket: + ctx.UnixSocketConfig = cfg.(*UnixSocketConfig) + case SystemdFD: + ctx.SysdConfig = cfg.(*SysdConfig) + } + errChan := make(chan error) + ctx.Done = errChan + if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil { + ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout) + ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)} + waitErrChan := make(chan error) + go func() { + waitErrChan <- serveFn(&ctx) + }() + go func() { + select { + case err := <-waitErrChan: + errChan <- err + case <-ctx.Idler.Chan(): + errChan <- ctx.Server.Shutdown(context.TODO()) + } + }() + } else { + ctx.Server = &http.Server{Handler: h} + go func() { + errChan <- serveFn(&ctx) + }() + } + return &ctx, nil +} diff --git a/anyhttp_test.go b/anyhttp_test.go index b6decaa..5024c11 100644 --- a/anyhttp_test.go +++ b/anyhttp_test.go @@ -1,7 +1,9 @@ package anyhttp import ( + "context" "encoding/json" + "log" "testing" "time" ) @@ -119,6 +121,12 @@ func Test_parseAddress(t *testing.T) { } } +func TestServe(t *testing.T) { + ctx, err := Serve("unix?path=/tmp/foo.sock", nil) + log.Printf("Got ctx: %v\n, err: %v", ctx, err) + ctx.Shutdown(context.TODO()) +} + // Helpers // print value instead of pointer diff --git a/examples/simple/main.go b/examples/simple/main.go new file mode 100644 index 0000000..19ea57b --- /dev/null +++ b/examples/simple/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "log" + "net/http" + "os" + "time" + + "go.balki.me/anyhttp" +) + +func main() { + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello\n")) + }) + //log.Println("Got error: ", anyhttp.ListenAndServe(os.Args[1], nil)) + ctx, err := anyhttp.Serve(os.Args[1], nil) + log.Printf("Got ctx: %v\n, err: %v", ctx, err) + log.Println(ctx.Addr()) + if err != nil { + return + } + select { + case doneErr := <-ctx.Done: + log.Println(doneErr) + case <-time.After(1 * time.Minute): + log.Println("Awake") + ctx.Shutdown(context.TODO()) + } +} diff --git a/examples/simple/simple b/examples/simple/simple new file mode 100755 index 0000000..f002bf6 Binary files /dev/null and b/examples/simple/simple differ