diff --git a/pkg/command/options.go b/pkg/command/options.go index 7eeaf83..6ea3872 100644 --- a/pkg/command/options.go +++ b/pkg/command/options.go @@ -18,7 +18,10 @@ type Options struct { User string `long:"user" description:"Database user"` Pass string `long:"pass" description:"Password for user"` DbName string `long:"db" description:"Database name"` - Ssl string `long:"ssl" description:"SSL option"` + Ssl string `long:"ssl" description:"SSL mode"` + SslRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"` + SslCert string `long:"ssl-cert" description:"SSL client certificate file"` + SslKey string `long:"ssl-key" description:"SSL client certificate key file"` HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"` HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"` AuthUser string `long:"auth-user" description:"HTTP basic auth user"` @@ -43,6 +46,7 @@ type Options struct { var Opts Options +// ParseOptions returns a new options struct from the input arguments func ParseOptions(args []string) (Options, error) { var opts = Options{} diff --git a/pkg/connection/connection_string.go b/pkg/connection/connection_string.go index 4ea141b..3ca831c 100644 --- a/pkg/connection/connection_string.go +++ b/pkg/connection/connection_string.go @@ -93,6 +93,8 @@ func IsBlank(opts command.Options) bool { // BuildStringFromOptions returns a new connection string built from options func BuildStringFromOptions(opts command.Options) (string, error) { + query := neturl.Values{} + // If connection string is provided we just use that if opts.URL != "" { return FormatURL(opts) @@ -106,14 +108,21 @@ func BuildStringFromOptions(opts command.Options) (string, error) { } } - // Disable ssl for localhost connections, most users have it disabled - if opts.Ssl == "" && (opts.Host == "localhost" || opts.Host == "127.0.0.1") { - opts.Ssl = "disable" - } - - query := neturl.Values{} if opts.Ssl != "" { query.Add("sslmode", opts.Ssl) + } else { + if opts.Host == "localhost" || opts.Host == "127.0.0.1" { + query.Add("sslmode", "disable") + } + } + if opts.SslCert != "" { + query.Add("sslcert", opts.SslCert) + } + if opts.SslKey != "" { + query.Add("sslkey", opts.SslKey) + } + if opts.SslRootCert != "" { + query.Add("sslrootcert", opts.SslRootCert) } url := neturl.URL{ diff --git a/pkg/connection/connection_string_test.go b/pkg/connection/connection_string_test.go index 5b4b004..ce82bfe 100644 --- a/pkg/connection/connection_string_test.go +++ b/pkg/connection/connection_string_test.go @@ -90,6 +90,14 @@ func Test_Localhost_Url_And_Ssl_Arg(t *testing.T) { assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str) } +func Test_ExtendedSSLFlags(t *testing.T) { + str, err := BuildStringFromOptions(command.Options{ + URL: "postgres://localhost/database?sslmode=require&sslcert=cert&sslkey=key&sslrootcert=ca", + }) + assert.Equal(t, nil, err) + assert.Equal(t, "postgres://localhost/database?sslcert=cert&sslkey=key&sslmode=require&sslrootcert=ca", str) +} + func Test_Flag_Args(t *testing.T) { str, err := BuildStringFromOptions(command.Options{ Host: "host", @@ -124,17 +132,20 @@ func Test_Localhost(t *testing.T) { func Test_Localhost_And_Ssl(t *testing.T) { opts := command.Options{ - Host: "localhost", - Port: 5432, - User: "user", - Pass: "password", - DbName: "db", - Ssl: "require", + Host: "localhost", + Port: 5432, + User: "user", + Pass: "password", + DbName: "db", + Ssl: "require", + SslKey: "keyPath", + SslCert: "certPath", + SslRootCert: "caPath", } str, err := BuildStringFromOptions(opts) assert.Equal(t, nil, err) - assert.Equal(t, "postgres://user:password@localhost:5432/db?sslmode=require", str) + assert.Equal(t, "postgres://user:password@localhost:5432/db?sslcert=certPath&sslkey=keyPath&sslmode=require&sslrootcert=caPath", str) } func Test_No_User(t *testing.T) {