Add global idler and other helpers #2
							
								
								
									
										84
									
								
								anyhttp.go
									
									
									
									
									
								
							
							
						
						
									
										84
									
								
								anyhttp.go
									
									
									
									
									
								
							@@ -14,6 +14,20 @@ import (
 | 
				
			|||||||
	"syscall"
 | 
						"syscall"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AddressType of the address passed
 | 
				
			||||||
 | 
					type AddressType string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						// UnixSocket - address is a unix socket, e.g. unix//run/foo.sock
 | 
				
			||||||
 | 
						UnixSocket AddressType = "UnixSocket"
 | 
				
			||||||
 | 
						// SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket
 | 
				
			||||||
 | 
						SystemdFD AddressType = "SystemdFD"
 | 
				
			||||||
 | 
						// TCP - address is a TCP address, e.g. :1234
 | 
				
			||||||
 | 
						TCP AddressType = "TCP"
 | 
				
			||||||
 | 
						// Unknown - address is not recognized
 | 
				
			||||||
 | 
						Unknown AddressType = "Unknown"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UnixSocketConfig has the configuration for Unix socket
 | 
					// UnixSocketConfig has the configuration for Unix socket
 | 
				
			||||||
type UnixSocketConfig struct {
 | 
					type UnixSocketConfig struct {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -182,59 +196,69 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
 | 
				
			|||||||
	return nil, errors.New("neither FDIndex nor FDName set")
 | 
						return nil, errors.New("neither FDIndex nor FDName set")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UnknownAddress Error is returned when address does not match any known syntax
 | 
					 | 
				
			||||||
type UnknownAddress struct{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (u UnknownAddress) Error() string {
 | 
					 | 
				
			||||||
	return "unknown address"
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// GetListener gets a unix or systemd socket listener
 | 
					// GetListener gets a unix or systemd socket listener
 | 
				
			||||||
func GetListener(addr string) (net.Listener, error) {
 | 
					func GetListener(addr string) (AddressType, net.Listener, error) {
 | 
				
			||||||
	if strings.HasPrefix(addr, "unix/") {
 | 
						if strings.HasPrefix(addr, "unix/") {
 | 
				
			||||||
		usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/"))
 | 
							usc := NewUnixSocketConfig(strings.TrimPrefix(addr, "unix/"))
 | 
				
			||||||
		return usc.GetListener()
 | 
							l, err := usc.GetListener()
 | 
				
			||||||
 | 
							return UnixSocket, l, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if strings.HasPrefix(addr, "sysd/fdidx/") {
 | 
						if strings.HasPrefix(addr, "sysd/fdidx/") {
 | 
				
			||||||
		idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/"))
 | 
							idx, err := strconv.Atoi(strings.TrimPrefix(addr, "sysd/fdidx/"))
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err)
 | 
								return Unknown, nil, fmt.Errorf("invalid fdidx, addr:%q err: %w", addr, err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		sysdc := NewSysDConfigWithFDIdx(idx)
 | 
							sysdc := NewSysDConfigWithFDIdx(idx)
 | 
				
			||||||
		return sysdc.GetListener()
 | 
							l, err := sysdc.GetListener()
 | 
				
			||||||
 | 
							return SystemdFD, l, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if strings.HasPrefix(addr, "sysd/fdname/") {
 | 
						if strings.HasPrefix(addr, "sysd/fdname/") {
 | 
				
			||||||
		sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/"))
 | 
							sysdc := NewSysDConfigWithFDName(strings.TrimPrefix(addr, "sysd/fdname/"))
 | 
				
			||||||
		return sysdc.GetListener()
 | 
							l, err := sysdc.GetListener()
 | 
				
			||||||
 | 
							return SystemdFD, l, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil, UnknownAddress{}
 | 
						if port, err := strconv.Atoi(addr); err == nil {
 | 
				
			||||||
 | 
							if port > 0 && port < 65536 {
 | 
				
			||||||
 | 
								addr = fmt.Sprintf(":%v", port)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return Unknown, nil, fmt.Errorf("invalid port: %v", port)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if addr == "" {
 | 
				
			||||||
 | 
							addr = ":http"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", addr)
 | 
				
			||||||
 | 
						return TCP, l, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Serve creates and serve a http server.
 | 
				
			||||||
 | 
					func Serve(addr string, h http.Handler) (AddressType, *http.Server, <-chan error, error) {
 | 
				
			||||||
 | 
						addrType, listener, err := GetListener(addr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return addrType, nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						srv := &http.Server{Handler: h}
 | 
				
			||||||
 | 
						done := make(chan error)
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							done <- srv.Serve(listener)
 | 
				
			||||||
 | 
							close(done)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						return addrType, srv, done, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
 | 
					// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
 | 
				
			||||||
// Supports unix and systemd sockets in addition
 | 
					// Supports unix and systemd sockets in addition
 | 
				
			||||||
func ListenAndServe(addr string, h http.Handler) error {
 | 
					func ListenAndServe(addr string, h http.Handler) error {
 | 
				
			||||||
 | 
						_, _, done, err := Serve(addr, h)
 | 
				
			||||||
	listener, err := GetListener(addr)
 | 
						if err != nil {
 | 
				
			||||||
	if _, isUnknown := err.(UnknownAddress); err != nil && !isUnknown {
 | 
					 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						return <-done
 | 
				
			||||||
	if listener != nil {
 | 
					 | 
				
			||||||
		return http.Serve(listener, h)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if port, err := strconv.Atoi(addr); err == nil {
 | 
					 | 
				
			||||||
		if port > 0 && port < 65536 {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			return http.ListenAndServe(fmt.Sprintf(":%v", port), h)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return fmt.Errorf("invalid port: %v", port)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return http.ListenAndServe(addr, h)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
 | 
					// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										68
									
								
								idle/idle.go
									
									
									
									
									
								
							
							
						
						
									
										68
									
								
								idle/idle.go
									
									
									
									
									
								
							@@ -1,16 +1,75 @@
 | 
				
			|||||||
// Package idle helps to gracefully shutdown idle servers
 | 
					// Package idle helps to gracefully shutdown idle (typically http) servers
 | 
				
			||||||
package idle
 | 
					package idle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
	"sync/atomic"
 | 
						"sync/atomic"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						// For simple servers without backgroud jobs, global singleton for simpler API
 | 
				
			||||||
 | 
						// Enter/Exit worn't work for global idler as Enter may be called before Wait, use CreateIdler in those cases
 | 
				
			||||||
 | 
						gIdler atomic.Pointer[idler]
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Wait waits till the server is idle and returns. i.e. no Ticks in last <timeout> duration
 | 
				
			||||||
 | 
					func Wait(timeout time.Duration) error {
 | 
				
			||||||
 | 
						i := CreateIdler(timeout).(*idler)
 | 
				
			||||||
 | 
						ok := gIdler.CompareAndSwap(nil, i)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return fmt.Errorf("idler already waiting")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						i.Wait()
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Tick records the current time. This will make the server not idle until next Tick or timeout
 | 
				
			||||||
 | 
					func Tick() {
 | 
				
			||||||
 | 
						i := gIdler.Load()
 | 
				
			||||||
 | 
						if i != nil {
 | 
				
			||||||
 | 
							i.Tick()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WrapHandler calls Tick() before processing passing request to http.Handler
 | 
				
			||||||
 | 
					func WrapHandler(h http.Handler) http.Handler {
 | 
				
			||||||
 | 
						if h == nil {
 | 
				
			||||||
 | 
							h = http.DefaultServeMux
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							Tick()
 | 
				
			||||||
 | 
							h.ServeHTTP(w, r)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WrapIdlerHandler calls idler.Tick() before processing passing request to http.Handler
 | 
				
			||||||
 | 
					func WrapIdlerHandler(i Idler, h http.Handler) http.Handler {
 | 
				
			||||||
 | 
						if h == nil {
 | 
				
			||||||
 | 
							h = http.DefaultServeMux
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							i.Tick()
 | 
				
			||||||
 | 
							h.ServeHTTP(w, r)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Idler helps manage idle servers
 | 
				
			||||||
type Idler interface {
 | 
					type Idler interface {
 | 
				
			||||||
	Enter()
 | 
						// Tick records the current time. This will make the server not idle until next Tick or timeout
 | 
				
			||||||
	Exit()
 | 
					 | 
				
			||||||
	Wait()
 | 
					 | 
				
			||||||
	Tick()
 | 
						Tick()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Wait waits till the server is idle and returns. i.e. no Ticks in last <timeout> duration
 | 
				
			||||||
 | 
						Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// For long running background jobs, use Enter to record start time. Wait will not return while there are active jobs running
 | 
				
			||||||
 | 
						Enter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Exit records end of a background job
 | 
				
			||||||
 | 
						Exit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the channel to wait yourself
 | 
				
			||||||
	Chan() <-chan struct{}
 | 
						Chan() <-chan struct{}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -29,6 +88,7 @@ func (i *idler) Exit() {
 | 
				
			|||||||
	i.active.Add(-1)
 | 
						i.active.Add(-1)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CreateIdler creates an Idler with given timeout
 | 
				
			||||||
func CreateIdler(timeout time.Duration) Idler {
 | 
					func CreateIdler(timeout time.Duration) Idler {
 | 
				
			||||||
	i := &idler{}
 | 
						i := &idler{}
 | 
				
			||||||
	i.c = make(chan struct{})
 | 
						i.c = make(chan struct{})
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,7 +5,30 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestIdlerChan(t *testing.T) {
 | 
					func TestIdlerChan(_ *testing.T) {
 | 
				
			||||||
	i := CreateIdler(1 * time.Second)
 | 
						i := CreateIdler(10 * time.Millisecond)
 | 
				
			||||||
	<-i.Chan()
 | 
						<-i.Chan()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestGlobalIdler(t *testing.T) {
 | 
				
			||||||
 | 
						err := Wait(10 * time.Millisecond)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("idle.Wait failed, %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = Wait(10 * time.Millisecond)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal("idle.Wait should fail when called second time")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIdlerEnterExit(t *testing.T) {
 | 
				
			||||||
 | 
						i := CreateIdler(10 * time.Millisecond).(*idler)
 | 
				
			||||||
 | 
						i.Enter()
 | 
				
			||||||
 | 
						if i.active.Load() != 1 {
 | 
				
			||||||
 | 
							t.FailNow()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						i.Exit()
 | 
				
			||||||
 | 
						if i.active.Load() != 0 {
 | 
				
			||||||
 | 
							t.FailNow()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user