Connect backend refactor (#801)
* Move connect backend code to its own package * Move errors into the connect package * Add NewBackend func
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/sosedoff/pgweb/pkg/bookmarks"
|
"github.com/sosedoff/pgweb/pkg/bookmarks"
|
||||||
"github.com/sosedoff/pgweb/pkg/client"
|
"github.com/sosedoff/pgweb/pkg/client"
|
||||||
"github.com/sosedoff/pgweb/pkg/command"
|
"github.com/sosedoff/pgweb/pkg/command"
|
||||||
|
"github.com/sosedoff/pgweb/pkg/connect"
|
||||||
"github.com/sosedoff/pgweb/pkg/connection"
|
"github.com/sosedoff/pgweb/pkg/connection"
|
||||||
"github.com/sosedoff/pgweb/pkg/metrics"
|
"github.com/sosedoff/pgweb/pkg/metrics"
|
||||||
"github.com/sosedoff/pgweb/pkg/queries"
|
"github.com/sosedoff/pgweb/pkg/queries"
|
||||||
@@ -92,18 +93,18 @@ func GetSessions(c *gin.Context) {
|
|||||||
|
|
||||||
// ConnectWithBackend creates a new connection based on backend resource
|
// ConnectWithBackend creates a new connection based on backend resource
|
||||||
func ConnectWithBackend(c *gin.Context) {
|
func ConnectWithBackend(c *gin.Context) {
|
||||||
// Setup a new backend client
|
backend := connect.NewBackend(command.Opts.ConnectBackend, command.Opts.ConnectToken)
|
||||||
backend := Backend{
|
backend.SetLogger(logger)
|
||||||
Endpoint: command.Opts.ConnectBackend,
|
|
||||||
Token: command.Opts.ConnectToken,
|
if command.Opts.ConnectHeaders != "" {
|
||||||
PassHeaders: strings.Split(command.Opts.ConnectHeaders, ","),
|
backend.SetPassHeaders(strings.Split(command.Opts.ConnectHeaders, ","))
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Fetch connection credentials
|
// Fetch connection credentials
|
||||||
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c)
|
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c.Request.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
badRequest(c, err)
|
badRequest(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Backend represents a third party configuration source
|
|
||||||
type Backend struct {
|
|
||||||
Endpoint string
|
|
||||||
Token string
|
|
||||||
PassHeaders []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// BackendRequest represents a payload sent to the third-party source
|
|
||||||
type BackendRequest struct {
|
|
||||||
Resource string `json:"resource"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
Headers map[string]string `json:"headers"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BackendCredential represents the third-party response
|
|
||||||
type BackendCredential struct {
|
|
||||||
DatabaseURL string `json:"database_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchCredential sends an authentication request to a third-party service
|
|
||||||
func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.Context) (*BackendCredential, error) {
|
|
||||||
logger.WithField("resource", resource).Debug("fetching database credential")
|
|
||||||
|
|
||||||
request := BackendRequest{
|
|
||||||
Resource: resource,
|
|
||||||
Token: be.Token,
|
|
||||||
Headers: map[string]string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pass white-listed client headers to the backend request
|
|
||||||
for _, name := range be.PassHeaders {
|
|
||||||
request.Headers[strings.ToLower(name)] = c.Request.Header.Get(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := json.Marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
logger.WithField("resource", resource).Error("backend request serialization error:", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("content-type", "application/json")
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
logger.WithField("resource", resource).Error("backend credential fetch failed:", err)
|
|
||||||
return nil, errBackendConnectError
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
err = fmt.Errorf("backend credential fetch received HTTP status code %v", resp.StatusCode)
|
|
||||||
|
|
||||||
logger.
|
|
||||||
WithField("resource", resource).
|
|
||||||
WithField("status", resp.StatusCode).
|
|
||||||
Error(err)
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cred := &BackendCredential{}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(cred); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if cred.DatabaseURL == "" {
|
|
||||||
return nil, errConnStringRequired
|
|
||||||
}
|
|
||||||
|
|
||||||
return cred, nil
|
|
||||||
}
|
|
||||||
@@ -7,12 +7,10 @@ import (
|
|||||||
var (
|
var (
|
||||||
errNotConnected = errors.New("Not connected")
|
errNotConnected = errors.New("Not connected")
|
||||||
errNotPermitted = errors.New("Not permitted")
|
errNotPermitted = errors.New("Not permitted")
|
||||||
errConnStringRequired = errors.New("Connection string is required")
|
|
||||||
errInvalidConnString = errors.New("Invalid connection string")
|
errInvalidConnString = errors.New("Invalid connection string")
|
||||||
errSessionRequired = errors.New("Session ID is required")
|
errSessionRequired = errors.New("Session ID is required")
|
||||||
errSessionLocked = errors.New("Session is locked")
|
errSessionLocked = errors.New("Session is locked")
|
||||||
errURLRequired = errors.New("URL parameter is required")
|
errURLRequired = errors.New("URL parameter is required")
|
||||||
errQueryRequired = errors.New("Query parameter is required")
|
errQueryRequired = errors.New("Query parameter is required")
|
||||||
errDatabaseNameRequired = errors.New("Database name is required")
|
errDatabaseNameRequired = errors.New("Database name is required")
|
||||||
errBackendConnectError = errors.New("Unable to connect to the auth backend")
|
|
||||||
)
|
)
|
||||||
|
|||||||
92
pkg/connect/backend.go
Normal file
92
pkg/connect/backend.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package connect
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Backend struct {
|
||||||
|
Endpoint string
|
||||||
|
Token string
|
||||||
|
PassHeaders []string
|
||||||
|
|
||||||
|
logger *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackend(endpoint string, token string) Backend {
|
||||||
|
return Backend{
|
||||||
|
Endpoint: endpoint,
|
||||||
|
Token: token,
|
||||||
|
logger: logrus.StandardLogger(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (be *Backend) SetLogger(logger *logrus.Logger) {
|
||||||
|
be.logger = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (be *Backend) SetPassHeaders(headers []string) {
|
||||||
|
be.PassHeaders = headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (be *Backend) FetchCredential(ctx context.Context, resource string, headers http.Header) (*Credential, error) {
|
||||||
|
be.logger.WithField("resource", resource).Debug("fetching database credential")
|
||||||
|
|
||||||
|
request := Request{
|
||||||
|
Resource: resource,
|
||||||
|
Token: be.Token,
|
||||||
|
Headers: map[string]string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass allow-listed client headers to the backend request
|
||||||
|
for _, name := range be.PassHeaders {
|
||||||
|
request.Headers[strings.ToLower(name)] = headers.Get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
be.logger.WithField("resource", resource).Error("backend request serialization error:", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
be.logger.WithField("resource", resource).Error("backend credential fetch failed:", err)
|
||||||
|
return nil, errBackendConnectError
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
err = fmt.Errorf("backend credential fetch received HTTP status code %v", resp.StatusCode)
|
||||||
|
|
||||||
|
be.logger.
|
||||||
|
WithField("resource", request.Resource).
|
||||||
|
WithField("status", resp.StatusCode).
|
||||||
|
Error(err)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := &Credential{}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(cred); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.DatabaseURL == "" {
|
||||||
|
return nil, errConnStringRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package api
|
package connect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,8 +18,8 @@ func TestBackendFetchCredential(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
backend Backend
|
backend Backend
|
||||||
resourceName string
|
resourceName string
|
||||||
cred *BackendCredential
|
cred *Credential
|
||||||
reqCtx *gin.Context
|
headers http.Header
|
||||||
ctx func() (context.Context, context.CancelFunc)
|
ctx func() (context.Context, context.CancelFunc)
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
@@ -33,12 +34,12 @@ func TestBackendFetchCredential(t *testing.T) {
|
|||||||
ctx: func() (context.Context, context.CancelFunc) {
|
ctx: func() (context.Context, context.CancelFunc) {
|
||||||
return context.WithTimeout(context.Background(), time.Millisecond*100)
|
return context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||||
},
|
},
|
||||||
err: errors.New("Unable to connect to the auth backend"),
|
err: errors.New("unable to connect to the auth backend"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty response",
|
name: "Empty response",
|
||||||
backend: Backend{Endpoint: "http://localhost:5555/empty-response"},
|
backend: Backend{Endpoint: "http://localhost:5555/empty-response"},
|
||||||
err: errors.New("Connection string is required"),
|
err: errors.New("connection string is required"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Missing header",
|
name: "Missing header",
|
||||||
@@ -51,19 +52,15 @@ func TestBackendFetchCredential(t *testing.T) {
|
|||||||
Endpoint: "http://localhost:5555/pass-header",
|
Endpoint: "http://localhost:5555/pass-header",
|
||||||
PassHeaders: []string{"x-foo"},
|
PassHeaders: []string{"x-foo"},
|
||||||
},
|
},
|
||||||
reqCtx: &gin.Context{
|
headers: http.Header{
|
||||||
Request: &http.Request{
|
|
||||||
Header: http.Header{
|
|
||||||
"X-Foo": []string{"bar"},
|
"X-Foo": []string{"bar"},
|
||||||
},
|
},
|
||||||
},
|
cred: &Credential{DatabaseURL: "postgres://hostname/bar"},
|
||||||
},
|
|
||||||
cred: &BackendCredential{DatabaseURL: "postgres://hostname/bar"},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Success",
|
name: "Success",
|
||||||
backend: Backend{Endpoint: "http://localhost:5555/success"},
|
backend: Backend{Endpoint: "http://localhost:5555/success"},
|
||||||
cred: &BackendCredential{DatabaseURL: "postgres://hostname/dbname"},
|
cred: &Credential{DatabaseURL: "postgres://hostname/dbname"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,6 +70,8 @@ func TestBackendFetchCredential(t *testing.T) {
|
|||||||
startTestBackend(srvCtx, "localhost:5555")
|
startTestBackend(srvCtx, "localhost:5555")
|
||||||
|
|
||||||
for _, ex := range examples {
|
for _, ex := range examples {
|
||||||
|
ex.backend.logger = logrus.StandardLogger()
|
||||||
|
|
||||||
t.Run(ex.name, func(t *testing.T) {
|
t.Run(ex.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
if ex.ctx != nil {
|
if ex.ctx != nil {
|
||||||
@@ -80,14 +79,7 @@ func TestBackendFetchCredential(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
reqCtx := ex.reqCtx
|
cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, ex.headers)
|
||||||
if reqCtx == nil {
|
|
||||||
reqCtx = &gin.Context{
|
|
||||||
Request: &http.Request{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, reqCtx)
|
|
||||||
assert.Equal(t, ex.err, err)
|
assert.Equal(t, ex.err, err)
|
||||||
assert.Equal(t, ex.cred, cred)
|
assert.Equal(t, ex.cred, cred)
|
||||||
})
|
})
|
||||||
@@ -117,7 +109,7 @@ func startTestBackend(ctx context.Context, listenAddr string) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
router.POST("/pass-header", func(c *gin.Context) {
|
router.POST("/pass-header", func(c *gin.Context) {
|
||||||
req := BackendRequest{}
|
req := Request{}
|
||||||
if err := c.BindJSON(&req); err != nil {
|
if err := c.BindJSON(&req); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
20
pkg/connect/types.go
Normal file
20
pkg/connect/types.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package connect
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
errBackendConnectError = errors.New("unable to connect to the auth backend")
|
||||||
|
errConnStringRequired = errors.New("connection string is required")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Request holds the resource request details
|
||||||
|
type Request struct {
|
||||||
|
Resource string `json:"resource"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
Headers map[string]string `json:"headers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Credential holds the database connection string
|
||||||
|
type Credential struct {
|
||||||
|
DatabaseURL string `json:"database_url"`
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user