diff --git a/pkg/api/api.go b/pkg/api/api.go index dd4017e..a6b30c4 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -19,10 +19,14 @@ import ( ) var ( - DbClient *client.Client + // DbClient represents the active database connection in a single-session mode + DbClient *client.Client + + // DbSessions represents the mapping for client connections DbSessions = map[string]*client.Client{} ) +// DB returns a database connection from the client context func DB(c *gin.Context) *client.Client { if command.Opts.Sessions { return DbSessions[getSessionId(c.Request)] @@ -63,11 +67,10 @@ func GetSessions(c *gin.Context) { // In debug mode endpoint will return a lot of sensitive information // like full database connection string and all query history. if command.Opts.Debug { - c.JSON(200, DbSessions) + successResponse(c, DbSessions) return } - - c.JSON(200, map[string]int{"sessions": len(DbSessions)}) + successResponse(c, gin.H{"sessions": len(DbSessions)}) } func ConnectWithBackend(c *gin.Context) { @@ -81,22 +84,22 @@ func ConnectWithBackend(c *gin.Context) { // Fetch connection credentials cred, err := backend.FetchCredential(c.Param("resource"), c) if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } // Make the new session sessionId, err := securerandom.Uuid() if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } c.Request.Header.Add("x-session-id", sessionId) // Connect to the database - cl, err := client.NewFromUrl(cred.DatabaseUrl, nil) + cl, err := client.NewFromUrl(cred.DatabaseURL, nil) if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } cl.External = true @@ -108,7 +111,7 @@ func ConnectWithBackend(c *gin.Context) { } if err != nil { cl.Close() - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } @@ -117,7 +120,7 @@ func ConnectWithBackend(c *gin.Context) { func Connect(c *gin.Context) { if command.Opts.LockSession { - c.JSON(400, Error{"Session is locked"}) + badRequest(c, "Session is locked") return } @@ -125,7 +128,7 @@ func Connect(c *gin.Context) { url := c.Request.FormValue("url") if url == "" { - c.JSON(400, Error{"Url parameter is required"}) + badRequest(c, "Url parameter is required") return } @@ -133,7 +136,7 @@ func Connect(c *gin.Context) { url, err := connection.FormatUrl(opts) if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } @@ -143,13 +146,13 @@ func Connect(c *gin.Context) { cl, err := client.NewFromUrl(url, sshInfo) if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } err = cl.Test() if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } @@ -159,16 +162,16 @@ func Connect(c *gin.Context) { } if err != nil { cl.Close() - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } - c.JSON(200, info.Format()[0]) + successResponse(c, info.Format()[0]) } func SwitchDb(c *gin.Context) { if command.Opts.LockSession { - c.JSON(400, Error{"Session is locked"}) + badRequest(c, "Session is locked") return } @@ -177,25 +180,25 @@ func SwitchDb(c *gin.Context) { name = c.Request.FormValue("db") } if name == "" { - c.JSON(400, Error{"Database name is not provided"}) + badRequest(c, "Database name is not provided") return } conn := DB(c) if conn == nil { - c.JSON(400, Error{"Not connected"}) + badRequest(c, "Not connected") return } // Do not allow switching databases for connections from third-party backends if conn.External { - c.JSON(400, Error{"Session is locked"}) + badRequest(c, "Session is locked") return } currentUrl, err := neturl.Parse(conn.ConnectionString) if err != nil { - c.JSON(400, Error{"Unable to parse current connection string"}) + badRequest(c, "Unable to parse current connection string") return } @@ -203,13 +206,13 @@ func SwitchDb(c *gin.Context) { cl, err := client.NewFromUrl(currentUrl.String(), nil) if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } err = cl.Test() if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } @@ -219,64 +222,62 @@ func SwitchDb(c *gin.Context) { } if err != nil { cl.Close() - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } conn.Close() - c.JSON(200, info.Format()[0]) + successResponse(c, info.Format()[0]) } func Disconnect(c *gin.Context) { if command.Opts.LockSession { - c.JSON(400, Error{"Session is locked"}) + badRequest(c, "Session is locked") return } conn := DB(c) if conn == nil { - c.JSON(400, Error{"Not connected"}) + badRequest(c, "Not connected") return } err := conn.Close() if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } - c.JSON(200, map[string]bool{"success": true}) + successResponse(c, gin.H{"success": true}) } func GetDatabases(c *gin.Context) { conn := DB(c) if conn.External { - c.JSON(403, Error{"Not permitted"}) + errorResponse(c, 403, "Not permitted") return } names, err := DB(c).Databases() - serveResult(names, err, c) + serveResult(c, names, err) } func GetObjects(c *gin.Context) { result, err := DB(c).Objects() if err != nil { - c.JSON(400, NewError(err)) + badRequest(c, err) return } - - objects := client.ObjectsFromResult(result) - c.JSON(200, objects) + successResponse(c, client.ObjectsFromResult(result)) } func RunQuery(c *gin.Context) { query := cleanQuery(c.Request.FormValue("query")) if query == "" { - c.JSON(400, NewError(errors.New("Query parameter is missing"))) + badRequest(c, "Query parameter is missing") return } @@ -287,7 +288,7 @@ func ExplainQuery(c *gin.Context) { query := cleanQuery(c.Request.FormValue("query")) if query == "" { - c.JSON(400, NewError(errors.New("Query parameter is missing"))) + badRequest(c, "Query parameter is missing") return } @@ -296,7 +297,7 @@ func ExplainQuery(c *gin.Context) { func GetSchemas(c *gin.Context) { res, err := DB(c).Schemas() - serveResult(res, err, c) + serveResult(c, res, err) } func GetTable(c *gin.Context) { @@ -309,19 +310,19 @@ func GetTable(c *gin.Context) { res, err = DB(c).Table(c.Params.ByName("table")) } - serveResult(res, err, c) + serveResult(c, res, err) } func GetTableRows(c *gin.Context) { offset, err := parseIntFormValue(c, "offset", 0) if err != nil { - c.JSON(400, NewError(err)) + badRequest(c, err) return } limit, err := parseIntFormValue(c, "limit", 100) if err != nil { - c.JSON(400, NewError(err)) + badRequest(c, err) return } @@ -335,13 +336,13 @@ func GetTableRows(c *gin.Context) { res, err := DB(c).TableRows(c.Params.ByName("table"), opts) if err != nil { - c.JSON(400, NewError(err)) + badRequest(c, err) return } countRes, err := DB(c).TableRowsCount(c.Params.ByName("table"), opts) if err != nil { - c.JSON(400, NewError(err)) + badRequest(c, err) return } @@ -361,51 +362,49 @@ func GetTableRows(c *gin.Context) { PerPage: numFetch, } - serveResult(res, err, c) + serveResult(c, res, err) } func GetTableInfo(c *gin.Context) { res, err := DB(c).TableInfo(c.Params.ByName("table")) - - if err != nil { - c.JSON(400, NewError(err)) - return + if err == nil { + successResponse(c, res.Format()[0]) + } else { + badRequest(c, err) } - - c.JSON(200, res.Format()[0]) } func GetHistory(c *gin.Context) { - c.JSON(200, DB(c).History) + successResponse(c, DB(c).History) } func GetConnectionInfo(c *gin.Context) { res, err := DB(c).Info() if err != nil { - c.JSON(400, NewError(err)) + badRequest(c, err) return } info := res.Format()[0] info["session_lock"] = command.Opts.LockSession - c.JSON(200, info) + successResponse(c, info) } func GetActivity(c *gin.Context) { res, err := DB(c).Activity() - serveResult(res, err, c) + serveResult(c, res, err) } func GetTableIndexes(c *gin.Context) { res, err := DB(c).TableIndexes(c.Params.ByName("table")) - serveResult(res, err, c) + serveResult(c, res, err) } func GetTableConstraints(c *gin.Context) { res, err := DB(c).TableConstraints(c.Params.ByName("table")) - serveResult(res, err, c) + serveResult(c, res, err) } func HandleQuery(query string, c *gin.Context) { @@ -416,7 +415,7 @@ func HandleQuery(query string, c *gin.Context) { result, err := DB(c).Query(query) if err != nil { - c.JSON(400, NewError(err)) + badRequest(c, err) return } @@ -445,17 +444,15 @@ func HandleQuery(query string, c *gin.Context) { func GetBookmarks(c *gin.Context) { bookmarks, err := bookmarks.ReadAll(bookmarks.Path(command.Opts.BookmarksDir)) - serveResult(bookmarks, err, c) + serveResult(c, bookmarks, err) } func GetInfo(c *gin.Context) { - info := map[string]string{ + successResponse(c, gin.H{ "version": command.VERSION, "git_sha": command.GitCommit, "build_time": command.BuildTime, - } - - c.JSON(200, info) + }) } // Export database or table data @@ -464,7 +461,7 @@ func DataExport(c *gin.Context) { info, err := db.Info() if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) return } @@ -475,7 +472,7 @@ func DataExport(c *gin.Context) { // If pg_dump is not available the following code will not show an error in browser. // This is due to the headers being written first. if !dump.CanExport() { - c.JSON(400, Error{"pg_dump is not found"}) + badRequest(c, "pg_dump is not found") return } @@ -485,11 +482,13 @@ func DataExport(c *gin.Context) { filename = filename + "_" + dump.Table } - attachment := fmt.Sprintf(`attachment; filename="%s.sql.gz"`, filename) - c.Header("Content-Disposition", attachment) + c.Header( + "Content-Disposition", + fmt.Sprintf(`attachment; filename="%s.sql.gz"`, filename), + ) err = dump.Export(db.ConnectionString, c.Writer) if err != nil { - c.JSON(400, Error{err.Error()}) + badRequest(c, err) } } diff --git a/pkg/api/backend.go b/pkg/api/backend.go index 90baf79..cfad64a 100644 --- a/pkg/api/backend.go +++ b/pkg/api/backend.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "io/ioutil" "log" "net/http" "strings" @@ -12,22 +11,26 @@ import ( "github.com/gin-gonic/gin" ) +// Backend represents a third party configuration source type Backend struct { Endpoint string Token string PassHeaders string } +// BackendRequest represents a payload sent to the third-party source type BackendRequest struct { Resource string `json:"resource"` Token string `json:"token"` Headers map[string]string `json:"headers"` } +// BackendCredential represents the third-party response type BackendCredential struct { - DatabaseUrl string `json:"database_url"` + DatabaseURL string `json:"database_url"` } +// FetchCredential sends an authentication request to a third-party service func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCredential, error) { request := BackendRequest{ Resource: resource, @@ -35,6 +38,7 @@ func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCred Headers: map[string]string{}, } + // Pass white-listed client headers to the backend request for _, name := range strings.Split(be.PassHeaders, ",") { request.Headers[strings.ToLower(name)] = c.Request.Header.Get(name) } @@ -58,17 +62,12 @@ func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCred return nil, fmt.Errorf("Got HTTP error %v from backend", resp.StatusCode) } - respBody, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - cred := &BackendCredential{} - if err := json.Unmarshal(respBody, cred); err != nil { + if err := json.NewDecoder(resp.Body).Decode(cred); err != nil { return nil, err } - if cred.DatabaseUrl == "" { - return nil, fmt.Errorf("Database url was not provided") + if cred.DatabaseURL == "" { + return nil, fmt.Errorf("Database URL was not provided") } return cred, nil diff --git a/pkg/api/helpers.go b/pkg/api/helpers.go index a83be47..03980fd 100644 --- a/pkg/api/helpers.go +++ b/pkg/api/helpers.go @@ -14,30 +14,33 @@ import ( "github.com/sosedoff/pgweb/pkg/shared" ) -var extraMimeTypes = map[string]string{ - ".icon": "image-x-icon", - ".ttf": "application/x-font-ttf", - ".woff": "application/x-font-woff", - ".eot": "application/vnd.ms-fontobject", - ".svg": "image/svg+xml", - ".html": "text/html; charset-utf-8", -} +var ( + // Mime types definitions + extraMimeTypes = map[string]string{ + ".icon": "image-x-icon", + ".ttf": "application/x-font-ttf", + ".woff": "application/x-font-woff", + ".eot": "application/vnd.ms-fontobject", + ".svg": "image/svg+xml", + ".html": "text/html; charset-utf-8", + } -// Paths that dont require database connection -var allowedPaths = map[string]bool{ - "/api/sessions": true, - "/api/info": true, - "/api/connect": true, - "/api/bookmarks": true, - "/api/history": true, -} + // Paths that dont require database connection + allowedPaths = map[string]bool{ + "/api/sessions": true, + "/api/info": true, + "/api/connect": true, + "/api/bookmarks": true, + "/api/history": true, + } -// List of characters replaced by javascript code to make queries url-safe. -var base64subs = map[string]string{ - "-": "+", - "_": "/", - ".": "=", -} + // List of characters replaced by javascript code to make queries url-safe. + base64subs = map[string]string{ + "-": "+", + "_": "/", + ".": "=", + } +) type Error struct { Message string `json:"error"` @@ -151,11 +154,37 @@ func serveStaticAsset(path string, c *gin.Context) { c.Data(200, assetContentType(path), data) } -func serveResult(result interface{}, err error, c *gin.Context) { - if err != nil { - c.JSON(400, NewError(err)) - return +// Send a query result to client +func serveResult(c *gin.Context, result interface{}, err interface{}) { + if err == nil { + successResponse(c, result) + } else { + badRequest(c, err) + } +} + +// Send successful response back to client +func successResponse(c *gin.Context, data interface{}) { + c.JSON(200, data) +} + +// Send an error response back to client +func errorResponse(c *gin.Context, status int, err interface{}) { + var message interface{} + + switch v := err.(type) { + case error: + message = v.Error() + case string: + message = v + default: + message = v } - c.JSON(200, result) + c.AbortWithStatusJSON(status, gin.H{"status": status, "error": message}) +} + +// Send a bad request (http 400) back to client +func badRequest(c *gin.Context, err interface{}) { + errorResponse(c, 400, err) } diff --git a/pkg/api/helpers_test.go b/pkg/api/helpers_test.go index be954ae..7738d10 100644 --- a/pkg/api/helpers_test.go +++ b/pkg/api/helpers_test.go @@ -43,13 +43,13 @@ func Test_getSessionId(t *testing.T) { func Test_serveResult(t *testing.T) { server := gin.Default() server.GET("/good", func(c *gin.Context) { - serveResult(gin.H{"foo": "bar"}, nil, c) + serveResult(c, gin.H{"foo": "bar"}, nil) }) server.GET("/bad", func(c *gin.Context) { - serveResult(nil, errors.New("message"), c) + serveResult(c, nil, errors.New("message")) }) server.GET("/nodata", func(c *gin.Context) { - serveResult(nil, nil, c) + serveResult(c, nil, nil) }) w := httptest.NewRecorder() @@ -62,7 +62,7 @@ func Test_serveResult(t *testing.T) { req, _ = http.NewRequest("GET", "/bad", nil) server.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) - assert.Equal(t, `{"error":"message"}`, w.Body.String()) + assert.Equal(t, `{"error":"message","status":400}`, w.Body.String()) w = httptest.NewRecorder() req, _ = http.NewRequest("GET", "/nodata", nil) diff --git a/pkg/api/middleware.go b/pkg/api/middleware.go index ef7781f..2610b27 100644 --- a/pkg/api/middleware.go +++ b/pkg/api/middleware.go @@ -9,21 +9,21 @@ import ( "github.com/sosedoff/pgweb/pkg/command" ) -// Middleware function to check database connection status before running queries +// Middleware to check database connection status before running queries func dbCheckMiddleware() gin.HandlerFunc { return func(c *gin.Context) { path := strings.Replace(c.Request.URL.Path, command.Opts.Prefix, "", -1) - if allowedPaths[path] == true { + // Allow whitelisted paths + if allowedPaths[path] { c.Next() return } - // We dont care about sessions unless they're enabled + // Check if session exists in single-session mode if !command.Opts.Sessions { if DbClient == nil { - c.JSON(400, Error{"Not connected"}) - c.Abort() + badRequest(c, "Not connected") return } @@ -31,17 +31,17 @@ func dbCheckMiddleware() gin.HandlerFunc { return } + // Determine session ID from the client request sessionId := getSessionId(c.Request) if sessionId == "" { - c.JSON(400, Error{"Session ID is required"}) - c.Abort() + badRequest(c, "Session ID is required") return } + // Determine the database connection handle for the session conn := DbSessions[sessionId] if conn == nil { - c.JSON(400, Error{"Not connected"}) - c.Abort() + badRequest(c, "Not connected") return } @@ -49,7 +49,7 @@ func dbCheckMiddleware() gin.HandlerFunc { } } -// Middleware function to print out request parameters and body for debugging +// Middleware to print out request parameters and body for debugging func requestInspectMiddleware() gin.HandlerFunc { return func(c *gin.Context) { err := c.Request.ParseForm() @@ -57,6 +57,7 @@ func requestInspectMiddleware() gin.HandlerFunc { } } +// Middleware to inject CORS headers func corsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") diff --git a/pkg/api/session_cleanup.go b/pkg/api/session_cleanup.go index 5951521..862d121 100644 --- a/pkg/api/session_cleanup.go +++ b/pkg/api/session_cleanup.go @@ -7,6 +7,16 @@ import ( "github.com/sosedoff/pgweb/pkg/command" ) +// StartSessionCleanup starts a goroutine to cleanup idle database sessions +func StartSessionCleanup() { + for range time.Tick(time.Minute) { + if command.Opts.Debug { + log.Println("Triggering idle session deletion") + } + cleanupIdleSessions() + } +} + func cleanupIdleSessions() { ids := []string{} @@ -16,11 +26,11 @@ func cleanupIdleSessions() { ids = append(ids, id) } } - - // Close and delete idle sessions if len(ids) == 0 { return } + + // Close and delete idle sessions log.Println("Closing", len(ids), "idle sessions") for _, id := range ids { // TODO: concurrent map edit will trigger panic @@ -29,17 +39,3 @@ func cleanupIdleSessions() { } } } - -func StartSessionCleanup() { - ticker := time.NewTicker(time.Minute) - - for { - <-ticker.C - - if command.Opts.Debug { - log.Println("Triggering idle session deletion") - } - - cleanupIdleSessions() - } -} diff --git a/pkg/client/tunnel.go b/pkg/client/tunnel.go index 7b57388..db167e5 100644 --- a/pkg/client/tunnel.go +++ b/pkg/client/tunnel.go @@ -19,10 +19,11 @@ import ( ) const ( - PORT_START = 29168 - PORT_LIMIT = 500 + portStart = 29168 + portLimit = 500 ) +// Tunnel represents the connection between local and remote server type Tunnel struct { TargetHost string TargetPort string @@ -121,6 +122,7 @@ func (tunnel *Tunnel) handleConnection(local net.Conn) { local.Close() } +// Close closes the tunnel connection func (tunnel *Tunnel) Close() { if tunnel.Client != nil { tunnel.Client.Close() @@ -131,6 +133,7 @@ func (tunnel *Tunnel) Close() { } } +// Configure establishes the tunnel between localhost and remote machine func (tunnel *Tunnel) Configure() error { config, err := makeConfig(tunnel.SSHInfo) if err != nil { @@ -153,6 +156,7 @@ func (tunnel *Tunnel) Configure() error { return nil } +// Start starts the connection handler loop func (tunnel *Tunnel) Start() { defer tunnel.Close() @@ -166,13 +170,14 @@ func (tunnel *Tunnel) Start() { } } +// NewTunnel instantiates a new tunnel struct from given ssh info func NewTunnel(sshInfo *shared.SSHInfo, dbUrl string) (*Tunnel, error) { uri, err := url.Parse(dbUrl) if err != nil { return nil, err } - listenPort, err := connection.AvailablePort(PORT_START, PORT_LIMIT) + listenPort, err := connection.FindAvailablePort(portStart, portLimit) if err != nil { return nil, err } diff --git a/pkg/connection/port.go b/pkg/connection/port.go index 4c87010..a596730 100644 --- a/pkg/connection/port.go +++ b/pkg/connection/port.go @@ -7,10 +7,9 @@ import ( "strings" ) -// Check if the TCP port available on localhost -func portAvailable(port int) bool { +// IsPortAvailable returns true if there's no listeners on a given port +func IsPortAvailable(port int) bool { conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%v", port)) - if err != nil { if strings.Index(err.Error(), "connection refused") > 0 { return true @@ -22,10 +21,10 @@ func portAvailable(port int) bool { return false } -// Get available TCP port on localhost by trying available ports in a range -func AvailablePort(start int, limit int) (int, error) { +// FindAvailablePort returns the first available TCP port in the range +func FindAvailablePort(start int, limit int) (int, error) { for i := start; i <= (start + limit); i++ { - if portAvailable(i) { + if IsPortAvailable(i) { return i, nil } } diff --git a/pkg/connection/port_test.go b/pkg/connection/port_test.go index a2b3114..4ee3eae 100644 --- a/pkg/connection/port_test.go +++ b/pkg/connection/port_test.go @@ -10,12 +10,12 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_portAvailable(t *testing.T) { +func TestIsPortAvailable(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("FIXME") } - assert.Equal(t, true, portAvailable(30000)) + assert.Equal(t, true, IsPortAvailable(30000)) serv, err := net.Listen("tcp", "127.0.0.1:30000") if err != nil { @@ -35,16 +35,16 @@ func Test_portAvailable(t *testing.T) { } }() - assert.Equal(t, false, portAvailable(30000)) - assert.Equal(t, true, portAvailable(30001)) + assert.Equal(t, false, IsPortAvailable(30000)) + assert.Equal(t, true, IsPortAvailable(30001)) } -func Test_getAvailablePort(t *testing.T) { +func TestFindAvailablePort(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("FIXME") } - port, err := AvailablePort(30000, 1) + port, err := FindAvailablePort(30000, 1) assert.Equal(t, nil, err) assert.Equal(t, 30000, port) @@ -65,11 +65,11 @@ func Test_getAvailablePort(t *testing.T) { } }() - port, err = AvailablePort(30000, 0) + port, err = FindAvailablePort(30000, 0) assert.EqualError(t, err, "No available port") assert.Equal(t, -1, port) - port, err = AvailablePort(30000, 1) + port, err = FindAvailablePort(30000, 1) assert.Equal(t, nil, err) assert.Equal(t, 30001, port) }