diff --git a/pkg/command/options.go b/pkg/command/options.go index 702db43..f0a5786 100644 --- a/pkg/command/options.go +++ b/pkg/command/options.go @@ -36,6 +36,7 @@ type Options struct { SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"` SSLCert string `long:"ssl-cert" description:"SSL client certificate file"` SSLKey string `long:"ssl-key" description:"SSL client certificate key file"` + OpenTimeout int `long:"open-timeout" description:" Maximum wait for connection, in seconds" default:"30"` HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"` HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"` AuthUser string `long:"auth-user" description:"HTTP basic auth user"` diff --git a/pkg/connection/connection_string.go b/pkg/connection/connection_string.go index 4ed8e76..f5f22b1 100644 --- a/pkg/connection/connection_string.go +++ b/pkg/connection/connection_string.go @@ -6,9 +6,11 @@ import ( neturl "net/url" "os" "os/user" + "strconv" "strings" "github.com/jackc/pgpassfile" + "github.com/sosedoff/pgweb/pkg/command" ) @@ -88,6 +90,11 @@ func FormatURL(opts command.Options) (string, error) { } } + // Configure default connect timeout + if opts.OpenTimeout > 0 { + params["connect_timeout"] = strconv.Itoa(opts.OpenTimeout) + } + // Rebuild query params query := neturl.Values{} for k, v := range params { @@ -142,6 +149,11 @@ func BuildStringFromOptions(opts command.Options) (string, error) { opts.Pass = lookupPassword(opts, nil) } + // Configure default connect timeout + if opts.OpenTimeout > 0 { + query.Add("connect_timeout", strconv.Itoa(opts.OpenTimeout)) + } + url := neturl.URL{ Scheme: "postgres", Host: fmt.Sprintf("%v:%v", opts.Host, opts.Port), diff --git a/pkg/connection/connection_string_test.go b/pkg/connection/connection_string_test.go index 4e656fa..e54b26b 100644 --- a/pkg/connection/connection_string_test.go +++ b/pkg/connection/connection_string_test.go @@ -173,6 +173,28 @@ func TestBuildStringFromOptions(t *testing.T) { assert.Equal(t, "postgres://foobar:password2@127.0.0.1:5432/foobar2?sslmode=disable", str) }) + t.Run("with connection timeout", func(t *testing.T) { + opts := command.Options{ + URL: "postgres://user:pass@localhost:5432/dbname", + OpenTimeout: 30, + } + str, err := BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://user:pass@localhost:5432/dbname?connect_timeout=30&sslmode=disable", str) + + opts = command.Options{ + Host: "localhost", + Port: 5432, + User: "username", + DbName: "dbname", + OpenTimeout: 30, + } + + str, err = BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://username:@localhost:5432/dbname?connect_timeout=30&sslmode=disable", str) + }) + t.Run("invalid url", func(t *testing.T) { opts := command.Options{} examples := []string{ @@ -231,6 +253,14 @@ func TestFormatURL(t *testing.T) { }, result: "postgres://username:password@localhost:5432/dbname?sslmode=disable", }, + { + name: "with timeout setting", + input: command.Options{ + URL: "postgres://username@localhost:5432/dbname", + OpenTimeout: 30, + }, + result: "postgres://username@localhost:5432/dbname?connect_timeout=30&sslmode=disable", + }, } for _, ex := range examples {