Implement global query timeout option (#609)

* Add global query timeout
* Tweak option settings
* Add timeout test
* Move query timeout option close to idle timeout
This commit is contained in:
Dan Sosedoff 2022-12-07 18:56:58 -06:00 committed by GitHub
parent adf1e4e9ea
commit d08dbf34aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 41 deletions

View File

@ -1,6 +1,7 @@
package client
import (
"context"
"errors"
"fmt"
"log"
@ -34,6 +35,7 @@ type Client struct {
serverVersion string
serverType string
lastQueryTime time.Time
queryTimeout time.Duration
closed bool
External bool `json:"external"`
History []history.Record `json:"history"`
@ -79,7 +81,7 @@ func New() (*Client, error) {
History: history.New(),
}
client.setServerVersion()
client.init()
return &client, nil
}
@ -139,10 +141,18 @@ func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) {
History: history.New(),
}
client.setServerVersion()
client.init()
return &client, nil
}
func (client *Client) init() {
if command.Opts.QueryTimeout > 0 {
client.queryTimeout = time.Second * time.Duration(command.Opts.QueryTimeout)
}
client.setServerVersion()
}
func (client *Client) setServerVersion() {
res, err := client.query("SELECT version()")
if err != nil || len(res.Rows) < 1 {
@ -338,6 +348,37 @@ func (client *Client) ServerVersion() string {
return fmt.Sprintf("%s %s", client.serverType, client.serverVersion)
}
func (client *Client) context() (context.Context, context.CancelFunc) {
if client.queryTimeout > 0 {
return context.WithTimeout(context.Background(), client.queryTimeout)
}
return context.Background(), func() {}
}
func (client *Client) exec(query string, args ...interface{}) (*Result, error) {
ctx, cancel := client.context()
defer cancel()
res, err := client.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
affected, err := res.RowsAffected()
if err != nil {
return nil, err
}
result := Result{
Columns: []string{"Rows Affected"},
Rows: []Row{
{affected},
},
}
return &result, nil
}
func (client *Client) query(query string, args ...interface{}) (*Result, error) {
if client.db == nil {
return nil, nil
@ -363,27 +404,13 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
hasReturnValues := strings.Contains(strings.ToLower(query), " returning ")
if (action == "update" || action == "delete") && !hasReturnValues {
res, err := client.db.Exec(query, args...)
if err != nil {
return nil, err
}
affected, err := res.RowsAffected()
if err != nil {
return nil, err
}
result := Result{
Columns: []string{"Rows Affected"},
Rows: []Row{
{affected},
},
}
return &result, nil
return client.exec(query, args...)
}
rows, err := client.db.Queryx(query, args...)
ctx, cancel := client.context()
defer cancel()
rows, err := client.db.QueryxContext(ctx, query, args...)
if err != nil {
if command.Opts.Debug {
log.Println("Failed query:", query, "\nArgs:", args)

View File

@ -397,10 +397,37 @@ func testTableNameWithCamelCase(t *testing.T) {
}
func testQuery(t *testing.T) {
res, err := testClient.Query("SELECT * FROM books")
assert.NoError(t, err)
assert.Equal(t, 4, len(res.Columns))
assert.Equal(t, 15, len(res.Rows))
t.Run("basic query", func(t *testing.T) {
res, err := testClient.Query("SELECT * FROM books")
assert.NoError(t, err)
assert.Equal(t, 4, len(res.Columns))
assert.Equal(t, 15, len(res.Rows))
})
t.Run("error", func(t *testing.T) {
res, err := testClient.Query("SELCT * FROM books")
assert.NotNil(t, err)
assert.Equal(t, "pq: syntax error at or near \"SELCT\"", err.Error())
assert.Nil(t, res)
})
t.Run("invalid table", func(t *testing.T) {
res, err := testClient.Query("SELECT * FROM books2")
assert.NotNil(t, err)
assert.Equal(t, "pq: relation \"books2\" does not exist", err.Error())
assert.Nil(t, res)
})
t.Run("timeout", func(t *testing.T) {
testClient.queryTimeout = time.Millisecond * 100
defer func() {
testClient.queryTimeout = 0
}()
res, err := testClient.query("SELECT pg_sleep(1);")
assert.Equal(t, "pq: canceling statement due to user request", err.Error())
assert.Nil(t, res)
})
}
func testUpdateQuery(t *testing.T) {
@ -446,20 +473,6 @@ func testUpdateQuery(t *testing.T) {
})
}
func testQueryError(t *testing.T) {
res, err := testClient.Query("SELCT * FROM books")
assert.NotNil(t, err)
assert.Equal(t, "pq: syntax error at or near \"SELCT\"", err.Error())
assert.Nil(t, res)
}
func testQueryInvalidTable(t *testing.T) {
res, err := testClient.Query("SELECT * FROM books2")
assert.NotNil(t, err)
assert.Equal(t, "pq: relation \"books2\" does not exist", err.Error())
assert.Nil(t, res)
}
func testTableRowsOrderEscape(t *testing.T) {
rows, err := testClient.TableRows("dummies", RowsOptions{SortColumn: "isDummy"})
assert.NoError(t, err)
@ -611,8 +624,6 @@ func TestAll(t *testing.T) {
testTableNameWithCamelCase(t)
testQuery(t)
testUpdateQuery(t)
testQueryError(t)
testQueryInvalidTable(t)
testTableRowsOrderEscape(t)
testFunctions(t)
testResult(t)

View File

@ -49,6 +49,7 @@ type Options struct {
ConnectHeaders string `long:"connect-headers" description:"List of headers to pass to the connect backend"`
DisableConnectionIdleTimeout bool `long:"no-idle-timeout" description:"Disable connection idle timeout"`
ConnectionIdleTimeout int `long:"idle-timeout" description:"Set connection idle timeout in minutes" default:"180"`
QueryTimeout int `long:"query-timeout" description:"Set global query execution timeout in seconds" default:"0"`
Cors bool `long:"cors" description:"Enable Cross-Origin Resource Sharing (CORS)"`
CorsOrigin string `long:"cors-origin" description:"Allowed CORS origins" default:"*"`
BinaryCodec string `long:"binary-codec" description:"Codec for binary data serialization, one of 'none', 'hex', 'base58', 'base64'" default:"none"`