diff --git a/pkg/api/api.go b/pkg/api/api.go index d1afaed..fafec07 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -26,13 +26,13 @@ var ( DbClient *client.Client // DbSessions represents the mapping for client connections - DbSessions = map[string]*client.Client{} + DbSessions *SessionManager ) // DB returns a database connection from the client context func DB(c *gin.Context) *client.Client { if command.Opts.Sessions { - return DbSessions[getSessionId(c.Request)] + return DbSessions.Get(getSessionId(c.Request)) } return DbClient } @@ -54,7 +54,7 @@ func setClient(c *gin.Context, newClient *client.Client) error { return errSessionRequired } - DbSessions[sid] = newClient + DbSessions.Add(sid, newClient) return nil } @@ -80,10 +80,10 @@ func GetSessions(c *gin.Context) { // In debug mode endpoint will return a lot of sensitive information // like full database connection string and all query history. if command.Opts.Debug { - successResponse(c, DbSessions) + successResponse(c, DbSessions.Sessions()) return } - successResponse(c, gin.H{"sessions": len(DbSessions)}) + successResponse(c, gin.H{"sessions": DbSessions.Len()}) } // ConnectWithBackend creates a new connection based on backend resource diff --git a/pkg/api/middleware.go b/pkg/api/middleware.go index f20dffb..be79ec1 100644 --- a/pkg/api/middleware.go +++ b/pkg/api/middleware.go @@ -39,7 +39,7 @@ func dbCheckMiddleware() gin.HandlerFunc { } // Determine the database connection handle for the session - conn := DbSessions[sid] + conn := DbSessions.Get(sid) if conn == nil { badRequest(c, errNotConnected) return diff --git a/pkg/api/session_cleanup.go b/pkg/api/session_cleanup.go deleted file mode 100644 index 862d121..0000000 --- a/pkg/api/session_cleanup.go +++ /dev/null @@ -1,41 +0,0 @@ -package api - -import ( - "log" - "time" - - "github.com/sosedoff/pgweb/pkg/command" -) - -// StartSessionCleanup starts a goroutine to cleanup idle database sessions -func StartSessionCleanup() { - for range time.Tick(time.Minute) { - if command.Opts.Debug { - log.Println("Triggering idle session deletion") - } - cleanupIdleSessions() - } -} - -func cleanupIdleSessions() { - ids := []string{} - - // Figure out which sessions are idle - for id, client := range DbSessions { - if client.IsIdle() { - ids = append(ids, id) - } - } - if len(ids) == 0 { - return - } - - // Close and delete idle sessions - log.Println("Closing", len(ids), "idle sessions") - for _, id := range ids { - // TODO: concurrent map edit will trigger panic - if err := DbSessions[id].Close(); err == nil { - delete(DbSessions, id) - } - } -} diff --git a/pkg/api/session_manager.go b/pkg/api/session_manager.go new file mode 100644 index 0000000..00478a8 --- /dev/null +++ b/pkg/api/session_manager.go @@ -0,0 +1,117 @@ +package api + +import ( + "sync" + "time" + + "github.com/sirupsen/logrus" + + "github.com/sosedoff/pgweb/pkg/client" +) + +type SessionManager struct { + logger *logrus.Logger + sessions map[string]*client.Client + mu sync.Mutex +} + +func NewSessionManager(logger *logrus.Logger) *SessionManager { + return &SessionManager{ + logger: logger, + sessions: map[string]*client.Client{}, + mu: sync.Mutex{}, + } +} + +func (m *SessionManager) IDs() []string { + m.mu.Lock() + defer m.mu.Unlock() + + ids := []string{} + for k := range m.sessions { + ids = append(ids, k) + } + + return ids +} + +func (m *SessionManager) Sessions() map[string]*client.Client { + m.mu.Lock() + sessions := m.sessions + defer m.mu.Unlock() + + return sessions +} + +func (m *SessionManager) Get(id string) *client.Client { + m.mu.Lock() + c := m.sessions[id] + m.mu.Unlock() + + return c +} + +func (m *SessionManager) Add(id string, conn *client.Client) { + m.mu.Lock() + m.sessions[id] = conn + m.mu.Unlock() +} + +func (m *SessionManager) Remove(id string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + conn, ok := m.sessions[id] + if ok { + conn.Close() + delete(m.sessions, id) + } + + return ok +} + +func (m *SessionManager) Len() int { + m.mu.Lock() + sz := len(m.sessions) + m.mu.Unlock() + + return sz +} + +func (m *SessionManager) Cleanup() int { + removed := 0 + + m.logger.Debug("starting idle sessions cleanup") + defer func() { + m.logger.Debug("removed idle sessions:", removed) + }() + + for _, id := range m.staleSessions() { + m.logger.WithField("id", id).Debug("closing stale session") + if m.Remove(id) { + removed++ + } + } + + return removed +} + +func (m *SessionManager) RunPeriodicCleanup() { + for range time.Tick(time.Minute) { + m.Cleanup() + } +} + +func (m *SessionManager) staleSessions() []string { + m.mu.TryLock() + defer m.mu.Unlock() + + ids := []string{} + for id, conn := range m.sessions { + if conn.IsIdle() { + ids = append(ids, id) + } + } + + return ids +} diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 8338295..7bd07bf 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -23,6 +23,7 @@ import ( ) var ( + logger *logrus.Logger options command.Options readonlyWarning = ` @@ -36,6 +37,10 @@ For proper read-only access please follow postgresql role management documentati regexErrAuthFailed = regexp.MustCompile(`authentication failed`) ) +func init() { + logger = logrus.New() +} + func exitWithMessage(message string) { fmt.Println("Error:", message) os.Exit(1) @@ -152,6 +157,10 @@ func initOptions() { os.Exit(0) } + if options.Debug { + logger.SetLevel(logrus.DebugLevel) + } + if options.ReadOnly { fmt.Println(readonlyWarning) } @@ -184,11 +193,6 @@ func printVersion() { } func startServer() { - logger := logrus.New() - if options.Debug { - logger.SetLevel(logrus.DebugLevel) - } - router := gin.New() router.Use(api.RequestLogger(logger)) router.Use(gin.Recovery()) @@ -258,8 +262,12 @@ func Run() { } // Start session cleanup worker - if options.Sessions && !command.Opts.DisableConnectionIdleTimeout { - go api.StartSessionCleanup() + if options.Sessions { + api.DbSessions = api.NewSessionManager(logger) + + if !command.Opts.DisableConnectionIdleTimeout { + go api.DbSessions.RunPeriodicCleanup() + } } startServer() diff --git a/pkg/client/client.go b/pkg/client/client.go index 865a2e3..d58365d 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -34,7 +34,8 @@ type Client struct { serverVersion string serverType string lastQueryTime time.Time - External bool + closed bool + External bool `json:"external"` History []history.Record `json:"history"` ConnectionString string `json:"connection_string"` } @@ -423,6 +424,13 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error) // Close database connection func (client *Client) Close() error { + if client.closed { + return nil + } + defer func() { + client.closed = true + }() + if client.tunnel != nil { client.tunnel.Close() }