Tunnel implementation, allow using ssh on connection screen
This commit is contained in:
@@ -6,12 +6,15 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/connection"
|
||||
"github.com/sosedoff/pgweb/pkg/shared"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,21 +25,22 @@ const (
|
||||
type Tunnel struct {
|
||||
TargetHost string
|
||||
TargetPort string
|
||||
|
||||
SshHost string
|
||||
SshPort string
|
||||
SshUser string
|
||||
SshPassword string
|
||||
SshKey string
|
||||
|
||||
Config *ssh.ClientConfig
|
||||
Client *ssh.Client
|
||||
Port int
|
||||
SSHInfo *shared.SSHInfo
|
||||
Config *ssh.ClientConfig
|
||||
Client *ssh.Client
|
||||
Listener *net.TCPListener
|
||||
}
|
||||
|
||||
func privateKeyPath() string {
|
||||
return os.Getenv("HOME") + "/.ssh/id_rsa"
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
|
||||
buff, err := ioutil.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
@@ -46,10 +50,11 @@ func parsePrivateKey(keyPath string) (ssh.Signer, error) {
|
||||
return ssh.ParsePrivateKey(buff)
|
||||
}
|
||||
|
||||
func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
|
||||
func makeConfig(info *shared.SSHInfo) (*ssh.ClientConfig, error) {
|
||||
methods := []ssh.AuthMethod{}
|
||||
|
||||
if keyPath != "" {
|
||||
keyPath := privateKeyPath()
|
||||
if fileExists(keyPath) {
|
||||
key, err := parsePrivateKey(keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -58,52 +63,19 @@ func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
|
||||
methods = append(methods, ssh.PublicKeys(key))
|
||||
}
|
||||
|
||||
methods = append(methods, ssh.Password(password))
|
||||
methods = append(methods, ssh.Password(info.Password))
|
||||
|
||||
return &ssh.ClientConfig{User: user, Auth: methods}, nil
|
||||
return &ssh.ClientConfig{User: info.User, Auth: methods}, nil
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) sshEndpoint() string {
|
||||
return fmt.Sprintf("%s:%v", tunnel.SshHost, tunnel.SshPort)
|
||||
return fmt.Sprintf("%s:%v", tunnel.SSHInfo.Host, tunnel.SSHInfo.Port)
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) targetEndpoint() string {
|
||||
return fmt.Sprintf("%v:%v", tunnel.TargetHost, tunnel.TargetPort)
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Start() error {
|
||||
config, err := makeConfig(tunnel.SshUser, tunnel.SshPassword, tunnel.SshKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
port, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go tunnel.handleConnection(conn, client)
|
||||
}
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(writer, reader); err != nil {
|
||||
@@ -111,8 +83,8 @@ func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
|
||||
remote, err := sshClient.Dial("tcp", tunnel.targetEndpoint())
|
||||
func (tunnel *Tunnel) handleConnection(local net.Conn) {
|
||||
remote, err := tunnel.Client.Dial("tcp", tunnel.targetEndpoint())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -124,4 +96,79 @@ func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
|
||||
go tunnel.copy(&wg, remote, local)
|
||||
|
||||
wg.Wait()
|
||||
local.Close()
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Close() {
|
||||
if tunnel.Client != nil {
|
||||
tunnel.Client.Close()
|
||||
}
|
||||
|
||||
if tunnel.Listener != nil {
|
||||
tunnel.Listener.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Configure() error {
|
||||
config, err := makeConfig(tunnel.SSHInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.Config = config
|
||||
|
||||
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.Client = client
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", tunnel.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.Listener = listener.(*net.TCPListener)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Start() {
|
||||
for {
|
||||
conn, err := tunnel.Listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go tunnel.handleConnection(conn)
|
||||
}
|
||||
|
||||
tunnel.Close()
|
||||
}
|
||||
|
||||
func NewTunnel(sshInfo *shared.SSHInfo, dbUrl string) (*Tunnel, error) {
|
||||
uri, err := url.Parse(dbUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listenPort, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chunks := strings.Split(uri.Host, ":")
|
||||
host := chunks[0]
|
||||
port := "5432"
|
||||
|
||||
if len(chunks) == 2 {
|
||||
port = chunks[1]
|
||||
}
|
||||
|
||||
tunnel := &Tunnel{
|
||||
Port: listenPort,
|
||||
SSHInfo: sshInfo,
|
||||
TargetHost: host,
|
||||
TargetPort: port,
|
||||
}
|
||||
|
||||
return tunnel, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user