10 Commits

5 changed files with 311 additions and 45 deletions

View File

@ -2,6 +2,8 @@ Create http server listening on unix sockets or systemd socket activated fds
## Quick Usage
go get go.balki.me/anyhttp
Just replace `http.ListenAndServe` with `anyhttp.ListenAndServe`.
```diff
@ -49,3 +51,12 @@ Anything else is directly passed to `http.ListenAndServe` as well. Below example
:http
:8888
127.0.0.1:8080
## Documentation
https://pkg.go.dev/go.balki.me/anyhttp
### Related links
* https://gist.github.com/teknoraver/5ffacb8757330715bcbcc90e6d46ac74#file-unixhttpd-go
* https://github.com/coreos/go-systemd/tree/main/activation

View File

@ -10,9 +10,24 @@ import (
"os"
"strconv"
"strings"
"sync"
"syscall"
)
// AddressType of the address passed
type AddressType string
var (
// UnixSocket - address is a unix socket, e.g. unix//run/foo.sock
UnixSocket AddressType = "UnixSocket"
// SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket
SystemdFD AddressType = "SystemdFD"
// TCP - address is a TCP address, e.g. :1234
TCP AddressType = "TCP"
// Unknown - address is not recognized
Unknown AddressType = "Unknown"
)
// UnixSocketConfig has the configuration for Unix socket
type UnixSocketConfig struct {
@ -39,6 +54,39 @@ func NewUnixSocketConfig(socketPath string) UnixSocketConfig {
return usc
}
type sysdEnvData struct {
pid int
fdNames []string
fdNamesStr string
numFds int
}
var sysdEnvParser = struct {
sysdOnce sync.Once
data sysdEnvData
err error
}{}
func parse() (sysdEnvData, error) {
p := &sysdEnvParser
p.sysdOnce.Do(func() {
p.data.pid, p.err = strconv.Atoi(os.Getenv("LISTEN_PID"))
if p.err != nil {
p.err = fmt.Errorf("invalid LISTEN_PID, err: %w", p.err)
return
}
p.data.numFds, p.err = strconv.Atoi(os.Getenv("LISTEN_FDS"))
if p.err != nil {
p.err = fmt.Errorf("invalid LISTEN_FDS, err: %w", p.err)
return
}
p.data.fdNamesStr = os.Getenv("LISTEN_FDNAMES")
p.data.fdNames = strings.Split(p.data.fdNamesStr, ":")
})
return p.data, p.err
}
// SysdConfig has the configuration for the socket activated fd
type SysdConfig struct {
// Integer value starting at 0. Either index or name is required
@ -54,7 +102,7 @@ type SysdConfig struct {
// DefaultSysdConfig has the default values for SysdConfig
var DefaultSysdConfig = SysdConfig{
CheckPID: true,
UnsetEnv: false,
UnsetEnv: true,
}
// NewSysDConfigWithFDIdx creates SysdConfig with defaults and fdIdx
@ -112,99 +160,110 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
defer UnsetSystemdListenVars()
}
if s.CheckPID {
pid, err := strconv.Atoi(os.Getenv("LISTEN_PID"))
if err != nil {
return nil, err
}
if pid != os.Getpid() {
return nil, fmt.Errorf("fd not for you")
}
}
numFds, err := strconv.Atoi(os.Getenv("LISTEN_FDS"))
envData, err := parse()
if err != nil {
return nil, err
}
fdNames := strings.Split(os.Getenv("LISTEN_FDNAMES"), ":")
if s.CheckPID {
if envData.pid != os.Getpid() {
return nil, fmt.Errorf("unexpected PID, current:%v, LISTEN_PID: %v", os.Getpid(), envData.pid)
}
}
if s.FDIndex != nil {
idx := *s.FDIndex
if idx < 0 || idx >= numFds {
return nil, fmt.Errorf("invalid fd")
if idx < 0 || idx >= envData.numFds {
return nil, fmt.Errorf("invalid fd index, expected between 0 and %v, got: %v", envData.numFds, idx)
}
fd := StartFD + idx
if idx < len(fdNames) {
return makeFdListener(fd, fdNames[idx])
if idx < len(envData.fdNames) {
return makeFdListener(fd, envData.fdNames[idx])
}
return makeFdListener(fd, fmt.Sprintf("sysdfd_%d", fd))
}
if s.FDName != nil {
for idx, name := range fdNames {
for idx, name := range envData.fdNames {
if name == *s.FDName {
fd := StartFD + idx
return makeFdListener(fd, name)
}
}
return nil, fmt.Errorf("fdName not found: %q", *s.FDName)
return nil, fmt.Errorf("fdName not found: %q, LISTEN_FDNAMES:%q", *s.FDName, envData.fdNamesStr)
}
return nil, fmt.Errorf("neither FDIndex nor FDName set")
return nil, errors.New("neither FDIndex nor FDName set")
}
// GetListener gets a unix or systemd socket listener
func GetListener(addr string) (net.Listener, error) {
func GetListener(addr string) (AddressType, net.Listener, error) {
if strings.HasPrefix(addr, "unix/") {
usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/"))
return usc.GetListener()
l, err := usc.GetListener()
return UnixSocket, l, err
}
if strings.HasPrefix(addr, "sysd/fdidx/") {
idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/"))
if err != nil {
return nil, err
return Unknown, nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err)
}
sysdc := NewSysDConfigWithFDIdx(idx)
return sysdc.GetListener()
l, err := sysdc.GetListener()
return SystemdFD, l, err
}
if strings.HasPrefix(addr, "sysd/fdname/") {
sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/"))
return sysdc.GetListener()
l, err := sysdc.GetListener()
return SystemdFD, l, err
}
return nil, nil
if port, err := strconv.Atoi(addr); err == nil {
if port > 0 && port < 65536 {
addr = fmt.Sprintf(":%v", port)
} else {
return Unknown, nil, fmt.Errorf("invalid port: %v", port)
}
}
if addr == "" {
addr = ":http"
}
l, err := net.Listen("tcp", addr)
return TCP, l, err
}
// Serve creates and serve a http server.
func Serve(addr string, h http.Handler) (AddressType, *http.Server, <-chan error, error) {
addrType, listener, err := GetListener(addr)
if err != nil {
return addrType, nil, nil, err
}
srv := &http.Server{Handler: h}
done := make(chan error)
go func() {
done <- srv.Serve(listener)
close(done)
}()
return addrType, srv, done, nil
}
// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
// Supports unix and systemd sockets in addition
func ListenAndServe(addr string, h http.Handler) error {
listener, err := GetListener(addr)
_, _, done, err := Serve(addr, h)
if err != nil {
return err
}
if listener != nil {
return http.Serve(listener, h)
}
if port, err := strconv.Atoi(addr); err == nil {
if port > 0 && port < 65536 {
return http.ListenAndServe(fmt.Sprintf(":%v", port), h)
}
return fmt.Errorf("invalid port: %v", port)
}
return http.ListenAndServe(addr, h)
return <-done
}
// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
func UnsetSystemdListenVars() {
os.Unsetenv("LISTEN_PID")
os.Unsetenv("LISTEN_FDS")
os.Unsetenv("LISTEN_FDNAMES")
_ = os.Unsetenv("LISTEN_PID")
_ = os.Unsetenv("LISTEN_FDS")
_ = os.Unsetenv("LISTEN_FDNAMES")
}

View File

@ -0,0 +1,35 @@
package main
import (
"context"
"net/http"
"time"
"go.balki.me/anyhttp/idle"
)
func main() {
idler := idle.CreateIdler(10 * time.Second)
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
idler.Tick()
w.Write([]byte("Hey there!\n"))
})
http.HandleFunc("/job", func(w http.ResponseWriter, r *http.Request) {
go func() {
idler.Enter()
defer idler.Exit()
time.Sleep(15 * time.Second)
}()
w.Write([]byte("Job scheduled\n"))
})
server := http.Server{
Addr: ":8888",
}
go server.ListenAndServe()
idler.Wait()
server.Shutdown(context.TODO())
}

127
idle/idle.go Normal file
View File

@ -0,0 +1,127 @@
// Package idle helps to gracefully shutdown idle (typically http) servers
package idle
import (
"fmt"
"net/http"
"sync/atomic"
"time"
)
var (
// For simple servers without backgroud jobs, global singleton for simpler API
// Enter/Exit worn't work for global idler as Enter may be called before Wait, use CreateIdler in those cases
gIdler atomic.Pointer[idler]
)
// Wait waits till the server is idle and returns. i.e. no Ticks in last <timeout> duration
func Wait(timeout time.Duration) error {
i := CreateIdler(timeout).(*idler)
ok := gIdler.CompareAndSwap(nil, i)
if !ok {
return fmt.Errorf("idler already waiting")
}
i.Wait()
return nil
}
// Tick records the current time. This will make the server not idle until next Tick or timeout
func Tick() {
i := gIdler.Load()
if i != nil {
i.Tick()
}
}
// WrapHandler calls Tick() before processing passing request to http.Handler
func WrapHandler(h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Tick()
h.ServeHTTP(w, r)
})
}
// WrapIdlerHandler calls idler.Tick() before processing passing request to http.Handler
func WrapIdlerHandler(i Idler, h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
i.Tick()
h.ServeHTTP(w, r)
})
}
// Idler helps manage idle servers
type Idler interface {
// Tick records the current time. This will make the server not idle until next Tick or timeout
Tick()
// Wait waits till the server is idle and returns. i.e. no Ticks in last <timeout> duration
Wait()
// For long running background jobs, use Enter to record start time. Wait will not return while there are active jobs running
Enter()
// Exit records end of a background job
Exit()
// Get the channel to wait yourself
Chan() <-chan struct{}
}
type idler struct {
lastTick atomic.Pointer[time.Time]
c chan struct{}
active atomic.Int64
}
func (i *idler) Enter() {
i.active.Add(1)
}
func (i *idler) Exit() {
i.Tick()
i.active.Add(-1)
}
// CreateIdler creates an Idler with given timeout
func CreateIdler(timeout time.Duration) Idler {
i := &idler{}
i.c = make(chan struct{})
i.Tick()
go func() {
for {
if i.active.Load() != 0 {
time.Sleep(timeout)
continue
}
t := *i.lastTick.Load()
now := time.Now()
dur := t.Add(timeout).Sub(now)
if dur == dur.Abs() {
time.Sleep(dur)
continue
}
break
}
close(i.c)
}()
return i
}
func (i *idler) Tick() {
now := time.Now()
i.lastTick.Store(&now)
}
func (i *idler) Wait() {
<-i.c
}
func (i *idler) Chan() <-chan struct{} {
return i.c
}

34
idle/idle_test.go Normal file
View File

@ -0,0 +1,34 @@
package idle
import (
"testing"
"time"
)
func TestIdlerChan(_ *testing.T) {
i := CreateIdler(10 * time.Millisecond)
<-i.Chan()
}
func TestGlobalIdler(t *testing.T) {
err := Wait(10 * time.Millisecond)
if err != nil {
t.Fatalf("idle.Wait failed, %v", err)
}
err = Wait(10 * time.Millisecond)
if err == nil {
t.Fatal("idle.Wait should fail when called second time")
}
}
func TestIdlerEnterExit(t *testing.T) {
i := CreateIdler(10 * time.Millisecond).(*idler)
i.Enter()
if i.active.Load() != 1 {
t.FailNow()
}
i.Exit()
if i.active.Load() != 0 {
t.FailNow()
}
}