Local queries (#641)

* Read local queries from pgweb home directory
* Refactor local query functionality
* Allow picking local query in the query tab
* WIP
* Disable local query dropdown during execution
* Only allow local queries running in a single session mode
* Add middleware to enforce local query endpoint availability
* Fix query check
* Add query store tests
* Make query store errors portable
* Skip building specific tests on windows
This commit is contained in:
Dan Sosedoff 2023-02-02 16:13:14 -06:00 committed by GitHub
parent 1c3ab1fd1c
commit 41bf189e6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 884 additions and 12 deletions

2
data/lc_example1.sql Normal file
View File

@ -0,0 +1,2 @@
-- pgweb: host="localhost"
select 'foo'

5
data/lc_example2.sql Normal file
View File

@ -0,0 +1,5 @@
-- pgweb: host="localhost"
-- some comment
-- pgweb: user="foo"
select 'foo'

2
data/lc_invalid_meta.sql Normal file
View File

@ -0,0 +1,2 @@
-- pgweb: host="localhost" mode="foo"
select 'foo'

1
data/lc_no_meta.sql Normal file
View File

@ -0,0 +1 @@
select 'foo'

View File

@ -17,6 +17,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/metrics" "github.com/sosedoff/pgweb/pkg/metrics"
"github.com/sosedoff/pgweb/pkg/queries"
"github.com/sosedoff/pgweb/pkg/shared" "github.com/sosedoff/pgweb/pkg/shared"
"github.com/sosedoff/pgweb/static" "github.com/sosedoff/pgweb/static"
) )
@ -27,6 +28,9 @@ var (
// DbSessions represents the mapping for client connections // DbSessions represents the mapping for client connections
DbSessions *SessionManager DbSessions *SessionManager
// QueryStore reads the SQL queries stores in the home directory
QueryStore *queries.Store
) )
// DB returns a database connection from the client context // DB returns a database connection from the client context
@ -555,6 +559,7 @@ func GetInfo(c *gin.Context) {
"features": gin.H{ "features": gin.H{
"session_lock": command.Opts.LockSession, "session_lock": command.Opts.LockSession,
"query_timeout": command.Opts.QueryTimeout, "query_timeout": command.Opts.QueryTimeout,
"local_queries": QueryStore != nil,
}, },
}) })
} }
@ -606,3 +611,78 @@ func GetFunction(c *gin.Context) {
res, err := DB(c).Function(c.Param("id")) res, err := DB(c).Function(c.Param("id"))
serveResult(c, res, err) serveResult(c, res, err)
} }
func GetLocalQueries(c *gin.Context) {
connCtx, err := DB(c).GetConnContext()
if err != nil {
badRequest(c, err)
return
}
storeQueries, err := QueryStore.ReadAll()
if err != nil {
badRequest(c, err)
return
}
queries := []localQuery{}
for _, q := range storeQueries {
if !q.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) {
continue
}
queries = append(queries, localQuery{
ID: q.ID,
Title: q.Meta.Title,
Description: q.Meta.Description,
Query: cleanQuery(q.Data),
})
}
successResponse(c, queries)
}
func RunLocalQuery(c *gin.Context) {
query, err := QueryStore.Read(c.Param("id"))
if err != nil {
if err == queries.ErrQueryFileNotExist {
query = nil
} else {
badRequest(c, err)
return
}
}
if query == nil {
errorResponse(c, 404, "query not found")
return
}
connCtx, err := DB(c).GetConnContext()
if err != nil {
badRequest(c, err)
return
}
if !query.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) {
errorResponse(c, 404, "query not found")
return
}
if c.Request.Method == http.MethodGet {
successResponse(c, localQuery{
ID: query.ID,
Title: query.Meta.Title,
Description: query.Meta.Description,
Query: query.Data,
})
return
}
statement := cleanQuery(query.Data)
if statement == "" {
badRequest(c, errQueryRequired)
return
}
HandleQuery(statement, c)
}

View File

@ -56,3 +56,14 @@ func corsMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Origin", command.Opts.CorsOrigin) c.Header("Access-Control-Allow-Origin", command.Opts.CorsOrigin)
} }
} }
func requireLocalQueries() gin.HandlerFunc {
return func(c *gin.Context) {
if QueryStore == nil {
badRequest(c, "local queries are disabled")
return
}
c.Next()
}
}

View File

@ -54,6 +54,9 @@ func SetupRoutes(router *gin.Engine) {
api.GET("/history", GetHistory) api.GET("/history", GetHistory)
api.GET("/bookmarks", GetBookmarks) api.GET("/bookmarks", GetBookmarks)
api.GET("/export", DataExport) api.GET("/export", DataExport)
api.GET("/local_queries", requireLocalQueries(), GetLocalQueries)
api.GET("/local_queries/:id", requireLocalQueries(), RunLocalQuery)
api.POST("/local_queries/:id", requireLocalQueries(), RunLocalQuery)
} }
func SetupMetrics(engine *gin.Engine) { func SetupMetrics(engine *gin.Engine) {

8
pkg/api/types.go Normal file
View File

@ -0,0 +1,8 @@
package api
type localQuery struct {
ID string `json:"id"`
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Query string `json:"query"`
}

View File

@ -1,6 +1,7 @@
package cli package cli
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
@ -20,6 +21,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/metrics" "github.com/sosedoff/pgweb/pkg/metrics"
"github.com/sosedoff/pgweb/pkg/queries"
"github.com/sosedoff/pgweb/pkg/util" "github.com/sosedoff/pgweb/pkg/util"
) )
@ -28,11 +30,11 @@ var (
options command.Options options command.Options
readonlyWarning = ` readonlyWarning = `
------------------------------------------------------ --------------------------------------------------------------------------------
SECURITY WARNING: You are running pgweb in read-only mode. SECURITY WARNING: You are running Pgweb in read-only mode.
This mode is designed for environments where users could potentially delete / change data. This mode is designed for environments where users could potentially delete or change data.
For proper read-only access please follow postgresql role management documentation. For proper read-only access please follow PostgreSQL role management documentation.
------------------------------------------------------` --------------------------------------------------------------------------------`
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`) regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
regexErrAuthFailed = regexp.MustCompile(`authentication failed`) regexErrAuthFailed = regexp.MustCompile(`authentication failed`)
@ -157,9 +159,33 @@ func initOptions() {
} }
} }
configureLocalQueryStore()
printVersion() printVersion()
} }
func configureLocalQueryStore() {
if options.Sessions || options.QueriesDir == "" {
return
}
stat, err := os.Stat(options.QueriesDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
logger.Debugf("local queries directory %q does not exist, disabling feature", options.QueriesDir)
} else {
logger.Debugf("local queries feature disabled due to error: %v", err)
}
return
}
if !stat.IsDir() {
logger.Debugf("local queries path %q is not a directory", options.QueriesDir)
return
}
api.QueryStore = queries.NewStore(options.QueriesDir)
}
func configureLogger(opts command.Options) error { func configureLogger(opts command.Options) error {
if options.Debug { if options.Debug {
logger.SetLevel(logrus.DebugLevel) logger.SetLevel(logrus.DebugLevel)

View File

@ -585,3 +585,44 @@ func (client *Client) hasHistoryRecord(query string) bool {
return result return result
} }
type ConnContext struct {
Host string
User string
Database string
Mode string
}
func (c ConnContext) String() string {
return fmt.Sprintf(
"host=%q user=%q database=%q mode=%q",
c.Host, c.User, c.Database, c.Mode,
)
}
// ConnContext returns information about current database connection
func (client *Client) GetConnContext() (*ConnContext, error) {
url, err := neturl.Parse(client.ConnectionString)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
connCtx := ConnContext{
Host: url.Hostname(),
Mode: "default",
}
if command.Opts.ReadOnly {
connCtx.Mode = "readonly"
}
row := client.db.QueryRowContext(ctx, "SELECT current_user, current_database()")
if err := row.Scan(&connCtx.User, &connCtx.Database); err != nil {
return nil, err
}
return &connCtx, nil
}

View File

@ -660,6 +660,15 @@ func testTablesStats(t *testing.T) {
assert.Equal(t, columns, result.Columns) assert.Equal(t, columns, result.Columns)
} }
func testConnContext(t *testing.T) {
result, err := testClient.GetConnContext()
assert.NoError(t, err)
assert.Equal(t, "localhost", result.Host)
assert.Equal(t, "postgres", result.User)
assert.Equal(t, "booktown", result.Database)
assert.Equal(t, "default", result.Mode)
}
func TestAll(t *testing.T) { func TestAll(t *testing.T) {
if onWindows() { if onWindows() {
t.Log("Unit testing on Windows platform is not supported.") t.Log("Unit testing on Windows platform is not supported.")
@ -698,6 +707,7 @@ func TestAll(t *testing.T) {
testReadOnlyMode(t) testReadOnlyMode(t)
testDumpExport(t) testDumpExport(t)
testTablesStats(t) testTablesStats(t)
testConnContext(t)
teardownClient() teardownClient()
teardown(t, true) teardown(t, true)

View File

@ -48,6 +48,7 @@ type Options struct {
LockSession bool `long:"lock-session" description:"Lock session to a single database connection"` LockSession bool `long:"lock-session" description:"Lock session to a single database connection"`
Bookmark string `short:"b" long:"bookmark" description:"Bookmark to use for connection. Bookmark files are stored under $HOME/.pgweb/bookmarks/*.toml" default:""` Bookmark string `short:"b" long:"bookmark" description:"Bookmark to use for connection. Bookmark files are stored under $HOME/.pgweb/bookmarks/*.toml" default:""`
BookmarksDir string `long:"bookmarks-dir" description:"Overrides default directory for bookmark files to search" default:""` BookmarksDir string `long:"bookmarks-dir" description:"Overrides default directory for bookmark files to search" default:""`
QueriesDir string `long:"queries-dir" description:"Overrides default directory for local queries"`
DisablePrettyJSON bool `long:"no-pretty-json" description:"Disable JSON formatting feature for result export"` DisablePrettyJSON bool `long:"no-pretty-json" description:"Disable JSON formatting feature for result export"`
DisableSSH bool `long:"no-ssh" description:"Disable database connections via SSH"` DisableSSH bool `long:"no-ssh" description:"Disable database connections via SSH"`
ConnectBackend string `long:"connect-backend" description:"Enable database authentication through a third party backend"` ConnectBackend string `long:"connect-backend" description:"Enable database authentication through a third party backend"`
@ -159,10 +160,19 @@ func ParseOptions(args []string) (Options, error) {
} }
} }
homePath, err := homedir.Dir()
if err != nil {
fmt.Fprintf(os.Stderr, "[WARN] cant detect home dir: %v", err)
homePath = os.Getenv("HOME")
}
if homePath != "" {
if opts.BookmarksDir == "" { if opts.BookmarksDir == "" {
path, err := homedir.Dir() opts.BookmarksDir = filepath.Join(homePath, ".pgweb/bookmarks")
if err == nil { }
opts.BookmarksDir = filepath.Join(path, ".pgweb/bookmarks")
if opts.QueriesDir == "" {
opts.QueriesDir = filepath.Join(homePath, ".pgweb/queries")
} }
} }

43
pkg/queries/field.go Normal file
View File

@ -0,0 +1,43 @@
package queries
import (
"fmt"
"regexp"
"strings"
)
type field struct {
value string
re *regexp.Regexp
}
func (f field) String() string {
return f.value
}
func (f field) matches(input string) bool {
if f.re != nil {
return f.re.MatchString(input)
}
return f.value == input
}
func newField(value string) (field, error) {
f := field{value: value}
if value == "*" { // match everything
f.re = reMatchAll
} else if reExpression.MatchString(value) { // match by given expression
// Make writing expressions easier for values like "foo_*"
if strings.Count(value, "*") == 1 {
value = strings.Replace(value, "*", "(.+)", 1)
}
re, err := regexp.Compile(fmt.Sprintf("^%s$", value))
if err != nil {
return f, err
}
f.re = re
}
return f, nil
}

41
pkg/queries/field_test.go Normal file
View File

@ -0,0 +1,41 @@
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_field(t *testing.T) {
field, err := newField("val")
assert.NoError(t, err)
assert.Equal(t, "val", field.value)
assert.Equal(t, true, field.matches("val"))
assert.Equal(t, false, field.matches("value"))
field, err = newField("*")
assert.NoError(t, err)
assert.Equal(t, "*", field.value)
assert.NotNil(t, field.re)
assert.Equal(t, true, field.matches("val"))
assert.Equal(t, true, field.matches("value"))
field, err = newField("(.+")
assert.EqualError(t, err, "error parsing regexp: missing closing ): `^(.+$`")
assert.NotNil(t, field)
field, err = newField("foo_*")
assert.NoError(t, err)
assert.Equal(t, "foo_*", field.value)
assert.NotNil(t, field.re)
assert.Equal(t, false, field.matches("foo"))
assert.Equal(t, true, field.matches("foo_bar"))
assert.Equal(t, true, field.matches("foo_bar_widget"))
}
func Test_fieldString(t *testing.T) {
field, err := newField("val")
assert.NoError(t, err)
assert.Equal(t, "val", field.String())
}

148
pkg/queries/metadata.go Normal file
View File

@ -0,0 +1,148 @@
package queries
import (
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
var (
reMetaPrefix = regexp.MustCompile(`(?m)^\s*--\s*pgweb:\s*(.+)`)
reMetaContent = regexp.MustCompile(`([\w]+)\s*=\s*"([^"]+)"`)
reMatchAll = regexp.MustCompile(`^(.+)$`)
reExpression = regexp.MustCompile(`[\[\]\(\)\+\*]+`)
allowedKeys = []string{"title", "description", "host", "user", "database", "mode", "timeout"}
allowedModes = map[string]bool{"readonly": true, "*": true}
)
type Metadata struct {
Title string
Description string
Host field
User field
Database field
Mode field
Timeout *time.Duration
}
func parseMetadata(input string) (*Metadata, error) {
fields, err := parseFields(input)
if err != nil {
return nil, err
}
if fields == nil {
return nil, nil
}
// Host must be set to limit queries availability
if fields["host"] == "" {
return nil, fmt.Errorf("host field must be set")
}
// Allow matching for any user, database and mode by default
if fields["user"] == "" {
fields["user"] = "*"
}
if fields["database"] == "" {
fields["database"] = "*"
}
if fields["mode"] == "" {
fields["mode"] = "*"
}
hostField, err := newField(fields["host"])
if err != nil {
return nil, fmt.Errorf(`error initializing "host" field: %w`, err)
}
userField, err := newField(fields["user"])
if err != nil {
return nil, fmt.Errorf(`error initializing "user" field: %w`, err)
}
dbField, err := newField(fields["database"])
if err != nil {
return nil, fmt.Errorf(`error initializing "database" field: %w`, err)
}
if !allowedModes[fields["mode"]] {
return nil, fmt.Errorf(`invalid "mode" field value: %q`, fields["mode"])
}
modeField, err := newField(fields["mode"])
if err != nil {
return nil, fmt.Errorf(`error initializing "mode" field: %w`, err)
}
var timeout *time.Duration
if fields["timeout"] != "" {
timeoutSec, err := strconv.Atoi(fields["timeout"])
if err != nil {
return nil, fmt.Errorf(`error initializing "timeout" field: %w`, err)
}
timeoutVal := time.Duration(timeoutSec) * time.Second
timeout = &timeoutVal
}
return &Metadata{
Title: fields["title"],
Description: fields["description"],
Host: hostField,
User: userField,
Database: dbField,
Mode: modeField,
Timeout: timeout,
}, nil
}
func parseFields(input string) (map[string]string, error) {
result := map[string]string{}
seenKeys := map[string]bool{}
allowed := map[string]bool{}
for _, key := range allowedKeys {
allowed[key] = true
}
matches := reMetaPrefix.FindAllStringSubmatch(input, -1)
if len(matches) == 0 {
return nil, nil
}
for _, match := range matches {
content := reMetaContent.FindAllStringSubmatch(match[1], -1)
if len(content) == 0 {
continue
}
for _, field := range content {
key := field[1]
value := field[2]
if !allowed[key] {
return result, fmt.Errorf("unknown key: %q", key)
}
if seenKeys[key] {
return result, fmt.Errorf("duplicate key: %q", key)
}
seenKeys[key] = true
result[key] = value
}
}
return result, nil
}
func sanitizeMetadata(input string) string {
lines := []string{}
for _, line := range strings.Split(input, "\n") {
line = reMetaPrefix.ReplaceAllString(line, "")
if len(line) > 0 {
lines = append(lines, line)
}
}
return strings.Join(lines, "\n")
}

View File

@ -0,0 +1,146 @@
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseFields(t *testing.T) {
examples := []struct {
input string
err error
vals map[string]string
}{
{input: "", err: nil, vals: nil},
{input: "foobar", err: nil, vals: nil},
{input: "-- no pgweb meta", err: nil, vals: nil},
{
input: `--pgweb: foo=bar`,
err: nil,
vals: map[string]string{},
},
{
input: `--pgweb: host="localhost"`,
err: nil,
vals: map[string]string{"host": "localhost"},
},
{
input: `--pgweb: host="*" user="admin" database ="mydb"; mode = "readonly"`,
err: nil,
vals: map[string]string{
"host": "*",
"database": "mydb",
"user": "admin",
"mode": "readonly",
},
},
}
for _, ex := range examples {
t.Run(ex.input, func(t *testing.T) {
fields, err := parseFields(ex.input)
assert.Equal(t, ex.err, err)
assert.Equal(t, ex.vals, fields)
})
}
}
func Test_parseMetadata(t *testing.T) {
examples := []struct {
input string
err string
check func(meta *Metadata) bool
}{
{
input: `--pgweb: `,
err: `host field must be set`,
},
{
input: `--pgweb: hello="world"`,
err: `unknown key: "hello"`,
},
{
input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="foo"`,
err: `invalid "mode" field value: "foo"`,
},
{
input: "--pgweb2:",
check: func(m *Metadata) bool {
return m == nil
},
},
{
input: `--pgweb: host="localhost"`,
check: func(m *Metadata) bool {
return m.Host.value == "localhost" &&
m.User.value == "*" &&
m.Database.value == "*" &&
m.Mode.value == "*" &&
m.Timeout == nil
},
},
{
input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="*"`,
check: func(m *Metadata) bool {
return m.Host.value == "localhost" &&
m.Host.re == nil &&
m.User.value == "anyuser" &&
m.Database.value == "anydb" &&
m.Mode.value == "*" &&
m.Timeout == nil
},
},
{
input: `--pgweb: host="localhost" timeout="foo"`,
err: `error initializing "timeout" field: strconv.Atoi: parsing "foo": invalid syntax`,
},
{
input: `-- pgweb: host="local(host|dev)"`,
check: func(m *Metadata) bool {
return m.Host.value == "local(host|dev)" && m.Host.re != nil &&
m.Host.matches("localhost") && m.Host.matches("localdev") &&
!m.Host.matches("localfoo") && !m.Host.matches("superlocaldev")
},
},
}
for _, ex := range examples {
t.Run(ex.input, func(t *testing.T) {
meta, err := parseMetadata(ex.input)
if ex.err != "" {
assert.Contains(t, err.Error(), ex.err)
}
if ex.check != nil {
assert.Equal(t, true, ex.check(meta))
}
})
}
}
func Test_sanitizeMetadata(t *testing.T) {
examples := []struct {
input string
output string
}{
{input: "", output: ""},
{input: "foo", output: "foo"},
{
input: `
-- pgweb: metadata
query1
-- pgweb: more metadata
query2
`,
output: "query1\nquery2",
},
}
for _, ex := range examples {
t.Run(ex.input, func(t *testing.T) {
assert.Equal(t, ex.output, sanitizeMetadata(ex.input))
})
}
}

23
pkg/queries/query.go Normal file
View File

@ -0,0 +1,23 @@
package queries
type Query struct {
ID string
Path string
Meta *Metadata
Data string
}
// IsPermitted returns true if a query is allowed to execute for a given db context
func (q Query) IsPermitted(host, user, database, mode string) bool {
// All fields must be provided for matching
if q.Meta == nil || host == "" || user == "" || database == "" || mode == "" {
return false
}
meta := q.Meta
return meta.Host.matches(host) &&
meta.User.matches(user) &&
meta.Database.matches(database) &&
meta.Mode.matches(mode)
}

77
pkg/queries/query_test.go Normal file
View File

@ -0,0 +1,77 @@
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestQueryIsPermitted(t *testing.T) {
examples := []struct {
name string
query Query
args []string
expected bool
}{
{
name: "no input provided",
query: makeQuery("localhost", "someuser", "somedb", "default"),
args: makeArgs("", "", "", ""),
expected: false,
},
{
name: "match on host",
query: makeQuery("localhost", "*", "*", "*"),
args: makeArgs("localhost", "user", "db", "default"),
expected: true,
},
{
name: "match on full set",
query: makeQuery("localhost", "user", "database", "mode"),
args: makeArgs("localhost", "someuser", "somedb", "default"),
expected: false,
},
{
name: "match on partial database",
query: makeQuery("localhost", "*", "myapp_*", "*"),
args: makeArgs("localhost", "user", "myapp_development", "default"),
expected: true,
},
{
name: "match on full set but not mode",
query: makeQuery("localhost", "*", "*", "readonly"),
args: makeArgs("localhost", "user", "db", "default"),
expected: false,
},
}
for _, ex := range examples {
t.Run(ex.name, func(t *testing.T) {
result := ex.query.IsPermitted(ex.args[0], ex.args[1], ex.args[2], ex.args[3])
assert.Equal(t, ex.expected, result)
})
}
}
func makeArgs(vals ...string) []string {
return vals
}
func makeQuery(host, user, database, mode string) Query {
mustfield := func(input string) field {
f, err := newField(input)
if err != nil {
panic(err)
}
return f
}
return Query{
Meta: &Metadata{
Host: mustfield(host),
User: mustfield(user),
Database: mustfield(database),
Mode: mustfield(mode),
},
}
}

88
pkg/queries/store.go Normal file
View File

@ -0,0 +1,88 @@
package queries
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
var (
ErrQueryDirNotExist = errors.New("queries directory does not exist")
ErrQueryFileNotExist = errors.New("query file does not exist")
)
type Store struct {
dir string
}
func NewStore(dir string) *Store {
return &Store{
dir: dir,
}
}
func (s Store) Read(id string) (*Query, error) {
path := filepath.Join(s.dir, fmt.Sprintf("%s.sql", id))
return readQuery(path)
}
func (s Store) ReadAll() ([]Query, error) {
entries, err := os.ReadDir(s.dir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
err = ErrQueryDirNotExist
}
return nil, err
}
queries := []Query{}
for _, entry := range entries {
name := entry.Name()
if filepath.Ext(name) != ".sql" {
continue
}
path := filepath.Join(s.dir, name)
query, err := readQuery(path)
if err != nil {
fmt.Fprintf(os.Stderr, "[WARN] skipping %q query file due to error: %v\n", name, err)
continue
}
if query == nil {
continue
}
queries = append(queries, *query)
}
return queries, nil
}
func readQuery(path string) (*Query, error) {
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, ErrQueryFileNotExist
}
return nil, err
}
dataStr := string(data)
meta, err := parseMetadata(dataStr)
if err != nil {
return nil, err
}
if meta == nil {
return nil, nil
}
return &Query{
ID: strings.Replace(filepath.Base(path), ".sql", "", 1),
Path: path,
Meta: meta,
Data: sanitizeMetadata(dataStr),
}, nil
}

71
pkg/queries/store_test.go Normal file
View File

@ -0,0 +1,71 @@
//go:build !windows
package queries
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStoreReadAll(t *testing.T) {
t.Run("valid dir", func(t *testing.T) {
queries, err := NewStore("../../data").ReadAll()
assert.NoError(t, err)
assert.Equal(t, 2, len(queries))
})
t.Run("invalid dir", func(t *testing.T) {
queries, err := NewStore("../../data2").ReadAll()
assert.Equal(t, err.Error(), "queries directory does not exist")
assert.Equal(t, 0, len(queries))
})
}
func TestStoreRead(t *testing.T) {
examples := []struct {
id string
err string
check func(*testing.T, *Query)
}{
{id: "foo", err: "query file does not exist"},
{id: "lc_no_meta"},
{id: "lc_invalid_meta", err: `invalid "mode" field value: "foo"`},
{
id: "lc_example1",
check: func(t *testing.T, q *Query) {
assert.Equal(t, "lc_example1", q.ID)
assert.Equal(t, "../../data/lc_example1.sql", q.Path)
assert.Equal(t, "select 'foo'", q.Data)
assert.Equal(t, "localhost", q.Meta.Host.String())
assert.Equal(t, "*", q.Meta.User.String())
assert.Equal(t, "*", q.Meta.Database.String())
},
},
{
id: "lc_example2",
check: func(t *testing.T, q *Query) {
assert.Equal(t, "lc_example2", q.ID)
assert.Equal(t, "../../data/lc_example2.sql", q.Path)
assert.Equal(t, "-- some comment\nselect 'foo'", q.Data)
assert.Equal(t, "localhost", q.Meta.Host.String())
assert.Equal(t, "foo", q.Meta.User.String())
assert.Equal(t, "*", q.Meta.Database.String())
},
},
}
store := NewStore("../../data")
for _, ex := range examples {
t.Run(ex.id, func(t *testing.T) {
query, err := store.Read(ex.id)
if ex.err != "" || err != nil {
assert.Equal(t, ex.err, err.Error())
}
if ex.check != nil {
ex.check(t, query)
}
})
}
}

View File

@ -357,6 +357,7 @@
#input .actions #query_progress { #input .actions #query_progress {
display: none; display: none;
float: left; float: left;
font-size: 12px;
line-height: 30px; line-height: 30px;
height: 30px; height: 30px;
color: #aaa; color: #aaa;

View File

@ -87,6 +87,13 @@
<li><a href="#" id="analyze">Analyze Query</a></li> <li><a href="#" id="analyze">Analyze Query</a></li>
</ul> </ul>
</div> </div>
<div id="load-query-dropdown" class="btn-group left" style="display: none">
<button id="load-local-query" type="button" class="btn btn-default dropdown-toggle" data-toggle="dropdown" disabled="disabled">
Template <span class="caret"></span>
</button>
<ul class="dropdown-menu" role="menu">
</ul>
</div>
<div id="query_progress">Please wait, query is executing...</div> <div id="query_progress">Please wait, query is executing...</div>
<div class="pull-right"> <div class="pull-right">
<span id="result-rows-count"></span> <span id="result-rows-count"></span>

View File

@ -178,6 +178,33 @@ function buildSchemaSection(name, objects) {
return section; return section;
} }
function loadLocalQueries() {
if (!appFeatures.local_queries) return;
$("body").on("click", "a.load-local-query", function(e) {
var id = $(this).data("id");
apiCall("get", "/local_queries/" + id, {}, function(resp) {
editor.setValue(resp.query);
editor.clearSelection();
});
});
apiCall("get", "/local_queries", {}, function(resp) {
if (resp.error) return;
var container = $("#load-query-dropdown").find(".dropdown-menu");
resp.forEach(function(item) {
var title = item.title || item.id;
$("<li><a href='#' class='load-local-query' data-id='" + item.id + "'>" + title + "</a></li>").appendTo(container);
});
if (resp.length > 0) $("#load-local-query").prop("disabled", "");
$("#load-query-dropdown").show();
});
}
function loadSchemas() { function loadSchemas() {
$("#objects").html(""); $("#objects").html("");
@ -738,13 +765,13 @@ function showActivityPanel() {
} }
function showQueryProgressMessage() { function showQueryProgressMessage() {
$("#run, #explain-dropdown-toggle, #csv, #json, #xml").prop("disabled", true); $("#run, #explain-dropdown-toggle, #csv, #json, #xml, #load-local-query").prop("disabled", true);
$("#explain-dropdown").removeClass("open"); $("#explain-dropdown").removeClass("open");
$("#query_progress").show(); $("#query_progress").show();
} }
function hideQueryProgressMessage() { function hideQueryProgressMessage() {
$("#run, #explain-dropdown-toggle, #csv, #json, #xml").prop("disabled", false); $("#run, #explain-dropdown-toggle, #csv, #json, #xml, #load-local-query").prop("disabled", false);
$("#query_progress").hide(); $("#query_progress").hide();
} }
@ -1810,6 +1837,7 @@ $(document).ready(function() {
connected = true; connected = true;
loadSchemas(); loadSchemas();
loadLocalQueries();
$("#current_database").text(resp.current_database); $("#current_database").text(resp.current_database);
$("#main").show(); $("#main").show();