Establish connections using bookmark ID only (#619)

* Establish connections using bookmark ID only
* Refactor specs
* Extra tests
* Fix homedir assertion for bookmarks path
* Fix newline in the warning message
* Check for bookmark file existence before reading
* Connect code restructure
This commit is contained in:
Dan Sosedoff
2022-12-19 12:33:13 -06:00
committed by GitHub
parent 0b9e7cdb4e
commit 69233cd769
11 changed files with 411 additions and 325 deletions

View File

@@ -1,28 +1,21 @@
package bookmarks
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/BurntSushi/toml"
"github.com/mitchellh/go-homedir"
"github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/shared"
)
// Bookmark contains information about bookmarked database connection
type Bookmark struct {
URL string `json:"url"` // Postgres connection URL
Host string `json:"host"` // Server hostname
Port int `json:"port"` // Server port
User string `json:"user"` // Database user
Password string `json:"password"` // User password
Database string `json:"database"` // Database name
SSLMode string `json:"ssl"` // Connection SSL mode
SSH *shared.SSHInfo `json:"ssh"` // SSH tunnel config
ID string // ID generated from the filename
URL string // Postgres connection URL
Host string // Server hostname
Port int // Server port
User string // Database user
Password string // User password
Database string // Database name
SSLMode string // Connection SSL mode
SSH *shared.SSHInfo // SSH tunnel config
}
// SSHInfoIsEmpty returns true if ssh configuration is not provided
@@ -42,100 +35,3 @@ func (b Bookmark) ConvertToOptions() command.Options {
SSLMode: b.SSLMode,
}
}
func readServerConfig(path string) (Bookmark, error) {
bookmark := Bookmark{}
buff, err := os.ReadFile(path)
if err != nil {
return bookmark, err
}
_, err = toml.Decode(string(buff), &bookmark)
if bookmark.Port == 0 {
bookmark.Port = 5432
}
// List of all supported postgres modes
modes := []string{"disable", "allow", "prefer", "require", "verify-ca", "verify-full"}
valid := false
for _, mode := range modes {
if bookmark.SSLMode == mode {
valid = true
break
}
}
// Fall back to a default mode if mode is not set or invalid
// Typical typo: ssl mode set to "disabled"
if bookmark.SSLMode == "" || !valid {
bookmark.SSLMode = "disable"
}
// Set default SSH port if it's not provided by user
if bookmark.SSH != nil && bookmark.SSH.Port == "" {
bookmark.SSH.Port = "22"
}
return bookmark, err
}
func fileBasename(path string) string {
filename := filepath.Base(path)
return strings.Replace(filename, filepath.Ext(path), "", 1)
}
// Path returns bookmarks storage path
func Path(overrideDir string) string {
if overrideDir == "" {
path, _ := homedir.Dir()
return fmt.Sprintf("%s/.pgweb/bookmarks", path)
}
return overrideDir
}
// ReadAll returns all available bookmarks
func ReadAll(path string) (map[string]Bookmark, error) {
results := map[string]Bookmark{}
files, err := os.ReadDir(path)
if err != nil {
return results, err
}
for _, file := range files {
if filepath.Ext(file.Name()) != ".toml" {
continue
}
fullPath := filepath.Join(path, file.Name())
key := fileBasename(file.Name())
config, err := readServerConfig(fullPath)
if err != nil {
fmt.Printf("%s parse error: %s\n", fullPath, err)
continue
}
results[key] = config
}
return results, nil
}
// GetBookmark reads an existing bookmark
func GetBookmark(bookmarkPath string, bookmarkName string) (Bookmark, error) {
bookmarks, err := ReadAll(bookmarkPath)
if err != nil {
return Bookmark{}, err
}
bookmark, ok := bookmarks[bookmarkName]
if !ok {
return Bookmark{}, fmt.Errorf("couldn't find a bookmark with name %s", bookmarkName)
}
return bookmark, nil
}

View File

@@ -8,115 +8,34 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_Invalid_Bookmark_Files(t *testing.T) {
_, err := readServerConfig("foobar")
assert.Error(t, err)
func TestBookmarkSSHInfoIsEmpty(t *testing.T) {
t.Run("empty", func(t *testing.T) {
info := &shared.SSHInfo{
Host: "",
Port: "",
User: "",
}
_, err = readServerConfig("../../data/invalid.toml")
assert.Error(t, err)
assert.Equal(t, "toml: line 1: expected '.' or '=', but got 'e' instead", err.Error())
b := Bookmark{SSH: nil}
assert.True(t, b.SSHInfoIsEmpty())
b = Bookmark{SSH: info}
assert.True(t, b.SSHInfoIsEmpty())
})
t.Run("populated", func(t *testing.T) {
info := &shared.SSHInfo{
Host: "localhost",
Port: "8080",
User: "postgres",
}
b := Bookmark{SSH: info}
assert.False(t, b.SSHInfoIsEmpty())
})
}
func Test_Bookmark(t *testing.T) {
bookmark, err := readServerConfig("../../data/bookmark.toml")
assert.Equal(t, nil, err)
assert.Equal(t, "localhost", bookmark.Host)
assert.Equal(t, 5432, bookmark.Port)
assert.Equal(t, "postgres", bookmark.User)
assert.Equal(t, "mydatabase", bookmark.Database)
assert.Equal(t, "disable", bookmark.SSLMode)
assert.Equal(t, "", bookmark.Password)
assert.Equal(t, "", bookmark.URL)
bookmark, err = readServerConfig("../../data/bookmark_invalid_ssl.toml")
assert.Equal(t, nil, err)
assert.Equal(t, "disable", bookmark.SSLMode)
}
func Test_Bookmark_URL(t *testing.T) {
bookmark, err := readServerConfig("../../data/bookmark_url.toml")
assert.Equal(t, nil, err)
assert.Equal(t, "postgres://username:password@host:port/database?sslmode=disable", bookmark.URL)
assert.Equal(t, "", bookmark.Host)
assert.Equal(t, 5432, bookmark.Port)
assert.Equal(t, "", bookmark.User)
assert.Equal(t, "", bookmark.Database)
assert.Equal(t, "disable", bookmark.SSLMode)
assert.Equal(t, "", bookmark.Password)
}
func Test_Bookmarks_Path(t *testing.T) {
assert.NotEqual(t, "/.pgweb/bookmarks", Path(""))
}
func Test_Basename(t *testing.T) {
assert.Equal(t, "filename", fileBasename("filename.toml"))
assert.Equal(t, "filename", fileBasename("path/filename.toml"))
assert.Equal(t, "filename", fileBasename("~/long/path/filename.toml"))
assert.Equal(t, "filename", fileBasename("filename"))
}
func Test_ReadBookmarks_Invalid(t *testing.T) {
bookmarks, err := ReadAll("foobar")
assert.Error(t, err)
assert.Equal(t, 0, len(bookmarks))
}
func Test_ReadBookmarks(t *testing.T) {
bookmarks, err := ReadAll("../../data")
assert.Equal(t, nil, err)
assert.Equal(t, 3, len(bookmarks))
}
func Test_GetBookmark(t *testing.T) {
expBookmark := Bookmark{
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "",
Database: "mydatabase",
SSLMode: "disable",
}
b, err := GetBookmark("../../data", "bookmark")
if assert.NoError(t, err) {
assert.Equal(t, expBookmark, b)
}
_, err = GetBookmark("../../data", "bar")
expErrStr := "couldn't find a bookmark with name bar"
assert.Equal(t, expErrStr, err.Error())
_, err = GetBookmark("foo", "bookmark")
assert.Error(t, err)
}
func Test_Bookmark_SSHInfoIsEmpty(t *testing.T) {
emptySSH := &shared.SSHInfo{
Host: "",
Port: "",
User: "",
}
populatedSSH := &shared.SSHInfo{
Host: "localhost",
Port: "8080",
User: "postgres",
}
b := Bookmark{SSH: nil}
assert.True(t, b.SSHInfoIsEmpty())
b = Bookmark{SSH: emptySSH}
assert.True(t, b.SSHInfoIsEmpty())
b.SSH = populatedSSH
assert.False(t, b.SSHInfoIsEmpty())
}
func Test_ConvertToOptions(t *testing.T) {
func TestBookmarkConvertToOptions(t *testing.T) {
b := Bookmark{
URL: "postgres://username:password@host:port/database?sslmode=disable",
Host: "localhost",
@@ -136,6 +55,7 @@ func Test_ConvertToOptions(t *testing.T) {
DbName: "mydatabase",
SSLMode: "disable",
}
opt := b.ConvertToOptions()
assert.Equal(t, expOpt, opt)
}

152
pkg/bookmarks/manager.go Normal file
View File

@@ -0,0 +1,152 @@
package bookmarks
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/BurntSushi/toml"
)
type Manager struct {
dir string
}
func NewManager(dir string) Manager {
return Manager{
dir: dir,
}
}
func (m Manager) Get(id string) (*Bookmark, error) {
bookmarks, err := m.list()
if err != nil {
return nil, err
}
for _, b := range bookmarks {
if b.ID == id {
return &b, nil
}
}
return nil, fmt.Errorf("bookmark %v not found", id)
}
func (m Manager) List() ([]Bookmark, error) {
return m.list()
}
func (m Manager) ListIDs() ([]string, error) {
bookmarks, err := m.list()
if err != nil {
return nil, err
}
ids := make([]string, len(bookmarks))
for i, bookmark := range bookmarks {
ids[i] = bookmark.ID
}
return ids, nil
}
func (m Manager) list() ([]Bookmark, error) {
result := []Bookmark{}
if m.dir == "" {
return result, nil
}
info, err := os.Stat(m.dir)
if err != nil {
// Do not fail if base dir does not exists: it's not created by default
if errors.Is(err, os.ErrNotExist) {
fmt.Fprintf(os.Stderr, "[WARN] bookmarks dir %s does not exist\n", m.dir)
return result, nil
}
return nil, err
}
if !info.IsDir() {
return nil, fmt.Errorf("path %s is not a directory", m.dir)
}
dirEntries, err := os.ReadDir(m.dir)
if err != nil {
return nil, err
}
for _, entry := range dirEntries {
name := entry.Name()
if filepath.Ext(name) != ".toml" {
continue
}
bookmark, err := readBookmark(filepath.Join(m.dir, name))
if err != nil {
// Do not fail if one of the bookmarks is invalid
fmt.Fprintf(os.Stderr, "[WARN] bookmark file %s is invalid: %s\n", name, err)
continue
}
result = append(result, bookmark)
}
return result, nil
}
func readBookmark(path string) (Bookmark, error) {
bookmark := Bookmark{
ID: fileBasename(path),
}
_, err := os.Stat(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
err = fmt.Errorf("bookmark file %s does not exist", path)
}
return bookmark, err
}
buff, err := os.ReadFile(path)
if err != nil {
return bookmark, err
}
_, err = toml.Decode(string(buff), &bookmark)
if bookmark.Port == 0 {
bookmark.Port = 5432
}
// List of all supported postgres modes
modes := []string{"disable", "allow", "prefer", "require", "verify-ca", "verify-full"}
valid := false
for _, mode := range modes {
if bookmark.SSLMode == mode {
valid = true
break
}
}
// Fall back to a default mode if mode is not set or invalid
// Typical typo: ssl mode set to "disabled"
if bookmark.SSLMode == "" || !valid {
bookmark.SSLMode = "disable"
}
// Set default SSH port if it's not provided by user
if bookmark.SSH != nil && bookmark.SSH.Port == "" {
bookmark.SSH.Port = "22"
}
return bookmark, err
}
func fileBasename(path string) string {
filename := filepath.Base(path)
return strings.Replace(filename, filepath.Ext(path), "", 1)
}

View File

@@ -0,0 +1,98 @@
package bookmarks
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestManagerList(t *testing.T) {
examples := []struct {
dir string
num int
err string
}{
{"../../data", 3, ""},
{"../../data/bookmark.toml", 0, "is not a directory"},
{"../../data2", 0, ""},
{"", 0, ""},
}
for _, ex := range examples {
t.Run(ex.dir, func(t *testing.T) {
bookmarks, err := NewManager(ex.dir).List()
if ex.err != "" {
assert.Contains(t, err.Error(), ex.err)
}
assert.Len(t, bookmarks, ex.num)
})
}
}
func TestManagerListIDs(t *testing.T) {
ids, err := NewManager("../../data").ListIDs()
assert.NoError(t, err)
assert.Equal(t, []string{"bookmark", "bookmark_invalid_ssl", "bookmark_url"}, ids)
}
func TestManagerGet(t *testing.T) {
manager := NewManager("../../data")
b, err := manager.Get("bookmark")
assert.NoError(t, err)
assert.Equal(t, "bookmark", b.ID)
b, err = manager.Get("foo")
assert.Equal(t, "bookmark foo not found", err.Error())
assert.Nil(t, b)
}
func Test_fileBasename(t *testing.T) {
assert.Equal(t, "filename", fileBasename("filename.toml"))
assert.Equal(t, "filename", fileBasename("path/filename.toml"))
assert.Equal(t, "filename", fileBasename("~/long/path/filename.toml"))
assert.Equal(t, "filename", fileBasename("filename"))
}
func Test_readBookmark(t *testing.T) {
t.Run("good", func(t *testing.T) {
b, err := readBookmark("../../data/bookmark.toml")
assert.NoError(t, err)
assert.Equal(t, "bookmark", b.ID)
assert.Equal(t, "localhost", b.Host)
assert.Equal(t, 5432, b.Port)
assert.Equal(t, "postgres", b.User)
assert.Equal(t, "mydatabase", b.Database)
assert.Equal(t, "disable", b.SSLMode)
assert.Equal(t, "", b.Password)
assert.Equal(t, "", b.URL)
})
t.Run("with url", func(t *testing.T) {
b, err := readBookmark("../../data/bookmark_url.toml")
assert.NoError(t, err)
assert.Equal(t, "postgres://username:password@host:port/database?sslmode=disable", b.URL)
assert.Equal(t, "", b.Host)
assert.Equal(t, 5432, b.Port)
assert.Equal(t, "", b.User)
assert.Equal(t, "", b.Database)
assert.Equal(t, "disable", b.SSLMode)
assert.Equal(t, "", b.Password)
})
t.Run("invalid ssl", func(t *testing.T) {
b, err := readBookmark("../../data/bookmark_invalid_ssl.toml")
assert.NoError(t, err)
assert.Equal(t, "disable", b.SSLMode)
})
t.Run("invalid file", func(t *testing.T) {
_, err := readBookmark("foobar")
assert.Equal(t, "bookmark file foobar does not exist", err.Error())
})
t.Run("invalid syntax", func(t *testing.T) {
_, err := readBookmark("../../data/invalid.toml")
assert.Equal(t, "toml: line 1: expected '.' or '=', but got 'e' instead", err.Error())
})
}