Perform client version validation before executing pg_dump command (#614)
* Add func to parse out pg_dump version * Perform client vs server version checking before dump exports * Fix dump tests * Add extra test to validate against empty server version * Fix attachment filenames cleanup function * Add extra test * Fix small typos in comments * Drop third-party package to deal with versions * Tweak the pg dump incompatibility error message * Run CI on pull requests
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
neturl "net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -58,7 +57,7 @@ func setClient(c *gin.Context, newClient *client.Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHome renderes the home page
|
||||
// GetHome renders the home page
|
||||
func GetHome(prefix string) http.Handler {
|
||||
if prefix != "" {
|
||||
prefix = "/" + prefix
|
||||
@@ -530,10 +529,10 @@ func DataExport(c *gin.Context) {
|
||||
Table: strings.TrimSpace(c.Request.FormValue("table")),
|
||||
}
|
||||
|
||||
// If pg_dump is not available the following code will not show an error in browser.
|
||||
// This is due to the headers being written first.
|
||||
if !dump.CanExport() {
|
||||
badRequest(c, errPgDumpNotFound)
|
||||
// Perform validation of pg_dump command availability and compatibility.
|
||||
// Must be done before the actual command is executed to display errors.
|
||||
if err := dump.Validate(db.ServerVersion()); err != nil {
|
||||
badRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -542,16 +541,18 @@ func DataExport(c *gin.Context) {
|
||||
if dump.Table != "" {
|
||||
filename = filename + "_" + dump.Table
|
||||
}
|
||||
reg := regexp.MustCompile(`[^._\\w]+`)
|
||||
cleanFilename := reg.ReplaceAllString(filename, "")
|
||||
|
||||
filename = sanitizeFilename(filename)
|
||||
filename = fmt.Sprintf("%s_%s", filename, time.Now().Format("20060102_150405"))
|
||||
|
||||
c.Header(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(`attachment; filename="%s.sql.gz"`, cleanFilename),
|
||||
fmt.Sprintf(`attachment; filename="%s.sql.gz"`, filename),
|
||||
)
|
||||
|
||||
err = dump.Export(db.ConnectionString, c.Writer)
|
||||
err = dump.Export(c.Request.Context(), db.ConnectionString, c.Writer)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("pg_dump command failed")
|
||||
badRequest(c, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,5 @@ var (
|
||||
errURLRequired = errors.New("URL parameter is required")
|
||||
errQueryRequired = errors.New("Query parameter is required")
|
||||
errDatabaseNameRequired = errors.New("Database name is required")
|
||||
errPgDumpNotFound = errors.New("pg_dump utility is not found")
|
||||
errBackendConnectError = errors.New("Unable to connect to the auth backend")
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"mime"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -39,6 +40,9 @@ var (
|
||||
"_": "/",
|
||||
".": "=",
|
||||
}
|
||||
|
||||
// Regular expression to remove unwanted characters in filenames
|
||||
regexCleanFilename = regexp.MustCompile(`[^\w]+`)
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
@@ -74,6 +78,11 @@ func desanitize64(query string) string {
|
||||
return query
|
||||
}
|
||||
|
||||
func sanitizeFilename(str string) string {
|
||||
str = strings.ReplaceAll(str, ".", "_")
|
||||
return regexCleanFilename.ReplaceAllString(str, "")
|
||||
}
|
||||
|
||||
func getSessionId(req *http.Request) string {
|
||||
id := req.Header.Get("x-session-id")
|
||||
if id == "" {
|
||||
|
||||
@@ -30,6 +30,22 @@ func Test_cleanQuery(t *testing.T) {
|
||||
assert.Equal(t, "test", cleanQuery("--test\ntest\n -- test\n"))
|
||||
}
|
||||
|
||||
func Test_sanitizeFilename(t *testing.T) {
|
||||
examples := map[string]string{
|
||||
"foo": "foo",
|
||||
"fooBar": "fooBar",
|
||||
"foo.bar": "foo_bar",
|
||||
`"foo"."bar"`: "foo_bar",
|
||||
"!@#$foo.&&*(&bar": "foo_bar",
|
||||
}
|
||||
|
||||
for given, expected := range examples {
|
||||
t.Run(given, func(t *testing.T) {
|
||||
assert.Equal(t, expected, sanitizeFilename(given))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getSessionId(t *testing.T) {
|
||||
req := &http.Request{Header: http.Header{}}
|
||||
req.Header.Add("x-session-id", "token")
|
||||
|
||||
Reference in New Issue
Block a user