Connect backend refactor (#801)

* Move connect backend code to its own package
* Move errors into the connect package
* Add NewBackend func
This commit is contained in:
Dan Sosedoff
2025-11-11 10:19:05 -08:00
committed by GitHub
parent 266a516076
commit 70f62feec8
6 changed files with 133 additions and 117 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/sosedoff/pgweb/pkg/bookmarks" "github.com/sosedoff/pgweb/pkg/bookmarks"
"github.com/sosedoff/pgweb/pkg/client" "github.com/sosedoff/pgweb/pkg/client"
"github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connect"
"github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/metrics" "github.com/sosedoff/pgweb/pkg/metrics"
"github.com/sosedoff/pgweb/pkg/queries" "github.com/sosedoff/pgweb/pkg/queries"
@@ -92,18 +93,18 @@ func GetSessions(c *gin.Context) {
// ConnectWithBackend creates a new connection based on backend resource // ConnectWithBackend creates a new connection based on backend resource
func ConnectWithBackend(c *gin.Context) { func ConnectWithBackend(c *gin.Context) {
// Setup a new backend client backend := connect.NewBackend(command.Opts.ConnectBackend, command.Opts.ConnectToken)
backend := Backend{ backend.SetLogger(logger)
Endpoint: command.Opts.ConnectBackend,
Token: command.Opts.ConnectToken, if command.Opts.ConnectHeaders != "" {
PassHeaders: strings.Split(command.Opts.ConnectHeaders, ","), backend.SetPassHeaders(strings.Split(command.Opts.ConnectHeaders, ","))
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
// Fetch connection credentials // Fetch connection credentials
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c) cred, err := backend.FetchCredential(ctx, c.Param("resource"), c.Request.Header)
if err != nil { if err != nil {
badRequest(c, err) badRequest(c, err)
return return

View File

@@ -1,87 +0,0 @@
package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// Backend represents a third party configuration source
type Backend struct {
Endpoint string
Token string
PassHeaders []string
}
// BackendRequest represents a payload sent to the third-party source
type BackendRequest struct {
Resource string `json:"resource"`
Token string `json:"token"`
Headers map[string]string `json:"headers"`
}
// BackendCredential represents the third-party response
type BackendCredential struct {
DatabaseURL string `json:"database_url"`
}
// FetchCredential sends an authentication request to a third-party service
func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.Context) (*BackendCredential, error) {
logger.WithField("resource", resource).Debug("fetching database credential")
request := BackendRequest{
Resource: resource,
Token: be.Token,
Headers: map[string]string{},
}
// Pass white-listed client headers to the backend request
for _, name := range be.PassHeaders {
request.Headers[strings.ToLower(name)] = c.Request.Header.Get(name)
}
body, err := json.Marshal(request)
if err != nil {
logger.WithField("resource", resource).Error("backend request serialization error:", err)
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.WithField("resource", resource).Error("backend credential fetch failed:", err)
return nil, errBackendConnectError
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
err = fmt.Errorf("backend credential fetch received HTTP status code %v", resp.StatusCode)
logger.
WithField("resource", resource).
WithField("status", resp.StatusCode).
Error(err)
return nil, err
}
cred := &BackendCredential{}
if err := json.NewDecoder(resp.Body).Decode(cred); err != nil {
return nil, err
}
if cred.DatabaseURL == "" {
return nil, errConnStringRequired
}
return cred, nil
}

View File

@@ -7,12 +7,10 @@ import (
var ( var (
errNotConnected = errors.New("Not connected") errNotConnected = errors.New("Not connected")
errNotPermitted = errors.New("Not permitted") errNotPermitted = errors.New("Not permitted")
errConnStringRequired = errors.New("Connection string is required")
errInvalidConnString = errors.New("Invalid connection string") errInvalidConnString = errors.New("Invalid connection string")
errSessionRequired = errors.New("Session ID is required") errSessionRequired = errors.New("Session ID is required")
errSessionLocked = errors.New("Session is locked") errSessionLocked = errors.New("Session is locked")
errURLRequired = errors.New("URL parameter is required") errURLRequired = errors.New("URL parameter is required")
errQueryRequired = errors.New("Query parameter is required") errQueryRequired = errors.New("Query parameter is required")
errDatabaseNameRequired = errors.New("Database name is required") errDatabaseNameRequired = errors.New("Database name is required")
errBackendConnectError = errors.New("Unable to connect to the auth backend")
) )

92
pkg/connect/backend.go Normal file
View File

@@ -0,0 +1,92 @@
package connect
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/sirupsen/logrus"
)
type Backend struct {
Endpoint string
Token string
PassHeaders []string
logger *logrus.Logger
}
func NewBackend(endpoint string, token string) Backend {
return Backend{
Endpoint: endpoint,
Token: token,
logger: logrus.StandardLogger(),
}
}
func (be *Backend) SetLogger(logger *logrus.Logger) {
be.logger = logger
}
func (be *Backend) SetPassHeaders(headers []string) {
be.PassHeaders = headers
}
func (be *Backend) FetchCredential(ctx context.Context, resource string, headers http.Header) (*Credential, error) {
be.logger.WithField("resource", resource).Debug("fetching database credential")
request := Request{
Resource: resource,
Token: be.Token,
Headers: map[string]string{},
}
// Pass allow-listed client headers to the backend request
for _, name := range be.PassHeaders {
request.Headers[strings.ToLower(name)] = headers.Get(name)
}
body, err := json.Marshal(request)
if err != nil {
be.logger.WithField("resource", resource).Error("backend request serialization error:", err)
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
be.logger.WithField("resource", resource).Error("backend credential fetch failed:", err)
return nil, errBackendConnectError
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
err = fmt.Errorf("backend credential fetch received HTTP status code %v", resp.StatusCode)
be.logger.
WithField("resource", request.Resource).
WithField("status", resp.StatusCode).
Error(err)
return nil, err
}
cred := &Credential{}
if err := json.NewDecoder(resp.Body).Decode(cred); err != nil {
return nil, err
}
if cred.DatabaseURL == "" {
return nil, errConnStringRequired
}
return cred, nil
}

View File

@@ -1,4 +1,4 @@
package api package connect
import ( import (
"context" "context"
@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -17,8 +18,8 @@ func TestBackendFetchCredential(t *testing.T) {
name string name string
backend Backend backend Backend
resourceName string resourceName string
cred *BackendCredential cred *Credential
reqCtx *gin.Context headers http.Header
ctx func() (context.Context, context.CancelFunc) ctx func() (context.Context, context.CancelFunc)
err error err error
}{ }{
@@ -33,12 +34,12 @@ func TestBackendFetchCredential(t *testing.T) {
ctx: func() (context.Context, context.CancelFunc) { ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Millisecond*100) return context.WithTimeout(context.Background(), time.Millisecond*100)
}, },
err: errors.New("Unable to connect to the auth backend"), err: errors.New("unable to connect to the auth backend"),
}, },
{ {
name: "Empty response", name: "Empty response",
backend: Backend{Endpoint: "http://localhost:5555/empty-response"}, backend: Backend{Endpoint: "http://localhost:5555/empty-response"},
err: errors.New("Connection string is required"), err: errors.New("connection string is required"),
}, },
{ {
name: "Missing header", name: "Missing header",
@@ -51,19 +52,15 @@ func TestBackendFetchCredential(t *testing.T) {
Endpoint: "http://localhost:5555/pass-header", Endpoint: "http://localhost:5555/pass-header",
PassHeaders: []string{"x-foo"}, PassHeaders: []string{"x-foo"},
}, },
reqCtx: &gin.Context{ headers: http.Header{
Request: &http.Request{
Header: http.Header{
"X-Foo": []string{"bar"}, "X-Foo": []string{"bar"},
}, },
}, cred: &Credential{DatabaseURL: "postgres://hostname/bar"},
},
cred: &BackendCredential{DatabaseURL: "postgres://hostname/bar"},
}, },
{ {
name: "Success", name: "Success",
backend: Backend{Endpoint: "http://localhost:5555/success"}, backend: Backend{Endpoint: "http://localhost:5555/success"},
cred: &BackendCredential{DatabaseURL: "postgres://hostname/dbname"}, cred: &Credential{DatabaseURL: "postgres://hostname/dbname"},
}, },
} }
@@ -73,6 +70,8 @@ func TestBackendFetchCredential(t *testing.T) {
startTestBackend(srvCtx, "localhost:5555") startTestBackend(srvCtx, "localhost:5555")
for _, ex := range examples { for _, ex := range examples {
ex.backend.logger = logrus.StandardLogger()
t.Run(ex.name, func(t *testing.T) { t.Run(ex.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
if ex.ctx != nil { if ex.ctx != nil {
@@ -80,14 +79,7 @@ func TestBackendFetchCredential(t *testing.T) {
} }
defer cancel() defer cancel()
reqCtx := ex.reqCtx cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, ex.headers)
if reqCtx == nil {
reqCtx = &gin.Context{
Request: &http.Request{},
}
}
cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, reqCtx)
assert.Equal(t, ex.err, err) assert.Equal(t, ex.err, err)
assert.Equal(t, ex.cred, cred) assert.Equal(t, ex.cred, cred)
}) })
@@ -117,7 +109,7 @@ func startTestBackend(ctx context.Context, listenAddr string) {
}) })
router.POST("/pass-header", func(c *gin.Context) { router.POST("/pass-header", func(c *gin.Context) {
req := BackendRequest{} req := Request{}
if err := c.BindJSON(&req); err != nil { if err := c.BindJSON(&req); err != nil {
panic(err) panic(err)
} }

20
pkg/connect/types.go Normal file
View File

@@ -0,0 +1,20 @@
package connect
import "errors"
var (
errBackendConnectError = errors.New("unable to connect to the auth backend")
errConnStringRequired = errors.New("connection string is required")
)
// Request holds the resource request details
type Request struct {
Resource string `json:"resource"`
Token string `json:"token"`
Headers map[string]string `json:"headers,omitempty"`
}
// Credential holds the database connection string
type Credential struct {
DatabaseURL string `json:"database_url"`
}