Refactor bookmarks ssl params
This commit is contained in:
parent
84bf1f091b
commit
dd1fb90355
@ -21,7 +21,7 @@ type Bookmark struct {
|
||||
User string `json:"user"` // Database user
|
||||
Password string `json:"password"` // User password
|
||||
Database string `json:"database"` // Database name
|
||||
Ssl string `json:"ssl"` // Connection SSL mode
|
||||
SSLMode string `json:"ssl"` // Connection SSL mode
|
||||
SSH *shared.SSHInfo `json:"ssh"` // SSH tunnel config
|
||||
}
|
||||
|
||||
@ -33,13 +33,13 @@ func (b Bookmark) SSHInfoIsEmpty() bool {
|
||||
// ConvertToOptions returns an options struct from connection details
|
||||
func (b Bookmark) ConvertToOptions() command.Options {
|
||||
return command.Options{
|
||||
URL: b.URL,
|
||||
Host: b.Host,
|
||||
Port: b.Port,
|
||||
User: b.User,
|
||||
Pass: b.Password,
|
||||
DbName: b.Database,
|
||||
Ssl: b.Ssl,
|
||||
URL: b.URL,
|
||||
Host: b.Host,
|
||||
Port: b.Port,
|
||||
User: b.User,
|
||||
Pass: b.Password,
|
||||
DbName: b.Database,
|
||||
SSLMode: b.SSLMode,
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,7 +62,7 @@ func readServerConfig(path string) (Bookmark, error) {
|
||||
valid := false
|
||||
|
||||
for _, mode := range modes {
|
||||
if bookmark.Ssl == mode {
|
||||
if bookmark.SSLMode == mode {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
@ -70,8 +70,8 @@ func readServerConfig(path string) (Bookmark, error) {
|
||||
|
||||
// Fall back to a default mode if mode is not set or invalid
|
||||
// Typical typo: ssl mode set to "disabled"
|
||||
if bookmark.Ssl == "" || !valid {
|
||||
bookmark.Ssl = "disable"
|
||||
if bookmark.SSLMode == "" || !valid {
|
||||
bookmark.SSLMode = "disable"
|
||||
}
|
||||
|
||||
// Set default SSH port if it's not provided by user
|
||||
|
@ -24,13 +24,13 @@ func Test_Bookmark(t *testing.T) {
|
||||
assert.Equal(t, 5432, bookmark.Port)
|
||||
assert.Equal(t, "postgres", bookmark.User)
|
||||
assert.Equal(t, "mydatabase", bookmark.Database)
|
||||
assert.Equal(t, "disable", bookmark.Ssl)
|
||||
assert.Equal(t, "disable", bookmark.SSLMode)
|
||||
assert.Equal(t, "", bookmark.Password)
|
||||
assert.Equal(t, "", bookmark.URL)
|
||||
|
||||
bookmark, err = readServerConfig("../../data/bookmark_invalid_ssl.toml")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "disable", bookmark.Ssl)
|
||||
assert.Equal(t, "disable", bookmark.SSLMode)
|
||||
}
|
||||
|
||||
func Test_Bookmark_URL(t *testing.T) {
|
||||
@ -42,7 +42,7 @@ func Test_Bookmark_URL(t *testing.T) {
|
||||
assert.Equal(t, 5432, bookmark.Port)
|
||||
assert.Equal(t, "", bookmark.User)
|
||||
assert.Equal(t, "", bookmark.Database)
|
||||
assert.Equal(t, "disable", bookmark.Ssl)
|
||||
assert.Equal(t, "disable", bookmark.SSLMode)
|
||||
assert.Equal(t, "", bookmark.Password)
|
||||
}
|
||||
|
||||
@ -79,7 +79,7 @@ func Test_GetBookmark(t *testing.T) {
|
||||
User: "postgres",
|
||||
Password: "",
|
||||
Database: "mydatabase",
|
||||
Ssl: "disable",
|
||||
SSLMode: "disable",
|
||||
}
|
||||
b, err := GetBookmark("../../data", "bookmark")
|
||||
if assert.NoError(t, err) {
|
||||
@ -124,17 +124,17 @@ func Test_ConvertToOptions(t *testing.T) {
|
||||
User: "postgres",
|
||||
Password: "password",
|
||||
Database: "mydatabase",
|
||||
Ssl: "disable",
|
||||
SSLMode: "disable",
|
||||
}
|
||||
|
||||
expOpt := command.Options{
|
||||
URL: "postgres://username:password@host:port/database?sslmode=disable",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Pass: "password",
|
||||
DbName: "mydatabase",
|
||||
Ssl: "disable",
|
||||
URL: "postgres://username:password@host:port/database?sslmode=disable",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Pass: "password",
|
||||
DbName: "mydatabase",
|
||||
SSLMode: "disable",
|
||||
}
|
||||
opt := b.ConvertToOptions()
|
||||
assert.Equal(t, expOpt, opt)
|
||||
|
@ -24,10 +24,10 @@ type Options struct {
|
||||
User string `long:"user" description:"Database user"`
|
||||
Pass string `long:"pass" description:"Password for user"`
|
||||
DbName string `long:"db" description:"Database name"`
|
||||
Ssl string `long:"ssl" description:"SSL mode"`
|
||||
SslRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"`
|
||||
SslCert string `long:"ssl-cert" description:"SSL client certificate file"`
|
||||
SslKey string `long:"ssl-key" description:"SSL client certificate key file"`
|
||||
SSLMode string `long:"ssl" description:"SSL mode"`
|
||||
SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"`
|
||||
SSLCert string `long:"ssl-cert" description:"SSL client certificate file"`
|
||||
SSLKey string `long:"ssl-key" description:"SSL client certificate key file"`
|
||||
HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"`
|
||||
HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"`
|
||||
AuthUser string `long:"auth-user" description:"HTTP basic auth user"`
|
||||
@ -96,7 +96,7 @@ func ParseOptions(args []string) (Options, error) {
|
||||
opts.User = ""
|
||||
opts.Pass = ""
|
||||
opts.DbName = ""
|
||||
opts.Ssl = ""
|
||||
opts.SSLMode = ""
|
||||
}
|
||||
|
||||
if opts.Prefix != "" && !strings.Contains(opts.Prefix, "/") {
|
||||
|
@ -66,13 +66,13 @@ func FormatURL(opts command.Options) (string, error) {
|
||||
|
||||
// Determine if we need to specify sslmode if it's missing
|
||||
if params["sslmode"] == "" {
|
||||
if opts.Ssl == "" {
|
||||
if opts.SSLMode == "" {
|
||||
// Only modify sslmode for local connections
|
||||
if strings.Contains(uri.Host, "localhost") || strings.Contains(uri.Host, "127.0.0.1") {
|
||||
params["sslmode"] = "disable"
|
||||
}
|
||||
} else {
|
||||
params["sslmode"] = opts.Ssl
|
||||
params["sslmode"] = opts.SSLMode
|
||||
}
|
||||
}
|
||||
|
||||
@ -108,21 +108,21 @@ func BuildStringFromOptions(opts command.Options) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if opts.Ssl != "" {
|
||||
query.Add("sslmode", opts.Ssl)
|
||||
if opts.SSLMode != "" {
|
||||
query.Add("sslmode", opts.SSLMode)
|
||||
} else {
|
||||
if opts.Host == "localhost" || opts.Host == "127.0.0.1" {
|
||||
query.Add("sslmode", "disable")
|
||||
}
|
||||
}
|
||||
if opts.SslCert != "" {
|
||||
query.Add("sslcert", opts.SslCert)
|
||||
if opts.SSLCert != "" {
|
||||
query.Add("sslcert", opts.SSLCert)
|
||||
}
|
||||
if opts.SslKey != "" {
|
||||
query.Add("sslkey", opts.SslKey)
|
||||
if opts.SSLKey != "" {
|
||||
query.Add("sslkey", opts.SSLKey)
|
||||
}
|
||||
if opts.SslRootCert != "" {
|
||||
query.Add("sslrootcert", opts.SslRootCert)
|
||||
if opts.SSLRootCert != "" {
|
||||
query.Add("sslrootcert", opts.SSLRootCert)
|
||||
}
|
||||
|
||||
url := neturl.URL{
|
||||
|
@ -38,8 +38,8 @@ func Test_Valid_Url(t *testing.T) {
|
||||
|
||||
func Test_Url_And_Ssl_Flag(t *testing.T) {
|
||||
str, err := BuildStringFromOptions(command.Options{
|
||||
URL: "postgres://myhost/database",
|
||||
Ssl: "disable",
|
||||
URL: "postgres://myhost/database",
|
||||
SSLMode: "disable",
|
||||
})
|
||||
|
||||
assert.Equal(t, nil, err)
|
||||
@ -62,15 +62,15 @@ func Test_Localhost_Url_And_No_Ssl_Flag(t *testing.T) {
|
||||
|
||||
func Test_Localhost_Url_And_Ssl_Flag(t *testing.T) {
|
||||
str, err := BuildStringFromOptions(command.Options{
|
||||
URL: "postgres://localhost/database",
|
||||
Ssl: "require",
|
||||
URL: "postgres://localhost/database",
|
||||
SSLMode: "require",
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "postgres://localhost/database?sslmode=require", str)
|
||||
|
||||
str, err = BuildStringFromOptions(command.Options{
|
||||
URL: "postgres://127.0.0.1/database",
|
||||
Ssl: "require",
|
||||
URL: "postgres://127.0.0.1/database",
|
||||
SSLMode: "require",
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str)
|
||||
@ -137,10 +137,10 @@ func Test_Localhost_And_Ssl(t *testing.T) {
|
||||
User: "user",
|
||||
Pass: "password",
|
||||
DbName: "db",
|
||||
Ssl: "require",
|
||||
SslKey: "keyPath",
|
||||
SslCert: "certPath",
|
||||
SslRootCert: "caPath",
|
||||
SSLMode: "require",
|
||||
SSLKey: "keyPath",
|
||||
SSLCert: "certPath",
|
||||
SSLRootCert: "caPath",
|
||||
}
|
||||
|
||||
str, err := BuildStringFromOptions(opts)
|
||||
|
Loading…
x
Reference in New Issue
Block a user