Reject queries that contain restricted keywords in read-only mode

This commit is contained in:
Dan Sosedoff 2019-02-20 18:20:27 -06:00
parent 9f29b10098
commit 40eb74529e
3 changed files with 24 additions and 2 deletions

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"errors"
"fmt" "fmt"
"log" "log"
neturl "net/url" neturl "net/url"
@ -344,6 +345,9 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
if err := client.SetReadOnlyMode(); err != nil { if err := client.SetReadOnlyMode(); err != nil {
return nil, err return nil, err
} }
if containsRestrictedKeywords(query) {
return nil, errors.New("query contains keywords not allowed in read-only mode")
}
} }
action := strings.ToLower(strings.Split(query, " ")[0]) action := strings.ToLower(strings.Split(query, " ")[0])

View File

@ -429,15 +429,24 @@ func testHistoryUniqueness(t *testing.T) {
} }
func testReadOnlyMode(t *testing.T) { func testReadOnlyMode(t *testing.T) {
command.Opts.ReadOnly = true
url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase) url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
client, _ := NewFromUrl(url, nil) client, _ := NewFromUrl(url, nil)
err := client.SetReadOnlyMode() err := client.SetReadOnlyMode()
assert.Equal(t, nil, err) assert.NoError(t, err)
_, err = client.Query("CREATE TABLE foobar(id integer);") _, err = client.Query("CREATE TABLE foobar(id integer);")
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Contains(t, err.Error(), "in a read-only transaction") assert.Error(t, err, "query contains keywords not allowed in read-only mode")
// Turn off guard
client.db.Exec("SET default_transaction_read_only=off;")
_, err = client.Query("CREATE TABLE foobar(id integer);")
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "query contains keywords not allowed in read-only mode")
} }
func TestAll(t *testing.T) { func TestAll(t *testing.T) {

View File

@ -1,9 +1,13 @@
package client package client
import ( import (
"regexp"
"strings" "strings"
) )
// List of keywords that are not allowed in read-only mode
var restrictedKeywords = regexp.MustCompile(`(?mi)\s?(CREATE|INSERT|DROP|DELETE|TRUNCATE|GRANT|OPEN|IMPORT|COPY|LOCK|SET)\s`)
// Get short version from the string // Get short version from the string
// Example: 10.2.3.1 -> 10.2 // Example: 10.2.3.1 -> 10.2
func getMajorMinorVersion(str string) string { func getMajorMinorVersion(str string) string {
@ -13,3 +17,8 @@ func getMajorMinorVersion(str string) string {
} }
return strings.Join(chunks[0:2], ".") return strings.Join(chunks[0:2], ".")
} }
// containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode
func containsRestrictedKeywords(str string) bool {
return restrictedKeywords.MatchString(str)
}