Add test for server type and version detection (#612)

This commit is contained in:
Dan Sosedoff 2022-12-08 13:33:38 -06:00 committed by GitHub
parent 1754274d46
commit 0008842a68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 24 deletions

View File

@ -7,7 +7,6 @@ import (
"log"
neturl "net/url"
"reflect"
"regexp"
"strings"
"time"
@ -21,14 +20,6 @@ import (
"github.com/sosedoff/pgweb/pkg/statements"
)
var (
postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s`)
postgresType = "PostgreSQL"
cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s`)
cockroachType = "CockroachDB"
)
type Client struct {
db *sqlx.DB
tunnel *Tunnel
@ -137,6 +128,7 @@ func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) {
client := Client{
db: db,
tunnel: tunnel,
serverType: postgresType,
ConnectionString: url,
History: history.New(),
}
@ -160,21 +152,10 @@ func (client *Client) setServerVersion() {
}
version := res.Rows[0][0].(string)
// Detect postgresql
matches := postgresSignature.FindAllStringSubmatch(version, 1)
if len(matches) > 0 {
client.serverType = postgresType
client.serverVersion = matches[0][1]
return
}
// Detect cockroachdb
matches = cockroachSignature.FindAllStringSubmatch(version, 1)
if len(matches) > 0 {
client.serverType = cockroachType
client.serverVersion = matches[0][1]
return
match, serverType, serverVersion := detectServerTypeAndVersion(version)
if match {
client.serverType = serverType
client.serverVersion = serverVersion
}
}

View File

@ -12,6 +12,14 @@ var (
// Comment regular expressions
reSlashComment = regexp.MustCompile(`(?m)/\*.+\*/`)
reDashComment = regexp.MustCompile(`(?m)--.+`)
// Postgres version signature
postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s?`)
postgresType = "PostgreSQL"
// Cockroach version signature
cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s?`)
cockroachType = "CockroachDB"
)
// Get short version from the string
@ -24,6 +32,24 @@ func getMajorMinorVersion(str string) string {
return strings.Join(chunks[0:2], ".")
}
func detectServerTypeAndVersion(version string) (bool, string, string) {
version = strings.TrimSpace(version)
// Detect postgresql
matches := postgresSignature.FindAllStringSubmatch(version, 1)
if len(matches) > 0 {
return true, postgresType, matches[0][1]
}
// Detect cockroachdb
matches = cockroachSignature.FindAllStringSubmatch(version, 1)
if len(matches) > 0 {
return true, cockroachType, matches[0][1]
}
return false, "", ""
}
// containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode
func containsRestrictedKeywords(str string) bool {
str = reSlashComment.ReplaceAllString(str, "")

50
pkg/client/util_test.go Normal file
View File

@ -0,0 +1,50 @@
package client
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDetectServerType(t *testing.T) {
examples := []struct {
input string
match bool
serverType string
version string
}{
{input: "",
match: false,
serverType: "",
version: "",
},
{
input: " postgresql 15 ",
match: true,
serverType: postgresType,
version: "15",
},
{
input: "PostgreSQL 14.5 (Homebrew) on aarch64-apple-darwin21.6.0",
match: true,
serverType: postgresType,
version: "14.5",
},
{
input: "PostgreSQL 11.16, compiled by Visual C++ build 1800, 64-bit",
match: true,
serverType: postgresType,
version: "11.16",
},
}
for _, ex := range examples {
t.Run("input:"+ex.input, func(t *testing.T) {
match, stype, version := detectServerTypeAndVersion(ex.input)
assert.Equal(t, ex.match, match)
assert.Equal(t, ex.serverType, stype)
assert.Equal(t, ex.version, version)
})
}
}