Perform client version validation before executing pg_dump command (#614)

* Add func to parse out pg_dump version
* Perform client vs server version checking before dump exports
* Fix dump tests
* Add extra test to validate against empty server version
* Fix attachment filenames cleanup function
* Add extra test
* Fix small typos in comments
* Drop third-party package to deal with versions
* Tweak the pg dump incompatibility error message
* Run CI on pull requests
This commit is contained in:
Dan Sosedoff
2022-12-12 15:09:12 -06:00
committed by GitHub
parent 7557ac854e
commit 4c40eef99a
15 changed files with 185 additions and 43 deletions

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net/http"
neturl "net/url"
"regexp"
"strings"
"time"
@@ -58,7 +57,7 @@ func setClient(c *gin.Context, newClient *client.Client) error {
return nil
}
// GetHome renderes the home page
// GetHome renders the home page
func GetHome(prefix string) http.Handler {
if prefix != "" {
prefix = "/" + prefix
@@ -530,10 +529,10 @@ func DataExport(c *gin.Context) {
Table: strings.TrimSpace(c.Request.FormValue("table")),
}
// If pg_dump is not available the following code will not show an error in browser.
// This is due to the headers being written first.
if !dump.CanExport() {
badRequest(c, errPgDumpNotFound)
// Perform validation of pg_dump command availability and compatibility.
// Must be done before the actual command is executed to display errors.
if err := dump.Validate(db.ServerVersion()); err != nil {
badRequest(c, err)
return
}
@@ -542,16 +541,18 @@ func DataExport(c *gin.Context) {
if dump.Table != "" {
filename = filename + "_" + dump.Table
}
reg := regexp.MustCompile(`[^._\\w]+`)
cleanFilename := reg.ReplaceAllString(filename, "")
filename = sanitizeFilename(filename)
filename = fmt.Sprintf("%s_%s", filename, time.Now().Format("20060102_150405"))
c.Header(
"Content-Disposition",
fmt.Sprintf(`attachment; filename="%s.sql.gz"`, cleanFilename),
fmt.Sprintf(`attachment; filename="%s.sql.gz"`, filename),
)
err = dump.Export(db.ConnectionString, c.Writer)
err = dump.Export(c.Request.Context(), db.ConnectionString, c.Writer)
if err != nil {
logger.WithError(err).Error("pg_dump command failed")
badRequest(c, err)
}
}

View File

@@ -14,6 +14,5 @@ var (
errURLRequired = errors.New("URL parameter is required")
errQueryRequired = errors.New("Query parameter is required")
errDatabaseNameRequired = errors.New("Database name is required")
errPgDumpNotFound = errors.New("pg_dump utility is not found")
errBackendConnectError = errors.New("Unable to connect to the auth backend")
)

View File

@@ -5,6 +5,7 @@ import (
"mime"
"net/http"
"path/filepath"
"regexp"
"strconv"
"strings"
@@ -39,6 +40,9 @@ var (
"_": "/",
".": "=",
}
// Regular expression to remove unwanted characters in filenames
regexCleanFilename = regexp.MustCompile(`[^\w]+`)
)
type Error struct {
@@ -74,6 +78,11 @@ func desanitize64(query string) string {
return query
}
func sanitizeFilename(str string) string {
str = strings.ReplaceAll(str, ".", "_")
return regexCleanFilename.ReplaceAllString(str, "")
}
func getSessionId(req *http.Request) string {
id := req.Header.Get("x-session-id")
if id == "" {

View File

@@ -30,6 +30,22 @@ func Test_cleanQuery(t *testing.T) {
assert.Equal(t, "test", cleanQuery("--test\ntest\n -- test\n"))
}
func Test_sanitizeFilename(t *testing.T) {
examples := map[string]string{
"foo": "foo",
"fooBar": "fooBar",
"foo.bar": "foo_bar",
`"foo"."bar"`: "foo_bar",
"!@#$foo.&&*(&bar": "foo_bar",
}
for given, expected := range examples {
t.Run(given, func(t *testing.T) {
assert.Equal(t, expected, sanitizeFilename(given))
})
}
}
func Test_getSessionId(t *testing.T) {
req := &http.Request{Header: http.Header{}}
req.Header.Add("x-session-id", "token")