Refactor: Add GetListener and TLS versions
This commit is contained in:
parent
27d7f247c3
commit
74ade60009
145
anyhttp.go
145
anyhttp.go
@ -23,9 +23,9 @@ import (
|
||||
type AddressType string
|
||||
|
||||
var (
|
||||
// UnixSocket - address is a unix socket, e.g. unix//run/foo.sock
|
||||
// UnixSocket - address is a unix socket, e.g. unix?path=/run/foo.sock
|
||||
UnixSocket AddressType = "UnixSocket"
|
||||
// SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket
|
||||
// SystemdFD - address is a systemd fd, e.g. sysd?name=myapp.socket
|
||||
SystemdFD AddressType = "SystemdFD"
|
||||
// TCP - address is a TCP address, e.g. :1234
|
||||
TCP AddressType = "TCP"
|
||||
@ -203,6 +203,34 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
|
||||
return nil, errors.New("neither FDIndex nor FDName set")
|
||||
}
|
||||
|
||||
// GetListener is low level function for use with non-http servers. e.g. tcp, smtp
|
||||
// Caller should handle idle timeout if needed
|
||||
func GetListener(addr string) (net.Listener, AddressType, any /* cfg */, error) {
|
||||
|
||||
addrType, unixSocketConfig, sysdConfig, perr := parseAddress(addr)
|
||||
if perr != nil {
|
||||
return nil, Unknown, nil, perr
|
||||
}
|
||||
if unixSocketConfig != nil {
|
||||
listener, err := unixSocketConfig.GetListener()
|
||||
if err != nil {
|
||||
return nil, Unknown, nil, err
|
||||
}
|
||||
return listener, addrType, unixSocketConfig, nil
|
||||
} else if sysdConfig != nil {
|
||||
listener, err := sysdConfig.GetListener()
|
||||
if err != nil {
|
||||
return nil, Unknown, nil, err
|
||||
}
|
||||
return listener, addrType, sysdConfig, nil
|
||||
}
|
||||
if addr == "" {
|
||||
addr = ":http"
|
||||
}
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
return listener, TCP, nil, err
|
||||
}
|
||||
|
||||
type ServerCtx struct {
|
||||
AddressType AddressType
|
||||
Listener net.Listener
|
||||
@ -229,53 +257,14 @@ func (s *ServerCtx) Shutdown(ctx context.Context) error {
|
||||
return <-s.Done
|
||||
}
|
||||
|
||||
// Serve creates and serve a http server.
|
||||
func Serve(addr string, h http.Handler) (*ServerCtx, error) {
|
||||
var ctx ServerCtx
|
||||
var err error
|
||||
ctx.AddressType, ctx.UnixSocketConfig, ctx.SysdConfig, err = parseAddress(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ServeTLS creates and serves a HTTPS server.
|
||||
func ServeTLS(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
|
||||
return serve(addr, h, certFile, keyFile)
|
||||
}
|
||||
|
||||
ctx.Listener, err = func() (net.Listener, error) {
|
||||
if ctx.UnixSocketConfig != nil {
|
||||
return ctx.UnixSocketConfig.GetListener()
|
||||
} else if ctx.SysdConfig != nil {
|
||||
return ctx.SysdConfig.GetListener()
|
||||
}
|
||||
if addr == "" {
|
||||
addr = ":http"
|
||||
}
|
||||
return net.Listen("tcp", addr)
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errChan := make(chan error)
|
||||
ctx.Done = errChan
|
||||
if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil {
|
||||
ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout)
|
||||
ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)}
|
||||
waitErrChan := make(chan error)
|
||||
go func() {
|
||||
waitErrChan <- ctx.Server.Serve(ctx.Listener)
|
||||
}()
|
||||
go func() {
|
||||
select {
|
||||
case err := <-waitErrChan:
|
||||
errChan <- err
|
||||
case <-ctx.Idler.Chan():
|
||||
errChan <- ctx.Server.Shutdown(context.TODO())
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
ctx.Server = &http.Server{Handler: h}
|
||||
go func() {
|
||||
errChan <- ctx.Server.Serve(ctx.Listener)
|
||||
}()
|
||||
}
|
||||
return &ctx, nil
|
||||
// Serve creates and serves a HTTP server.
|
||||
func Serve(addr string, h http.Handler) (*ServerCtx, error) {
|
||||
return serve(addr, h, "", "")
|
||||
}
|
||||
|
||||
// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
|
||||
@ -288,6 +277,14 @@ func ListenAndServe(addr string, h http.Handler) error {
|
||||
return ctx.Wait()
|
||||
}
|
||||
|
||||
func ListenAndServeTLS(addr string, certFile string, keyFile string, h http.Handler) error {
|
||||
ctx, err := ServeTLS(addr, h, certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ctx.Wait()
|
||||
}
|
||||
|
||||
// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
|
||||
func UnsetSystemdListenVars() {
|
||||
_ = os.Unsetenv("LISTEN_PID")
|
||||
@ -389,3 +386,55 @@ func parseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sys
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func serve(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
|
||||
|
||||
serveFn := func() func(ctx *ServerCtx) error {
|
||||
if certFile != "" {
|
||||
return func(ctx *ServerCtx) error {
|
||||
return ctx.Server.ServeTLS(ctx.Listener, certFile, keyFile)
|
||||
}
|
||||
}
|
||||
return func(ctx *ServerCtx) error {
|
||||
return ctx.Server.Serve(ctx.Listener)
|
||||
}
|
||||
}()
|
||||
var ctx ServerCtx
|
||||
var err error
|
||||
var cfg any
|
||||
|
||||
ctx.Listener, ctx.AddressType, cfg, err = GetListener(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch ctx.AddressType {
|
||||
case UnixSocket:
|
||||
ctx.UnixSocketConfig = cfg.(*UnixSocketConfig)
|
||||
case SystemdFD:
|
||||
ctx.SysdConfig = cfg.(*SysdConfig)
|
||||
}
|
||||
errChan := make(chan error)
|
||||
ctx.Done = errChan
|
||||
if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil {
|
||||
ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout)
|
||||
ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)}
|
||||
waitErrChan := make(chan error)
|
||||
go func() {
|
||||
waitErrChan <- serveFn(&ctx)
|
||||
}()
|
||||
go func() {
|
||||
select {
|
||||
case err := <-waitErrChan:
|
||||
errChan <- err
|
||||
case <-ctx.Idler.Chan():
|
||||
errChan <- ctx.Server.Shutdown(context.TODO())
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
ctx.Server = &http.Server{Handler: h}
|
||||
go func() {
|
||||
errChan <- serveFn(&ctx)
|
||||
}()
|
||||
}
|
||||
return &ctx, nil
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package anyhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -119,6 +121,12 @@ func Test_parseAddress(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
ctx, err := Serve("unix?path=/tmp/foo.sock", nil)
|
||||
log.Printf("Got ctx: %v\n, err: %v", ctx, err)
|
||||
ctx.Shutdown(context.TODO())
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
// print value instead of pointer
|
||||
|
32
examples/simple/main.go
Normal file
32
examples/simple/main.go
Normal file
@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"go.balki.me/anyhttp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello\n"))
|
||||
})
|
||||
//log.Println("Got error: ", anyhttp.ListenAndServe(os.Args[1], nil))
|
||||
ctx, err := anyhttp.Serve(os.Args[1], nil)
|
||||
log.Printf("Got ctx: %v\n, err: %v", ctx, err)
|
||||
log.Println(ctx.Addr())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case doneErr := <-ctx.Done:
|
||||
log.Println(doneErr)
|
||||
case <-time.After(1 * time.Minute):
|
||||
log.Println("Awake")
|
||||
ctx.Shutdown(context.TODO())
|
||||
}
|
||||
}
|
BIN
examples/simple/simple
Executable file
BIN
examples/simple/simple
Executable file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user