Implement global query timeout option (#609)
* Add global query timeout * Tweak option settings * Add timeout test * Move query timeout option close to idle timeout
This commit is contained in:
parent
adf1e4e9ea
commit
d08dbf34aa
@ -1,6 +1,7 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@ -34,6 +35,7 @@ type Client struct {
|
||||
serverVersion string
|
||||
serverType string
|
||||
lastQueryTime time.Time
|
||||
queryTimeout time.Duration
|
||||
closed bool
|
||||
External bool `json:"external"`
|
||||
History []history.Record `json:"history"`
|
||||
@ -79,7 +81,7 @@ func New() (*Client, error) {
|
||||
History: history.New(),
|
||||
}
|
||||
|
||||
client.setServerVersion()
|
||||
client.init()
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
@ -139,10 +141,18 @@ func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) {
|
||||
History: history.New(),
|
||||
}
|
||||
|
||||
client.setServerVersion()
|
||||
client.init()
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
func (client *Client) init() {
|
||||
if command.Opts.QueryTimeout > 0 {
|
||||
client.queryTimeout = time.Second * time.Duration(command.Opts.QueryTimeout)
|
||||
}
|
||||
|
||||
client.setServerVersion()
|
||||
}
|
||||
|
||||
func (client *Client) setServerVersion() {
|
||||
res, err := client.query("SELECT version()")
|
||||
if err != nil || len(res.Rows) < 1 {
|
||||
@ -338,6 +348,37 @@ func (client *Client) ServerVersion() string {
|
||||
return fmt.Sprintf("%s %s", client.serverType, client.serverVersion)
|
||||
}
|
||||
|
||||
func (client *Client) context() (context.Context, context.CancelFunc) {
|
||||
if client.queryTimeout > 0 {
|
||||
return context.WithTimeout(context.Background(), client.queryTimeout)
|
||||
}
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
|
||||
func (client *Client) exec(query string, args ...interface{}) (*Result, error) {
|
||||
ctx, cancel := client.context()
|
||||
defer cancel()
|
||||
|
||||
res, err := client.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := Result{
|
||||
Columns: []string{"Rows Affected"},
|
||||
Rows: []Row{
|
||||
{affected},
|
||||
},
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (client *Client) query(query string, args ...interface{}) (*Result, error) {
|
||||
if client.db == nil {
|
||||
return nil, nil
|
||||
@ -363,27 +404,13 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
|
||||
hasReturnValues := strings.Contains(strings.ToLower(query), " returning ")
|
||||
|
||||
if (action == "update" || action == "delete") && !hasReturnValues {
|
||||
res, err := client.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := Result{
|
||||
Columns: []string{"Rows Affected"},
|
||||
Rows: []Row{
|
||||
{affected},
|
||||
},
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
return client.exec(query, args...)
|
||||
}
|
||||
|
||||
rows, err := client.db.Queryx(query, args...)
|
||||
ctx, cancel := client.context()
|
||||
defer cancel()
|
||||
|
||||
rows, err := client.db.QueryxContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
if command.Opts.Debug {
|
||||
log.Println("Failed query:", query, "\nArgs:", args)
|
||||
|
@ -397,10 +397,37 @@ func testTableNameWithCamelCase(t *testing.T) {
|
||||
}
|
||||
|
||||
func testQuery(t *testing.T) {
|
||||
res, err := testClient.Query("SELECT * FROM books")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4, len(res.Columns))
|
||||
assert.Equal(t, 15, len(res.Rows))
|
||||
t.Run("basic query", func(t *testing.T) {
|
||||
res, err := testClient.Query("SELECT * FROM books")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4, len(res.Columns))
|
||||
assert.Equal(t, 15, len(res.Rows))
|
||||
})
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
res, err := testClient.Query("SELCT * FROM books")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "pq: syntax error at or near \"SELCT\"", err.Error())
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
t.Run("invalid table", func(t *testing.T) {
|
||||
res, err := testClient.Query("SELECT * FROM books2")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "pq: relation \"books2\" does not exist", err.Error())
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
|
||||
t.Run("timeout", func(t *testing.T) {
|
||||
testClient.queryTimeout = time.Millisecond * 100
|
||||
defer func() {
|
||||
testClient.queryTimeout = 0
|
||||
}()
|
||||
|
||||
res, err := testClient.query("SELECT pg_sleep(1);")
|
||||
assert.Equal(t, "pq: canceling statement due to user request", err.Error())
|
||||
assert.Nil(t, res)
|
||||
})
|
||||
}
|
||||
|
||||
func testUpdateQuery(t *testing.T) {
|
||||
@ -446,20 +473,6 @@ func testUpdateQuery(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func testQueryError(t *testing.T) {
|
||||
res, err := testClient.Query("SELCT * FROM books")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "pq: syntax error at or near \"SELCT\"", err.Error())
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func testQueryInvalidTable(t *testing.T) {
|
||||
res, err := testClient.Query("SELECT * FROM books2")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "pq: relation \"books2\" does not exist", err.Error())
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func testTableRowsOrderEscape(t *testing.T) {
|
||||
rows, err := testClient.TableRows("dummies", RowsOptions{SortColumn: "isDummy"})
|
||||
assert.NoError(t, err)
|
||||
@ -611,8 +624,6 @@ func TestAll(t *testing.T) {
|
||||
testTableNameWithCamelCase(t)
|
||||
testQuery(t)
|
||||
testUpdateQuery(t)
|
||||
testQueryError(t)
|
||||
testQueryInvalidTable(t)
|
||||
testTableRowsOrderEscape(t)
|
||||
testFunctions(t)
|
||||
testResult(t)
|
||||
|
@ -49,6 +49,7 @@ type Options struct {
|
||||
ConnectHeaders string `long:"connect-headers" description:"List of headers to pass to the connect backend"`
|
||||
DisableConnectionIdleTimeout bool `long:"no-idle-timeout" description:"Disable connection idle timeout"`
|
||||
ConnectionIdleTimeout int `long:"idle-timeout" description:"Set connection idle timeout in minutes" default:"180"`
|
||||
QueryTimeout int `long:"query-timeout" description:"Set global query execution timeout in seconds" default:"0"`
|
||||
Cors bool `long:"cors" description:"Enable Cross-Origin Resource Sharing (CORS)"`
|
||||
CorsOrigin string `long:"cors-origin" description:"Allowed CORS origins" default:"*"`
|
||||
BinaryCodec string `long:"binary-codec" description:"Codec for binary data serialization, one of 'none', 'hex', 'base58', 'base64'" default:"none"`
|
||||
|
Loading…
x
Reference in New Issue
Block a user