From 7d08017c7f66d8874720fadfc82ac8be26be4b61 Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Sat, 5 Nov 2016 22:23:26 -0500 Subject: [PATCH] Add endpoint to switch active database --- pkg/api/api.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++ pkg/api/routes.go | 1 + 2 files changed, 54 insertions(+) diff --git a/pkg/api/api.go b/pkg/api/api.go index 6928291..93dbe9a 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -4,6 +4,8 @@ import ( "encoding/base64" "errors" "fmt" + neturl "net/url" + "strings" "time" "github.com/gin-gonic/gin" @@ -118,6 +120,57 @@ func Connect(c *gin.Context) { c.JSON(200, info.Format()[0]) } +func SwitchDb(c *gin.Context) { + if command.Opts.LockSession { + c.JSON(400, Error{"Session is locked"}) + return + } + + name := c.Request.URL.Query().Get("db") + if name == "" { + c.JSON(400, Error{"Database name is not provided"}) + return + } + + conn := DB(c) + if conn == nil { + c.JSON(400, Error{"Not connected"}) + return + } + + currentUrl, err := neturl.Parse(conn.ConnectionString) + if err != nil { + c.JSON(400, Error{"Unable to parse current connection string"}) + return + } + + newStr := strings.Replace(conn.ConnectionString, currentUrl.Path, "/"+name, 1) + + cl, err := client.NewFromUrl(newStr, nil) + if err != nil { + c.JSON(400, Error{err.Error()}) + return + } + + err = cl.Test() + if err != nil { + c.JSON(400, Error{err.Error()}) + return + } + + info, err := cl.Info() + if err == nil { + err = setClient(c, cl) + if err != nil { + cl.Close() + c.JSON(400, Error{err.Error()}) + return + } + } + + c.JSON(200, info.Format()[0]) +} + func Disconnect(c *gin.Context) { if command.Opts.LockSession { c.JSON(400, Error{"Session is locked"}) diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 81729db..d800f87 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -30,6 +30,7 @@ func SetupRoutes(router *gin.Engine) { api.GET("/info", GetInfo) api.POST("/connect", Connect) api.POST("/disconnect", Disconnect) + api.POST("/switchdb", SwitchDb) api.GET("/databases", GetDatabases) api.GET("/connection", GetConnectionInfo) api.GET("/activity", GetActivity)