diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 319c656..e86d573 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -2,6 +2,7 @@ name: checks on: - push + - pull_request env: GO_VERSION: 1.19 diff --git a/go.mod b/go.mod index ca1485e..6056b37 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/lib/pq v1.10.5 github.com/mitchellh/go-homedir v1.1.0 github.com/mr-tron/base58 v1.2.0 + github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.7.1 github.com/tuvistavie/securerandom v0.0.0-20140719024926-15512123a948 golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88 @@ -32,7 +33,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/sirupsen/logrus v1.9.0 // indirect github.com/ugorji/go/codec v1.2.6 // indirect golang.org/x/sys v0.2.0 // indirect golang.org/x/text v0.3.7 // indirect diff --git a/go.sum b/go.sum index d5cc007..3b0194c 100644 --- a/go.sum +++ b/go.sum @@ -87,7 +87,6 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/tuvistavie/securerandom v0.0.0-20140719024926-15512123a948 h1:yL0l/u242MzDP6D0B5vGC+wxm5WRY+alQZy+dJk3bFI= github.com/tuvistavie/securerandom v0.0.0-20140719024926-15512123a948/go.mod h1:a06d/M1pxWi51qiSrfGMHaEydtuXT06nha8N2aNQuXk= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go v1.2.6 h1:tGiWC9HENWE2tqYycIqFTNorMmFRVhNwCpDOpWqnk8E= github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn0= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.6 h1:7kbGefxLoDBuYXOms4yD7223OpNMMPNPZxXk5TvFcyQ= @@ -100,19 +99,15 @@ golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88 h1:Tgea0cVUD0ivh5ADBX4Wwu golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200219091948-cb0a6d8edb6c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210923061019-b8560ed6a9b7 h1:c20P3CcPbopVp2f7099WLOqSNKURf30Z0uq66HpijZY= -golang.org/x/sys v0.0.0-20210923061019-b8560ed6a9b7/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/api/api.go b/pkg/api/api.go index 152da8f..ee8f91a 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" neturl "net/url" - "regexp" "strings" "time" @@ -58,7 +57,7 @@ func setClient(c *gin.Context, newClient *client.Client) error { return nil } -// GetHome renderes the home page +// GetHome renders the home page func GetHome(prefix string) http.Handler { if prefix != "" { prefix = "/" + prefix @@ -530,10 +529,10 @@ func DataExport(c *gin.Context) { Table: strings.TrimSpace(c.Request.FormValue("table")), } - // If pg_dump is not available the following code will not show an error in browser. - // This is due to the headers being written first. - if !dump.CanExport() { - badRequest(c, errPgDumpNotFound) + // Perform validation of pg_dump command availability and compatibility. + // Must be done before the actual command is executed to display errors. + if err := dump.Validate(db.ServerVersion()); err != nil { + badRequest(c, err) return } @@ -542,16 +541,18 @@ func DataExport(c *gin.Context) { if dump.Table != "" { filename = filename + "_" + dump.Table } - reg := regexp.MustCompile(`[^._\\w]+`) - cleanFilename := reg.ReplaceAllString(filename, "") + + filename = sanitizeFilename(filename) + filename = fmt.Sprintf("%s_%s", filename, time.Now().Format("20060102_150405")) c.Header( "Content-Disposition", - fmt.Sprintf(`attachment; filename="%s.sql.gz"`, cleanFilename), + fmt.Sprintf(`attachment; filename="%s.sql.gz"`, filename), ) - err = dump.Export(db.ConnectionString, c.Writer) + err = dump.Export(c.Request.Context(), db.ConnectionString, c.Writer) if err != nil { + logger.WithError(err).Error("pg_dump command failed") badRequest(c, err) } } diff --git a/pkg/api/errors.go b/pkg/api/errors.go index 54f8d80..4ef6330 100644 --- a/pkg/api/errors.go +++ b/pkg/api/errors.go @@ -14,6 +14,5 @@ var ( errURLRequired = errors.New("URL parameter is required") errQueryRequired = errors.New("Query parameter is required") errDatabaseNameRequired = errors.New("Database name is required") - errPgDumpNotFound = errors.New("pg_dump utility is not found") errBackendConnectError = errors.New("Unable to connect to the auth backend") ) diff --git a/pkg/api/helpers.go b/pkg/api/helpers.go index f56093a..ed7be10 100644 --- a/pkg/api/helpers.go +++ b/pkg/api/helpers.go @@ -5,6 +5,7 @@ import ( "mime" "net/http" "path/filepath" + "regexp" "strconv" "strings" @@ -39,6 +40,9 @@ var ( "_": "/", ".": "=", } + + // Regular expression to remove unwanted characters in filenames + regexCleanFilename = regexp.MustCompile(`[^\w]+`) ) type Error struct { @@ -74,6 +78,11 @@ func desanitize64(query string) string { return query } +func sanitizeFilename(str string) string { + str = strings.ReplaceAll(str, ".", "_") + return regexCleanFilename.ReplaceAllString(str, "") +} + func getSessionId(req *http.Request) string { id := req.Header.Get("x-session-id") if id == "" { diff --git a/pkg/api/helpers_test.go b/pkg/api/helpers_test.go index 7738d10..79f41db 100644 --- a/pkg/api/helpers_test.go +++ b/pkg/api/helpers_test.go @@ -30,6 +30,22 @@ func Test_cleanQuery(t *testing.T) { assert.Equal(t, "test", cleanQuery("--test\ntest\n -- test\n")) } +func Test_sanitizeFilename(t *testing.T) { + examples := map[string]string{ + "foo": "foo", + "fooBar": "fooBar", + "foo.bar": "foo_bar", + `"foo"."bar"`: "foo_bar", + "!@#$foo.&&*(&bar": "foo_bar", + } + + for given, expected := range examples { + t.Run(given, func(t *testing.T) { + assert.Equal(t, expected, sanitizeFilename(given)) + }) + } +} + func Test_getSessionId(t *testing.T) { req := &http.Request{Header: http.Header{}} req.Header.Add("x-session-id", "token") diff --git a/pkg/bookmarks/bookmarks.go b/pkg/bookmarks/bookmarks.go index d90cb0f..f89efb8 100644 --- a/pkg/bookmarks/bookmarks.go +++ b/pkg/bookmarks/bookmarks.go @@ -25,7 +25,7 @@ type Bookmark struct { SSH *shared.SSHInfo `json:"ssh"` // SSH tunnel config } -// SSHInfoIsEmpty returns true if ssh configration is not provided +// SSHInfoIsEmpty returns true if ssh configuration is not provided func (b Bookmark) SSHInfoIsEmpty() bool { return b.SSH == nil || b.SSH.User == "" && b.SSH.Host == "" && b.SSH.Port == "" } diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 5b57b99..2a939e9 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -123,7 +123,7 @@ func initClient() { } if !command.Opts.Sessions { - fmt.Printf("Connected to %s\n", cl.ServerVersion()) + fmt.Printf("Connected to %s\n", cl.ServerVersionInfo()) } fmt.Println("Checking database objects...") diff --git a/pkg/client/client.go b/pkg/client/client.go index 2c4993d..22216ca 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -291,7 +291,7 @@ func (client *Client) Activity() (*Result, error) { return client.query("SHOW QUERIES") } - version := getMajorMinorVersion(client.serverVersion) + version := getMajorMinorVersionString(client.serverVersion) query := statements.Activity[version] if query == "" { query = statements.Activity["default"] @@ -325,10 +325,14 @@ func (client *Client) SetReadOnlyMode() error { return nil } -func (client *Client) ServerVersion() string { +func (client *Client) ServerVersionInfo() string { return fmt.Sprintf("%s %s", client.serverType, client.serverVersion) } +func (client *Client) ServerVersion() string { + return client.serverVersion +} + func (client *Client) context() (context.Context, context.CancelFunc) { if client.queryTimeout > 0 { return context.WithTimeout(context.Background(), client.queryTimeout) @@ -405,7 +409,7 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error) return nil, err } - // Make sure to never return null colums + // Make sure to never return null columns if cols == nil { cols = []string{} } diff --git a/pkg/client/dump.go b/pkg/client/dump.go index d9c1705..6cdbe75 100644 --- a/pkg/client/dump.go +++ b/pkg/client/dump.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "fmt" "io" "net/url" @@ -20,21 +21,37 @@ type Dump struct { Table string } -// CanExport returns true if database dump tool could be used without an error -func (d *Dump) CanExport() bool { - return exec.Command("pg_dump", "--version").Run() == nil +// Validate checks availability and version of pg_dump CLI +func (d *Dump) Validate(serverVersion string) error { + out := bytes.NewBuffer(nil) + + cmd := exec.Command("pg_dump", "--version") + cmd.Stdout = out + cmd.Stderr = out + + if err := cmd.Run(); err != nil { + return fmt.Errorf("pg_dump command failed: %s", out.Bytes()) + } + + detected, dumpVersion := detectDumpVersion(out.String()) + if detected && serverVersion != "" { + satisfied := checkVersionRequirement(dumpVersion, serverVersion) + if !satisfied { + return fmt.Errorf("pg_dump version %v not compatible with server version %v", dumpVersion, serverVersion) + } + } + + return nil } // Export streams the database dump to the specified writer -func (d *Dump) Export(connstr string, writer io.Writer) error { +func (d *Dump) Export(ctx context.Context, connstr string, writer io.Writer) error { if str, err := removeUnsupportedOptions(connstr); err != nil { return err } else { connstr = str } - errOutput := bytes.NewBuffer(nil) - opts := []string{ "--no-owner", // skip restoration of object ownership in plain-text format "--clean", // clean (drop) database objects before recreating @@ -46,8 +63,9 @@ func (d *Dump) Export(connstr string, writer io.Writer) error { } opts = append(opts, connstr) + errOutput := bytes.NewBuffer(nil) - cmd := exec.Command("pg_dump", opts...) + cmd := exec.CommandContext(ctx, "pg_dump", opts...) cmd.Stdout = writer cmd.Stderr = errOutput diff --git a/pkg/client/dump_test.go b/pkg/client/dump_test.go index 61c5171..f194ce8 100644 --- a/pkg/client/dump_test.go +++ b/pkg/client/dump_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "os" "testing" @@ -27,26 +28,28 @@ func testDumpExport(t *testing.T) { dump := Dump{} // Test for pg_dump presence - assert.True(t, dump.CanExport()) + assert.NoError(t, dump.Validate("10.0")) + assert.NoError(t, dump.Validate("")) + assert.Contains(t, dump.Validate("20").Error(), "not compatible with server version 20") // Test full db dump - err = dump.Export(url, saveFile) + err = dump.Export(context.Background(), url, saveFile) assert.NoError(t, err) // Test nonexistent database invalidURL := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, "foobar") - err = dump.Export(invalidURL, saveFile) + err = dump.Export(context.Background(), invalidURL, saveFile) assert.Contains(t, err.Error(), `database "foobar" does not exist`) // Test dump of non existent db dump = Dump{Table: "foobar"} - err = dump.Export(url, saveFile) + err = dump.Export(context.Background(), url, saveFile) assert.NotNil(t, err) assert.Contains(t, err.Error(), "no matching tables were found") // Should drop "search_path" param from URI dump = Dump{} searchPathURL := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable&search_path=private", serverUser, serverHost, serverPort, serverDatabase) - err = dump.Export(searchPathURL, saveFile) + err = dump.Export(context.Background(), searchPathURL, saveFile) assert.NoError(t, err) } diff --git a/pkg/client/tunnel.go b/pkg/client/tunnel.go index 685cc0b..b64e515 100644 --- a/pkg/client/tunnel.go +++ b/pkg/client/tunnel.go @@ -85,7 +85,7 @@ func makeConfig(info *shared.SSHInfo) (*ssh.ClientConfig, error) { return nil, errors.New("ssh public key not found at " + keyPath) } - // Appen public key authentication method + // Append public key authentication method key, err := parsePrivateKey(keyPath, info.KeyPassword) if err != nil { return nil, err diff --git a/pkg/client/util.go b/pkg/client/util.go index 5dd4bb0..50d83d5 100644 --- a/pkg/client/util.go +++ b/pkg/client/util.go @@ -1,6 +1,7 @@ package client import ( + "fmt" "regexp" "strings" ) @@ -14,22 +15,31 @@ var ( reDashComment = regexp.MustCompile(`(?m)--.+`) // Postgres version signature - postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s?`) - postgresType = "PostgreSQL" + postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s?`) + postgresDumpSignature = regexp.MustCompile(`\s([\d\.]+)\s?`) + postgresType = "PostgreSQL" // Cockroach version signature cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s?`) cockroachType = "CockroachDB" ) +// Get major and minor version components +// Example: 10.2.3.1 -> 10.2 +func getMajorMinorVersion(str string) (major int, minor int) { + chunks := strings.Split(str, ".") + fmt.Sscanf(chunks[0], "%d", &major) + if len(chunks) > 1 { + fmt.Sscanf(chunks[1], "%d", &minor) + } + return +} + // Get short version from the string // Example: 10.2.3.1 -> 10.2 -func getMajorMinorVersion(str string) string { - chunks := strings.Split(str, ".") - if len(chunks) == 0 { - return str - } - return strings.Join(chunks[0:2], ".") +func getMajorMinorVersionString(str string) string { + major, minor := getMajorMinorVersion(str) + return fmt.Sprintf("%d.%d", major, minor) } func detectServerTypeAndVersion(version string) (bool, string, string) { @@ -50,6 +60,26 @@ func detectServerTypeAndVersion(version string) (bool, string, string) { return false, "", "" } +// detectDumpVersion parses out version from `pg_dump -V` command. +func detectDumpVersion(version string) (bool, string) { + matches := postgresDumpSignature.FindAllStringSubmatch(version, 1) + if len(matches) > 0 { + return true, matches[0][1] + } + return false, "" +} + +func checkVersionRequirement(client, server string) bool { + clientMajor, clientMinor := getMajorMinorVersion(client) + serverMajor, serverMinor := getMajorMinorVersion(server) + + if serverMajor < 10 { + return clientMajor >= serverMajor && clientMinor >= serverMinor + } + + return clientMajor >= serverMajor +} + // containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode func containsRestrictedKeywords(str string) bool { str = reSlashComment.ReplaceAllString(str, "") diff --git a/pkg/client/util_test.go b/pkg/client/util_test.go index 5036a8a..2551d87 100644 --- a/pkg/client/util_test.go +++ b/pkg/client/util_test.go @@ -48,3 +48,69 @@ func TestDetectServerType(t *testing.T) { }) } } + +func TestDetectDumpVersion(t *testing.T) { + examples := []struct { + input string + match bool + version string + }{ + {"", false, ""}, + {"pg_dump (PostgreSQL) 9.6", true, "9.6"}, + {"pg_dump 10", true, "10"}, + {"pg_dump (PostgreSQL) 14.5 (Homebrew)", true, "14.5"}, + } + + for _, ex := range examples { + t.Run("input:"+ex.input, func(t *testing.T) { + match, version := detectDumpVersion(ex.input) + + assert.Equal(t, ex.match, match) + assert.Equal(t, ex.version, version) + }) + } +} + +func TestGetMajorMinorVersion(t *testing.T) { + examples := []struct { + input string + major int + minor int + }{ + {"", 0, 0}, + {" ", 0, 0}, + {"0", 0, 0}, + {"9.6", 9, 6}, + {"9.6.1.1", 9, 6}, + {"10", 10, 0}, + {"10.1 ", 10, 1}, + } + + for _, ex := range examples { + t.Run(ex.input, func(t *testing.T) { + major, minor := getMajorMinorVersion(ex.input) + assert.Equal(t, ex.major, major) + assert.Equal(t, ex.minor, minor) + }) + } +} + +func TestCheckVersionRequirement(t *testing.T) { + examples := []struct { + client string + server string + result bool + }{ + {"", "", true}, + {"0", "0", true}, + {"9.6", "9.7", false}, + {"9.6.10", "9.6.25", true}, + {"10.0", "10.1", true}, + {"10.5", "10.1", true}, + {"14.5", "15.1", false}, + } + + for _, ex := range examples { + assert.Equal(t, ex.result, checkVersionRequirement(ex.client, ex.server)) + } +}