diff --git a/pkg/client/client.go b/pkg/client/client.go index 1085c2e..4f1dd40 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -1,6 +1,7 @@ package client import ( + "context" "errors" "fmt" "log" @@ -34,6 +35,7 @@ type Client struct { serverVersion string serverType string lastQueryTime time.Time + queryTimeout time.Duration closed bool External bool `json:"external"` History []history.Record `json:"history"` @@ -79,7 +81,7 @@ func New() (*Client, error) { History: history.New(), } - client.setServerVersion() + client.init() return &client, nil } @@ -139,10 +141,18 @@ func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) { History: history.New(), } - client.setServerVersion() + client.init() return &client, nil } +func (client *Client) init() { + if command.Opts.QueryTimeout > 0 { + client.queryTimeout = time.Second * time.Duration(command.Opts.QueryTimeout) + } + + client.setServerVersion() +} + func (client *Client) setServerVersion() { res, err := client.query("SELECT version()") if err != nil || len(res.Rows) < 1 { @@ -338,6 +348,37 @@ func (client *Client) ServerVersion() string { return fmt.Sprintf("%s %s", client.serverType, client.serverVersion) } +func (client *Client) context() (context.Context, context.CancelFunc) { + if client.queryTimeout > 0 { + return context.WithTimeout(context.Background(), client.queryTimeout) + } + return context.Background(), func() {} +} + +func (client *Client) exec(query string, args ...interface{}) (*Result, error) { + ctx, cancel := client.context() + defer cancel() + + res, err := client.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, err + } + + affected, err := res.RowsAffected() + if err != nil { + return nil, err + } + + result := Result{ + Columns: []string{"Rows Affected"}, + Rows: []Row{ + {affected}, + }, + } + + return &result, nil +} + func (client *Client) query(query string, args ...interface{}) (*Result, error) { if client.db == nil { return nil, nil @@ -363,27 +404,13 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error) hasReturnValues := strings.Contains(strings.ToLower(query), " returning ") if (action == "update" || action == "delete") && !hasReturnValues { - res, err := client.db.Exec(query, args...) - if err != nil { - return nil, err - } - - affected, err := res.RowsAffected() - if err != nil { - return nil, err - } - - result := Result{ - Columns: []string{"Rows Affected"}, - Rows: []Row{ - {affected}, - }, - } - - return &result, nil + return client.exec(query, args...) } - rows, err := client.db.Queryx(query, args...) + ctx, cancel := client.context() + defer cancel() + + rows, err := client.db.QueryxContext(ctx, query, args...) if err != nil { if command.Opts.Debug { log.Println("Failed query:", query, "\nArgs:", args) diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 6989db9..7072b9f 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -397,10 +397,37 @@ func testTableNameWithCamelCase(t *testing.T) { } func testQuery(t *testing.T) { - res, err := testClient.Query("SELECT * FROM books") - assert.NoError(t, err) - assert.Equal(t, 4, len(res.Columns)) - assert.Equal(t, 15, len(res.Rows)) + t.Run("basic query", func(t *testing.T) { + res, err := testClient.Query("SELECT * FROM books") + assert.NoError(t, err) + assert.Equal(t, 4, len(res.Columns)) + assert.Equal(t, 15, len(res.Rows)) + }) + + t.Run("error", func(t *testing.T) { + res, err := testClient.Query("SELCT * FROM books") + assert.NotNil(t, err) + assert.Equal(t, "pq: syntax error at or near \"SELCT\"", err.Error()) + assert.Nil(t, res) + }) + + t.Run("invalid table", func(t *testing.T) { + res, err := testClient.Query("SELECT * FROM books2") + assert.NotNil(t, err) + assert.Equal(t, "pq: relation \"books2\" does not exist", err.Error()) + assert.Nil(t, res) + }) + + t.Run("timeout", func(t *testing.T) { + testClient.queryTimeout = time.Millisecond * 100 + defer func() { + testClient.queryTimeout = 0 + }() + + res, err := testClient.query("SELECT pg_sleep(1);") + assert.Equal(t, "pq: canceling statement due to user request", err.Error()) + assert.Nil(t, res) + }) } func testUpdateQuery(t *testing.T) { @@ -446,20 +473,6 @@ func testUpdateQuery(t *testing.T) { }) } -func testQueryError(t *testing.T) { - res, err := testClient.Query("SELCT * FROM books") - assert.NotNil(t, err) - assert.Equal(t, "pq: syntax error at or near \"SELCT\"", err.Error()) - assert.Nil(t, res) -} - -func testQueryInvalidTable(t *testing.T) { - res, err := testClient.Query("SELECT * FROM books2") - assert.NotNil(t, err) - assert.Equal(t, "pq: relation \"books2\" does not exist", err.Error()) - assert.Nil(t, res) -} - func testTableRowsOrderEscape(t *testing.T) { rows, err := testClient.TableRows("dummies", RowsOptions{SortColumn: "isDummy"}) assert.NoError(t, err) @@ -611,8 +624,6 @@ func TestAll(t *testing.T) { testTableNameWithCamelCase(t) testQuery(t) testUpdateQuery(t) - testQueryError(t) - testQueryInvalidTable(t) testTableRowsOrderEscape(t) testFunctions(t) testResult(t) diff --git a/pkg/command/options.go b/pkg/command/options.go index 05f19d6..ca76aff 100644 --- a/pkg/command/options.go +++ b/pkg/command/options.go @@ -49,6 +49,7 @@ type Options struct { ConnectHeaders string `long:"connect-headers" description:"List of headers to pass to the connect backend"` DisableConnectionIdleTimeout bool `long:"no-idle-timeout" description:"Disable connection idle timeout"` ConnectionIdleTimeout int `long:"idle-timeout" description:"Set connection idle timeout in minutes" default:"180"` + QueryTimeout int `long:"query-timeout" description:"Set global query execution timeout in seconds" default:"0"` Cors bool `long:"cors" description:"Enable Cross-Origin Resource Sharing (CORS)"` CorsOrigin string `long:"cors-origin" description:"Allowed CORS origins" default:"*"` BinaryCodec string `long:"binary-codec" description:"Codec for binary data serialization, one of 'none', 'hex', 'base58', 'base64'" default:"none"`