Perform client version validation before executing pg_dump command (#614)

* Add func to parse out pg_dump version
* Perform client vs server version checking before dump exports
* Fix dump tests
* Add extra test to validate against empty server version
* Fix attachment filenames cleanup function
* Add extra test
* Fix small typos in comments
* Drop third-party package to deal with versions
* Tweak the pg dump incompatibility error message
* Run CI on pull requests
This commit is contained in:
Dan Sosedoff
2022-12-12 15:09:12 -06:00
committed by GitHub
parent 7557ac854e
commit 4c40eef99a
15 changed files with 185 additions and 43 deletions

View File

@@ -291,7 +291,7 @@ func (client *Client) Activity() (*Result, error) {
return client.query("SHOW QUERIES")
}
version := getMajorMinorVersion(client.serverVersion)
version := getMajorMinorVersionString(client.serverVersion)
query := statements.Activity[version]
if query == "" {
query = statements.Activity["default"]
@@ -325,10 +325,14 @@ func (client *Client) SetReadOnlyMode() error {
return nil
}
func (client *Client) ServerVersion() string {
func (client *Client) ServerVersionInfo() string {
return fmt.Sprintf("%s %s", client.serverType, client.serverVersion)
}
func (client *Client) ServerVersion() string {
return client.serverVersion
}
func (client *Client) context() (context.Context, context.CancelFunc) {
if client.queryTimeout > 0 {
return context.WithTimeout(context.Background(), client.queryTimeout)
@@ -405,7 +409,7 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
return nil, err
}
// Make sure to never return null colums
// Make sure to never return null columns
if cols == nil {
cols = []string{}
}

View File

@@ -2,6 +2,7 @@ package client
import (
"bytes"
"context"
"fmt"
"io"
"net/url"
@@ -20,21 +21,37 @@ type Dump struct {
Table string
}
// CanExport returns true if database dump tool could be used without an error
func (d *Dump) CanExport() bool {
return exec.Command("pg_dump", "--version").Run() == nil
// Validate checks availability and version of pg_dump CLI
func (d *Dump) Validate(serverVersion string) error {
out := bytes.NewBuffer(nil)
cmd := exec.Command("pg_dump", "--version")
cmd.Stdout = out
cmd.Stderr = out
if err := cmd.Run(); err != nil {
return fmt.Errorf("pg_dump command failed: %s", out.Bytes())
}
detected, dumpVersion := detectDumpVersion(out.String())
if detected && serverVersion != "" {
satisfied := checkVersionRequirement(dumpVersion, serverVersion)
if !satisfied {
return fmt.Errorf("pg_dump version %v not compatible with server version %v", dumpVersion, serverVersion)
}
}
return nil
}
// Export streams the database dump to the specified writer
func (d *Dump) Export(connstr string, writer io.Writer) error {
func (d *Dump) Export(ctx context.Context, connstr string, writer io.Writer) error {
if str, err := removeUnsupportedOptions(connstr); err != nil {
return err
} else {
connstr = str
}
errOutput := bytes.NewBuffer(nil)
opts := []string{
"--no-owner", // skip restoration of object ownership in plain-text format
"--clean", // clean (drop) database objects before recreating
@@ -46,8 +63,9 @@ func (d *Dump) Export(connstr string, writer io.Writer) error {
}
opts = append(opts, connstr)
errOutput := bytes.NewBuffer(nil)
cmd := exec.Command("pg_dump", opts...)
cmd := exec.CommandContext(ctx, "pg_dump", opts...)
cmd.Stdout = writer
cmd.Stderr = errOutput

View File

@@ -1,6 +1,7 @@
package client
import (
"context"
"fmt"
"os"
"testing"
@@ -27,26 +28,28 @@ func testDumpExport(t *testing.T) {
dump := Dump{}
// Test for pg_dump presence
assert.True(t, dump.CanExport())
assert.NoError(t, dump.Validate("10.0"))
assert.NoError(t, dump.Validate(""))
assert.Contains(t, dump.Validate("20").Error(), "not compatible with server version 20")
// Test full db dump
err = dump.Export(url, saveFile)
err = dump.Export(context.Background(), url, saveFile)
assert.NoError(t, err)
// Test nonexistent database
invalidURL := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, "foobar")
err = dump.Export(invalidURL, saveFile)
err = dump.Export(context.Background(), invalidURL, saveFile)
assert.Contains(t, err.Error(), `database "foobar" does not exist`)
// Test dump of non existent db
dump = Dump{Table: "foobar"}
err = dump.Export(url, saveFile)
err = dump.Export(context.Background(), url, saveFile)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "no matching tables were found")
// Should drop "search_path" param from URI
dump = Dump{}
searchPathURL := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable&search_path=private", serverUser, serverHost, serverPort, serverDatabase)
err = dump.Export(searchPathURL, saveFile)
err = dump.Export(context.Background(), searchPathURL, saveFile)
assert.NoError(t, err)
}

View File

@@ -85,7 +85,7 @@ func makeConfig(info *shared.SSHInfo) (*ssh.ClientConfig, error) {
return nil, errors.New("ssh public key not found at " + keyPath)
}
// Appen public key authentication method
// Append public key authentication method
key, err := parsePrivateKey(keyPath, info.KeyPassword)
if err != nil {
return nil, err

View File

@@ -1,6 +1,7 @@
package client
import (
"fmt"
"regexp"
"strings"
)
@@ -14,22 +15,31 @@ var (
reDashComment = regexp.MustCompile(`(?m)--.+`)
// Postgres version signature
postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s?`)
postgresType = "PostgreSQL"
postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s?`)
postgresDumpSignature = regexp.MustCompile(`\s([\d\.]+)\s?`)
postgresType = "PostgreSQL"
// Cockroach version signature
cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s?`)
cockroachType = "CockroachDB"
)
// Get major and minor version components
// Example: 10.2.3.1 -> 10.2
func getMajorMinorVersion(str string) (major int, minor int) {
chunks := strings.Split(str, ".")
fmt.Sscanf(chunks[0], "%d", &major)
if len(chunks) > 1 {
fmt.Sscanf(chunks[1], "%d", &minor)
}
return
}
// Get short version from the string
// Example: 10.2.3.1 -> 10.2
func getMajorMinorVersion(str string) string {
chunks := strings.Split(str, ".")
if len(chunks) == 0 {
return str
}
return strings.Join(chunks[0:2], ".")
func getMajorMinorVersionString(str string) string {
major, minor := getMajorMinorVersion(str)
return fmt.Sprintf("%d.%d", major, minor)
}
func detectServerTypeAndVersion(version string) (bool, string, string) {
@@ -50,6 +60,26 @@ func detectServerTypeAndVersion(version string) (bool, string, string) {
return false, "", ""
}
// detectDumpVersion parses out version from `pg_dump -V` command.
func detectDumpVersion(version string) (bool, string) {
matches := postgresDumpSignature.FindAllStringSubmatch(version, 1)
if len(matches) > 0 {
return true, matches[0][1]
}
return false, ""
}
func checkVersionRequirement(client, server string) bool {
clientMajor, clientMinor := getMajorMinorVersion(client)
serverMajor, serverMinor := getMajorMinorVersion(server)
if serverMajor < 10 {
return clientMajor >= serverMajor && clientMinor >= serverMinor
}
return clientMajor >= serverMajor
}
// containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode
func containsRestrictedKeywords(str string) bool {
str = reSlashComment.ReplaceAllString(str, "")

View File

@@ -48,3 +48,69 @@ func TestDetectServerType(t *testing.T) {
})
}
}
func TestDetectDumpVersion(t *testing.T) {
examples := []struct {
input string
match bool
version string
}{
{"", false, ""},
{"pg_dump (PostgreSQL) 9.6", true, "9.6"},
{"pg_dump 10", true, "10"},
{"pg_dump (PostgreSQL) 14.5 (Homebrew)", true, "14.5"},
}
for _, ex := range examples {
t.Run("input:"+ex.input, func(t *testing.T) {
match, version := detectDumpVersion(ex.input)
assert.Equal(t, ex.match, match)
assert.Equal(t, ex.version, version)
})
}
}
func TestGetMajorMinorVersion(t *testing.T) {
examples := []struct {
input string
major int
minor int
}{
{"", 0, 0},
{" ", 0, 0},
{"0", 0, 0},
{"9.6", 9, 6},
{"9.6.1.1", 9, 6},
{"10", 10, 0},
{"10.1 ", 10, 1},
}
for _, ex := range examples {
t.Run(ex.input, func(t *testing.T) {
major, minor := getMajorMinorVersion(ex.input)
assert.Equal(t, ex.major, major)
assert.Equal(t, ex.minor, minor)
})
}
}
func TestCheckVersionRequirement(t *testing.T) {
examples := []struct {
client string
server string
result bool
}{
{"", "", true},
{"0", "0", true},
{"9.6", "9.7", false},
{"9.6.10", "9.6.25", true},
{"10.0", "10.1", true},
{"10.5", "10.1", true},
{"14.5", "15.1", false},
}
for _, ex := range examples {
assert.Equal(t, ex.result, checkVersionRequirement(ex.client, ex.server))
}
}