Local queries (#641)
* Read local queries from pgweb home directory * Refactor local query functionality * Allow picking local query in the query tab * WIP * Disable local query dropdown during execution * Only allow local queries running in a single session mode * Add middleware to enforce local query endpoint availability * Fix query check * Add query store tests * Make query store errors portable * Skip building specific tests on windows
This commit is contained in:
parent
1c3ab1fd1c
commit
41bf189e6b
2
data/lc_example1.sql
Normal file
2
data/lc_example1.sql
Normal file
@ -0,0 +1,2 @@
|
||||
-- pgweb: host="localhost"
|
||||
select 'foo'
|
5
data/lc_example2.sql
Normal file
5
data/lc_example2.sql
Normal file
@ -0,0 +1,5 @@
|
||||
-- pgweb: host="localhost"
|
||||
-- some comment
|
||||
-- pgweb: user="foo"
|
||||
|
||||
select 'foo'
|
2
data/lc_invalid_meta.sql
Normal file
2
data/lc_invalid_meta.sql
Normal file
@ -0,0 +1,2 @@
|
||||
-- pgweb: host="localhost" mode="foo"
|
||||
select 'foo'
|
1
data/lc_no_meta.sql
Normal file
1
data/lc_no_meta.sql
Normal file
@ -0,0 +1 @@
|
||||
select 'foo'
|
@ -17,6 +17,7 @@ import (
|
||||
"github.com/sosedoff/pgweb/pkg/command"
|
||||
"github.com/sosedoff/pgweb/pkg/connection"
|
||||
"github.com/sosedoff/pgweb/pkg/metrics"
|
||||
"github.com/sosedoff/pgweb/pkg/queries"
|
||||
"github.com/sosedoff/pgweb/pkg/shared"
|
||||
"github.com/sosedoff/pgweb/static"
|
||||
)
|
||||
@ -27,6 +28,9 @@ var (
|
||||
|
||||
// DbSessions represents the mapping for client connections
|
||||
DbSessions *SessionManager
|
||||
|
||||
// QueryStore reads the SQL queries stores in the home directory
|
||||
QueryStore *queries.Store
|
||||
)
|
||||
|
||||
// DB returns a database connection from the client context
|
||||
@ -555,6 +559,7 @@ func GetInfo(c *gin.Context) {
|
||||
"features": gin.H{
|
||||
"session_lock": command.Opts.LockSession,
|
||||
"query_timeout": command.Opts.QueryTimeout,
|
||||
"local_queries": QueryStore != nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
@ -606,3 +611,78 @@ func GetFunction(c *gin.Context) {
|
||||
res, err := DB(c).Function(c.Param("id"))
|
||||
serveResult(c, res, err)
|
||||
}
|
||||
|
||||
func GetLocalQueries(c *gin.Context) {
|
||||
connCtx, err := DB(c).GetConnContext()
|
||||
if err != nil {
|
||||
badRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
storeQueries, err := QueryStore.ReadAll()
|
||||
if err != nil {
|
||||
badRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
queries := []localQuery{}
|
||||
for _, q := range storeQueries {
|
||||
if !q.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) {
|
||||
continue
|
||||
}
|
||||
|
||||
queries = append(queries, localQuery{
|
||||
ID: q.ID,
|
||||
Title: q.Meta.Title,
|
||||
Description: q.Meta.Description,
|
||||
Query: cleanQuery(q.Data),
|
||||
})
|
||||
}
|
||||
|
||||
successResponse(c, queries)
|
||||
}
|
||||
|
||||
func RunLocalQuery(c *gin.Context) {
|
||||
query, err := QueryStore.Read(c.Param("id"))
|
||||
if err != nil {
|
||||
if err == queries.ErrQueryFileNotExist {
|
||||
query = nil
|
||||
} else {
|
||||
badRequest(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if query == nil {
|
||||
errorResponse(c, 404, "query not found")
|
||||
return
|
||||
}
|
||||
|
||||
connCtx, err := DB(c).GetConnContext()
|
||||
if err != nil {
|
||||
badRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !query.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) {
|
||||
errorResponse(c, 404, "query not found")
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodGet {
|
||||
successResponse(c, localQuery{
|
||||
ID: query.ID,
|
||||
Title: query.Meta.Title,
|
||||
Description: query.Meta.Description,
|
||||
Query: query.Data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
statement := cleanQuery(query.Data)
|
||||
if statement == "" {
|
||||
badRequest(c, errQueryRequired)
|
||||
return
|
||||
}
|
||||
|
||||
HandleQuery(statement, c)
|
||||
}
|
||||
|
@ -56,3 +56,14 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
c.Header("Access-Control-Allow-Origin", command.Opts.CorsOrigin)
|
||||
}
|
||||
}
|
||||
|
||||
func requireLocalQueries() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if QueryStore == nil {
|
||||
badRequest(c, "local queries are disabled")
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@ -54,6 +54,9 @@ func SetupRoutes(router *gin.Engine) {
|
||||
api.GET("/history", GetHistory)
|
||||
api.GET("/bookmarks", GetBookmarks)
|
||||
api.GET("/export", DataExport)
|
||||
api.GET("/local_queries", requireLocalQueries(), GetLocalQueries)
|
||||
api.GET("/local_queries/:id", requireLocalQueries(), RunLocalQuery)
|
||||
api.POST("/local_queries/:id", requireLocalQueries(), RunLocalQuery)
|
||||
}
|
||||
|
||||
func SetupMetrics(engine *gin.Engine) {
|
||||
|
8
pkg/api/types.go
Normal file
8
pkg/api/types.go
Normal file
@ -0,0 +1,8 @@
|
||||
package api
|
||||
|
||||
type localQuery struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Query string `json:"query"`
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@ -20,6 +21,7 @@ import (
|
||||
"github.com/sosedoff/pgweb/pkg/command"
|
||||
"github.com/sosedoff/pgweb/pkg/connection"
|
||||
"github.com/sosedoff/pgweb/pkg/metrics"
|
||||
"github.com/sosedoff/pgweb/pkg/queries"
|
||||
"github.com/sosedoff/pgweb/pkg/util"
|
||||
)
|
||||
|
||||
@ -28,11 +30,11 @@ var (
|
||||
options command.Options
|
||||
|
||||
readonlyWarning = `
|
||||
------------------------------------------------------
|
||||
SECURITY WARNING: You are running pgweb in read-only mode.
|
||||
This mode is designed for environments where users could potentially delete / change data.
|
||||
For proper read-only access please follow postgresql role management documentation.
|
||||
------------------------------------------------------`
|
||||
--------------------------------------------------------------------------------
|
||||
SECURITY WARNING: You are running Pgweb in read-only mode.
|
||||
This mode is designed for environments where users could potentially delete or change data.
|
||||
For proper read-only access please follow PostgreSQL role management documentation.
|
||||
--------------------------------------------------------------------------------`
|
||||
|
||||
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
|
||||
regexErrAuthFailed = regexp.MustCompile(`authentication failed`)
|
||||
@ -157,9 +159,33 @@ func initOptions() {
|
||||
}
|
||||
}
|
||||
|
||||
configureLocalQueryStore()
|
||||
printVersion()
|
||||
}
|
||||
|
||||
func configureLocalQueryStore() {
|
||||
if options.Sessions || options.QueriesDir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stat, err := os.Stat(options.QueriesDir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
logger.Debugf("local queries directory %q does not exist, disabling feature", options.QueriesDir)
|
||||
} else {
|
||||
logger.Debugf("local queries feature disabled due to error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !stat.IsDir() {
|
||||
logger.Debugf("local queries path %q is not a directory", options.QueriesDir)
|
||||
return
|
||||
}
|
||||
|
||||
api.QueryStore = queries.NewStore(options.QueriesDir)
|
||||
}
|
||||
|
||||
func configureLogger(opts command.Options) error {
|
||||
if options.Debug {
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
|
@ -585,3 +585,44 @@ func (client *Client) hasHistoryRecord(query string) bool {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type ConnContext struct {
|
||||
Host string
|
||||
User string
|
||||
Database string
|
||||
Mode string
|
||||
}
|
||||
|
||||
func (c ConnContext) String() string {
|
||||
return fmt.Sprintf(
|
||||
"host=%q user=%q database=%q mode=%q",
|
||||
c.Host, c.User, c.Database, c.Mode,
|
||||
)
|
||||
}
|
||||
|
||||
// ConnContext returns information about current database connection
|
||||
func (client *Client) GetConnContext() (*ConnContext, error) {
|
||||
url, err := neturl.Parse(client.ConnectionString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
connCtx := ConnContext{
|
||||
Host: url.Hostname(),
|
||||
Mode: "default",
|
||||
}
|
||||
|
||||
if command.Opts.ReadOnly {
|
||||
connCtx.Mode = "readonly"
|
||||
}
|
||||
|
||||
row := client.db.QueryRowContext(ctx, "SELECT current_user, current_database()")
|
||||
if err := row.Scan(&connCtx.User, &connCtx.Database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &connCtx, nil
|
||||
}
|
||||
|
@ -660,6 +660,15 @@ func testTablesStats(t *testing.T) {
|
||||
assert.Equal(t, columns, result.Columns)
|
||||
}
|
||||
|
||||
func testConnContext(t *testing.T) {
|
||||
result, err := testClient.GetConnContext()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "localhost", result.Host)
|
||||
assert.Equal(t, "postgres", result.User)
|
||||
assert.Equal(t, "booktown", result.Database)
|
||||
assert.Equal(t, "default", result.Mode)
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
if onWindows() {
|
||||
t.Log("Unit testing on Windows platform is not supported.")
|
||||
@ -698,6 +707,7 @@ func TestAll(t *testing.T) {
|
||||
testReadOnlyMode(t)
|
||||
testDumpExport(t)
|
||||
testTablesStats(t)
|
||||
testConnContext(t)
|
||||
|
||||
teardownClient()
|
||||
teardown(t, true)
|
||||
|
@ -48,6 +48,7 @@ type Options struct {
|
||||
LockSession bool `long:"lock-session" description:"Lock session to a single database connection"`
|
||||
Bookmark string `short:"b" long:"bookmark" description:"Bookmark to use for connection. Bookmark files are stored under $HOME/.pgweb/bookmarks/*.toml" default:""`
|
||||
BookmarksDir string `long:"bookmarks-dir" description:"Overrides default directory for bookmark files to search" default:""`
|
||||
QueriesDir string `long:"queries-dir" description:"Overrides default directory for local queries"`
|
||||
DisablePrettyJSON bool `long:"no-pretty-json" description:"Disable JSON formatting feature for result export"`
|
||||
DisableSSH bool `long:"no-ssh" description:"Disable database connections via SSH"`
|
||||
ConnectBackend string `long:"connect-backend" description:"Enable database authentication through a third party backend"`
|
||||
@ -159,10 +160,19 @@ func ParseOptions(args []string) (Options, error) {
|
||||
}
|
||||
}
|
||||
|
||||
homePath, err := homedir.Dir()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[WARN] cant detect home dir: %v", err)
|
||||
homePath = os.Getenv("HOME")
|
||||
}
|
||||
|
||||
if homePath != "" {
|
||||
if opts.BookmarksDir == "" {
|
||||
path, err := homedir.Dir()
|
||||
if err == nil {
|
||||
opts.BookmarksDir = filepath.Join(path, ".pgweb/bookmarks")
|
||||
opts.BookmarksDir = filepath.Join(homePath, ".pgweb/bookmarks")
|
||||
}
|
||||
|
||||
if opts.QueriesDir == "" {
|
||||
opts.QueriesDir = filepath.Join(homePath, ".pgweb/queries")
|
||||
}
|
||||
}
|
||||
|
||||
|
43
pkg/queries/field.go
Normal file
43
pkg/queries/field.go
Normal file
@ -0,0 +1,43 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type field struct {
|
||||
value string
|
||||
re *regexp.Regexp
|
||||
}
|
||||
|
||||
func (f field) String() string {
|
||||
return f.value
|
||||
}
|
||||
|
||||
func (f field) matches(input string) bool {
|
||||
if f.re != nil {
|
||||
return f.re.MatchString(input)
|
||||
}
|
||||
return f.value == input
|
||||
}
|
||||
|
||||
func newField(value string) (field, error) {
|
||||
f := field{value: value}
|
||||
|
||||
if value == "*" { // match everything
|
||||
f.re = reMatchAll
|
||||
} else if reExpression.MatchString(value) { // match by given expression
|
||||
// Make writing expressions easier for values like "foo_*"
|
||||
if strings.Count(value, "*") == 1 {
|
||||
value = strings.Replace(value, "*", "(.+)", 1)
|
||||
}
|
||||
re, err := regexp.Compile(fmt.Sprintf("^%s$", value))
|
||||
if err != nil {
|
||||
return f, err
|
||||
}
|
||||
f.re = re
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
41
pkg/queries/field_test.go
Normal file
41
pkg/queries/field_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_field(t *testing.T) {
|
||||
field, err := newField("val")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "val", field.value)
|
||||
assert.Equal(t, true, field.matches("val"))
|
||||
assert.Equal(t, false, field.matches("value"))
|
||||
|
||||
field, err = newField("*")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "*", field.value)
|
||||
assert.NotNil(t, field.re)
|
||||
assert.Equal(t, true, field.matches("val"))
|
||||
assert.Equal(t, true, field.matches("value"))
|
||||
|
||||
field, err = newField("(.+")
|
||||
assert.EqualError(t, err, "error parsing regexp: missing closing ): `^(.+$`")
|
||||
assert.NotNil(t, field)
|
||||
|
||||
field, err = newField("foo_*")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "foo_*", field.value)
|
||||
assert.NotNil(t, field.re)
|
||||
assert.Equal(t, false, field.matches("foo"))
|
||||
assert.Equal(t, true, field.matches("foo_bar"))
|
||||
assert.Equal(t, true, field.matches("foo_bar_widget"))
|
||||
|
||||
}
|
||||
|
||||
func Test_fieldString(t *testing.T) {
|
||||
field, err := newField("val")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "val", field.String())
|
||||
}
|
148
pkg/queries/metadata.go
Normal file
148
pkg/queries/metadata.go
Normal file
@ -0,0 +1,148 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
reMetaPrefix = regexp.MustCompile(`(?m)^\s*--\s*pgweb:\s*(.+)`)
|
||||
reMetaContent = regexp.MustCompile(`([\w]+)\s*=\s*"([^"]+)"`)
|
||||
reMatchAll = regexp.MustCompile(`^(.+)$`)
|
||||
reExpression = regexp.MustCompile(`[\[\]\(\)\+\*]+`)
|
||||
|
||||
allowedKeys = []string{"title", "description", "host", "user", "database", "mode", "timeout"}
|
||||
allowedModes = map[string]bool{"readonly": true, "*": true}
|
||||
)
|
||||
|
||||
type Metadata struct {
|
||||
Title string
|
||||
Description string
|
||||
Host field
|
||||
User field
|
||||
Database field
|
||||
Mode field
|
||||
Timeout *time.Duration
|
||||
}
|
||||
|
||||
func parseMetadata(input string) (*Metadata, error) {
|
||||
fields, err := parseFields(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fields == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Host must be set to limit queries availability
|
||||
if fields["host"] == "" {
|
||||
return nil, fmt.Errorf("host field must be set")
|
||||
}
|
||||
|
||||
// Allow matching for any user, database and mode by default
|
||||
if fields["user"] == "" {
|
||||
fields["user"] = "*"
|
||||
}
|
||||
if fields["database"] == "" {
|
||||
fields["database"] = "*"
|
||||
}
|
||||
if fields["mode"] == "" {
|
||||
fields["mode"] = "*"
|
||||
}
|
||||
|
||||
hostField, err := newField(fields["host"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`error initializing "host" field: %w`, err)
|
||||
}
|
||||
|
||||
userField, err := newField(fields["user"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`error initializing "user" field: %w`, err)
|
||||
}
|
||||
|
||||
dbField, err := newField(fields["database"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`error initializing "database" field: %w`, err)
|
||||
}
|
||||
|
||||
if !allowedModes[fields["mode"]] {
|
||||
return nil, fmt.Errorf(`invalid "mode" field value: %q`, fields["mode"])
|
||||
}
|
||||
modeField, err := newField(fields["mode"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`error initializing "mode" field: %w`, err)
|
||||
}
|
||||
|
||||
var timeout *time.Duration
|
||||
if fields["timeout"] != "" {
|
||||
timeoutSec, err := strconv.Atoi(fields["timeout"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`error initializing "timeout" field: %w`, err)
|
||||
}
|
||||
timeoutVal := time.Duration(timeoutSec) * time.Second
|
||||
timeout = &timeoutVal
|
||||
}
|
||||
|
||||
return &Metadata{
|
||||
Title: fields["title"],
|
||||
Description: fields["description"],
|
||||
Host: hostField,
|
||||
User: userField,
|
||||
Database: dbField,
|
||||
Mode: modeField,
|
||||
Timeout: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseFields(input string) (map[string]string, error) {
|
||||
result := map[string]string{}
|
||||
seenKeys := map[string]bool{}
|
||||
|
||||
allowed := map[string]bool{}
|
||||
for _, key := range allowedKeys {
|
||||
allowed[key] = true
|
||||
}
|
||||
|
||||
matches := reMetaPrefix.FindAllStringSubmatch(input, -1)
|
||||
if len(matches) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
content := reMetaContent.FindAllStringSubmatch(match[1], -1)
|
||||
if len(content) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, field := range content {
|
||||
key := field[1]
|
||||
value := field[2]
|
||||
|
||||
if !allowed[key] {
|
||||
return result, fmt.Errorf("unknown key: %q", key)
|
||||
}
|
||||
if seenKeys[key] {
|
||||
return result, fmt.Errorf("duplicate key: %q", key)
|
||||
}
|
||||
|
||||
seenKeys[key] = true
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func sanitizeMetadata(input string) string {
|
||||
lines := []string{}
|
||||
for _, line := range strings.Split(input, "\n") {
|
||||
line = reMetaPrefix.ReplaceAllString(line, "")
|
||||
if len(line) > 0 {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
146
pkg/queries/metadata_test.go
Normal file
146
pkg/queries/metadata_test.go
Normal file
@ -0,0 +1,146 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseFields(t *testing.T) {
|
||||
examples := []struct {
|
||||
input string
|
||||
err error
|
||||
vals map[string]string
|
||||
}{
|
||||
{input: "", err: nil, vals: nil},
|
||||
{input: "foobar", err: nil, vals: nil},
|
||||
{input: "-- no pgweb meta", err: nil, vals: nil},
|
||||
{
|
||||
input: `--pgweb: foo=bar`,
|
||||
err: nil,
|
||||
vals: map[string]string{},
|
||||
},
|
||||
{
|
||||
input: `--pgweb: host="localhost"`,
|
||||
err: nil,
|
||||
vals: map[string]string{"host": "localhost"},
|
||||
},
|
||||
{
|
||||
input: `--pgweb: host="*" user="admin" database ="mydb"; mode = "readonly"`,
|
||||
err: nil,
|
||||
vals: map[string]string{
|
||||
"host": "*",
|
||||
"database": "mydb",
|
||||
"user": "admin",
|
||||
"mode": "readonly",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, ex := range examples {
|
||||
t.Run(ex.input, func(t *testing.T) {
|
||||
fields, err := parseFields(ex.input)
|
||||
assert.Equal(t, ex.err, err)
|
||||
assert.Equal(t, ex.vals, fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseMetadata(t *testing.T) {
|
||||
examples := []struct {
|
||||
input string
|
||||
err string
|
||||
check func(meta *Metadata) bool
|
||||
}{
|
||||
{
|
||||
input: `--pgweb: `,
|
||||
err: `host field must be set`,
|
||||
},
|
||||
{
|
||||
input: `--pgweb: hello="world"`,
|
||||
err: `unknown key: "hello"`,
|
||||
},
|
||||
{
|
||||
input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="foo"`,
|
||||
err: `invalid "mode" field value: "foo"`,
|
||||
},
|
||||
{
|
||||
input: "--pgweb2:",
|
||||
check: func(m *Metadata) bool {
|
||||
return m == nil
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `--pgweb: host="localhost"`,
|
||||
check: func(m *Metadata) bool {
|
||||
return m.Host.value == "localhost" &&
|
||||
m.User.value == "*" &&
|
||||
m.Database.value == "*" &&
|
||||
m.Mode.value == "*" &&
|
||||
m.Timeout == nil
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="*"`,
|
||||
check: func(m *Metadata) bool {
|
||||
return m.Host.value == "localhost" &&
|
||||
m.Host.re == nil &&
|
||||
m.User.value == "anyuser" &&
|
||||
m.Database.value == "anydb" &&
|
||||
m.Mode.value == "*" &&
|
||||
m.Timeout == nil
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `--pgweb: host="localhost" timeout="foo"`,
|
||||
err: `error initializing "timeout" field: strconv.Atoi: parsing "foo": invalid syntax`,
|
||||
},
|
||||
{
|
||||
input: `-- pgweb: host="local(host|dev)"`,
|
||||
check: func(m *Metadata) bool {
|
||||
return m.Host.value == "local(host|dev)" && m.Host.re != nil &&
|
||||
m.Host.matches("localhost") && m.Host.matches("localdev") &&
|
||||
!m.Host.matches("localfoo") && !m.Host.matches("superlocaldev")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, ex := range examples {
|
||||
t.Run(ex.input, func(t *testing.T) {
|
||||
meta, err := parseMetadata(ex.input)
|
||||
if ex.err != "" {
|
||||
assert.Contains(t, err.Error(), ex.err)
|
||||
}
|
||||
if ex.check != nil {
|
||||
assert.Equal(t, true, ex.check(meta))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sanitizeMetadata(t *testing.T) {
|
||||
examples := []struct {
|
||||
input string
|
||||
output string
|
||||
}{
|
||||
{input: "", output: ""},
|
||||
{input: "foo", output: "foo"},
|
||||
{
|
||||
input: `
|
||||
-- pgweb: metadata
|
||||
query1
|
||||
-- pgweb: more metadata
|
||||
|
||||
query2
|
||||
|
||||
`,
|
||||
output: "query1\nquery2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, ex := range examples {
|
||||
t.Run(ex.input, func(t *testing.T) {
|
||||
assert.Equal(t, ex.output, sanitizeMetadata(ex.input))
|
||||
})
|
||||
}
|
||||
}
|
23
pkg/queries/query.go
Normal file
23
pkg/queries/query.go
Normal file
@ -0,0 +1,23 @@
|
||||
package queries
|
||||
|
||||
type Query struct {
|
||||
ID string
|
||||
Path string
|
||||
Meta *Metadata
|
||||
Data string
|
||||
}
|
||||
|
||||
// IsPermitted returns true if a query is allowed to execute for a given db context
|
||||
func (q Query) IsPermitted(host, user, database, mode string) bool {
|
||||
// All fields must be provided for matching
|
||||
if q.Meta == nil || host == "" || user == "" || database == "" || mode == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
meta := q.Meta
|
||||
|
||||
return meta.Host.matches(host) &&
|
||||
meta.User.matches(user) &&
|
||||
meta.Database.matches(database) &&
|
||||
meta.Mode.matches(mode)
|
||||
}
|
77
pkg/queries/query_test.go
Normal file
77
pkg/queries/query_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestQueryIsPermitted(t *testing.T) {
|
||||
examples := []struct {
|
||||
name string
|
||||
query Query
|
||||
args []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no input provided",
|
||||
query: makeQuery("localhost", "someuser", "somedb", "default"),
|
||||
args: makeArgs("", "", "", ""),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "match on host",
|
||||
query: makeQuery("localhost", "*", "*", "*"),
|
||||
args: makeArgs("localhost", "user", "db", "default"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "match on full set",
|
||||
query: makeQuery("localhost", "user", "database", "mode"),
|
||||
args: makeArgs("localhost", "someuser", "somedb", "default"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "match on partial database",
|
||||
query: makeQuery("localhost", "*", "myapp_*", "*"),
|
||||
args: makeArgs("localhost", "user", "myapp_development", "default"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "match on full set but not mode",
|
||||
query: makeQuery("localhost", "*", "*", "readonly"),
|
||||
args: makeArgs("localhost", "user", "db", "default"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, ex := range examples {
|
||||
t.Run(ex.name, func(t *testing.T) {
|
||||
result := ex.query.IsPermitted(ex.args[0], ex.args[1], ex.args[2], ex.args[3])
|
||||
assert.Equal(t, ex.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeArgs(vals ...string) []string {
|
||||
return vals
|
||||
}
|
||||
|
||||
func makeQuery(host, user, database, mode string) Query {
|
||||
mustfield := func(input string) field {
|
||||
f, err := newField(input)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
return Query{
|
||||
Meta: &Metadata{
|
||||
Host: mustfield(host),
|
||||
User: mustfield(user),
|
||||
Database: mustfield(database),
|
||||
Mode: mustfield(mode),
|
||||
},
|
||||
}
|
||||
}
|
88
pkg/queries/store.go
Normal file
88
pkg/queries/store.go
Normal file
@ -0,0 +1,88 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrQueryDirNotExist = errors.New("queries directory does not exist")
|
||||
ErrQueryFileNotExist = errors.New("query file does not exist")
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
func NewStore(dir string) *Store {
|
||||
return &Store{
|
||||
dir: dir,
|
||||
}
|
||||
}
|
||||
|
||||
func (s Store) Read(id string) (*Query, error) {
|
||||
path := filepath.Join(s.dir, fmt.Sprintf("%s.sql", id))
|
||||
return readQuery(path)
|
||||
}
|
||||
|
||||
func (s Store) ReadAll() ([]Query, error) {
|
||||
entries, err := os.ReadDir(s.dir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
err = ErrQueryDirNotExist
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queries := []Query{}
|
||||
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if filepath.Ext(name) != ".sql" {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(s.dir, name)
|
||||
query, err := readQuery(path)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[WARN] skipping %q query file due to error: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
if query == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
queries = append(queries, *query)
|
||||
}
|
||||
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
func readQuery(path string) (*Query, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, ErrQueryFileNotExist
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
dataStr := string(data)
|
||||
|
||||
meta, err := parseMetadata(dataStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if meta == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &Query{
|
||||
ID: strings.Replace(filepath.Base(path), ".sql", "", 1),
|
||||
Path: path,
|
||||
Meta: meta,
|
||||
Data: sanitizeMetadata(dataStr),
|
||||
}, nil
|
||||
}
|
71
pkg/queries/store_test.go
Normal file
71
pkg/queries/store_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
//go:build !windows
|
||||
|
||||
package queries
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStoreReadAll(t *testing.T) {
|
||||
t.Run("valid dir", func(t *testing.T) {
|
||||
queries, err := NewStore("../../data").ReadAll()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(queries))
|
||||
})
|
||||
|
||||
t.Run("invalid dir", func(t *testing.T) {
|
||||
queries, err := NewStore("../../data2").ReadAll()
|
||||
assert.Equal(t, err.Error(), "queries directory does not exist")
|
||||
assert.Equal(t, 0, len(queries))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRead(t *testing.T) {
|
||||
examples := []struct {
|
||||
id string
|
||||
err string
|
||||
check func(*testing.T, *Query)
|
||||
}{
|
||||
{id: "foo", err: "query file does not exist"},
|
||||
{id: "lc_no_meta"},
|
||||
{id: "lc_invalid_meta", err: `invalid "mode" field value: "foo"`},
|
||||
{
|
||||
id: "lc_example1",
|
||||
check: func(t *testing.T, q *Query) {
|
||||
assert.Equal(t, "lc_example1", q.ID)
|
||||
assert.Equal(t, "../../data/lc_example1.sql", q.Path)
|
||||
assert.Equal(t, "select 'foo'", q.Data)
|
||||
assert.Equal(t, "localhost", q.Meta.Host.String())
|
||||
assert.Equal(t, "*", q.Meta.User.String())
|
||||
assert.Equal(t, "*", q.Meta.Database.String())
|
||||
},
|
||||
},
|
||||
{
|
||||
id: "lc_example2",
|
||||
check: func(t *testing.T, q *Query) {
|
||||
assert.Equal(t, "lc_example2", q.ID)
|
||||
assert.Equal(t, "../../data/lc_example2.sql", q.Path)
|
||||
assert.Equal(t, "-- some comment\nselect 'foo'", q.Data)
|
||||
assert.Equal(t, "localhost", q.Meta.Host.String())
|
||||
assert.Equal(t, "foo", q.Meta.User.String())
|
||||
assert.Equal(t, "*", q.Meta.Database.String())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store := NewStore("../../data")
|
||||
|
||||
for _, ex := range examples {
|
||||
t.Run(ex.id, func(t *testing.T) {
|
||||
query, err := store.Read(ex.id)
|
||||
if ex.err != "" || err != nil {
|
||||
assert.Equal(t, ex.err, err.Error())
|
||||
}
|
||||
if ex.check != nil {
|
||||
ex.check(t, query)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -357,6 +357,7 @@
|
||||
#input .actions #query_progress {
|
||||
display: none;
|
||||
float: left;
|
||||
font-size: 12px;
|
||||
line-height: 30px;
|
||||
height: 30px;
|
||||
color: #aaa;
|
||||
|
@ -87,6 +87,13 @@
|
||||
<li><a href="#" id="analyze">Analyze Query</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="load-query-dropdown" class="btn-group left" style="display: none">
|
||||
<button id="load-local-query" type="button" class="btn btn-default dropdown-toggle" data-toggle="dropdown" disabled="disabled">
|
||||
Template <span class="caret"></span>
|
||||
</button>
|
||||
<ul class="dropdown-menu" role="menu">
|
||||
</ul>
|
||||
</div>
|
||||
<div id="query_progress">Please wait, query is executing...</div>
|
||||
<div class="pull-right">
|
||||
<span id="result-rows-count"></span>
|
||||
|
@ -178,6 +178,33 @@ function buildSchemaSection(name, objects) {
|
||||
return section;
|
||||
}
|
||||
|
||||
function loadLocalQueries() {
|
||||
if (!appFeatures.local_queries) return;
|
||||
|
||||
$("body").on("click", "a.load-local-query", function(e) {
|
||||
var id = $(this).data("id");
|
||||
|
||||
apiCall("get", "/local_queries/" + id, {}, function(resp) {
|
||||
editor.setValue(resp.query);
|
||||
editor.clearSelection();
|
||||
});
|
||||
});
|
||||
|
||||
apiCall("get", "/local_queries", {}, function(resp) {
|
||||
if (resp.error) return;
|
||||
|
||||
var container = $("#load-query-dropdown").find(".dropdown-menu");
|
||||
|
||||
resp.forEach(function(item) {
|
||||
var title = item.title || item.id;
|
||||
$("<li><a href='#' class='load-local-query' data-id='" + item.id + "'>" + title + "</a></li>").appendTo(container);
|
||||
});
|
||||
|
||||
if (resp.length > 0) $("#load-local-query").prop("disabled", "");
|
||||
$("#load-query-dropdown").show();
|
||||
});
|
||||
}
|
||||
|
||||
function loadSchemas() {
|
||||
$("#objects").html("");
|
||||
|
||||
@ -738,13 +765,13 @@ function showActivityPanel() {
|
||||
}
|
||||
|
||||
function showQueryProgressMessage() {
|
||||
$("#run, #explain-dropdown-toggle, #csv, #json, #xml").prop("disabled", true);
|
||||
$("#run, #explain-dropdown-toggle, #csv, #json, #xml, #load-local-query").prop("disabled", true);
|
||||
$("#explain-dropdown").removeClass("open");
|
||||
$("#query_progress").show();
|
||||
}
|
||||
|
||||
function hideQueryProgressMessage() {
|
||||
$("#run, #explain-dropdown-toggle, #csv, #json, #xml").prop("disabled", false);
|
||||
$("#run, #explain-dropdown-toggle, #csv, #json, #xml, #load-local-query").prop("disabled", false);
|
||||
$("#query_progress").hide();
|
||||
}
|
||||
|
||||
@ -1810,6 +1837,7 @@ $(document).ready(function() {
|
||||
|
||||
connected = true;
|
||||
loadSchemas();
|
||||
loadLocalQueries();
|
||||
|
||||
$("#current_database").text(resp.current_database);
|
||||
$("#main").show();
|
||||
|
Loading…
x
Reference in New Issue
Block a user