pgweb/pkg/client/tunnel.go
2016-01-13 01:29:14 -06:00

128 lines
2.3 KiB
Go

package client
import (
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"path/filepath"
	"sync"

	"github.com/sosedoff/pgweb/pkg/connection"
	"golang.org/x/crypto/ssh"
)
// Local port search parameters handed to connection.AvailablePort by Start
// when it chooses the port the tunnel listener binds to on 127.0.0.1.
const (
	// PORT_START is the first candidate local port.
	PORT_START = 29168

	// PORT_LIMIT is passed alongside PORT_START; presumably it bounds how many
	// ports are probed (NOTE(review): confirm against connection.AvailablePort).
	PORT_LIMIT = 500
)
// Tunnel describes an SSH port-forwarding tunnel: connections accepted on a
// local port are relayed through the SSH server at SshHost:SshPort to
// TargetHost:TargetPort.
type Tunnel struct {
	// Destination dialed from the SSH server's side of the tunnel.
	TargetHost, TargetPort string

	// SSH server address and credentials. SshKey is a path to a private key
	// file; when it is empty no public-key authentication is configured.
	SshHost, SshPort     string
	SshUser, SshPassword string
	SshKey               string

	// NOTE(review): Start keeps its SSH config and client in locals and never
	// assigns these two fields in the code visible here — confirm their use.
	Config *ssh.ClientConfig
	Client *ssh.Client
}
// privateKeyPath returns the conventional location of the user's default RSA
// private key, $HOME/.ssh/id_rsa.
//
// The path is assembled with filepath.Join rather than string concatenation
// so a trailing slash in HOME does not produce a doubled separator and the
// platform's path separator is used. If HOME is unset the result is the
// relative path .ssh/id_rsa.
func privateKeyPath() string {
	return filepath.Join(os.Getenv("HOME"), ".ssh", "id_rsa")
}
// parsePrivateKey reads the private key file at keyPath and decodes it into
// an ssh.Signer. Errors from reading the file or from ssh.ParsePrivateKey are
// returned unchanged.
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
	keyData, readErr := ioutil.ReadFile(keyPath)
	if readErr != nil {
		return nil, readErr
	}

	return ssh.ParsePrivateKey(keyData)
}
// makeConfig builds the SSH client configuration for user.
//
// When keyPath is non-empty, the private key stored there is loaded and
// offered first as public-key authentication. Password authentication is
// appended only when a non-empty password is supplied. Sending an empty
// password would just add a failing attempt that may count against the
// server's authentication-attempt limit.
//
// It returns an error if the key file cannot be read or parsed.
//
// NOTE(review): no HostKeyCallback is set, so host keys are not verified.
// Recent releases of golang.org/x/crypto/ssh refuse to dial with a nil
// HostKeyCallback — confirm the vendored version and choose a policy.
func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
	methods := make([]ssh.AuthMethod, 0, 2)

	if keyPath != "" {
		key, err := parsePrivateKey(keyPath)
		if err != nil {
			return nil, err
		}
		methods = append(methods, ssh.PublicKeys(key))
	}

	if password != "" {
		methods = append(methods, ssh.Password(password))
	}

	return &ssh.ClientConfig{User: user, Auth: methods}, nil
}
// sshEndpoint returns the "host:port" address of the SSH server.
// net.JoinHostPort brackets IPv6 literals (e.g. "[::1]:22"), which a plain
// "%s:%v" format would render as the unparsable "::1:22".
func (tunnel *Tunnel) sshEndpoint() string {
	return net.JoinHostPort(tunnel.SshHost, tunnel.SshPort)
}
// targetEndpoint returns the "host:port" address that the SSH server dials on
// the tunnel's behalf. net.JoinHostPort brackets IPv6 literals so the result
// stays a valid dial address.
func (tunnel *Tunnel) targetEndpoint() string {
	return net.JoinHostPort(tunnel.TargetHost, tunnel.TargetPort)
}
// Start opens the tunnel and serves it. It performs these steps:
//
//  1. Build the SSH config from the tunnel's credentials.
//  2. Connect to the SSH server.
//  3. Pick a free local port via connection.AvailablePort and listen on
//     127.0.0.1 at that port.
//  4. Hand every accepted connection to handleConnection in its own goroutine.
//
// Start blocks and only returns when one of these steps (including Accept)
// fails. That error is returned, and the listener and SSH client are closed
// on the way out.
//
// NOTE(review): the chosen local port is kept in a local variable and is not
// exposed to callers here — confirm how clients learn which port to use.
func (tunnel *Tunnel) Start() error {
	cfg, err := makeConfig(tunnel.SshUser, tunnel.SshPassword, tunnel.SshKey)
	if err != nil {
		return err
	}

	sshClient, err := ssh.Dial("tcp", tunnel.sshEndpoint(), cfg)
	if err != nil {
		return err
	}
	defer sshClient.Close()

	localPort, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
	if err != nil {
		return err
	}

	listenAddr := fmt.Sprintf("127.0.0.1:%v", localPort)
	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		return err
	}
	defer ln.Close()

	for {
		conn, acceptErr := ln.Accept()
		if acceptErr != nil {
			return acceptErr
		}
		go tunnel.handleConnection(conn, sshClient)
	}
}
// copy streams bytes from reader to writer until reader reports EOF or an
// error, then signals wg. Copy errors are logged, not returned.
//
// Once the stream ends, writer's peer is told that no more data is coming.
// It uses CloseWrite when writer supports a half-close, otherwise a full
// Close. Without this, the opposite direction of the tunnel can block
// forever waiting for a peer that never learns this side has finished.
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
	defer wg.Done()

	if _, err := io.Copy(writer, reader); err != nil {
		log.Println("Tunnel copy error:", err)
	}

	// Best effort: the connection is torn down by handleConnection anyway,
	// so a failure to half-close is not worth reporting.
	if cw, ok := writer.(interface{ CloseWrite() error }); ok {
		_ = cw.CloseWrite()
	} else {
		_ = writer.Close()
	}
}

// handleConnection proxies one accepted local connection to the tunnel
// target by dialing targetEndpoint through sshClient. Traffic is copied in
// both directions until both directions finish. Both connections are always
// closed before returning, including when the SSH-side dial fails (the error
// is then logged).
func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
	defer local.Close()

	remote, err := sshClient.Dial("tcp", tunnel.targetEndpoint())
	if err != nil {
		log.Println("Tunnel dial error:", err)
		return
	}
	defer remote.Close()

	var wg sync.WaitGroup
	wg.Add(2)
	go tunnel.copy(&wg, local, remote)
	go tunnel.copy(&wg, remote, local)
	wg.Wait()
}