Tunnel implementation, allow using ssh on connection screen

This commit is contained in:
Dan Sosedoff
2016-01-14 19:50:01 -06:00
parent fb66acebc3
commit f0f447857f
11 changed files with 229 additions and 93 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/history"
"github.com/sosedoff/pgweb/pkg/shared"
"github.com/sosedoff/pgweb/pkg/statements"
)
@@ -63,7 +64,32 @@ func New() (*Client, error) {
return &client, nil
}
func NewFromUrl(url string) (*Client, error) {
func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) {
var tunnel *Tunnel
if sshInfo != nil {
if command.Opts.Debug {
fmt.Println("Opening SSH tunnel for:", sshInfo)
}
tunnel, err := NewTunnel(sshInfo, url)
if err != nil {
tunnel.Close()
return nil, err
}
err = tunnel.Configure()
if err != nil {
tunnel.Close()
return nil, err
}
go tunnel.Start()
// Override remote postgres port with local proxy port
url = strings.Replace(url, ":5432", fmt.Sprintf(":%v", tunnel.Port), 1)
}
if command.Opts.Debug {
fmt.Println("Creating a new client for:", url)
}
@@ -75,6 +101,7 @@ func NewFromUrl(url string) (*Client, error) {
client := Client{
db: db,
tunnel: tunnel,
ConnectionString: url,
History: history.New(),
}
@@ -230,9 +257,14 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
// Close database connection
func (client *Client) Close() error {
if client.tunnel != nil {
client.tunnel.Close()
}
if client.db != nil {
return client.db.Close()
}
return nil
}

View File

@@ -60,7 +60,7 @@ func setup() {
}
func setupClient() {
testClient, _ = NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable")
testClient, _ = NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable", nil)
}
func teardownClient() {
@@ -79,7 +79,7 @@ func teardown() {
func test_NewClientFromUrl(t *testing.T) {
url := "postgres://postgres@localhost/booktown?sslmode=disable"
client, err := NewFromUrl(url)
client, err := NewFromUrl(url, nil)
if err != nil {
defer client.Close()
@@ -91,7 +91,7 @@ func test_NewClientFromUrl(t *testing.T) {
func test_NewClientFromUrl2(t *testing.T) {
url := "postgresql://postgres@localhost/booktown?sslmode=disable"
client, err := NewFromUrl(url)
client, err := NewFromUrl(url, nil)
if err != nil {
defer client.Close()
@@ -257,7 +257,7 @@ func test_HistoryError(t *testing.T) {
}
func test_HistoryUniqueness(t *testing.T) {
client, _ := NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable")
client, _ := NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable", nil)
client.Query("SELECT * FROM books WHERE id = 1")
client.Query("SELECT * FROM books WHERE id = 1")

View File

@@ -6,12 +6,15 @@ import (
"io/ioutil"
"log"
"net"
"net/url"
"os"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/shared"
)
const (
@@ -22,21 +25,22 @@ const (
type Tunnel struct {
TargetHost string
TargetPort string
SshHost string
SshPort string
SshUser string
SshPassword string
SshKey string
Config *ssh.ClientConfig
Client *ssh.Client
Port int
SSHInfo *shared.SSHInfo
Config *ssh.ClientConfig
Client *ssh.Client
Listener *net.TCPListener
}
func privateKeyPath() string {
return os.Getenv("HOME") + "/.ssh/id_rsa"
}
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
buff, err := ioutil.ReadFile(keyPath)
if err != nil {
@@ -46,10 +50,11 @@ func parsePrivateKey(keyPath string) (ssh.Signer, error) {
return ssh.ParsePrivateKey(buff)
}
func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
func makeConfig(info *shared.SSHInfo) (*ssh.ClientConfig, error) {
methods := []ssh.AuthMethod{}
if keyPath != "" {
keyPath := privateKeyPath()
if fileExists(keyPath) {
key, err := parsePrivateKey(keyPath)
if err != nil {
return nil, err
@@ -58,52 +63,19 @@ func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
methods = append(methods, ssh.PublicKeys(key))
}
methods = append(methods, ssh.Password(password))
methods = append(methods, ssh.Password(info.Password))
return &ssh.ClientConfig{User: user, Auth: methods}, nil
return &ssh.ClientConfig{User: info.User, Auth: methods}, nil
}
func (tunnel *Tunnel) sshEndpoint() string {
return fmt.Sprintf("%s:%v", tunnel.SshHost, tunnel.SshPort)
return fmt.Sprintf("%s:%v", tunnel.SSHInfo.Host, tunnel.SSHInfo.Port)
}
func (tunnel *Tunnel) targetEndpoint() string {
return fmt.Sprintf("%v:%v", tunnel.TargetHost, tunnel.TargetPort)
}
func (tunnel *Tunnel) Start() error {
config, err := makeConfig(tunnel.SshUser, tunnel.SshPassword, tunnel.SshKey)
if err != nil {
return err
}
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
if err != nil {
return err
}
defer client.Close()
port, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
if err != nil {
return err
}
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port))
if err != nil {
return err
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
return err
}
go tunnel.handleConnection(conn, client)
}
}
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
defer wg.Done()
if _, err := io.Copy(writer, reader); err != nil {
@@ -111,8 +83,8 @@ func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
}
}
func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
remote, err := sshClient.Dial("tcp", tunnel.targetEndpoint())
func (tunnel *Tunnel) handleConnection(local net.Conn) {
remote, err := tunnel.Client.Dial("tcp", tunnel.targetEndpoint())
if err != nil {
return
}
@@ -124,4 +96,79 @@ func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
go tunnel.copy(&wg, remote, local)
wg.Wait()
local.Close()
}
func (tunnel *Tunnel) Close() {
if tunnel.Client != nil {
tunnel.Client.Close()
}
if tunnel.Listener != nil {
tunnel.Listener.Close()
}
}
func (tunnel *Tunnel) Configure() error {
config, err := makeConfig(tunnel.SSHInfo)
if err != nil {
return err
}
tunnel.Config = config
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
if err != nil {
return err
}
tunnel.Client = client
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", tunnel.Port))
if err != nil {
return err
}
tunnel.Listener = listener.(*net.TCPListener)
return nil
}
func (tunnel *Tunnel) Start() {
for {
conn, err := tunnel.Listener.Accept()
if err != nil {
return
}
go tunnel.handleConnection(conn)
}
tunnel.Close()
}
func NewTunnel(sshInfo *shared.SSHInfo, dbUrl string) (*Tunnel, error) {
uri, err := url.Parse(dbUrl)
if err != nil {
return nil, err
}
listenPort, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
if err != nil {
return nil, err
}
chunks := strings.Split(uri.Host, ":")
host := chunks[0]
port := "5432"
if len(chunks) == 2 {
port = chunks[1]
}
tunnel := &Tunnel{
Port: listenPort,
SSHInfo: sshInfo,
TargetHost: host,
TargetPort: port,
}
return tunnel, nil
}