Merge pull request #421 from sosedoff/restrict-keywords
Keyword restriction in read-only mode
This commit is contained in:
commit
15c21b6379
@ -1,6 +1,7 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
neturl "net/url"
|
||||
@ -344,6 +345,9 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
|
||||
if err := client.SetReadOnlyMode(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if containsRestrictedKeywords(query) {
|
||||
return nil, errors.New("query contains keywords not allowed in read-only mode")
|
||||
}
|
||||
}
|
||||
|
||||
action := strings.ToLower(strings.Split(query, " ")[0])
|
||||
|
@ -429,15 +429,33 @@ func testHistoryUniqueness(t *testing.T) {
|
||||
}
|
||||
|
||||
func testReadOnlyMode(t *testing.T) {
|
||||
command.Opts.ReadOnly = true
|
||||
defer func() {
|
||||
command.Opts.ReadOnly = false
|
||||
}()
|
||||
|
||||
url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
|
||||
client, _ := NewFromUrl(url, nil)
|
||||
|
||||
err := client.SetReadOnlyMode()
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = client.Query("CREATE TABLE foobar(id integer);")
|
||||
_, err = client.Query("\nCREATE TABLE foobar(id integer);\n")
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "in a read-only transaction")
|
||||
assert.Error(t, err, "query contains keywords not allowed in read-only mode")
|
||||
|
||||
// Turn off guard
|
||||
client.db.Exec("SET default_transaction_read_only=off;")
|
||||
|
||||
_, err = client.Query("\nCREATE TABLE foobar(id integer);\n")
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "query contains keywords not allowed in read-only mode")
|
||||
|
||||
_, err = client.Query("-- CREATE TABLE foobar(id integer);\nSELECT 'foo';")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = client.Query("/* CREATE TABLE foobar(id integer); */ SELECT 'foo';")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
|
@ -1,9 +1,19 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// List of keywords that are not allowed in read-only mode
|
||||
reRestrictedKeywords = regexp.MustCompile(`(?mi)\s?(CREATE|INSERT|DROP|DELETE|TRUNCATE|GRANT|OPEN|IMPORT|COPY)\s`)
|
||||
|
||||
// Comment regular expressions
|
||||
reSlashComment = regexp.MustCompile(`(?m)/\*.+\*/`)
|
||||
reDashComment = regexp.MustCompile(`(?m)--.+`)
|
||||
)
|
||||
|
||||
// Get short version from the string
|
||||
// Example: 10.2.3.1 -> 10.2
|
||||
func getMajorMinorVersion(str string) string {
|
||||
@ -13,3 +23,11 @@ func getMajorMinorVersion(str string) string {
|
||||
}
|
||||
return strings.Join(chunks[0:2], ".")
|
||||
}
|
||||
|
||||
// containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode
|
||||
func containsRestrictedKeywords(str string) bool {
|
||||
str = reSlashComment.ReplaceAllString(str, "")
|
||||
str = reDashComment.ReplaceAllString(str, "")
|
||||
|
||||
return reRestrictedKeywords.MatchString(str)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user