Local queries (#641)

* Read local queries from pgweb home directory
* Refactor local query functionality
* Allow picking local query in the query tab
* WIP
* Disable local query dropdown during execution
* Only allow local queries running in a single session mode
* Add middleware to enforce local query endpoint availability
* Fix query check
* Add query store tests
* Make query store errors portable
* Skip building specific tests on windows
This commit is contained in:
Dan Sosedoff
2023-02-02 16:13:14 -06:00
committed by GitHub
parent 1c3ab1fd1c
commit 41bf189e6b
23 changed files with 884 additions and 12 deletions

43
pkg/queries/field.go Normal file
View File

@@ -0,0 +1,43 @@
package queries
import (
"fmt"
"regexp"
"strings"
)
type field struct {
value string
re *regexp.Regexp
}
func (f field) String() string {
return f.value
}
func (f field) matches(input string) bool {
if f.re != nil {
return f.re.MatchString(input)
}
return f.value == input
}
func newField(value string) (field, error) {
f := field{value: value}
if value == "*" { // match everything
f.re = reMatchAll
} else if reExpression.MatchString(value) { // match by given expression
// Make writing expressions easier for values like "foo_*"
if strings.Count(value, "*") == 1 {
value = strings.Replace(value, "*", "(.+)", 1)
}
re, err := regexp.Compile(fmt.Sprintf("^%s$", value))
if err != nil {
return f, err
}
f.re = re
}
return f, nil
}

41
pkg/queries/field_test.go Normal file
View File

@@ -0,0 +1,41 @@
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_field(t *testing.T) {
field, err := newField("val")
assert.NoError(t, err)
assert.Equal(t, "val", field.value)
assert.Equal(t, true, field.matches("val"))
assert.Equal(t, false, field.matches("value"))
field, err = newField("*")
assert.NoError(t, err)
assert.Equal(t, "*", field.value)
assert.NotNil(t, field.re)
assert.Equal(t, true, field.matches("val"))
assert.Equal(t, true, field.matches("value"))
field, err = newField("(.+")
assert.EqualError(t, err, "error parsing regexp: missing closing ): `^(.+$`")
assert.NotNil(t, field)
field, err = newField("foo_*")
assert.NoError(t, err)
assert.Equal(t, "foo_*", field.value)
assert.NotNil(t, field.re)
assert.Equal(t, false, field.matches("foo"))
assert.Equal(t, true, field.matches("foo_bar"))
assert.Equal(t, true, field.matches("foo_bar_widget"))
}
func Test_fieldString(t *testing.T) {
field, err := newField("val")
assert.NoError(t, err)
assert.Equal(t, "val", field.String())
}

148
pkg/queries/metadata.go Normal file
View File

@@ -0,0 +1,148 @@
package queries
import (
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
var (
reMetaPrefix = regexp.MustCompile(`(?m)^\s*--\s*pgweb:\s*(.+)`)
reMetaContent = regexp.MustCompile(`([\w]+)\s*=\s*"([^"]+)"`)
reMatchAll = regexp.MustCompile(`^(.+)$`)
reExpression = regexp.MustCompile(`[\[\]\(\)\+\*]+`)
allowedKeys = []string{"title", "description", "host", "user", "database", "mode", "timeout"}
allowedModes = map[string]bool{"readonly": true, "*": true}
)
type Metadata struct {
Title string
Description string
Host field
User field
Database field
Mode field
Timeout *time.Duration
}
func parseMetadata(input string) (*Metadata, error) {
fields, err := parseFields(input)
if err != nil {
return nil, err
}
if fields == nil {
return nil, nil
}
// Host must be set to limit queries availability
if fields["host"] == "" {
return nil, fmt.Errorf("host field must be set")
}
// Allow matching for any user, database and mode by default
if fields["user"] == "" {
fields["user"] = "*"
}
if fields["database"] == "" {
fields["database"] = "*"
}
if fields["mode"] == "" {
fields["mode"] = "*"
}
hostField, err := newField(fields["host"])
if err != nil {
return nil, fmt.Errorf(`error initializing "host" field: %w`, err)
}
userField, err := newField(fields["user"])
if err != nil {
return nil, fmt.Errorf(`error initializing "user" field: %w`, err)
}
dbField, err := newField(fields["database"])
if err != nil {
return nil, fmt.Errorf(`error initializing "database" field: %w`, err)
}
if !allowedModes[fields["mode"]] {
return nil, fmt.Errorf(`invalid "mode" field value: %q`, fields["mode"])
}
modeField, err := newField(fields["mode"])
if err != nil {
return nil, fmt.Errorf(`error initializing "mode" field: %w`, err)
}
var timeout *time.Duration
if fields["timeout"] != "" {
timeoutSec, err := strconv.Atoi(fields["timeout"])
if err != nil {
return nil, fmt.Errorf(`error initializing "timeout" field: %w`, err)
}
timeoutVal := time.Duration(timeoutSec) * time.Second
timeout = &timeoutVal
}
return &Metadata{
Title: fields["title"],
Description: fields["description"],
Host: hostField,
User: userField,
Database: dbField,
Mode: modeField,
Timeout: timeout,
}, nil
}
func parseFields(input string) (map[string]string, error) {
result := map[string]string{}
seenKeys := map[string]bool{}
allowed := map[string]bool{}
for _, key := range allowedKeys {
allowed[key] = true
}
matches := reMetaPrefix.FindAllStringSubmatch(input, -1)
if len(matches) == 0 {
return nil, nil
}
for _, match := range matches {
content := reMetaContent.FindAllStringSubmatch(match[1], -1)
if len(content) == 0 {
continue
}
for _, field := range content {
key := field[1]
value := field[2]
if !allowed[key] {
return result, fmt.Errorf("unknown key: %q", key)
}
if seenKeys[key] {
return result, fmt.Errorf("duplicate key: %q", key)
}
seenKeys[key] = true
result[key] = value
}
}
return result, nil
}
func sanitizeMetadata(input string) string {
lines := []string{}
for _, line := range strings.Split(input, "\n") {
line = reMetaPrefix.ReplaceAllString(line, "")
if len(line) > 0 {
lines = append(lines, line)
}
}
return strings.Join(lines, "\n")
}

View File

@@ -0,0 +1,146 @@
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseFields(t *testing.T) {
examples := []struct {
input string
err error
vals map[string]string
}{
{input: "", err: nil, vals: nil},
{input: "foobar", err: nil, vals: nil},
{input: "-- no pgweb meta", err: nil, vals: nil},
{
input: `--pgweb: foo=bar`,
err: nil,
vals: map[string]string{},
},
{
input: `--pgweb: host="localhost"`,
err: nil,
vals: map[string]string{"host": "localhost"},
},
{
input: `--pgweb: host="*" user="admin" database ="mydb"; mode = "readonly"`,
err: nil,
vals: map[string]string{
"host": "*",
"database": "mydb",
"user": "admin",
"mode": "readonly",
},
},
}
for _, ex := range examples {
t.Run(ex.input, func(t *testing.T) {
fields, err := parseFields(ex.input)
assert.Equal(t, ex.err, err)
assert.Equal(t, ex.vals, fields)
})
}
}
func Test_parseMetadata(t *testing.T) {
examples := []struct {
input string
err string
check func(meta *Metadata) bool
}{
{
input: `--pgweb: `,
err: `host field must be set`,
},
{
input: `--pgweb: hello="world"`,
err: `unknown key: "hello"`,
},
{
input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="foo"`,
err: `invalid "mode" field value: "foo"`,
},
{
input: "--pgweb2:",
check: func(m *Metadata) bool {
return m == nil
},
},
{
input: `--pgweb: host="localhost"`,
check: func(m *Metadata) bool {
return m.Host.value == "localhost" &&
m.User.value == "*" &&
m.Database.value == "*" &&
m.Mode.value == "*" &&
m.Timeout == nil
},
},
{
input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="*"`,
check: func(m *Metadata) bool {
return m.Host.value == "localhost" &&
m.Host.re == nil &&
m.User.value == "anyuser" &&
m.Database.value == "anydb" &&
m.Mode.value == "*" &&
m.Timeout == nil
},
},
{
input: `--pgweb: host="localhost" timeout="foo"`,
err: `error initializing "timeout" field: strconv.Atoi: parsing "foo": invalid syntax`,
},
{
input: `-- pgweb: host="local(host|dev)"`,
check: func(m *Metadata) bool {
return m.Host.value == "local(host|dev)" && m.Host.re != nil &&
m.Host.matches("localhost") && m.Host.matches("localdev") &&
!m.Host.matches("localfoo") && !m.Host.matches("superlocaldev")
},
},
}
for _, ex := range examples {
t.Run(ex.input, func(t *testing.T) {
meta, err := parseMetadata(ex.input)
if ex.err != "" {
assert.Contains(t, err.Error(), ex.err)
}
if ex.check != nil {
assert.Equal(t, true, ex.check(meta))
}
})
}
}
func Test_sanitizeMetadata(t *testing.T) {
examples := []struct {
input string
output string
}{
{input: "", output: ""},
{input: "foo", output: "foo"},
{
input: `
-- pgweb: metadata
query1
-- pgweb: more metadata
query2
`,
output: "query1\nquery2",
},
}
for _, ex := range examples {
t.Run(ex.input, func(t *testing.T) {
assert.Equal(t, ex.output, sanitizeMetadata(ex.input))
})
}
}

23
pkg/queries/query.go Normal file
View File

@@ -0,0 +1,23 @@
package queries
type Query struct {
ID string
Path string
Meta *Metadata
Data string
}
// IsPermitted returns true if a query is allowed to execute for a given db context
func (q Query) IsPermitted(host, user, database, mode string) bool {
// All fields must be provided for matching
if q.Meta == nil || host == "" || user == "" || database == "" || mode == "" {
return false
}
meta := q.Meta
return meta.Host.matches(host) &&
meta.User.matches(user) &&
meta.Database.matches(database) &&
meta.Mode.matches(mode)
}

77
pkg/queries/query_test.go Normal file
View File

@@ -0,0 +1,77 @@
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestQueryIsPermitted(t *testing.T) {
examples := []struct {
name string
query Query
args []string
expected bool
}{
{
name: "no input provided",
query: makeQuery("localhost", "someuser", "somedb", "default"),
args: makeArgs("", "", "", ""),
expected: false,
},
{
name: "match on host",
query: makeQuery("localhost", "*", "*", "*"),
args: makeArgs("localhost", "user", "db", "default"),
expected: true,
},
{
name: "match on full set",
query: makeQuery("localhost", "user", "database", "mode"),
args: makeArgs("localhost", "someuser", "somedb", "default"),
expected: false,
},
{
name: "match on partial database",
query: makeQuery("localhost", "*", "myapp_*", "*"),
args: makeArgs("localhost", "user", "myapp_development", "default"),
expected: true,
},
{
name: "match on full set but not mode",
query: makeQuery("localhost", "*", "*", "readonly"),
args: makeArgs("localhost", "user", "db", "default"),
expected: false,
},
}
for _, ex := range examples {
t.Run(ex.name, func(t *testing.T) {
result := ex.query.IsPermitted(ex.args[0], ex.args[1], ex.args[2], ex.args[3])
assert.Equal(t, ex.expected, result)
})
}
}
func makeArgs(vals ...string) []string {
return vals
}
func makeQuery(host, user, database, mode string) Query {
mustfield := func(input string) field {
f, err := newField(input)
if err != nil {
panic(err)
}
return f
}
return Query{
Meta: &Metadata{
Host: mustfield(host),
User: mustfield(user),
Database: mustfield(database),
Mode: mustfield(mode),
},
}
}

88
pkg/queries/store.go Normal file
View File

@@ -0,0 +1,88 @@
package queries
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
var (
ErrQueryDirNotExist = errors.New("queries directory does not exist")
ErrQueryFileNotExist = errors.New("query file does not exist")
)
type Store struct {
dir string
}
func NewStore(dir string) *Store {
return &Store{
dir: dir,
}
}
func (s Store) Read(id string) (*Query, error) {
path := filepath.Join(s.dir, fmt.Sprintf("%s.sql", id))
return readQuery(path)
}
func (s Store) ReadAll() ([]Query, error) {
entries, err := os.ReadDir(s.dir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
err = ErrQueryDirNotExist
}
return nil, err
}
queries := []Query{}
for _, entry := range entries {
name := entry.Name()
if filepath.Ext(name) != ".sql" {
continue
}
path := filepath.Join(s.dir, name)
query, err := readQuery(path)
if err != nil {
fmt.Fprintf(os.Stderr, "[WARN] skipping %q query file due to error: %v\n", name, err)
continue
}
if query == nil {
continue
}
queries = append(queries, *query)
}
return queries, nil
}
func readQuery(path string) (*Query, error) {
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, ErrQueryFileNotExist
}
return nil, err
}
dataStr := string(data)
meta, err := parseMetadata(dataStr)
if err != nil {
return nil, err
}
if meta == nil {
return nil, nil
}
return &Query{
ID: strings.Replace(filepath.Base(path), ".sql", "", 1),
Path: path,
Meta: meta,
Data: sanitizeMetadata(dataStr),
}, nil
}

71
pkg/queries/store_test.go Normal file
View File

@@ -0,0 +1,71 @@
//go:build !windows
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStoreReadAll(t *testing.T) {
t.Run("valid dir", func(t *testing.T) {
queries, err := NewStore("../../data").ReadAll()
assert.NoError(t, err)
assert.Equal(t, 2, len(queries))
})
t.Run("invalid dir", func(t *testing.T) {
queries, err := NewStore("../../data2").ReadAll()
assert.Equal(t, err.Error(), "queries directory does not exist")
assert.Equal(t, 0, len(queries))
})
}
func TestStoreRead(t *testing.T) {
examples := []struct {
id string
err string
check func(*testing.T, *Query)
}{
{id: "foo", err: "query file does not exist"},
{id: "lc_no_meta"},
{id: "lc_invalid_meta", err: `invalid "mode" field value: "foo"`},
{
id: "lc_example1",
check: func(t *testing.T, q *Query) {
assert.Equal(t, "lc_example1", q.ID)
assert.Equal(t, "../../data/lc_example1.sql", q.Path)
assert.Equal(t, "select 'foo'", q.Data)
assert.Equal(t, "localhost", q.Meta.Host.String())
assert.Equal(t, "*", q.Meta.User.String())
assert.Equal(t, "*", q.Meta.Database.String())
},
},
{
id: "lc_example2",
check: func(t *testing.T, q *Query) {
assert.Equal(t, "lc_example2", q.ID)
assert.Equal(t, "../../data/lc_example2.sql", q.Path)
assert.Equal(t, "-- some comment\nselect 'foo'", q.Data)
assert.Equal(t, "localhost", q.Meta.Host.String())
assert.Equal(t, "foo", q.Meta.User.String())
assert.Equal(t, "*", q.Meta.Database.String())
},
},
}
store := NewStore("../../data")
for _, ex := range examples {
t.Run(ex.id, func(t *testing.T) {
query, err := store.Read(ex.id)
if ex.err != "" || err != nil {
assert.Equal(t, ex.err, err.Error())
}
if ex.check != nil {
ex.check(t, query)
}
})
}
}