Reject queries that contain restricted keywords in read-only mode
This commit is contained in:
parent
9f29b10098
commit
40eb74529e
@ -1,6 +1,7 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
neturl "net/url"
|
||||
@ -344,6 +345,9 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
|
||||
if err := client.SetReadOnlyMode(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if containsRestrictedKeywords(query) {
|
||||
return nil, errors.New("query contains keywords not allowed in read-only mode")
|
||||
}
|
||||
}
|
||||
|
||||
action := strings.ToLower(strings.Split(query, " ")[0])
|
||||
|
@ -429,15 +429,24 @@ func testHistoryUniqueness(t *testing.T) {
|
||||
}
|
||||
|
||||
func testReadOnlyMode(t *testing.T) {
|
||||
command.Opts.ReadOnly = true
|
||||
|
||||
url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
|
||||
client, _ := NewFromUrl(url, nil)
|
||||
|
||||
err := client.SetReadOnlyMode()
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = client.Query("CREATE TABLE foobar(id integer);")
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "in a read-only transaction")
|
||||
assert.Error(t, err, "query contains keywords not allowed in read-only mode")
|
||||
|
||||
// Turn off guard
|
||||
client.db.Exec("SET default_transaction_read_only=off;")
|
||||
|
||||
_, err = client.Query("CREATE TABLE foobar(id integer);")
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "query contains keywords not allowed in read-only mode")
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
|
@ -1,9 +1,13 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// List of keywords that are not allowed in read-only mode
|
||||
var restrictedKeywords = regexp.MustCompile(`(?mi)\s?(CREATE|INSERT|DROP|DELETE|TRUNCATE|GRANT|OPEN|IMPORT|COPY|LOCK|SET)\s`)
|
||||
|
||||
// Get short version from the string
|
||||
// Example: 10.2.3.1 -> 10.2
|
||||
func getMajorMinorVersion(str string) string {
|
||||
@ -13,3 +17,8 @@ func getMajorMinorVersion(str string) string {
|
||||
}
|
||||
return strings.Join(chunks[0:2], ".")
|
||||
}
|
||||
|
||||
// containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode
|
||||
func containsRestrictedKeywords(str string) bool {
|
||||
return restrictedKeywords.MatchString(str)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user