Compare commits

..

No commits in common. "main" and "addr-url" have entirely different histories.

4 changed files with 49 additions and 184 deletions

View File

@ -17,13 +17,13 @@ Just replace `http.ListenAndServe` with `anyhttp.ListenAndServe`.
Syntax
unix?path=<socket_path>&mode=<socket file mode>&remove_existing=<true|false>
unix?path=<socket_path>&mode=<socket file mode>&remove_existing=<yes|no>
Examples
unix?path=relative/path.sock
unix?path=/var/run/app/absolutepath.sock
unix?path=/run/app.sock&mode=600&remove_existing=false
unix?path=/run/app.sock&mode=600&remove_existing=no
| option | description | default |
|-----------------|------------------------------------------------|----------|
@ -35,7 +35,7 @@ Examples
Syntax
sysd?idx=<fd index>&name=<fd name>&check_pid=<true|false>&unset_env=<true|false>&idle_timeout=<duration>
sysd?idx=<fd index>&name=<fd name>&check_pid=<yes|no>&unset_env=<yes|no>&idle_timeout=<duration>
Only one of `idx` or `name` has to be set

View File

@ -23,9 +23,9 @@ import (
type AddressType string
var (
// UnixSocket - address is a unix socket, e.g. unix?path=/run/foo.sock
// UnixSocket - address is a unix socket, e.g. unix//run/foo.sock
UnixSocket AddressType = "UnixSocket"
// SystemdFD - address is a systemd fd, e.g. sysd?name=myapp.socket
// SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket
SystemdFD AddressType = "SystemdFD"
// TCP - address is a TCP address, e.g. :1234
TCP AddressType = "TCP"
@ -203,86 +203,61 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
return nil, errors.New("neither FDIndex nor FDName set")
}
// GetListener is low level function for use with non-http servers. e.g. tcp, smtp
// Caller should handle idle timeout if needed
func GetListener(addr string) (net.Listener, AddressType, any /* cfg */, error) {
addrType, unixSocketConfig, sysdConfig, perr := parseAddress(addr)
if perr != nil {
return nil, Unknown, nil, perr
}
if unixSocketConfig != nil {
listener, err := unixSocketConfig.GetListener()
if err != nil {
return nil, Unknown, nil, err
}
return listener, addrType, unixSocketConfig, nil
} else if sysdConfig != nil {
listener, err := sysdConfig.GetListener()
if err != nil {
return nil, Unknown, nil, err
}
return listener, addrType, sysdConfig, nil
}
if addr == "" {
addr = ":http"
}
listener, err := net.Listen("tcp", addr)
return listener, TCP, nil, err
}
type ServerCtx struct {
AddressType AddressType
Listener net.Listener
Server *http.Server
Idler idle.Idler
Done <-chan error
UnixSocketConfig *UnixSocketConfig
SysdConfig *SysdConfig
}
func (s *ServerCtx) Wait() error {
return <-s.Done
}
func (s *ServerCtx) Addr() net.Addr {
return s.Listener.Addr()
}
func (s *ServerCtx) Shutdown(ctx context.Context) error {
err := s.Server.Shutdown(ctx)
// Serve creates and serve a http server.
func Serve(addr string, h http.Handler) (addrType AddressType, srv *http.Server, idler idle.Idler, done <-chan error, err error) {
addrType, usc, sysc, err := parseAddress(addr)
if err != nil {
return err
return
}
return <-s.Done
}
// ServeTLS creates and serves a HTTPS server.
func ServeTLS(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
return serve(addr, h, certFile, keyFile)
}
// Serve creates and serves a HTTP server.
func Serve(addr string, h http.Handler) (*ServerCtx, error) {
return serve(addr, h, "", "")
listener, err := func() (net.Listener, error) {
if usc != nil {
return usc.GetListener()
} else if sysc != nil {
return sysc.GetListener()
}
if addr == "" {
addr = ":http"
}
return net.Listen("tcp", addr)
}()
if err != nil {
return
}
errChan := make(chan error)
done = errChan
if addrType == SystemdFD && sysc.IdleTimeout != nil {
idler = idle.CreateIdler(*sysc.IdleTimeout)
srv = &http.Server{Handler: idle.WrapIdlerHandler(idler, h)}
waitErrChan := make(chan error)
go func() {
waitErrChan <- srv.Serve(listener)
}()
go func() {
select {
case err := <-waitErrChan:
errChan <- err
case <-idler.Chan():
errChan <- srv.Shutdown(context.TODO())
}
}()
} else {
srv = &http.Server{Handler: h}
go func() {
errChan <- srv.Serve(listener)
}()
}
return
}
// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
// Supports unix and systemd sockets in addition
func ListenAndServe(addr string, h http.Handler) error {
ctx, err := Serve(addr, h)
_, _, _, done, err := Serve(addr, h)
if err != nil {
return err
}
return ctx.Wait()
}
func ListenAndServeTLS(addr string, certFile string, keyFile string, h http.Handler) error {
ctx, err := ServeTLS(addr, h, certFile, keyFile)
if err != nil {
return err
}
return ctx.Wait()
return <-done
}
// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
@ -386,55 +361,3 @@ func parseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sys
}
return
}
func serve(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
serveFn := func() func(ctx *ServerCtx) error {
if certFile != "" {
return func(ctx *ServerCtx) error {
return ctx.Server.ServeTLS(ctx.Listener, certFile, keyFile)
}
}
return func(ctx *ServerCtx) error {
return ctx.Server.Serve(ctx.Listener)
}
}()
var ctx ServerCtx
var err error
var cfg any
ctx.Listener, ctx.AddressType, cfg, err = GetListener(addr)
if err != nil {
return nil, err
}
switch ctx.AddressType {
case UnixSocket:
ctx.UnixSocketConfig = cfg.(*UnixSocketConfig)
case SystemdFD:
ctx.SysdConfig = cfg.(*SysdConfig)
}
errChan := make(chan error)
ctx.Done = errChan
if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil {
ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout)
ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)}
waitErrChan := make(chan error)
go func() {
waitErrChan <- serveFn(&ctx)
}()
go func() {
select {
case err := <-waitErrChan:
errChan <- err
case <-ctx.Idler.Chan():
errChan <- ctx.Server.Shutdown(context.TODO())
}
}()
} else {
ctx.Server = &http.Server{Handler: h}
go func() {
errChan <- serveFn(&ctx)
}()
}
return &ctx, nil
}

View File

@ -1,7 +1,6 @@
package anyhttp
import (
"context"
"encoding/json"
"testing"
"time"
@ -73,20 +72,6 @@ func Test_parseAddress(t *testing.T) {
wantSysc: nil,
wantErr: true,
},
{
name: "systemd address with check_pid and unset_env",
addr: "sysd?idx=0&check_pid=false&unset_env=f",
wantAddrType: SystemdFD,
wantUsc: nil,
wantSysc: &SysdConfig{
FDIndex: ptr(0),
FDName: nil,
CheckPID: false,
UnsetEnv: false,
IdleTimeout: nil,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -120,17 +105,6 @@ func Test_parseAddress(t *testing.T) {
}
}
func TestServe(t *testing.T) {
ctx, err := Serve("unix?path=/tmp/foo.sock", nil)
if err != nil {
t.Fatal()
}
if ctx.AddressType != UnixSocket {
t.Errorf("Serve() ServerCtx = %v, want %v", ctx.AddressType, UnixSocket)
}
ctx.Shutdown(context.TODO())
}
// Helpers
// print value instead of pointer

View File

@ -1,32 +0,0 @@
package main
import (
"context"
"log"
"net/http"
"os"
"time"
"go.balki.me/anyhttp"
)
func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello\n"))
})
//log.Println("Got error: ", anyhttp.ListenAndServe(os.Args[1], nil))
ctx, err := anyhttp.Serve(os.Args[1], nil)
log.Printf("Got ctx: %v\n, err: %v", ctx, err)
log.Println(ctx.Addr())
if err != nil {
return
}
select {
case doneErr := <-ctx.Done:
log.Println(doneErr)
case <-time.After(1 * time.Minute):
log.Println("Awake")
ctx.Shutdown(context.TODO())
}
}