diff --git a/pkg/api/api.go b/pkg/api/api.go index d1afaed..fafec07 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -26,13 +26,13 @@ var ( DbClient *client.Client // DbSessions represents the mapping for client connections - DbSessions = map[string]*client.Client{} + DbSessions *SessionManager ) // DB returns a database connection from the client context func DB(c *gin.Context) *client.Client { if command.Opts.Sessions { - return DbSessions[getSessionId(c.Request)] + return DbSessions.Get(getSessionId(c.Request)) } return DbClient } @@ -54,7 +54,7 @@ func setClient(c *gin.Context, newClient *client.Client) error { return errSessionRequired } - DbSessions[sid] = newClient + DbSessions.Add(sid, newClient) return nil } @@ -80,10 +80,10 @@ func GetSessions(c *gin.Context) { // In debug mode endpoint will return a lot of sensitive information // like full database connection string and all query history. if command.Opts.Debug { - successResponse(c, DbSessions) + successResponse(c, DbSessions.Sessions()) return } - successResponse(c, gin.H{"sessions": len(DbSessions)}) + successResponse(c, gin.H{"sessions": DbSessions.Len()}) } // ConnectWithBackend creates a new connection based on backend resource diff --git a/pkg/api/middleware.go b/pkg/api/middleware.go index f20dffb..be79ec1 100644 --- a/pkg/api/middleware.go +++ b/pkg/api/middleware.go @@ -39,7 +39,7 @@ func dbCheckMiddleware() gin.HandlerFunc { } // Determine the database connection handle for the session - conn := DbSessions[sid] + conn := DbSessions.Get(sid) if conn == nil { badRequest(c, errNotConnected) return diff --git a/pkg/api/session_cleanup.go b/pkg/api/session_cleanup.go deleted file mode 100644 index 862d121..0000000 --- a/pkg/api/session_cleanup.go +++ /dev/null @@ -1,41 +0,0 @@ -package api - -import ( - "log" - "time" - - "github.com/sosedoff/pgweb/pkg/command" -) - -// StartSessionCleanup starts a goroutine to cleanup idle database sessions -func StartSessionCleanup() { - for range time.Tick(time.Minute) { - if command.Opts.Debug { - log.Println("Triggering idle session deletion") - } - cleanupIdleSessions() - } -} - -func cleanupIdleSessions() { - ids := []string{} - - // Figure out which sessions are idle - for id, client := range DbSessions { - if client.IsIdle() { - ids = append(ids, id) - } - } - if len(ids) == 0 { - return - } - - // Close and delete idle sessions - log.Println("Closing", len(ids), "idle sessions") - for _, id := range ids { - // TODO: concurrent map edit will trigger panic - if err := DbSessions[id].Close(); err == nil { - delete(DbSessions, id) - } - } -} diff --git a/pkg/api/session_manager.go b/pkg/api/session_manager.go new file mode 100644 index 0000000..9166f9f --- /dev/null +++ b/pkg/api/session_manager.go @@ -0,0 +1,130 @@ +package api + +import ( + "sync" + "time" + + "github.com/sirupsen/logrus" + + "github.com/sosedoff/pgweb/pkg/client" +) + +type SessionManager struct { + logger *logrus.Logger + sessions map[string]*client.Client + mu sync.Mutex + idleTimeout time.Duration +} + +func NewSessionManager(logger *logrus.Logger) *SessionManager { + return &SessionManager{ + logger: logger, + sessions: map[string]*client.Client{}, + mu: sync.Mutex{}, + } +} + +func (m *SessionManager) SetIdleTimeout(timeout time.Duration) { + m.idleTimeout = timeout +} + +func (m *SessionManager) IDs() []string { + m.mu.Lock() + defer m.mu.Unlock() + + ids := []string{} + for k := range m.sessions { + ids = append(ids, k) + } + + return ids +} + +func (m *SessionManager) Sessions() map[string]*client.Client { + m.mu.Lock() + sessions := m.sessions + defer m.mu.Unlock() + + return sessions +} + +func (m *SessionManager) Get(id string) *client.Client { + m.mu.Lock() + c := m.sessions[id] + m.mu.Unlock() + + return c +} + +func (m *SessionManager) Add(id string, conn *client.Client) { + m.mu.Lock() + m.sessions[id] = conn + m.mu.Unlock() +} + +func (m *SessionManager) Remove(id string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + conn, ok := m.sessions[id] + if ok { + conn.Close() + delete(m.sessions, id) + } + + return ok +} + +func (m *SessionManager) Len() int { + m.mu.Lock() + sz := len(m.sessions) + m.mu.Unlock() + + return sz +} + +func (m *SessionManager) Cleanup() int { + if m.idleTimeout == 0 { + return 0 + } + + removed := 0 + + m.logger.Debug("starting idle sessions cleanup") + defer func() { + m.logger.Debug("removed idle sessions:", removed) + }() + + for _, id := range m.staleSessions() { + m.logger.WithField("id", id).Debug("closing stale session") + if m.Remove(id) { + removed++ + } + } + + return removed +} + +func (m *SessionManager) RunPeriodicCleanup() { + m.logger.WithField("timeout", m.idleTimeout).Info("session manager cleanup enabled") + + for range time.Tick(time.Minute) { + m.Cleanup() + } +} + +func (m *SessionManager) staleSessions() []string { + m.mu.TryLock() + defer m.mu.Unlock() + + now := time.Now() + ids := []string{} + + for id, conn := range m.sessions { + if now.Sub(conn.LastQueryTime()) > m.idleTimeout { + ids = append(ids, id) + } + } + + return ids +} diff --git a/pkg/api/session_manager_test.go b/pkg/api/session_manager_test.go new file mode 100644 index 0000000..594af36 --- /dev/null +++ b/pkg/api/session_manager_test.go @@ -0,0 +1,80 @@ +package api + +import ( + "sort" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/sosedoff/pgweb/pkg/client" +) + +func TestSessionManager(t *testing.T) { + t.Run("return ids", func(t *testing.T) { + manager := NewSessionManager(nil) + assert.Equal(t, []string{}, manager.IDs()) + + manager.sessions["foo"] = &client.Client{} + manager.sessions["bar"] = &client.Client{} + + ids := manager.IDs() + sort.Strings(ids) + assert.Equal(t, []string{"bar", "foo"}, ids) + }) + + t.Run("get session", func(t *testing.T) { + manager := NewSessionManager(nil) + assert.Nil(t, manager.Get("foo")) + + manager.sessions["foo"] = &client.Client{} + assert.NotNil(t, manager.Get("foo")) + }) + + t.Run("set session", func(t *testing.T) { + manager := NewSessionManager(nil) + assert.Nil(t, manager.Get("foo")) + + manager.Add("foo", &client.Client{}) + assert.NotNil(t, manager.Get("foo")) + }) + + t.Run("remove session", func(t *testing.T) { + manager := NewSessionManager(nil) + assert.Nil(t, manager.Get("foo")) + + manager.Add("foo", &client.Client{}) + assert.NotNil(t, manager.Get("foo")) + assert.True(t, manager.Remove("foo")) + assert.False(t, manager.Remove("foo")) + assert.Nil(t, manager.Get("foo")) + }) + + t.Run("return len", func(t *testing.T) { + manager := NewSessionManager(nil) + manager.sessions["foo"] = &client.Client{} + manager.sessions["bar"] = &client.Client{} + + assert.Equal(t, 2, manager.Len()) + }) + + t.Run("clean up stale sessions", func(t *testing.T) { + manager := NewSessionManager(logrus.New()) + conn := &client.Client{} + manager.Add("foo", conn) + + assert.Equal(t, 1, manager.Len()) + assert.Equal(t, 0, manager.Cleanup()) + assert.Equal(t, 1, manager.Len()) + + res, err := conn.Query("select 1") + assert.Nil(t, res) + assert.Nil(t, err) + + manager.SetIdleTimeout(time.Minute) + assert.Equal(t, 1, manager.Cleanup()) + assert.Equal(t, 0, manager.Len()) + assert.True(t, conn.IsClosed()) + }) +} diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 8338295..bb3e306 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -8,6 +8,7 @@ import ( "regexp" "strings" "syscall" + "time" "github.com/gin-gonic/gin" "github.com/jessevdk/go-flags" @@ -23,6 +24,7 @@ import ( ) var ( + logger *logrus.Logger options command.Options readonlyWarning = ` @@ -36,6 +38,10 @@ For proper read-only access please follow postgresql role management documentati regexErrAuthFailed = regexp.MustCompile(`authentication failed`) ) +func init() { + logger = logrus.New() +} + func exitWithMessage(message string) { fmt.Println("Error:", message) os.Exit(1) @@ -152,6 +158,10 @@ func initOptions() { os.Exit(0) } + if options.Debug { + logger.SetLevel(logrus.DebugLevel) + } + if options.ReadOnly { fmt.Println(readonlyWarning) } @@ -184,11 +194,6 @@ func printVersion() { } func startServer() { - logger := logrus.New() - if options.Debug { - logger.SetLevel(logrus.DebugLevel) - } - router := gin.New() router.Use(api.RequestLogger(logger)) router.Use(gin.Recovery()) @@ -258,8 +263,13 @@ func Run() { } // Start session cleanup worker - if options.Sessions && !command.Opts.DisableConnectionIdleTimeout { - go api.StartSessionCleanup() + if options.Sessions { + api.DbSessions = api.NewSessionManager(logger) + + if !command.Opts.DisableConnectionIdleTimeout { + api.DbSessions.SetIdleTimeout(time.Minute * time.Duration(command.Opts.ConnectionIdleTimeout)) + go api.DbSessions.RunPeriodicCleanup() + } } startServer() diff --git a/pkg/client/client.go b/pkg/client/client.go index 865a2e3..3d8bd50 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -34,7 +34,8 @@ type Client struct { serverVersion string serverType string lastQueryTime time.Time - External bool + closed bool + External bool `json:"external"` History []history.Record `json:"history"` ConnectionString string `json:"connection_string"` } @@ -334,6 +335,10 @@ func (client *Client) ServerVersion() string { } func (client *Client) query(query string, args ...interface{}) (*Result, error) { + if client.db == nil { + return nil, nil + } + // Update the last usage time defer func() { client.lastQueryTime = time.Now().UTC() @@ -365,7 +370,7 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error) result := Result{ Columns: []string{"Rows Affected"}, Rows: []Row{ - Row{affected}, + {affected}, }, } @@ -423,6 +428,13 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error) // Close database connection func (client *Client) Close() error { + if client.closed { + return nil + } + defer func() { + client.closed = true + }() + if client.tunnel != nil { client.tunnel.Close() } @@ -434,6 +446,14 @@ func (client *Client) Close() error { return nil } +func (c *Client) IsClosed() bool { + return c.closed +} + +func (c *Client) LastQueryTime() time.Time { + return c.lastQueryTime +} + func (client *Client) IsIdle() bool { mins := int(time.Since(client.lastQueryTime).Minutes())