Add test for server type and version detection (#612)
This commit is contained in:
parent
1754274d46
commit
0008842a68
@ -7,7 +7,6 @@ import (
|
||||
"log"
|
||||
neturl "net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -21,14 +20,6 @@ import (
|
||||
"github.com/sosedoff/pgweb/pkg/statements"
|
||||
)
|
||||
|
||||
var (
|
||||
postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s`)
|
||||
postgresType = "PostgreSQL"
|
||||
|
||||
cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s`)
|
||||
cockroachType = "CockroachDB"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
db *sqlx.DB
|
||||
tunnel *Tunnel
|
||||
@ -137,6 +128,7 @@ func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) {
|
||||
client := Client{
|
||||
db: db,
|
||||
tunnel: tunnel,
|
||||
serverType: postgresType,
|
||||
ConnectionString: url,
|
||||
History: history.New(),
|
||||
}
|
||||
@ -160,21 +152,10 @@ func (client *Client) setServerVersion() {
|
||||
}
|
||||
|
||||
version := res.Rows[0][0].(string)
|
||||
|
||||
// Detect postgresql
|
||||
matches := postgresSignature.FindAllStringSubmatch(version, 1)
|
||||
if len(matches) > 0 {
|
||||
client.serverType = postgresType
|
||||
client.serverVersion = matches[0][1]
|
||||
return
|
||||
}
|
||||
|
||||
// Detect cockroachdb
|
||||
matches = cockroachSignature.FindAllStringSubmatch(version, 1)
|
||||
if len(matches) > 0 {
|
||||
client.serverType = cockroachType
|
||||
client.serverVersion = matches[0][1]
|
||||
return
|
||||
match, serverType, serverVersion := detectServerTypeAndVersion(version)
|
||||
if match {
|
||||
client.serverType = serverType
|
||||
client.serverVersion = serverVersion
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12,6 +12,14 @@ var (
|
||||
// Comment regular expressions
|
||||
reSlashComment = regexp.MustCompile(`(?m)/\*.+\*/`)
|
||||
reDashComment = regexp.MustCompile(`(?m)--.+`)
|
||||
|
||||
// Postgres version signature
|
||||
postgresSignature = regexp.MustCompile(`(?i)postgresql ([\d\.]+)\s?`)
|
||||
postgresType = "PostgreSQL"
|
||||
|
||||
// Cockroach version signature
|
||||
cockroachSignature = regexp.MustCompile(`(?i)cockroachdb ccl v([\d\.]+)\s?`)
|
||||
cockroachType = "CockroachDB"
|
||||
)
|
||||
|
||||
// Get short version from the string
|
||||
@ -24,6 +32,24 @@ func getMajorMinorVersion(str string) string {
|
||||
return strings.Join(chunks[0:2], ".")
|
||||
}
|
||||
|
||||
func detectServerTypeAndVersion(version string) (bool, string, string) {
|
||||
version = strings.TrimSpace(version)
|
||||
|
||||
// Detect postgresql
|
||||
matches := postgresSignature.FindAllStringSubmatch(version, 1)
|
||||
if len(matches) > 0 {
|
||||
return true, postgresType, matches[0][1]
|
||||
}
|
||||
|
||||
// Detect cockroachdb
|
||||
matches = cockroachSignature.FindAllStringSubmatch(version, 1)
|
||||
if len(matches) > 0 {
|
||||
return true, cockroachType, matches[0][1]
|
||||
}
|
||||
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
// containsRestrictedKeywords returns true if given keyword is not allowed in read-only mode
|
||||
func containsRestrictedKeywords(str string) bool {
|
||||
str = reSlashComment.ReplaceAllString(str, "")
|
||||
|
50
pkg/client/util_test.go
Normal file
50
pkg/client/util_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDetectServerType(t *testing.T) {
|
||||
examples := []struct {
|
||||
input string
|
||||
match bool
|
||||
serverType string
|
||||
version string
|
||||
}{
|
||||
{input: "",
|
||||
match: false,
|
||||
serverType: "",
|
||||
version: "",
|
||||
},
|
||||
{
|
||||
input: " postgresql 15 ",
|
||||
match: true,
|
||||
serverType: postgresType,
|
||||
version: "15",
|
||||
},
|
||||
{
|
||||
input: "PostgreSQL 14.5 (Homebrew) on aarch64-apple-darwin21.6.0",
|
||||
match: true,
|
||||
serverType: postgresType,
|
||||
version: "14.5",
|
||||
},
|
||||
{
|
||||
input: "PostgreSQL 11.16, compiled by Visual C++ build 1800, 64-bit",
|
||||
match: true,
|
||||
serverType: postgresType,
|
||||
version: "11.16",
|
||||
},
|
||||
}
|
||||
|
||||
for _, ex := range examples {
|
||||
t.Run("input:"+ex.input, func(t *testing.T) {
|
||||
match, stype, version := detectServerTypeAndVersion(ex.input)
|
||||
|
||||
assert.Equal(t, ex.match, match)
|
||||
assert.Equal(t, ex.serverType, stype)
|
||||
assert.Equal(t, ex.version, version)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user