Merge pull request #593 from sosedoff/sessions-manager
Add internal sessions manager
This commit is contained in:
commit
e5480621ee
@ -26,13 +26,13 @@ var (
|
||||
DbClient *client.Client
|
||||
|
||||
// DbSessions represents the mapping for client connections
|
||||
DbSessions = map[string]*client.Client{}
|
||||
DbSessions *SessionManager
|
||||
)
|
||||
|
||||
// DB returns a database connection from the client context
|
||||
func DB(c *gin.Context) *client.Client {
|
||||
if command.Opts.Sessions {
|
||||
return DbSessions[getSessionId(c.Request)]
|
||||
return DbSessions.Get(getSessionId(c.Request))
|
||||
}
|
||||
return DbClient
|
||||
}
|
||||
@ -54,7 +54,7 @@ func setClient(c *gin.Context, newClient *client.Client) error {
|
||||
return errSessionRequired
|
||||
}
|
||||
|
||||
DbSessions[sid] = newClient
|
||||
DbSessions.Add(sid, newClient)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -80,10 +80,10 @@ func GetSessions(c *gin.Context) {
|
||||
// In debug mode endpoint will return a lot of sensitive information
|
||||
// like full database connection string and all query history.
|
||||
if command.Opts.Debug {
|
||||
successResponse(c, DbSessions)
|
||||
successResponse(c, DbSessions.Sessions())
|
||||
return
|
||||
}
|
||||
successResponse(c, gin.H{"sessions": len(DbSessions)})
|
||||
successResponse(c, gin.H{"sessions": DbSessions.Len()})
|
||||
}
|
||||
|
||||
// ConnectWithBackend creates a new connection based on backend resource
|
||||
|
@ -39,7 +39,7 @@ func dbCheckMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
// Determine the database connection handle for the session
|
||||
conn := DbSessions[sid]
|
||||
conn := DbSessions.Get(sid)
|
||||
if conn == nil {
|
||||
badRequest(c, errNotConnected)
|
||||
return
|
||||
|
@ -1,41 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/command"
|
||||
)
|
||||
|
||||
// StartSessionCleanup starts a goroutine to cleanup idle database sessions
|
||||
func StartSessionCleanup() {
|
||||
for range time.Tick(time.Minute) {
|
||||
if command.Opts.Debug {
|
||||
log.Println("Triggering idle session deletion")
|
||||
}
|
||||
cleanupIdleSessions()
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupIdleSessions() {
|
||||
ids := []string{}
|
||||
|
||||
// Figure out which sessions are idle
|
||||
for id, client := range DbSessions {
|
||||
if client.IsIdle() {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Close and delete idle sessions
|
||||
log.Println("Closing", len(ids), "idle sessions")
|
||||
for _, id := range ids {
|
||||
// TODO: concurrent map edit will trigger panic
|
||||
if err := DbSessions[id].Close(); err == nil {
|
||||
delete(DbSessions, id)
|
||||
}
|
||||
}
|
||||
}
|
130
pkg/api/session_manager.go
Normal file
130
pkg/api/session_manager.go
Normal file
@ -0,0 +1,130 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/client"
|
||||
)
|
||||
|
||||
type SessionManager struct {
|
||||
logger *logrus.Logger
|
||||
sessions map[string]*client.Client
|
||||
mu sync.Mutex
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewSessionManager(logger *logrus.Logger) *SessionManager {
|
||||
return &SessionManager{
|
||||
logger: logger,
|
||||
sessions: map[string]*client.Client{},
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) SetIdleTimeout(timeout time.Duration) {
|
||||
m.idleTimeout = timeout
|
||||
}
|
||||
|
||||
func (m *SessionManager) IDs() []string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
ids := []string{}
|
||||
for k := range m.sessions {
|
||||
ids = append(ids, k)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (m *SessionManager) Sessions() map[string]*client.Client {
|
||||
m.mu.Lock()
|
||||
sessions := m.sessions
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (m *SessionManager) Get(id string) *client.Client {
|
||||
m.mu.Lock()
|
||||
c := m.sessions[id]
|
||||
m.mu.Unlock()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (m *SessionManager) Add(id string, conn *client.Client) {
|
||||
m.mu.Lock()
|
||||
m.sessions[id] = conn
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *SessionManager) Remove(id string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
conn, ok := m.sessions[id]
|
||||
if ok {
|
||||
conn.Close()
|
||||
delete(m.sessions, id)
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *SessionManager) Len() int {
|
||||
m.mu.Lock()
|
||||
sz := len(m.sessions)
|
||||
m.mu.Unlock()
|
||||
|
||||
return sz
|
||||
}
|
||||
|
||||
func (m *SessionManager) Cleanup() int {
|
||||
if m.idleTimeout == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
m.logger.Debug("starting idle sessions cleanup")
|
||||
defer func() {
|
||||
m.logger.Debug("removed idle sessions:", removed)
|
||||
}()
|
||||
|
||||
for _, id := range m.staleSessions() {
|
||||
m.logger.WithField("id", id).Debug("closing stale session")
|
||||
if m.Remove(id) {
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
func (m *SessionManager) RunPeriodicCleanup() {
|
||||
m.logger.WithField("timeout", m.idleTimeout).Info("session manager cleanup enabled")
|
||||
|
||||
for range time.Tick(time.Minute) {
|
||||
m.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) staleSessions() []string {
|
||||
m.mu.TryLock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
ids := []string{}
|
||||
|
||||
for id, conn := range m.sessions {
|
||||
if now.Sub(conn.LastQueryTime()) > m.idleTimeout {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
80
pkg/api/session_manager_test.go
Normal file
80
pkg/api/session_manager_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/client"
|
||||
)
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
t.Run("return ids", func(t *testing.T) {
|
||||
manager := NewSessionManager(nil)
|
||||
assert.Equal(t, []string{}, manager.IDs())
|
||||
|
||||
manager.sessions["foo"] = &client.Client{}
|
||||
manager.sessions["bar"] = &client.Client{}
|
||||
|
||||
ids := manager.IDs()
|
||||
sort.Strings(ids)
|
||||
assert.Equal(t, []string{"bar", "foo"}, ids)
|
||||
})
|
||||
|
||||
t.Run("get session", func(t *testing.T) {
|
||||
manager := NewSessionManager(nil)
|
||||
assert.Nil(t, manager.Get("foo"))
|
||||
|
||||
manager.sessions["foo"] = &client.Client{}
|
||||
assert.NotNil(t, manager.Get("foo"))
|
||||
})
|
||||
|
||||
t.Run("set session", func(t *testing.T) {
|
||||
manager := NewSessionManager(nil)
|
||||
assert.Nil(t, manager.Get("foo"))
|
||||
|
||||
manager.Add("foo", &client.Client{})
|
||||
assert.NotNil(t, manager.Get("foo"))
|
||||
})
|
||||
|
||||
t.Run("remove session", func(t *testing.T) {
|
||||
manager := NewSessionManager(nil)
|
||||
assert.Nil(t, manager.Get("foo"))
|
||||
|
||||
manager.Add("foo", &client.Client{})
|
||||
assert.NotNil(t, manager.Get("foo"))
|
||||
assert.True(t, manager.Remove("foo"))
|
||||
assert.False(t, manager.Remove("foo"))
|
||||
assert.Nil(t, manager.Get("foo"))
|
||||
})
|
||||
|
||||
t.Run("return len", func(t *testing.T) {
|
||||
manager := NewSessionManager(nil)
|
||||
manager.sessions["foo"] = &client.Client{}
|
||||
manager.sessions["bar"] = &client.Client{}
|
||||
|
||||
assert.Equal(t, 2, manager.Len())
|
||||
})
|
||||
|
||||
t.Run("clean up stale sessions", func(t *testing.T) {
|
||||
manager := NewSessionManager(logrus.New())
|
||||
conn := &client.Client{}
|
||||
manager.Add("foo", conn)
|
||||
|
||||
assert.Equal(t, 1, manager.Len())
|
||||
assert.Equal(t, 0, manager.Cleanup())
|
||||
assert.Equal(t, 1, manager.Len())
|
||||
|
||||
res, err := conn.Query("select 1")
|
||||
assert.Nil(t, res)
|
||||
assert.Nil(t, err)
|
||||
|
||||
manager.SetIdleTimeout(time.Minute)
|
||||
assert.Equal(t, 1, manager.Cleanup())
|
||||
assert.Equal(t, 0, manager.Len())
|
||||
assert.True(t, conn.IsClosed())
|
||||
})
|
||||
}
|
@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jessevdk/go-flags"
|
||||
@ -23,6 +24,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
logger *logrus.Logger
|
||||
options command.Options
|
||||
|
||||
readonlyWarning = `
|
||||
@ -36,6 +38,10 @@ For proper read-only access please follow postgresql role management documentati
|
||||
regexErrAuthFailed = regexp.MustCompile(`authentication failed`)
|
||||
)
|
||||
|
||||
func init() {
|
||||
logger = logrus.New()
|
||||
}
|
||||
|
||||
func exitWithMessage(message string) {
|
||||
fmt.Println("Error:", message)
|
||||
os.Exit(1)
|
||||
@ -152,6 +158,10 @@ func initOptions() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if options.Debug {
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
if options.ReadOnly {
|
||||
fmt.Println(readonlyWarning)
|
||||
}
|
||||
@ -184,11 +194,6 @@ func printVersion() {
|
||||
}
|
||||
|
||||
func startServer() {
|
||||
logger := logrus.New()
|
||||
if options.Debug {
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(api.RequestLogger(logger))
|
||||
router.Use(gin.Recovery())
|
||||
@ -258,8 +263,13 @@ func Run() {
|
||||
}
|
||||
|
||||
// Start session cleanup worker
|
||||
if options.Sessions && !command.Opts.DisableConnectionIdleTimeout {
|
||||
go api.StartSessionCleanup()
|
||||
if options.Sessions {
|
||||
api.DbSessions = api.NewSessionManager(logger)
|
||||
|
||||
if !command.Opts.DisableConnectionIdleTimeout {
|
||||
api.DbSessions.SetIdleTimeout(time.Minute * time.Duration(command.Opts.ConnectionIdleTimeout))
|
||||
go api.DbSessions.RunPeriodicCleanup()
|
||||
}
|
||||
}
|
||||
|
||||
startServer()
|
||||
|
@ -34,7 +34,8 @@ type Client struct {
|
||||
serverVersion string
|
||||
serverType string
|
||||
lastQueryTime time.Time
|
||||
External bool
|
||||
closed bool
|
||||
External bool `json:"external"`
|
||||
History []history.Record `json:"history"`
|
||||
ConnectionString string `json:"connection_string"`
|
||||
}
|
||||
@ -334,6 +335,10 @@ func (client *Client) ServerVersion() string {
|
||||
}
|
||||
|
||||
func (client *Client) query(query string, args ...interface{}) (*Result, error) {
|
||||
if client.db == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Update the last usage time
|
||||
defer func() {
|
||||
client.lastQueryTime = time.Now().UTC()
|
||||
@ -365,7 +370,7 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
|
||||
result := Result{
|
||||
Columns: []string{"Rows Affected"},
|
||||
Rows: []Row{
|
||||
Row{affected},
|
||||
{affected},
|
||||
},
|
||||
}
|
||||
|
||||
@ -423,6 +428,13 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
|
||||
|
||||
// Close database connection
|
||||
func (client *Client) Close() error {
|
||||
if client.closed {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
client.closed = true
|
||||
}()
|
||||
|
||||
if client.tunnel != nil {
|
||||
client.tunnel.Close()
|
||||
}
|
||||
@ -434,6 +446,14 @@ func (client *Client) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) IsClosed() bool {
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func (c *Client) LastQueryTime() time.Time {
|
||||
return c.lastQueryTime
|
||||
}
|
||||
|
||||
func (client *Client) IsIdle() bool {
|
||||
mins := int(time.Since(client.lastQueryTime).Minutes())
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user