diff --git a/pkg/api/api.go b/pkg/api/api.go index ecda2ee..28dd4de 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -453,8 +453,14 @@ func GetHistory(c *gin.Context) { // GetConnectionInfo renders information about current connection func GetConnectionInfo(c *gin.Context) { - res, err := DB(c).Info() + conn := DB(c) + if err := conn.TestWithTimeout(5 * time.Second); err != nil { + badRequest(c, err) + return + } + + res, err := conn.Info() if err != nil { badRequest(c, err) return diff --git a/pkg/client/client.go b/pkg/client/client.go index 4fce389..008733c 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -188,6 +188,32 @@ func (client *Client) Test() error { return client.db.Ping() } +func (client *Client) TestWithTimeout(timeout time.Duration) (result error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Check connection status right away without waiting for the ticker to kick in. + // We're expecting to get "connection refused" here for the most part. + if err := client.db.PingContext(ctx); err == nil { + return nil + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + result = client.db.PingContext(ctx) + if result == nil { + return + } + case <-ctx.Done(): + return + } + } +} + func (client *Client) Info() (*Result, error) { return client.query(statements.Info) }