Refactor: Add GetListener and TLS versions
This commit is contained in:
		
							
								
								
									
										145
									
								
								anyhttp.go
									
									
									
									
									
								
							
							
						
						
									
										145
									
								
								anyhttp.go
									
									
									
									
									
								
							@@ -23,9 +23,9 @@ import (
 | 
			
		||||
type AddressType string
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// UnixSocket - address is a unix socket, e.g. unix//run/foo.sock
 | 
			
		||||
	// UnixSocket - address is a unix socket, e.g. unix?path=/run/foo.sock
 | 
			
		||||
	UnixSocket AddressType = "UnixSocket"
 | 
			
		||||
	// SystemdFD - address is a systemd fd, e.g. sysd/fdname/myapp.socket
 | 
			
		||||
	// SystemdFD - address is a systemd fd, e.g. sysd?name=myapp.socket
 | 
			
		||||
	SystemdFD AddressType = "SystemdFD"
 | 
			
		||||
	// TCP - address is a TCP address, e.g. :1234
 | 
			
		||||
	TCP AddressType = "TCP"
 | 
			
		||||
@@ -203,6 +203,34 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
 | 
			
		||||
	return nil, errors.New("neither FDIndex nor FDName set")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetListener is low level function for use with non-http servers. e.g. tcp, smtp
 | 
			
		||||
// Caller should handle idle timeout if needed
 | 
			
		||||
func GetListener(addr string) (net.Listener, AddressType, any /* cfg */, error) {
 | 
			
		||||
 | 
			
		||||
	addrType, unixSocketConfig, sysdConfig, perr := parseAddress(addr)
 | 
			
		||||
	if perr != nil {
 | 
			
		||||
		return nil, Unknown, nil, perr
 | 
			
		||||
	}
 | 
			
		||||
	if unixSocketConfig != nil {
 | 
			
		||||
		listener, err := unixSocketConfig.GetListener()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, Unknown, nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return listener, addrType, unixSocketConfig, nil
 | 
			
		||||
	} else if sysdConfig != nil {
 | 
			
		||||
		listener, err := sysdConfig.GetListener()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, Unknown, nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return listener, addrType, sysdConfig, nil
 | 
			
		||||
	}
 | 
			
		||||
	if addr == "" {
 | 
			
		||||
		addr = ":http"
 | 
			
		||||
	}
 | 
			
		||||
	listener, err := net.Listen("tcp", addr)
 | 
			
		||||
	return listener, TCP, nil, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ServerCtx struct {
 | 
			
		||||
	AddressType      AddressType
 | 
			
		||||
	Listener         net.Listener
 | 
			
		||||
@@ -229,53 +257,14 @@ func (s *ServerCtx) Shutdown(ctx context.Context) error {
 | 
			
		||||
	return <-s.Done
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Serve creates and serve a http server.
 | 
			
		||||
func Serve(addr string, h http.Handler) (*ServerCtx, error) {
 | 
			
		||||
	var ctx ServerCtx
 | 
			
		||||
	var err error
 | 
			
		||||
	ctx.AddressType, ctx.UnixSocketConfig, ctx.SysdConfig, err = parseAddress(addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
// ServeTLS creates and serves a HTTPS server.
 | 
			
		||||
func ServeTLS(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
 | 
			
		||||
	return serve(addr, h, certFile, keyFile)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	ctx.Listener, err = func() (net.Listener, error) {
 | 
			
		||||
		if ctx.UnixSocketConfig != nil {
 | 
			
		||||
			return ctx.UnixSocketConfig.GetListener()
 | 
			
		||||
		} else if ctx.SysdConfig != nil {
 | 
			
		||||
			return ctx.SysdConfig.GetListener()
 | 
			
		||||
		}
 | 
			
		||||
		if addr == "" {
 | 
			
		||||
			addr = ":http"
 | 
			
		||||
		}
 | 
			
		||||
		return net.Listen("tcp", addr)
 | 
			
		||||
	}()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	ctx.Done = errChan
 | 
			
		||||
	if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil {
 | 
			
		||||
		ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout)
 | 
			
		||||
		ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)}
 | 
			
		||||
		waitErrChan := make(chan error)
 | 
			
		||||
		go func() {
 | 
			
		||||
			waitErrChan <- ctx.Server.Serve(ctx.Listener)
 | 
			
		||||
		}()
 | 
			
		||||
		go func() {
 | 
			
		||||
			select {
 | 
			
		||||
			case err := <-waitErrChan:
 | 
			
		||||
				errChan <- err
 | 
			
		||||
			case <-ctx.Idler.Chan():
 | 
			
		||||
				errChan <- ctx.Server.Shutdown(context.TODO())
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	} else {
 | 
			
		||||
		ctx.Server = &http.Server{Handler: h}
 | 
			
		||||
		go func() {
 | 
			
		||||
			errChan <- ctx.Server.Serve(ctx.Listener)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	return &ctx, nil
 | 
			
		||||
// Serve creates and serves a HTTP server.
 | 
			
		||||
func Serve(addr string, h http.Handler) (*ServerCtx, error) {
 | 
			
		||||
	return serve(addr, h, "", "")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
 | 
			
		||||
@@ -288,6 +277,14 @@ func ListenAndServe(addr string, h http.Handler) error {
 | 
			
		||||
	return ctx.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ListenAndServeTLS(addr string, certFile string, keyFile string, h http.Handler) error {
 | 
			
		||||
	ctx, err := ServeTLS(addr, h, certFile, keyFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return ctx.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
 | 
			
		||||
func UnsetSystemdListenVars() {
 | 
			
		||||
	_ = os.Unsetenv("LISTEN_PID")
 | 
			
		||||
@@ -389,3 +386,55 @@ func parseAddress(addr string) (addrType AddressType, usc *UnixSocketConfig, sys
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func serve(addr string, h http.Handler, certFile string, keyFile string) (*ServerCtx, error) {
 | 
			
		||||
 | 
			
		||||
	serveFn := func() func(ctx *ServerCtx) error {
 | 
			
		||||
		if certFile != "" {
 | 
			
		||||
			return func(ctx *ServerCtx) error {
 | 
			
		||||
				return ctx.Server.ServeTLS(ctx.Listener, certFile, keyFile)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return func(ctx *ServerCtx) error {
 | 
			
		||||
			return ctx.Server.Serve(ctx.Listener)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	var ctx ServerCtx
 | 
			
		||||
	var err error
 | 
			
		||||
	var cfg any
 | 
			
		||||
 | 
			
		||||
	ctx.Listener, ctx.AddressType, cfg, err = GetListener(addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	switch ctx.AddressType {
 | 
			
		||||
	case UnixSocket:
 | 
			
		||||
		ctx.UnixSocketConfig = cfg.(*UnixSocketConfig)
 | 
			
		||||
	case SystemdFD:
 | 
			
		||||
		ctx.SysdConfig = cfg.(*SysdConfig)
 | 
			
		||||
	}
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	ctx.Done = errChan
 | 
			
		||||
	if ctx.AddressType == SystemdFD && ctx.SysdConfig.IdleTimeout != nil {
 | 
			
		||||
		ctx.Idler = idle.CreateIdler(*ctx.SysdConfig.IdleTimeout)
 | 
			
		||||
		ctx.Server = &http.Server{Handler: idle.WrapIdlerHandler(ctx.Idler, h)}
 | 
			
		||||
		waitErrChan := make(chan error)
 | 
			
		||||
		go func() {
 | 
			
		||||
			waitErrChan <- serveFn(&ctx)
 | 
			
		||||
		}()
 | 
			
		||||
		go func() {
 | 
			
		||||
			select {
 | 
			
		||||
			case err := <-waitErrChan:
 | 
			
		||||
				errChan <- err
 | 
			
		||||
			case <-ctx.Idler.Chan():
 | 
			
		||||
				errChan <- ctx.Server.Shutdown(context.TODO())
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	} else {
 | 
			
		||||
		ctx.Server = &http.Server{Handler: h}
 | 
			
		||||
		go func() {
 | 
			
		||||
			errChan <- serveFn(&ctx)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	return &ctx, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,9 @@
 | 
			
		||||
package anyhttp
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"log"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@@ -119,6 +121,12 @@ func Test_parseAddress(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestServe(t *testing.T) {
 | 
			
		||||
	ctx, err := Serve("unix?path=/tmp/foo.sock", nil)
 | 
			
		||||
	log.Printf("Got ctx: %v\n,  err: %v", ctx, err)
 | 
			
		||||
	ctx.Shutdown(context.TODO())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helpers
 | 
			
		||||
 | 
			
		||||
// print value instead of pointer
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										32
									
								
								examples/simple/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								examples/simple/main.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.balki.me/anyhttp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
 | 
			
		||||
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		w.Write([]byte("hello\n"))
 | 
			
		||||
	})
 | 
			
		||||
	//log.Println("Got error: ", anyhttp.ListenAndServe(os.Args[1], nil))
 | 
			
		||||
	ctx, err := anyhttp.Serve(os.Args[1], nil)
 | 
			
		||||
	log.Printf("Got ctx: %v\n,  err: %v", ctx, err)
 | 
			
		||||
	log.Println(ctx.Addr())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	select {
 | 
			
		||||
	case doneErr := <-ctx.Done:
 | 
			
		||||
		log.Println(doneErr)
 | 
			
		||||
	case <-time.After(1 * time.Minute):
 | 
			
		||||
		log.Println("Awake")
 | 
			
		||||
		ctx.Shutdown(context.TODO())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								examples/simple/simple
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								examples/simple/simple
									
									
									
									
									
										Executable file
									
								
							
										
											Binary file not shown.
										
									
								
							
		Reference in New Issue
	
	Block a user