From b0151ee9856725910207fc89b6ad4e90e9d3c60a Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Thu, 13 Sep 2018 22:25:15 -0500 Subject: [PATCH 1/3] Refactor connection string generator --- pkg/connection/connection_string.go | 59 ++++++++++++++++++------ pkg/connection/connection_string_test.go | 7 +-- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/pkg/connection/connection_string.go b/pkg/connection/connection_string.go index 2594748..d11b703 100644 --- a/pkg/connection/connection_string.go +++ b/pkg/connection/connection_string.go @@ -11,6 +11,10 @@ import ( "github.com/sosedoff/pgweb/pkg/command" ) +var ( + formatError = errors.New("Invalid URL. Valid format: postgres://user:password@host:port/db?sslmode=mode") +) + func currentUser() (string, error) { u, err := user.Current() if err == nil { @@ -25,32 +29,57 @@ func currentUser() (string, error) { return "", errors.New("Unable to detect OS user") } +// Check if connection url has a correct postgres prefix +func hasValidPrefix(str string) bool { + return strings.HasPrefix(str, "postgres://") || strings.HasPrefix(str, "postgresql://") +} + +// Extract all query vals and return as a map +func valsFromQuery(vals neturl.Values) map[string]string { + result := map[string]string{} + for k, v := range vals { + result[strings.ToLower(k)] = v[0] + } + return result +} + func FormatUrl(opts command.Options) (string, error) { url := opts.Url - // Make sure to only accept urls in a standard format - if !strings.HasPrefix(url, "postgres://") && !strings.HasPrefix(url, "postgresql://") { - return "", errors.New("Invalid URL. Valid format: postgres://user:password@host:port/db?sslmode=mode") + // Validate connection string prefix + if !hasValidPrefix(url) { + return "", formatError } - // Special handling for local connections - if strings.Contains(url, "localhost") || strings.Contains(url, "127.0.0.1") { - if !strings.Contains(url, "?sslmode") { - if opts.Ssl == "" { - url += fmt.Sprintf("?sslmode=%s", "disable") - } else { - url += fmt.Sprintf("?sslmode=%s", opts.Ssl) + // Validate the URL + uri, err := neturl.Parse(url) + if err != nil { + return "", formatError + } + + // Get query params + params := valsFromQuery(uri.Query()) + + // Determine if we need to specify sslmode if it's missing + if params["sslmode"] == "" { + if opts.Ssl == "" { + // Only modify sslmode for local connections + if strings.Contains(uri.Host, "localhost") || strings.Contains(uri.Host, "127.0.0.1") { + params["sslmode"] = "disable" } + } else { + params["sslmode"] = opts.Ssl } } - // Append sslmode parameter only if its defined as a flag and not present - // in the connection string. - if !strings.Contains(url, "?sslmode") && opts.Ssl != "" { - url += fmt.Sprintf("?sslmode=%s", opts.Ssl) + // Rebuild query params + query := neturl.Values{} + for k, v := range params { + query.Add(k, v) } + uri.RawQuery = query.Encode() - return url, nil + return uri.String(), nil } func IsBlank(opts command.Options) bool { diff --git a/pkg/connection/connection_string_test.go b/pkg/connection/connection_string_test.go index a13b35e..232dba2 100644 --- a/pkg/connection/connection_string_test.go +++ b/pkg/connection/connection_string_test.go @@ -13,6 +13,7 @@ func Test_Invalid_Url(t *testing.T) { opts := command.Options{} examples := []string{ "postgre://foobar", + "tcp://blah", "foobar", } @@ -48,14 +49,12 @@ func Test_Localhost_Url_And_No_Ssl_Flag(t *testing.T) { str, err := BuildString(command.Options{ Url: "postgres://localhost/database", }) - assert.Equal(t, nil, err) assert.Equal(t, "postgres://localhost/database?sslmode=disable", str) str, err = BuildString(command.Options{ Url: "postgres://127.0.0.1/database", }) - assert.Equal(t, nil, err) assert.Equal(t, "postgres://127.0.0.1/database?sslmode=disable", str) } @@ -65,7 +64,6 @@ func Test_Localhost_Url_And_Ssl_Flag(t *testing.T) { Url: "postgres://localhost/database", Ssl: "require", }) - assert.Equal(t, nil, err) assert.Equal(t, "postgres://localhost/database?sslmode=require", str) @@ -73,7 +71,6 @@ func Test_Localhost_Url_And_Ssl_Flag(t *testing.T) { Url: "postgres://127.0.0.1/database", Ssl: "require", }) - assert.Equal(t, nil, err) assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str) } @@ -82,14 +79,12 @@ func Test_Localhost_Url_And_Ssl_Arg(t *testing.T) { str, err := BuildString(command.Options{ Url: "postgres://localhost/database?sslmode=require", }) - assert.Equal(t, nil, err) assert.Equal(t, "postgres://localhost/database?sslmode=require", str) str, err = BuildString(command.Options{ Url: "postgres://127.0.0.1/database?sslmode=require", }) - assert.Equal(t, nil, err) assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str) } From dc4e8598f7d3661c33fed17c4eee783351706b45 Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Thu, 13 Sep 2018 22:44:11 -0500 Subject: [PATCH 2/3] Refactor building connection string from options --- pkg/cli/cli.go | 2 +- pkg/client/client.go | 2 +- pkg/connection/connection_string.go | 39 ++++++++++-------------- pkg/connection/connection_string_test.go | 34 ++++++++++----------- 4 files changed, 35 insertions(+), 42 deletions(-) diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 9fad2fd..6a72bbb 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -38,7 +38,7 @@ func initClientUsingBookmark(bookmarkPath, bookmarkName string) (*client.Client, if opt.Url != "" { // if the bookmark has url set, use it connStr = opt.Url } else { - connStr, err = connection.BuildString(opt) + connStr, err = connection.BuildStringFromOptions(opt) if err != nil { return nil, fmt.Errorf("error building connection string: %v", err) } diff --git a/pkg/client/client.go b/pkg/client/client.go index 694aa68..d3c38bc 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -55,7 +55,7 @@ func getSchemaAndTable(str string) (string, string) { } func New() (*Client, error) { - str, err := connection.BuildString(command.Opts) + str, err := connection.BuildStringFromOptions(command.Opts) if command.Opts.Debug && str != "" { fmt.Println("Creating a new client for:", str) diff --git a/pkg/connection/connection_string.go b/pkg/connection/connection_string.go index d11b703..51aba00 100644 --- a/pkg/connection/connection_string.go +++ b/pkg/connection/connection_string.go @@ -86,7 +86,9 @@ func IsBlank(opts command.Options) bool { return opts.Host == "" && opts.User == "" && opts.DbName == "" && opts.Url == "" } -func BuildString(opts command.Options) (string, error) { +// Build a new database connection string for the CLI options +func BuildStringFromOptions(opts command.Options) (string, error) { + // If connection string is provided we just use that if opts.Url != "" { return FormatUrl(opts) } @@ -100,31 +102,22 @@ func BuildString(opts command.Options) (string, error) { } // Disable ssl for localhost connections, most users have it disabled - if opts.Host == "localhost" || opts.Host == "127.0.0.1" { - if opts.Ssl == "" { - opts.Ssl = "disable" - } - } - - url := "postgres://" - - if opts.User != "" { - url += opts.User - } - - if opts.Pass != "" { - url += fmt.Sprintf(":%s", neturl.QueryEscape(opts.Pass)) - } - - url += fmt.Sprintf("@%s:%d", opts.Host, opts.Port) - - if opts.DbName != "" { - url += fmt.Sprintf("/%s", opts.DbName) + if opts.Ssl == "" && (opts.Host == "localhost" || opts.Host == "127.0.0.1") { + opts.Ssl = "disable" } + query := neturl.Values{} if opts.Ssl != "" { - url += fmt.Sprintf("?sslmode=%s", opts.Ssl) + query.Add("sslmode", opts.Ssl) } - return url, nil + url := neturl.URL{ + Scheme: "postgres", + Host: fmt.Sprintf("%v:%v", opts.Host, opts.Port), + User: neturl.UserPassword(opts.User, opts.Pass), + Path: fmt.Sprintf("/%s", opts.DbName), + RawQuery: query.Encode(), + } + + return url.String(), nil } diff --git a/pkg/connection/connection_string_test.go b/pkg/connection/connection_string_test.go index 232dba2..9db792b 100644 --- a/pkg/connection/connection_string_test.go +++ b/pkg/connection/connection_string_test.go @@ -19,7 +19,7 @@ func Test_Invalid_Url(t *testing.T) { for _, val := range examples { opts.Url = val - str, err := BuildString(opts) + str, err := BuildStringFromOptions(opts) assert.Equal(t, "", str) assert.Error(t, err) @@ -29,14 +29,14 @@ func Test_Invalid_Url(t *testing.T) { func Test_Valid_Url(t *testing.T) { url := "postgres://myhost/database" - str, err := BuildString(command.Options{Url: url}) + str, err := BuildStringFromOptions(command.Options{Url: url}) assert.Equal(t, nil, err) assert.Equal(t, url, str) } func Test_Url_And_Ssl_Flag(t *testing.T) { - str, err := BuildString(command.Options{ + str, err := BuildStringFromOptions(command.Options{ Url: "postgres://myhost/database", Ssl: "disable", }) @@ -46,13 +46,13 @@ func Test_Url_And_Ssl_Flag(t *testing.T) { } func Test_Localhost_Url_And_No_Ssl_Flag(t *testing.T) { - str, err := BuildString(command.Options{ + str, err := BuildStringFromOptions(command.Options{ Url: "postgres://localhost/database", }) assert.Equal(t, nil, err) assert.Equal(t, "postgres://localhost/database?sslmode=disable", str) - str, err = BuildString(command.Options{ + str, err = BuildStringFromOptions(command.Options{ Url: "postgres://127.0.0.1/database", }) assert.Equal(t, nil, err) @@ -60,14 +60,14 @@ func Test_Localhost_Url_And_No_Ssl_Flag(t *testing.T) { } func Test_Localhost_Url_And_Ssl_Flag(t *testing.T) { - str, err := BuildString(command.Options{ + str, err := BuildStringFromOptions(command.Options{ Url: "postgres://localhost/database", Ssl: "require", }) assert.Equal(t, nil, err) assert.Equal(t, "postgres://localhost/database?sslmode=require", str) - str, err = BuildString(command.Options{ + str, err = BuildStringFromOptions(command.Options{ Url: "postgres://127.0.0.1/database", Ssl: "require", }) @@ -76,13 +76,13 @@ func Test_Localhost_Url_And_Ssl_Flag(t *testing.T) { } func Test_Localhost_Url_And_Ssl_Arg(t *testing.T) { - str, err := BuildString(command.Options{ + str, err := BuildStringFromOptions(command.Options{ Url: "postgres://localhost/database?sslmode=require", }) assert.Equal(t, nil, err) assert.Equal(t, "postgres://localhost/database?sslmode=require", str) - str, err = BuildString(command.Options{ + str, err = BuildStringFromOptions(command.Options{ Url: "postgres://127.0.0.1/database?sslmode=require", }) assert.Equal(t, nil, err) @@ -90,7 +90,7 @@ func Test_Localhost_Url_And_Ssl_Arg(t *testing.T) { } func Test_Flag_Args(t *testing.T) { - str, err := BuildString(command.Options{ + str, err := BuildStringFromOptions(command.Options{ Host: "host", Port: 5432, User: "user", @@ -111,12 +111,12 @@ func Test_Localhost(t *testing.T) { DbName: "db", } - str, err := BuildString(opts) + str, err := BuildStringFromOptions(opts) assert.Equal(t, nil, err) assert.Equal(t, "postgres://user:password@localhost:5432/db?sslmode=disable", str) opts.Host = "127.0.0.1" - str, err = BuildString(opts) + str, err = BuildStringFromOptions(opts) assert.Equal(t, nil, err) assert.Equal(t, "postgres://user:password@127.0.0.1:5432/db?sslmode=disable", str) } @@ -131,7 +131,7 @@ func Test_Localhost_And_Ssl(t *testing.T) { Ssl: "require", } - str, err := BuildString(opts) + str, err := BuildStringFromOptions(opts) assert.Equal(t, nil, err) assert.Equal(t, "postgres://user:password@localhost:5432/db?sslmode=require", str) } @@ -139,18 +139,18 @@ func Test_Localhost_And_Ssl(t *testing.T) { func Test_No_User(t *testing.T) { opts := command.Options{Host: "host", Port: 5432, DbName: "db"} u, _ := user.Current() - str, err := BuildString(opts) + str, err := BuildStringFromOptions(opts) assert.Equal(t, nil, err) - assert.Equal(t, fmt.Sprintf("postgres://%s@host:5432/db", u.Username), str) + assert.Equal(t, fmt.Sprintf("postgres://%s:@host:5432/db", u.Username), str) } func Test_Port(t *testing.T) { opts := command.Options{Host: "host", User: "user", Port: 5000, DbName: "db"} - str, err := BuildString(opts) + str, err := BuildStringFromOptions(opts) assert.Equal(t, nil, err) - assert.Equal(t, "postgres://user@host:5000/db", str) + assert.Equal(t, "postgres://user:@host:5000/db", str) } func Test_Blank(t *testing.T) { From fc380df8dd0a72da4bac1abac9f5320ac53d0487 Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Thu, 13 Sep 2018 23:22:25 -0500 Subject: [PATCH 3/3] Fix failing win test --- pkg/connection/connection_string_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/connection/connection_string_test.go b/pkg/connection/connection_string_test.go index 9db792b..de109ea 100644 --- a/pkg/connection/connection_string_test.go +++ b/pkg/connection/connection_string_test.go @@ -2,6 +2,7 @@ package connection import ( "fmt" + "net/url" "os/user" "testing" @@ -140,9 +141,10 @@ func Test_No_User(t *testing.T) { opts := command.Options{Host: "host", Port: 5432, DbName: "db"} u, _ := user.Current() str, err := BuildStringFromOptions(opts) + userAndPass := url.UserPassword(u.Username, "").String() assert.Equal(t, nil, err) - assert.Equal(t, fmt.Sprintf("postgres://%s:@host:5432/db", u.Username), str) + assert.Equal(t, fmt.Sprintf("postgres://%s@host:5432/db", userAndPass), str) } func Test_Port(t *testing.T) {