Add global idler and other helpers #2
							
								
								
									
										84
									
								
								anyhttp.go
									
									
									
									
									
								
							
							
						
						
									
										84
									
								
								anyhttp.go
									
									
									
									
									
								
							@@ -14,6 +14,20 @@ import (
 | 
			
		||||
	"syscall"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// AddressType of the address passed
 | 
			
		||||
type AddressType string
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// UnixSocket - address is a unix socket, e.g. unix//run/foo.sock
 | 
			
		||||
	UnixSocket AddressType = "UnixSocket"
 | 
			
		||||
	// SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket
 | 
			
		||||
	SystemdFD AddressType = "SystemdFD"
 | 
			
		||||
	// TCP - address is a TCP address, e.g. :1234
 | 
			
		||||
	TCP AddressType = "TCP"
 | 
			
		||||
	// Unknown - address is not recognized
 | 
			
		||||
	Unknown AddressType = "Unknown"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// UnixSocketConfig has the configuration for Unix socket
 | 
			
		||||
type UnixSocketConfig struct {
 | 
			
		||||
 | 
			
		||||
@@ -182,59 +196,69 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
 | 
			
		||||
	return nil, errors.New("neither FDIndex nor FDName set")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnknownAddress Error is returned when address does not match any known syntax
 | 
			
		||||
type UnknownAddress struct{}
 | 
			
		||||
 | 
			
		||||
func (u UnknownAddress) Error() string {
 | 
			
		||||
	return "unknown address"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetListener gets a unix or systemd socket listener
 | 
			
		||||
func GetListener(addr string) (net.Listener, error) {
 | 
			
		||||
func GetListener(addr string) (AddressType, net.Listener, error) {
 | 
			
		||||
	if strings.HasPrefix(addr, "unix/") {
 | 
			
		||||
		usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/"))
 | 
			
		||||
		return usc.GetListener()
 | 
			
		||||
		l, err := usc.GetListener()
 | 
			
		||||
		return UnixSocket, l, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(addr, "sysd/fdidx/") {
 | 
			
		||||
		idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/"))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err)
 | 
			
		||||
			return Unknown, nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err)
 | 
			
		||||
		}
 | 
			
		||||
		sysdc := NewSysDConfigWithFDIdx(idx)
 | 
			
		||||
		return sysdc.GetListener()
 | 
			
		||||
		l, err := sysdc.GetListener()
 | 
			
		||||
		return SystemdFD, l, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(addr, "sysd/fdname/") {
 | 
			
		||||
		sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/"))
 | 
			
		||||
		return sysdc.GetListener()
 | 
			
		||||
		l, err := sysdc.GetListener()
 | 
			
		||||
		return SystemdFD, l, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, UnknownAddress{}
 | 
			
		||||
	if port, err := strconv.Atoi(addr); err == nil {
 | 
			
		||||
		if port > 0 && port < 65536 {
 | 
			
		||||
			addr = fmt.Sprintf(":%v", port)
 | 
			
		||||
		} else {
 | 
			
		||||
			return Unknown, nil, fmt.Errorf("invalid port: %v", port)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if addr == "" {
 | 
			
		||||
		addr = ":http"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	l, err := net.Listen("tcp", addr)
 | 
			
		||||
	return TCP, l, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Serve creates and serve a http server.
 | 
			
		||||
func Serve(addr string, h http.Handler) (AddressType, *http.Server, <-chan error, error) {
 | 
			
		||||
	addrType, listener, err := GetListener(addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return addrType, nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	srv := &http.Server{Handler: h}
 | 
			
		||||
	done := make(chan error)
 | 
			
		||||
	go func() {
 | 
			
		||||
		done <- srv.Serve(listener)
 | 
			
		||||
		close(done)
 | 
			
		||||
	}()
 | 
			
		||||
	return addrType, srv, done, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
 | 
			
		||||
// Supports unix and systemd sockets in addition
 | 
			
		||||
func ListenAndServe(addr string, h http.Handler) error {
 | 
			
		||||
 | 
			
		||||
	listener, err := GetListener(addr)
 | 
			
		||||
	if _, isUnknown := err.(UnknownAddress); err != nil && !isUnknown {
 | 
			
		||||
	_, _, done, err := Serve(addr, h)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if listener != nil {
 | 
			
		||||
		return http.Serve(listener, h)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if port, err := strconv.Atoi(addr); err == nil {
 | 
			
		||||
		if port > 0 && port < 65536 {
 | 
			
		||||
 | 
			
		||||
			return http.ListenAndServe(fmt.Sprintf(":%v", port), h)
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Errorf("invalid port: %v", port)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return http.ListenAndServe(addr, h)
 | 
			
		||||
	return <-done
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										70
									
								
								idle/idle.go
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								idle/idle.go
									
									
									
									
									
								
							@@ -1,16 +1,75 @@
 | 
			
		||||
// Package idle helps to gracefully shutdown idle servers
 | 
			
		||||
// Package idle helps to gracefully shutdown idle (typically http) servers
 | 
			
		||||
package idle
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Idler interface {
 | 
			
		||||
	Enter()
 | 
			
		||||
	Exit()
 | 
			
		||||
	Wait()
 | 
			
		||||
var (
 | 
			
		||||
	// For simple servers without backgroud jobs, global singleton for simpler API
 | 
			
		||||
	// Enter/Exit worn't work for global idler as Enter may be called before Wait, use CreateIdler in those cases
 | 
			
		||||
	gIdler atomic.Pointer[idler]
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Wait waits till the server is idle and returns. i.e. no Ticks in last <timeout> duration
 | 
			
		||||
func Wait(timeout time.Duration) error {
 | 
			
		||||
	i := CreateIdler(timeout).(*idler)
 | 
			
		||||
	ok := gIdler.CompareAndSwap(nil, i)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return fmt.Errorf("idler already waiting")
 | 
			
		||||
	}
 | 
			
		||||
	i.Wait()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tick records the current time. This will make the server not idle until next Tick or timeout
 | 
			
		||||
func Tick() {
 | 
			
		||||
	i := gIdler.Load()
 | 
			
		||||
	if i != nil {
 | 
			
		||||
		i.Tick()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WrapHandler calls Tick() before processing passing request to http.Handler
 | 
			
		||||
func WrapHandler(h http.Handler) http.Handler {
 | 
			
		||||
	if h == nil {
 | 
			
		||||
		h = http.DefaultServeMux
 | 
			
		||||
	}
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		Tick()
 | 
			
		||||
		h.ServeHTTP(w, r)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WrapIdlerHandler calls idler.Tick() before processing passing request to http.Handler
 | 
			
		||||
func WrapIdlerHandler(i Idler, h http.Handler) http.Handler {
 | 
			
		||||
	if h == nil {
 | 
			
		||||
		h = http.DefaultServeMux
 | 
			
		||||
	}
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		i.Tick()
 | 
			
		||||
		h.ServeHTTP(w, r)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Idler helps manage idle servers
 | 
			
		||||
type Idler interface {
 | 
			
		||||
	// Tick records the current time. This will make the server not idle until next Tick or timeout
 | 
			
		||||
	Tick()
 | 
			
		||||
 | 
			
		||||
	// Wait waits till the server is idle and returns. i.e. no Ticks in last <timeout> duration
 | 
			
		||||
	Wait()
 | 
			
		||||
 | 
			
		||||
	// For long running background jobs, use Enter to record start time. Wait will not return while there are active jobs running
 | 
			
		||||
	Enter()
 | 
			
		||||
 | 
			
		||||
	// Exit records end of a background job
 | 
			
		||||
	Exit()
 | 
			
		||||
 | 
			
		||||
	// Get the channel to wait yourself
 | 
			
		||||
	Chan() <-chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -29,6 +88,7 @@ func (i *idler) Exit() {
 | 
			
		||||
	i.active.Add(-1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateIdler creates an Idler with given timeout
 | 
			
		||||
func CreateIdler(timeout time.Duration) Idler {
 | 
			
		||||
	i := &idler{}
 | 
			
		||||
	i.c = make(chan struct{})
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,30 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestIdlerChan(t *testing.T) {
 | 
			
		||||
	i := CreateIdler(1 * time.Second)
 | 
			
		||||
func TestIdlerChan(_ *testing.T) {
 | 
			
		||||
	i := CreateIdler(10 * time.Millisecond)
 | 
			
		||||
	<-i.Chan()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGlobalIdler(t *testing.T) {
 | 
			
		||||
	err := Wait(10 * time.Millisecond)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("idle.Wait failed, %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	err = Wait(10 * time.Millisecond)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Fatal("idle.Wait should fail when called second time")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIdlerEnterExit(t *testing.T) {
 | 
			
		||||
	i := CreateIdler(10 * time.Millisecond).(*idler)
 | 
			
		||||
	i.Enter()
 | 
			
		||||
	if i.active.Load() != 1 {
 | 
			
		||||
		t.FailNow()
 | 
			
		||||
	}
 | 
			
		||||
	i.Exit()
 | 
			
		||||
	if i.active.Load() != 0 {
 | 
			
		||||
		t.FailNow()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user