Initial ssh tunnel implementation
This commit is contained in:
parent
ab94dd5b86
commit
68c2b4d084
@ -18,6 +18,12 @@ type Bookmark struct {
|
|||||||
Password string `json:"password"` // User password
|
Password string `json:"password"` // User password
|
||||||
Database string `json:"database"` // Database name
|
Database string `json:"database"` // Database name
|
||||||
Ssl string `json:"ssl"` // Connection SSL mode
|
Ssl string `json:"ssl"` // Connection SSL mode
|
||||||
|
|
||||||
|
SshHost string `json:"ssh_user"`
|
||||||
|
SshPort string `json:"ssh_port"`
|
||||||
|
SshUser string `json:"ssh_user"`
|
||||||
|
SshPassword string `json:"ssh_password"`
|
||||||
|
SshKey string `json:"ssh_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func readServerConfig(path string) (Bookmark, error) {
|
func readServerConfig(path string) (Bookmark, error) {
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
|
tunnel *Tunnel
|
||||||
History []history.Record `json:"history"`
|
History []history.Record `json:"history"`
|
||||||
ConnectionString string `json:"connection_string"`
|
ConnectionString string `json:"connection_string"`
|
||||||
}
|
}
|
||||||
@ -49,7 +50,6 @@ func New() (*Client, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db, err := sqlx.Open("postgres", str)
|
db, err := sqlx.Open("postgres", str)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
127
pkg/client/tunnel.go
Normal file
127
pkg/client/tunnel.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/sosedoff/pgweb/pkg/connection"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PORT_START = 29168
|
||||||
|
PORT_LIMIT = 500
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tunnel struct {
|
||||||
|
TargetHost string
|
||||||
|
TargetPort string
|
||||||
|
|
||||||
|
SshHost string
|
||||||
|
SshPort string
|
||||||
|
SshUser string
|
||||||
|
SshPassword string
|
||||||
|
SshKey string
|
||||||
|
|
||||||
|
Config *ssh.ClientConfig
|
||||||
|
Client *ssh.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func privateKeyPath() string {
|
||||||
|
return os.Getenv("HOME") + "/.ssh/id_rsa"
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
|
||||||
|
buff, err := ioutil.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.ParsePrivateKey(buff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
|
||||||
|
methods := []ssh.AuthMethod{}
|
||||||
|
|
||||||
|
if keyPath != "" {
|
||||||
|
key, err := parsePrivateKey(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
methods = append(methods, ssh.PublicKeys(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
methods = append(methods, ssh.Password(password))
|
||||||
|
|
||||||
|
return &ssh.ClientConfig{User: user, Auth: methods}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tunnel *Tunnel) sshEndpoint() string {
|
||||||
|
return fmt.Sprintf("%s:%v", tunnel.SshHost, tunnel.SshPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tunnel *Tunnel) targetEndpoint() string {
|
||||||
|
return fmt.Sprintf("%v:%v", tunnel.TargetHost, tunnel.TargetPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tunnel *Tunnel) Start() error {
|
||||||
|
config, err := makeConfig(tunnel.SshUser, tunnel.SshPassword, tunnel.SshKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
port, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go tunnel.handleConnection(conn, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
|
||||||
|
defer wg.Done()
|
||||||
|
if _, err := io.Copy(writer, reader); err != nil {
|
||||||
|
log.Println("Tunnel copy error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
|
||||||
|
remote, err := sshClient.Dial("tcp", tunnel.targetEndpoint())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go tunnel.copy(&wg, local, remote)
|
||||||
|
go tunnel.copy(&wg, remote, local)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
@ -23,7 +23,7 @@ func portAvailable(port int) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get available TCP port on localhost by trying available ports in a range
|
// Get available TCP port on localhost by trying available ports in a range
|
||||||
func getAvailablePort(start int, limit int) (int, error) {
|
func AvailablePort(start int, limit int) (int, error) {
|
||||||
for i := start; i <= (start + limit); i++ {
|
for i := start; i <= (start + limit); i++ {
|
||||||
if portAvailable(i) {
|
if portAvailable(i) {
|
||||||
return i, nil
|
return i, nil
|
||||||
|
@ -44,7 +44,7 @@ func Test_getAvailablePort(t *testing.T) {
|
|||||||
t.Skip("FIXME")
|
t.Skip("FIXME")
|
||||||
}
|
}
|
||||||
|
|
||||||
port, err := getAvailablePort(8081, 1)
|
port, err := AvailablePort(8081, 1)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, 8081, port)
|
assert.Equal(t, 8081, port)
|
||||||
|
|
||||||
@ -65,11 +65,11 @@ func Test_getAvailablePort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
port, err = getAvailablePort(8081, 0)
|
port, err = AvailablePort(8081, 0)
|
||||||
assert.EqualError(t, err, "No available port")
|
assert.EqualError(t, err, "No available port")
|
||||||
assert.Equal(t, -1, port)
|
assert.Equal(t, -1, port)
|
||||||
|
|
||||||
port, err = getAvailablePort(8081, 1)
|
port, err = AvailablePort(8081, 1)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, 8082, port)
|
assert.Equal(t, 8082, port)
|
||||||
}
|
}
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user