Allow retrying a connection on startup (#695)
* Allow retrying a connection on startup * Remove unused vars * Add an extra comment * Restructure retry logic a bit * Update retry logic * Fix comment * Update comment * Change type for RetryCount and RetryDelay to uint * Extra test cases * Tweak
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
neturl "net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +22,18 @@ import (
|
||||
"github.com/sosedoff/pgweb/pkg/statements"
|
||||
)
|
||||
|
||||
var (
|
||||
regexErrAuthFailed = regexp.MustCompile(`(authentication failed|role "(.*)" does not exist)`)
|
||||
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
|
||||
regexErrDatabaseNotExist = regexp.MustCompile(`database "(.*)" does not exist`)
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthFailed = errors.New("authentication failed")
|
||||
ErrConnectionRefused = errors.New("connection refused")
|
||||
ErrDatabaseNotExist = errors.New("database does not exist")
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
db *sqlx.DB
|
||||
tunnel *Tunnel
|
||||
@@ -179,7 +192,28 @@ func (client *Client) setServerVersion() {
|
||||
}
|
||||
|
||||
func (client *Client) Test() error {
|
||||
return client.db.Ping()
|
||||
// NOTE: This is a different timeout defined in CLI OpenTimeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.db.PingContext(ctx)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
|
||||
if regexErrConnectionRefused.MatchString(errMsg) {
|
||||
return ErrConnectionRefused
|
||||
}
|
||||
if regexErrAuthFailed.MatchString(errMsg) {
|
||||
return ErrAuthFailed
|
||||
}
|
||||
if regexErrDatabaseNotExist.MatchString(errMsg) {
|
||||
return ErrDatabaseNotExist
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (client *Client) TestWithTimeout(timeout time.Duration) (result error) {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/sosedoff/pgweb/pkg/command"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -199,7 +200,46 @@ func testClientIdleTime(t *testing.T) {
|
||||
}
|
||||
|
||||
func testTest(t *testing.T) {
|
||||
assert.NoError(t, testClient.Test())
|
||||
examples := []struct {
|
||||
name string
|
||||
input string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "connection refused",
|
||||
input: "postgresql://localhost:5433/dbname",
|
||||
err: ErrConnectionRefused,
|
||||
},
|
||||
{
|
||||
name: "invalid user",
|
||||
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", "foo", serverPassword, serverHost, serverPort, serverDatabase),
|
||||
err: ErrAuthFailed,
|
||||
},
|
||||
{
|
||||
name: "invalid password",
|
||||
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", serverUser, "foo", serverHost, serverPort, serverDatabase),
|
||||
err: ErrAuthFailed,
|
||||
},
|
||||
{
|
||||
name: "invalid database",
|
||||
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, "foo"),
|
||||
err: ErrDatabaseNotExist,
|
||||
},
|
||||
}
|
||||
|
||||
for _, ex := range examples {
|
||||
t.Run(ex.name, func(t *testing.T) {
|
||||
conn, err := NewFromUrl(ex.input, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, ex.err, conn.Test())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testInfo(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user