Connect backend refactor (#801)
* Move connect backend code to its own package * Move errors into the connect package * Add NewBackend func
This commit is contained in:
92
pkg/connect/backend.go
Normal file
92
pkg/connect/backend.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Backend struct {
|
||||
Endpoint string
|
||||
Token string
|
||||
PassHeaders []string
|
||||
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
func NewBackend(endpoint string, token string) Backend {
|
||||
return Backend{
|
||||
Endpoint: endpoint,
|
||||
Token: token,
|
||||
logger: logrus.StandardLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
func (be *Backend) SetLogger(logger *logrus.Logger) {
|
||||
be.logger = logger
|
||||
}
|
||||
|
||||
func (be *Backend) SetPassHeaders(headers []string) {
|
||||
be.PassHeaders = headers
|
||||
}
|
||||
|
||||
func (be *Backend) FetchCredential(ctx context.Context, resource string, headers http.Header) (*Credential, error) {
|
||||
be.logger.WithField("resource", resource).Debug("fetching database credential")
|
||||
|
||||
request := Request{
|
||||
Resource: resource,
|
||||
Token: be.Token,
|
||||
Headers: map[string]string{},
|
||||
}
|
||||
|
||||
// Pass allow-listed client headers to the backend request
|
||||
for _, name := range be.PassHeaders {
|
||||
request.Headers[strings.ToLower(name)] = headers.Get(name)
|
||||
}
|
||||
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
be.logger.WithField("resource", resource).Error("backend request serialization error:", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
be.logger.WithField("resource", resource).Error("backend credential fetch failed:", err)
|
||||
return nil, errBackendConnectError
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("backend credential fetch received HTTP status code %v", resp.StatusCode)
|
||||
|
||||
be.logger.
|
||||
WithField("resource", request.Resource).
|
||||
WithField("status", resp.StatusCode).
|
||||
Error(err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cred := &Credential{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(cred); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cred.DatabaseURL == "" {
|
||||
return nil, errConnStringRequired
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
173
pkg/connect/backend_test.go
Normal file
173
pkg/connect/backend_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBackendFetchCredential(t *testing.T) {
|
||||
examples := []struct {
|
||||
name string
|
||||
backend Backend
|
||||
resourceName string
|
||||
cred *Credential
|
||||
headers http.Header
|
||||
ctx func() (context.Context, context.CancelFunc)
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Bad auth token",
|
||||
backend: Backend{Endpoint: "http://localhost:5555/unauthorized"},
|
||||
err: errors.New("backend credential fetch received HTTP status code 401"),
|
||||
},
|
||||
{
|
||||
name: "Backend timeout",
|
||||
backend: Backend{Endpoint: "http://localhost:5555/timeout"},
|
||||
ctx: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||
},
|
||||
err: errors.New("unable to connect to the auth backend"),
|
||||
},
|
||||
{
|
||||
name: "Empty response",
|
||||
backend: Backend{Endpoint: "http://localhost:5555/empty-response"},
|
||||
err: errors.New("connection string is required"),
|
||||
},
|
||||
{
|
||||
name: "Missing header",
|
||||
backend: Backend{Endpoint: "http://localhost:5555/pass-header"},
|
||||
err: errors.New("backend credential fetch received HTTP status code 400"),
|
||||
},
|
||||
{
|
||||
name: "Require header",
|
||||
backend: Backend{
|
||||
Endpoint: "http://localhost:5555/pass-header",
|
||||
PassHeaders: []string{"x-foo"},
|
||||
},
|
||||
headers: http.Header{
|
||||
"X-Foo": []string{"bar"},
|
||||
},
|
||||
cred: &Credential{DatabaseURL: "postgres://hostname/bar"},
|
||||
},
|
||||
{
|
||||
name: "Success",
|
||||
backend: Backend{Endpoint: "http://localhost:5555/success"},
|
||||
cred: &Credential{DatabaseURL: "postgres://hostname/dbname"},
|
||||
},
|
||||
}
|
||||
|
||||
srvCtx, srvCancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer srvCancel()
|
||||
|
||||
startTestBackend(srvCtx, "localhost:5555")
|
||||
|
||||
for _, ex := range examples {
|
||||
ex.backend.logger = logrus.StandardLogger()
|
||||
|
||||
t.Run(ex.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
if ex.ctx != nil {
|
||||
ctx, cancel = ex.ctx()
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, ex.headers)
|
||||
assert.Equal(t, ex.err, err)
|
||||
assert.Equal(t, ex.cred, cred)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func startTestBackend(ctx context.Context, listenAddr string) {
|
||||
router := gin.New()
|
||||
|
||||
router.Use(func(c *gin.Context) {
|
||||
if c.GetHeader("content-type") != "application/json" {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
router.POST("/unauthorized", func(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
})
|
||||
|
||||
router.POST("/timeout", func(c *gin.Context) {
|
||||
time.Sleep(time.Second)
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
router.POST("/empty-response", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
router.POST("/pass-header", func(c *gin.Context) {
|
||||
req := Request{}
|
||||
if err := c.BindJSON(&req); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
header := req.Headers["x-foo"]
|
||||
if header == "" {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"database_url": "postgres://hostname/" + header,
|
||||
})
|
||||
})
|
||||
|
||||
router.POST("/success", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"database_url": "postgres://hostname/dbname",
|
||||
})
|
||||
})
|
||||
|
||||
server := &http.Server{Addr: listenAddr, Handler: router}
|
||||
mustStartServer(server)
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := server.Shutdown(context.Background()); err != nil && err != http.ErrServerClosed {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func mustStartServer(server *http.Server) {
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServer(server.Addr, 5); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForServer(addr string, n int) error {
|
||||
var lastErr error
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
20
pkg/connect/types.go
Normal file
20
pkg/connect/types.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package connect
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
errBackendConnectError = errors.New("unable to connect to the auth backend")
|
||||
errConnStringRequired = errors.New("connection string is required")
|
||||
)
|
||||
|
||||
// Request holds the resource request details
|
||||
type Request struct {
|
||||
Resource string `json:"resource"`
|
||||
Token string `json:"token"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
// Credential holds the database connection string
|
||||
type Credential struct {
|
||||
DatabaseURL string `json:"database_url"`
|
||||
}
|
||||
Reference in New Issue
Block a user