Compare commits

..

No commits in common. "main" and "globalIdler" have entirely different histories.

4 changed files with 61 additions and 440 deletions

View File

@ -1,4 +1,4 @@
Create http server listening on unix sockets and systemd socket activated fds
Create http server listening on unix sockets or systemd socket activated fds
## Quick Usage
@ -17,52 +17,36 @@ Just replace `http.ListenAndServe` with `anyhttp.ListenAndServe`.
Syntax
unix?path=<socket_path>&mode=<socket file mode>&remove_existing=<true|false>
unix/<path to socket>
Examples
unix?path=relative/path.sock
unix?path=/var/run/app/absolutepath.sock
unix?path=/run/app.sock&mode=600&remove_existing=false
| option | description | default |
|-----------------|------------------------------------------------|----------|
| path | path to unix socket | Required |
| mode | socket file mode | 666 |
| remove_existing | Whether to remove existing socket file or fail | true |
unix/relative/path.sock
unix//var/run/app/absolutepath.sock
### Systemd Socket activated fd:
Syntax
sysd?idx=<fd index>&name=<fd name>&check_pid=<true|false>&unset_env=<true|false>&idle_timeout=<duration>
Only one of `idx` or `name` has to be set
sysd/fdidx/<fd index starting at 0>
sysd/fdname/<fd name set using FileDescriptorName socket setting >
Examples:
# First (or only) socket fd passed to app
sysd?idx=0
sysd/fdidx/0
# Socket with FileDescriptorName
sysd?name=myapp
sysd/fdname/myapp
# Using default name and auto shutdown if no requests received in last 30 minutes
sysd?name=myapp.socket&idle_timeout=30m
# Using default name
sysd/fdname/myapp.socket
| option | description | default |
|--------------|--------------------------------------------------------------------------------------------|------------------|
| name | Name configured via FileDescriptorName or socket file name | Required |
| idx | FD Index. Actual fd num will be 3 + idx | Required |
| idle_timeout | time to wait before shutdown. [syntax][0] | no auto shutdown |
| check_pid | Check process PID matches LISTEN_PID | true |
| unset_env | Unsets the LISTEN\* environment variables, so they don't get passed to any child processes | true |
### TCP port
### TCP
If the address is a number less than 65536, it is assumed as a port and passed as `http.ListenAndServe(":<port>",...)`
If the address is not one of above, it is assumed to be tcp and passed to `http.ListenAndServe`.
Examples:
Anything else is directly passed to `http.ListenAndServe` as well. Below examples should work
:http
:8888
@ -75,6 +59,4 @@ https://pkg.go.dev/go.balki.me/anyhttp
### Related links
* https://gist.github.com/teknoraver/5ffacb8757330715bcbcc90e6d46ac74#file-unixhttpd-go
* https://github.com/coreos/go-systemd/tree/main/activation
[0]: https://pkg.go.dev/time#ParseDuration
* https://github.com/coreos/go-systemd/tree/main/activation

View File

@ -2,30 +2,25 @@
package anyhttp
import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"syscall"
"time"
"go.balki.me/anyhttp/idle"
)
// AddressType of the address passed
type AddressType string
var (
// UnixSocket - address is a unix socket, e.g. unix?path=/run/foo.sock
// UnixSocket - address is a unix socket, e.g. unix//run/foo.sock
UnixSocket AddressType = "UnixSocket"
// SystemdFD - address is a systemd fd, e.g. sysd?name=myapp.socket
// SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket
SystemdFD AddressType = "SystemdFD"
// TCP - address is a TCP address, e.g. :1234
TCP AddressType = "TCP"
@ -102,8 +97,6 @@ type SysdConfig struct {
CheckPID bool
// Unsets the LISTEN* environment variables, so they don't get passed to any child processes
UnsetEnv bool
// Shutdown http server if no requests received for below timeout
IdleTimeout *time.Duration
}
// DefaultSysdConfig has the default values for SysdConfig
@ -203,86 +196,69 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
return nil, errors.New("neither FDIndex nor FDName set")
}
// GetListener is low level function for use with non-http servers. e.g. tcp, smtp
// Caller should handle idle timeout if needed
func GetListener(addr string) (net.Listener, AddressType, any /* cfg */, error) {
// GetListener gets a unix or systemd socket listener
func GetListener(addr string) (AddressType, net.Listener, error) {
if strings.HasPrefix(addr, "unix/") {
usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/"))
l, err := usc.GetListener()
return UnixSocket, l, err
}
addrType, unixSocketConfig, sysdConfig, perr := parseAddress(addr)
if perr != nil {
return nil, Unknown, nil, perr
}
if unixSocketConfig != nil {
listener, err := unixSocketConfig.GetListener()
if strings.HasPrefix(addr, "sysd/fdidx/") {
idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/"))
if err != nil {
return nil, Unknown, nil, err
return Unknown, nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err)
}
return listener, addrType, unixSocketConfig, nil
} else if sysdConfig != nil {
listener, err := sysdConfig.GetListener()
if err != nil {
return nil, Unknown, nil, err
}
return listener, addrType, sysdConfig, nil
sysdc := NewSysDConfigWithFDIdx(idx)
l, err := sysdc.GetListener()
return SystemdFD, l, err
}
if strings.HasPrefix(addr, "sysd/fdname/") {
sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/"))
l, err := sysdc.GetListener()
return SystemdFD, l, err
}
if port, err := strconv.Atoi(addr); err == nil {
if port > 0 && port < 65536 {
addr = fmt.Sprintf(":%v", port)
} else {
return Unknown, nil, fmt.Errorf("invalid port: %v", port)
}
}
if addr == "" {
addr = ":http"
}
listener, err := net.Listen("tcp", addr)
return listener, TCP, nil, err
l, err := net.Listen("tcp", addr)
return TCP, l, err
}
type ServerCtx struct {
AddressType AddressType
Listener net.Listener
Server *http.Server
Idler idle.Idler
Done <-chan error
UnixSocketConfig *UnixSocketConfig
SysdConfig *SysdConfig
}
func (s *ServerCtx) Wait() error {
return <-s.Done
}
func (s *ServerCtx) Addr() net.Addr {
return s.Listener.Addr()
}
func (s *ServerCtx) Shutdown(ctx context.Context) error {
err := s.Server.Shutdown(ctx)
// Serve creates and serve a http server.
func Serve(addr string, h http.Handler) (AddressType, *http.Server, <-chan error, error) {
addrType, listener, err := GetListener(addr)
if err != nil {
return err
return addrType, nil, nil, err
}
return <-s.Done
}
// ServeTLS creates and serves a HTTPS server.
func ServeTLS(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
return serve(addr, h, certFile, keyFile)
}
// Serve creates and serves a HTTP server.
func Serve(addr string, h http.Handler) (*ServerCtx, error) {
return serve(addr, h, "", "")
srv := &http.Server{Handler: h}
done := make(chan error)
go func() {
done <- srv.Serve(listener)
close(done)
}()
return addrType, srv, done, nil
}
// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
// Supports unix and systemd sockets in addition
func ListenAndServe(addr string, h http.Handler) error {
ctx, err := Serve(addr, h)
_, _, done, err := Serve(addr, h)
if err != nil {
return err
}
return ctx.Wait()
}
func ListenAndServeTLS(addr string, certFile string, keyFile string, h http.Handler) error {
ctx, err := ServeTLS(addr, h, certFile, keyFile)
if err != nil {
return err
}
return ctx.Wait()
return <-done
}
// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
@ -291,150 +267,3 @@ func UnsetSystemdListenVars() {
_ = os.Unsetenv("LISTEN_FDS")
_ = os.Unsetenv("LISTEN_FDNAMES")
}
func parseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sysc *SysdConfig, err error) {
usc = nil
sysc = nil
err = nil
u, err := url.Parse(addr)
if err != nil {
return TCP, nil, nil, nil
}
if u.Path == "unix" {
duc := DefaultUnixSocketConfig
usc = &duc
addrType = UnixSocket
for key, val := range u.Query() {
if len(val) != 1 {
err = fmt.Errorf("unix socket address error. Multiple %v found: %v", key, val)
return
}
if key == "path" {
usc.SocketPath = val[0]
} else if key == "mode" {
if _, serr := fmt.Sscanf(val[0], "%o", &usc.SocketMode); serr != nil {
err = fmt.Errorf("unix socket address error. Bad mode: %v, err: %w", val, serr)
return
}
} else if key == "remove_existing" {
if removeExisting, berr := strconv.ParseBool(val[0]); berr == nil {
usc.RemoveExisting = removeExisting
} else {
err = fmt.Errorf("unix socket address error. Bad remove_existing: %v, err: %w", val, berr)
return
}
} else {
err = fmt.Errorf("unix socket address error. Bad option; key: %v, val: %v", key, val)
return
}
}
if usc.SocketPath == "" {
err = fmt.Errorf("unix socket address error. Missing path; addr: %v", addr)
return
}
} else if u.Path == "sysd" {
dsc := DefaultSysdConfig
sysc = &dsc
addrType = SystemdFD
for key, val := range u.Query() {
if len(val) != 1 {
err = fmt.Errorf("systemd socket fd address error. Multiple %v found: %v", key, val)
return
}
if key == "name" {
sysc.FDName = &val[0]
} else if key == "idx" {
if idx, ierr := strconv.Atoi(val[0]); ierr == nil {
sysc.FDIndex = &idx
} else {
err = fmt.Errorf("systemd socket fd address error. Bad idx: %v, err: %w", val, ierr)
return
}
} else if key == "check_pid" {
if checkPID, berr := strconv.ParseBool(val[0]); berr == nil {
sysc.CheckPID = checkPID
} else {
err = fmt.Errorf("systemd socket fd address error. Bad check_pid: %v, err: %w", val, berr)
return
}
} else if key == "unset_env" {
if unsetEnv, berr := strconv.ParseBool(val[0]); berr == nil {
sysc.UnsetEnv = unsetEnv
} else {
err = fmt.Errorf("systemd socket fd address error. Bad unset_env: %v, err: %w", val, berr)
return
}
} else if key == "idle_timeout" {
if timeout, terr := time.ParseDuration(val[0]); terr == nil {
sysc.IdleTimeout = &timeout
} else {
err = fmt.Errorf("systemd socket fd address error. Bad idle_timeout: %v, err: %w", val, terr)
return
}
} else {
err = fmt.Errorf("systemd socket fd address error. Bad option; key: %v, val: %v", key, val)
return
}
}
if (sysc.FDIndex == nil) == (sysc.FDName == nil) {
err = fmt.Errorf("systemd socket fd address error. Exactly only one of name and idx has to be set. name: %v, idx: %v", sysc.FDName, sysc.FDIndex)
return
}
} else {
// Just assume as TCP address
return TCP, nil, nil, nil
}
return
}
func serve(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
serveFn := func() func(ctx *ServerCtx) error {
if certFile != "" {
return func(ctx *ServerCtx) error {
return ctx.Server.ServeTLS(ctx.Listener, certFile, keyFile)
}
}
return func(ctx *ServerCtx) error {
return ctx.Server.Serve(ctx.Listener)
}
}()
var ctx ServerCtx
var err error
var cfg any
ctx.Listener, ctx.AddressType, cfg, err = GetListener(addr)
if err != nil {
return nil, err
}
switch ctx.AddressType {
case UnixSocket:
ctx.UnixSocketConfig = cfg.(*UnixSocketConfig)
case SystemdFD:
ctx.SysdConfig = cfg.(*SysdConfig)
}
errChan := make(chan error)
ctx.Done = errChan
if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil {
ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout)
ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)}
waitErrChan := make(chan error)
go func() {
waitErrChan <- serveFn(&ctx)
}()
go func() {
select {
case err := <-waitErrChan:
errChan <- err
case <-ctx.Idler.Chan():
errChan <- ctx.Server.Shutdown(context.TODO())
}
}()
} else {
ctx.Server = &http.Server{Handler: h}
go func() {
errChan <- serveFn(&ctx)
}()
}
return &ctx, nil
}

View File

@ -1,158 +0,0 @@
package anyhttp
import (
"context"
"encoding/json"
"testing"
"time"
)
func Test_parseAddress(t *testing.T) {
tests := []struct {
name string // description of this test case
// Named input parameters for target function.
addr string
wantAddrType AddressType
wantUsc *UnixSocketConfig
wantSysc *SysdConfig
wantErr bool
}{
{
name: "tcp port",
addr: ":8080",
wantAddrType: TCP,
wantUsc: nil,
wantSysc: nil,
wantErr: false,
},
{
name: "unix address",
addr: "unix?path=/run/foo.sock&mode=660",
wantAddrType: UnixSocket,
wantUsc: &UnixSocketConfig{
SocketPath: "/run/foo.sock",
SocketMode: 0660,
RemoveExisting: true,
},
wantSysc: nil,
wantErr: false,
},
{
name: "systemd address",
addr: "sysd?name=foo.socket",
wantAddrType: SystemdFD,
wantUsc: nil,
wantSysc: &SysdConfig{
FDIndex: nil,
FDName: ptr("foo.socket"),
CheckPID: true,
UnsetEnv: true,
IdleTimeout: nil,
},
wantErr: false,
},
{
name: "systemd address with index",
addr: "sysd?idx=0&idle_timeout=30m",
wantAddrType: SystemdFD,
wantUsc: nil,
wantSysc: &SysdConfig{
FDIndex: ptr(0),
FDName: nil,
CheckPID: true,
UnsetEnv: true,
IdleTimeout: ptr(30 * time.Minute),
},
wantErr: false,
},
{
name: "systemd address. Bad example",
addr: "sysd?idx=0&idle_timeout=30m&name=foo",
wantAddrType: SystemdFD,
wantUsc: nil,
wantSysc: nil,
wantErr: true,
},
{
name: "systemd address with check_pid and unset_env",
addr: "sysd?idx=0&check_pid=false&unset_env=f",
wantAddrType: SystemdFD,
wantUsc: nil,
wantSysc: &SysdConfig{
FDIndex: ptr(0),
FDName: nil,
CheckPID: false,
UnsetEnv: false,
IdleTimeout: nil,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotAddrType, gotUsc, gotSysc, gotErr := parseAddress(tt.addr)
if gotErr != nil {
if !tt.wantErr {
t.Errorf("parseAddress() failed: %v", gotErr)
}
return
}
if tt.wantErr {
t.Fatal("parseAddress() succeeded unexpectedly")
}
if gotAddrType != tt.wantAddrType {
t.Errorf("parseAddress() addrType = %v, want %v", gotAddrType, tt.wantAddrType)
}
if !check(gotUsc, tt.wantUsc) {
t.Errorf("parseAddress() Usc = %v, want %v", gotUsc, tt.wantUsc)
}
if !check(gotSysc, tt.wantSysc) {
if (gotSysc == nil || tt.wantSysc == nil) ||
!(check(gotSysc.FDIndex, tt.wantSysc.FDIndex) &&
check(gotSysc.FDName, tt.wantSysc.FDName) &&
check(gotSysc.IdleTimeout, tt.wantSysc.IdleTimeout)) {
t.Errorf("parseAddress() Sysc = %v, want %v", asJSON(gotSysc), asJSON(tt.wantSysc))
}
}
})
}
}
func TestServe(t *testing.T) {
ctx, err := Serve("unix?path=/tmp/foo.sock", nil)
if err != nil {
t.Fatal()
}
if ctx.AddressType != UnixSocket {
t.Errorf("Serve() ServerCtx = %v, want %v", ctx.AddressType, UnixSocket)
}
ctx.Shutdown(context.TODO())
}
// Helpers
// print value instead of pointer
func asJSON[T any](val T) string {
op, err := json.Marshal(val)
if err != nil {
return err.Error()
}
return string(op)
}
func ptr[T any](val T) *T {
return &val
}
// nil safe equal check
func check[T comparable](got, want *T) bool {
if (got == nil) != (want == nil) {
return false
}
if got == nil {
return true
}
return *got == *want
}

View File

@ -1,32 +0,0 @@
package main
import (
"context"
"log"
"net/http"
"os"
"time"
"go.balki.me/anyhttp"
)
func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello\n"))
})
//log.Println("Got error: ", anyhttp.ListenAndServe(os.Args[1], nil))
ctx, err := anyhttp.Serve(os.Args[1], nil)
log.Printf("Got ctx: %v\n, err: %v", ctx, err)
log.Println(ctx.Addr())
if err != nil {
return
}
select {
case doneErr := <-ctx.Done:
log.Println(doneErr)
case <-time.After(1 * time.Minute):
log.Println("Awake")
ctx.Shutdown(context.TODO())
}
}