From 0008842a684928c9c10e3514a4395c646db3200b Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Thu, 8 Dec 2022 13:33:38 -0600 Subject: [PATCH] Add test for server type and version detection (#612) --- pkg/client/client.go | 29 +++++------------------- pkg/client/util.go | 26 +++++++++++++++++++++ pkg/client/util_test.go | 50 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 24 deletions(-) create mode 100644 pkg/client/util_test.go diff --git a/pkg/client/client.go b/pkg/client/client.go index 4f1dd40..2c4993d 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -7,7 +7,6 @@ import ( "log" neturl "net/url" "reflect" - "regexp" "strings" "time" @@ -21,14 +20,6 @@ import ( "github.com/sosedoff/pgweb/pkg/statements" ) -var ( - postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s`) - postgresType = "PostgreSQL" - - cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s`) - cockroachType = "CockroachDB" -) - type Client struct { db *sqlx.DB tunnel *Tunnel @@ -137,6 +128,7 @@ func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) { client := Client{ db: db, tunnel: tunnel, + serverType: postgresType, ConnectionString: url, History: history.New(), } @@ -160,21 +152,10 @@ func (client *Client) setServerVersion() { } version := res.Rows[0][0].(string) - - // Detect postgresql - matches := postgresSignature.FindAllStringSubmatch(version, 1) - if len(matches) > 0 { - client.serverType = postgresType - client.serverVersion = matches[0][1] - return - } - - // Detect cockroachdb - matches = cockroachSignature.FindAllStringSubmatch(version, 1) - if len(matches) > 0 { - client.serverType = cockroachType - client.serverVersion = matches[0][1] - return + match, serverType, serverVersion := detectServerTypeAndVersion(version) + if match { + client.serverType = serverType + client.serverVersion = serverVersion } } diff --git a/pkg/client/util.go b/pkg/client/util.go index d83bf1d..5dd4bb0 100644 --- a/pkg/client/util.go +++ b/pkg/client/util.go @@ -12,6 +12,14 @@ var ( // Comment regular expressions reSlashComment = regexp.MustCompile(`(?m)/\*.+\*/`) reDashComment = regexp.MustCompile(`(?m)--.+`) + + // Postgres version signature + postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s?`) + postgresType = "PostgreSQL" + + // Cockroach version signature + cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s?`) + cockroachType = "CockroachDB" ) // Get short version from the string @@ -24,6 +32,24 @@ func getMajorMinorVersion(str string) string { return strings.Join(chunks[0:2], ".") } +func detectServerTypeAndVersion(version string) (bool, string, string) { + version = strings.TrimSpace(version) + + // Detect postgresql + matches := postgresSignature.FindAllStringSubmatch(version, 1) + if len(matches) > 0 { + return true, postgresType, matches[0][1] + } + + // Detect cockroachdb + matches = cockroachSignature.FindAllStringSubmatch(version, 1) + if len(matches) > 0 { + return true, cockroachType, matches[0][1] + } + + return false, "", "" +} + // containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode func containsRestrictedKeywords(str string) bool { str = reSlashComment.ReplaceAllString(str, "") diff --git a/pkg/client/util_test.go b/pkg/client/util_test.go new file mode 100644 index 0000000..5036a8a --- /dev/null +++ b/pkg/client/util_test.go @@ -0,0 +1,50 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDetectServerType(t *testing.T) { + examples := []struct { + input string + match bool + serverType string + version string + }{ + {input: "", + match: false, + serverType: "", + version: "", + }, + { + input: " postgresql 15 ", + match: true, + serverType: postgresType, + version: "15", + }, + { + input: "PostgreSQL 14.5 (Homebrew) on aarch64-apple-darwin21.6.0", + match: true, + serverType: postgresType, + version: "14.5", + }, + { + input: "PostgreSQL 11.16, compiled by Visual C++ build 1800, 64-bit", + match: true, + serverType: postgresType, + version: "11.16", + }, + } + + for _, ex := range examples { + t.Run("input:"+ex.input, func(t *testing.T) { + match, stype, version := detectServerTypeAndVersion(ex.input) + + assert.Equal(t, ex.match, match) + assert.Equal(t, ex.serverType, stype) + assert.Equal(t, ex.version, version) + }) + } +}