Initial ssh tunnel implementation

This commit is contained in:
Dan Sosedoff 2016-01-13 01:29:14 -06:00
parent ab94dd5b86
commit 68c2b4d084
6 changed files with 330 additions and 272 deletions

View File

@ -18,6 +18,12 @@ type Bookmark struct {
Password string `json:"password"` // User password
Database string `json:"database"` // Database name
Ssl string `json:"ssl"` // Connection SSL mode
SshHost string `json:"ssh_user"`
SshPort string `json:"ssh_port"`
SshUser string `json:"ssh_user"`
SshPassword string `json:"ssh_password"`
SshKey string `json:"ssh_key"`
}
func readServerConfig(path string) (Bookmark, error) {

View File

@ -16,6 +16,7 @@ import (
type Client struct {
db *sqlx.DB
tunnel *Tunnel
History []history.Record `json:"history"`
ConnectionString string `json:"connection_string"`
}
@ -49,7 +50,6 @@ func New() (*Client, error) {
}
db, err := sqlx.Open("postgres", str)
if err != nil {
return nil, err
}

127
pkg/client/tunnel.go Normal file
View File

@ -0,0 +1,127 @@
package client
import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"sync"
"golang.org/x/crypto/ssh"
"github.com/sosedoff/pgweb/pkg/connection"
)
const (
PORT_START = 29168
PORT_LIMIT = 500
)
type Tunnel struct {
TargetHost string
TargetPort string
SshHost string
SshPort string
SshUser string
SshPassword string
SshKey string
Config *ssh.ClientConfig
Client *ssh.Client
}
func privateKeyPath() string {
return os.Getenv("HOME") + "/.ssh/id_rsa"
}
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
buff, err := ioutil.ReadFile(keyPath)
if err != nil {
return nil, err
}
return ssh.ParsePrivateKey(buff)
}
func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
methods := []ssh.AuthMethod{}
if keyPath != "" {
key, err := parsePrivateKey(keyPath)
if err != nil {
return nil, err
}
methods = append(methods, ssh.PublicKeys(key))
}
methods = append(methods, ssh.Password(password))
return &ssh.ClientConfig{User: user, Auth: methods}, nil
}
func (tunnel *Tunnel) sshEndpoint() string {
return fmt.Sprintf("%s:%v", tunnel.SshHost, tunnel.SshPort)
}
func (tunnel *Tunnel) targetEndpoint() string {
return fmt.Sprintf("%v:%v", tunnel.TargetHost, tunnel.TargetPort)
}
func (tunnel *Tunnel) Start() error {
config, err := makeConfig(tunnel.SshUser, tunnel.SshPassword, tunnel.SshKey)
if err != nil {
return err
}
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
if err != nil {
return err
}
defer client.Close()
port, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
if err != nil {
return err
}
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port))
if err != nil {
return err
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
return err
}
go tunnel.handleConnection(conn, client)
}
}
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
defer wg.Done()
if _, err := io.Copy(writer, reader); err != nil {
log.Println("Tunnel copy error:", err)
}
}
func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
remote, err := sshClient.Dial("tcp", tunnel.targetEndpoint())
if err != nil {
return
}
wg := sync.WaitGroup{}
wg.Add(2)
go tunnel.copy(&wg, local, remote)
go tunnel.copy(&wg, remote, local)
wg.Wait()
}

View File

@ -23,7 +23,7 @@ func portAvailable(port int) bool {
}
// Get available TCP port on localhost by trying available ports in a range
func getAvailablePort(start int, limit int) (int, error) {
func AvailablePort(start int, limit int) (int, error) {
for i := start; i <= (start + limit); i++ {
if portAvailable(i) {
return i, nil

View File

@ -44,7 +44,7 @@ func Test_getAvailablePort(t *testing.T) {
t.Skip("FIXME")
}
port, err := getAvailablePort(8081, 1)
port, err := AvailablePort(8081, 1)
assert.Equal(t, nil, err)
assert.Equal(t, 8081, port)
@ -65,11 +65,11 @@ func Test_getAvailablePort(t *testing.T) {
}
}()
port, err = getAvailablePort(8081, 0)
port, err = AvailablePort(8081, 0)
assert.EqualError(t, err, "No available port")
assert.Equal(t, -1, port)
port, err = getAvailablePort(8081, 1)
port, err = AvailablePort(8081, 1)
assert.Equal(t, nil, err)
assert.Equal(t, 8082, port)
}

File diff suppressed because one or more lines are too long