From 450badfbbb2cb94d2759896bd1fec3a3fb3dd76e Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Mon, 5 Dec 2022 20:56:21 -0600 Subject: [PATCH] Handle returning values in update/delete queries --- pkg/client/client.go | 4 +++- pkg/client/client_test.go | 46 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index 3d8bd50..7497791 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -356,7 +356,9 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error) } action := strings.ToLower(strings.Split(query, " ")[0]) - if action == "update" || action == "delete" { + hasReturnValues := strings.Contains(strings.ToLower(query), " returning ") + + if (action == "update" || action == "delete") && !hasReturnValues { res, err := client.db.Exec(query, args...) if err != nil { return nil, err diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 7ca3e75..e9304b4 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -352,12 +352,55 @@ func testTableNameWithCamelCase(t *testing.T) { func testQuery(t *testing.T) { res, err := testClient.Query("SELECT * FROM books") - assert.Equal(t, nil, err) assert.Equal(t, 4, len(res.Columns)) assert.Equal(t, 15, len(res.Rows)) } +func testUpdateQuery(t *testing.T) { + t.Run("updating data", func(t *testing.T) { + // Add new row + _, err := testClient.db.Exec("INSERT INTO books (id, title) VALUES (8888, 'Test Book'), (8889, 'Test Book 2')") + + // Update without return values + res, err := testClient.Query("UPDATE books SET title = 'Foo' WHERE id >= 8888 AND id <= 8889") + assert.NoError(t, err) + assert.Equal(t, "Rows Affected", res.Columns[0]) + assert.Equal(t, int64(2), res.Rows[0][0]) + + // Update with return values + res, err = testClient.Query("UPDATE books SET title = 'Foo2' WHERE id >= 8888 AND id <= 8889 RETURNING id, title") + assert.NoError(t, err) + assert.Equal(t, []string{"id", "title"}, res.Columns) + assert.Equal(t, Row{int64(8888), "Foo2"}, res.Rows[0]) + assert.Equal(t, Row{int64(8889), "Foo2"}, res.Rows[1]) + }) + + t.Run("deleting data", func(t *testing.T) { + // Add new row + _, err := testClient.db.Exec("INSERT INTO books (id, title) VALUES (9999, 'Test Book')") + + // Delete the existing row + res, err := testClient.Query("DELETE FROM books WHERE id = 9999") + assert.NoError(t, err) + assert.Equal(t, "Rows Affected", res.Columns[0]) + assert.Equal(t, int64(1), res.Rows[0][0]) + + // Deleting already deleted row + res, err = testClient.Query("DELETE FROM books WHERE id = 9999") + assert.NoError(t, err) + assert.Equal(t, int64(0), res.Rows[0][0]) + + // Delete with returning value + _, err = testClient.db.Exec("INSERT INTO books (id, title) VALUES (9999, 'Test Book')") + assert.NoError(t, err) + + res, err = testClient.Query("DELETE FROM books WHERE id = 9999 RETURNING id") + assert.NoError(t, err) + assert.Equal(t, int64(9999), res.Rows[0][0]) + }) +} + func testQueryError(t *testing.T) { res, err := testClient.Query("SELCT * FROM books") @@ -491,6 +534,7 @@ func TestAll(t *testing.T) { testTableConstraints(t) testTableNameWithCamelCase(t) testQuery(t) + testUpdateQuery(t) testQueryError(t) testQueryInvalidTable(t) testTableRowsOrderEscape(t)