File: C:/Program Files/MySQL/MySQL Workbench 8.0/sshtunnel.py
# Copyright (c) 2012, 2019, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have included with MySQL.
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import platform
import threading
import random
import queue
import traceback
import socket
import select
import sys
import time
import os
import mforms
import paramiko
from workbench.log import log_warning, log_error, log_debug, log_debug2, log_debug3, log_info
from wb_common import SSHFingerprintNewError, format_bad_host_exception
import grt
# Default SSH server port, used when a server address is given without an explicit port
SSH_PORT = 22
# Default port of the forwarding target, used when a target address is given without an
# explicit port (3306 -- presumably the MySQL server port; confirm against callers)
REMOTE_PORT = 3306
# timeout for closing an unused tunnel; also used as the select() timeout of the
# forwarding loop (compared against time.time() differences, i.e. seconds)
TUNNEL_TIMEOUT = 3
# Passed as the timeout argument of the paramiko SSHClient.connect() call
SSH_CONNECTION_TIMEOUT = 10
# paramiko 1.6 didn't have this class, so provide a local replacement when it is missing
if hasattr(paramiko, "WarningPolicy"):
    WarningPolicy = paramiko.WarningPolicy
else:
    class WarningPolicy(paramiko.MissingHostKeyPolicy):
        """Host key policy that only logs a warning about an unknown host key.

        missing_host_key() returns without raising after writing the warning.
        """

        def missing_host_key(self, client, hostname, key):
            """Log the type and hex fingerprint of the unknown key offered by hostname."""
            import binascii
            # hexlify() returns bytes on Python 3; decode it so the log shows the plain
            # hex digits instead of a b'...' literal.
            fingerprint = binascii.hexlify(key.get_fingerprint()).decode('ascii')
            log_warning('WARNING: Unknown %s host key for %s: %s\n' % (key.get_name(), hostname, fingerprint))
class StoreIfConfirmedPolicy(paramiko.MissingHostKeyPolicy):
    """Host key policy that hands the decision about an unknown host key to the caller.

    Rather than accepting or rejecting the key itself, it raises SSHFingerprintNewError
    built from the client, the host name and the offered key.
    """

    def missing_host_key(self, client, hostname, key):
        """Raise SSHFingerprintNewError describing the unknown key of hostname."""
        unknown_key_error = SSHFingerprintNewError("Key mismatched", client, hostname, key)
        raise unknown_key_error
# Counts the tunnel threads started so far; do_run() uses it to name each thread.
tunnel_serial = 0


class Tunnel(threading.Thread):
    """This class is a threaded implementation of an SSH tunnel.

    You should not access the attributes that start with an underscore outside this thread
    of execution (e.g. self._server) for this could run into race conditions. Even when
    accessing its public attributes (those that don't start with an underscore) you should
    be careful of acquiring the self.lock reentrant lock (and releasing it once done):

        with tunnel.lock:
            if tunnel.connecting:
                ... whatever...

    Progress and errors are reported as (msg_type, msg_object) tuples put on the queue
    that was passed to the constructor.
    """

    def __init__(self, q, server, username, target, password, keyfile):
        """Store the connection parameters; nothing is opened until the thread runs.

        q        -- queue receiving the (msg_type, msg_object) notifications
        server   -- (host, port) of the SSH server
        username -- user name for the SSH login
        target   -- (host, port) the SSH server is asked to forward connections to
        password -- password passed to the SSH client's connect() call
        keyfile  -- private key file name, or a false value when no key is used
        """
        super(Tunnel, self).__init__()
        self.daemon = True

        self._server = server
        self._username = username
        self._target = target
        self._password = password
        self._keyfile = keyfile

        # Acquire and release this lock while accessing the attributes of objects
        # of this class from the main thread:
        self.lock = threading.RLock()

        # This event marks when a random port is selected to be used:
        self.port_is_set = threading.Event()
        self.local_port = None

        self._listen_sock = None
        self.q = q
        self._shutdown = False
        self.connecting = False
        self._client = paramiko.SSHClient()
        # List of (local socket, SSH channel) pairs currently being forwarded
        self._connections = []

    def is_connecting(self):
        """Return the connecting flag, read under the lock."""
        with self.lock:
            return self.connecting

    def run(self):
        """Thread entry point: run do_run() and log any exception escaping from it."""
        try:
            self.do_run()
        except Exception as e:
            log_error("Unhandled exception in SSH tunnel: %s\n" % e)

    def do_run(self):
        """Bind the local port, connect to the SSH server and forward data until shutdown."""
        global tunnel_serial
        tunnel_serial += 1
        mforms.Utilities.set_thread_name("SSHTunnel%i" % tunnel_serial)
        sys.stdout.write('Thread started\n')  # sys.stdout.write is thread safe while print isn't
        log_debug2("SSH Tunel %i thread started\n" % tunnel_serial)

        try:
            self._bind_local_port()
        finally:
            # Signal the waiters once, after binding either succeeded or definitively
            # failed; signalling on every retry would let them read local_port too early.
            with self.lock:
                self.connecting = True
            self.port_is_set.set()

        if self._keyfile:
            self.notify('INFO', 'Connecting to SSH server at %s:%s using key %s...' % (self._server[0], self._server[1], self._keyfile))
        else:
            self.notify('INFO', 'Connecting to SSH server at %s:%s...' % (self._server[0], self._server[1]))

        connected = self._connect_ssh()
        if not connected:
            self._listen_sock.close()
            self._shutdown = True

        with self.lock:
            self.connecting = False

        if connected:
            self.notify('INFO', 'Connection opened')
            del self._password

        try:
            self._forward_until_shutdown()
        finally:
            # Time to shutdown; runs even when forwarding raised so nothing is left open.
            self._close_everything()
        log_debug("Leaving tunnel thread %s\n" % self.local_port)

    def _bind_local_port(self):
        """Create the listening socket and bind it to a random port on 127.0.0.1.

        Stores the chosen port in self.local_port. Re-raises the socket error when
        binding fails for a reason that is not worth retrying with another port.
        """
        import errno
        # 22 is the error number the code has always retried on; a port that is
        # already taken is retried with another random port as well.
        retry_errnos = (22, errno.EADDRINUSE)

        self._listen_sock = socket.socket()
        while True:
            local_port = random.randint(1024, 65535)
            try:
                self._listen_sock.bind(('127.0.0.1', local_port))
                self._listen_sock.listen(2)
            except socket.error as exc:
                sys.stdout.write('Socket error: %s for port %d\n' % (exc, local_port))
                if exc.errno in retry_errnos:
                    continue  # retry
                self.notify_exception_error('ERROR', "Error initializing server end of tunnel", sys.exc_info())
                self._listen_sock.close()
                raise
            with self.lock:
                self.local_port = local_port
            return

    def _forward_until_shutdown(self):
        """Accept local clients and relay their data until shutdown or inactivity."""
        last_activity = time.time()
        while not self._shutdown:
            try:
                socks = [self._listen_sock]
                for sock, chan in self._connections:
                    socks.append(sock)
                    socks.append(chan)
                readable, _, _ = select.select(socks, [], [], TUNNEL_TIMEOUT)
            except Exception as e:
                if not self._shutdown:
                    self.notify_exception_error('ERROR', 'Error while forwarding data: %r' % e, sys.exc_info())
                break

            if not readable and len(socks) <= 1 and time.time() - last_activity > TUNNEL_TIMEOUT:
                self.notify('INFO', 'Closing tunnel to %s:%s for inactivity...' % (self._server[0], self._server[1]))
                break

            last_activity = time.time()

            if self._listen_sock in readable:
                self.notify('INFO', 'New client connection')
                self.accept_client()

            closed = []
            for sock, chan in self._connections:
                try:
                    if sock in readable and not self._forward(sock, chan):
                        closed.append((sock, chan))
                    if chan in readable and not self._forward(chan, sock):
                        closed.append((sock, chan))
                except socket.error as exc:
                    # A failure on one connection must not take the whole tunnel down:
                    # drop just this pair.
                    log_warning("Error forwarding data through tunnel port %s: %r\n" % (self.local_port, exc))
                    closed.append((sock, chan))

            for item in set(closed):  # set() will remove duplicates from closed list
                for endpoint in item:
                    try:
                        endpoint.close()
                    except Exception:
                        pass  # best effort, the pair is discarded either way
                self.notify('INFO', 'Client for %s disconnected' % self.local_port)
                self._connections.remove(item)

            if closed and not self._connections and time.time() - last_activity > TUNNEL_TIMEOUT:
                self.notify('INFO', 'Closing tunnel to %s:%s for inactivity...' % (self._server[0], self._server[1]))
                break

    @staticmethod
    def _forward(source, destination):
        """Relay one chunk of up to 1024 bytes from source to destination.

        Returns False when source has no more data (recv() returned nothing), True
        otherwise. sendall() is used because send() may transmit only part of the
        data, which would silently drop the remainder.
        """
        data = source.recv(1024)
        if not data:
            return False
        destination.sendall(data)
        return True

    def _close_everything(self):
        """Close every forwarded connection, the listening socket and the SSH client."""
        for sock, chan in self._connections:
            for endpoint in (sock, chan):
                try:
                    endpoint.close()
                except Exception:
                    pass  # best effort: keep closing the remaining endpoints
        self._listen_sock.close()
        self._client.close()

    def notify(self, msg_type, msg_object):
        """Log a message and put it on the notification queue as (msg_type, msg_object)."""
        # %s instead of %i: local_port is still None while the port has not been bound,
        # and %i would raise TypeError in that case.
        log_debug2("tunnel_%s: %s %s\n" % (self.local_port, msg_type, msg_object))
        self.q.put((msg_type, msg_object))

    def notify_exception_error(self, msg_type, msg_txt, msg_obj=None):
        """Notify msg_txt and log the traceback of the exception being handled.

        msg_obj is accepted for the callers that pass sys.exc_info() but is not used;
        the traceback is taken from traceback.format_exc().
        """
        self.notify(msg_type, msg_txt)
        log_error("%s\n" % traceback.format_exc())

    def match(self, server, username, target):
        """Return True when this tunnel was created for the given server, user and target."""
        with self.lock:
            return self._server == server and self._username == username and self._target == target

    def _get_ssh_config_path(self):
        """Return the first existing ssh config file among the candidates, or None.

        The path configured in the 'pathtosshconfig' option is tried first, followed by
        the platform specific default locations.
        """
        paths = []
        user_path = grt.root.wb.options.options['pathtosshconfig']
        if user_path:
            paths.append(user_path)

        if platform.system().lower() == "windows":
            data_folder = mforms.App.get().get_user_data_folder()
            paths.append("%s\\ssh\\config" % data_folder)
            paths.append("%s\\ssh\\ssh_config" % data_folder)
        else:
            paths.append("~/.ssh/config")
            paths.append("~/.ssh/ssh_config")

        for path in paths:
            expanded = os.path.expanduser(path)
            if os.path.isfile(expanded):
                return expanded
            log_debug3("ssh config file not found")
        return None

    def _connect_ssh(self):
        """Create the SSH client and set up the connection.

        Returns True when the connection was established and False otherwise.

        Any exception coming from paramiko will be notified as an error
        that would cause the failure of the connection. Some of these are:

        paramiko.AuthenticationException   --- raised when authentication failed for some reason
        paramiko.PasswordRequiredException --- raised when a password is needed to unlock a private key file;
                                               this is a subclass of paramiko.AuthenticationException
        """
        try:
            config = paramiko.config.SSHConfig()
            config_file_path = self._get_ssh_config_path()
            if config_file_path:
                with open(config_file_path) as f:
                    config.parse(f)
            opts = config.lookup(self._server[0])

            if "userknownhostsfile" in opts:
                ssh_known_hosts_file = opts["userknownhostsfile"]
            else:
                self._client.get_host_keys().clear()
                if platform.system().lower() == "windows":
                    ssh_known_hosts_file = '%s\\ssh\\known_hosts' % mforms.App.get().get_user_data_folder()
                else:
                    ssh_known_hosts_file = '~/.ssh/known_hosts'

            try:
                self._client.load_host_keys(os.path.expanduser(ssh_known_hosts_file))
            except IOError as e:
                log_warning("IOError, probably caused by file %s not found, the message was: %s\n" % (ssh_known_hosts_file, e))

            if "stricthostkeychecking" in opts and opts["stricthostkeychecking"].lower() == "no":
                self._client.set_missing_host_key_policy(WarningPolicy())
            else:
                self._client.set_missing_host_key_policy(StoreIfConfirmedPolicy())

            has_key = bool(self._keyfile)
            self._client.connect(self._server[0], self._server[1], username=self._username,
                                 key_filename=self._keyfile, password=self._password,
                                 look_for_keys=has_key, allow_agent=has_key, timeout=SSH_CONNECTION_TIMEOUT)
        except paramiko.BadHostKeyException as exc:
            if platform.system().lower() == "windows":
                known_hosts_desc = '%s\\ssh\\known_hosts' % mforms.App.get().get_user_data_folder()
            else:
                known_hosts_desc = "~/.ssh/known_hosts file"
            self.notify_exception_error('ERROR', format_bad_host_exception(exc, known_hosts_desc))
            return False
        except paramiko.BadAuthenticationType as exc:
            # Must come before AuthenticationException, which is its base class.
            self.notify_exception_error('ERROR', "Bad authentication type, the server is not accepting this type of authentication.\nAllowed ones are:\n %s" % exc.allowed_types, sys.exc_info())
            return False
        except paramiko.AuthenticationException:
            self.notify_exception_error('ERROR', "Authentication failed, please check credentials.\nPlease refer to logs for details", sys.exc_info())
            return False
        except socket.gaierror as exc:
            self.notify_exception_error('ERROR', "Error connecting to SSH server: %s\nPlease refer to logs for details." % str(exc))
            return False
        except paramiko.ChannelException as exc:
            self.notify_exception_error('ERROR', "Error connecting SSH channel.\nPlease refer to logs for details: %s" % str(exc), sys.exc_info())
            return False
        except SSHFingerprintNewError as exc:
            # Raised by StoreIfConfirmedPolicy; the exception travels with the message so
            # the receiver of the notification can ask the user and store the key.
            self.notify_exception_error('KEY_ERROR', {'msg': "The authenticity of host '%(0)s (%(0)s)' can't be established.\n%(1)s key fingerprint is %(2)s\nAre you sure you want to continue connecting?" % {'0': "%s:%s" % (self._server[0], self._server[1]), '1': exc.key.get_name(), '2': exc.fingerprint}, 'obj': exc})
            return False
        except IOError as exc:
            # IO errors should be reported to the user, who may be able to fix the issue
            self.notify_exception_error('IO_ERROR', "IO Error: %s.\n Please refer to logs for details." % str(exc), sys.exc_info())
            return False
        except Exception:
            self.notify_exception_error('ERROR', "Authentication error, unhandled exception caught in tunnel manager, please refer to logs for details", sys.exc_info())
            return False
        else:
            log_debug("connect_ssh2 OK\n")
            return True

    def close(self):
        """Ask the tunnel thread to stop and close the listening socket."""
        self.notify('INFO', 'Closing tunnel')
        # Raise the flag before closing the socket, so the select() failure that the
        # close provokes in the tunnel thread is not reported as a forwarding error.
        self._shutdown = True
        # The socket only exists once the thread has started running.
        if self._listen_sock is not None:
            self._listen_sock.close()

    def accept_client(self):
        """Accept a pending local connection and open a forwarding channel for it.

        On success the (local socket, SSH channel) pair is added to self._connections;
        on any failure the error is notified and the local socket is closed.
        """
        try:
            local_sock, peeraddr = self._listen_sock.accept()
        except Exception as e:
            self.notify_exception_error('ERROR', 'Error accepting new tunnel client: %r' % e, sys.exc_info())
            return
        self.notify('INFO', 'Client connection established')

        transport = self._client.get_transport()
        try:
            sshchan = transport.open_channel('direct-tcpip', self._target, local_sock.getpeername())
        except paramiko.ChannelException as exc:
            self.notify_exception_error('ERROR', 'Could not open port forwarding SSH channel: %s' % exc)
            local_sock.close()
            return
        except Exception as e:
            self.notify_exception_error('ERROR', 'Remote connection to %s:%d failed: %r' % (self._target[0], self._target[1], e), sys.exc_info())
            local_sock.close()
            return

        if sshchan is None:
            self.notify_exception_error('ERROR', 'Remote connection to %s:%d was rejected by the SSH server.' % (self._target[0], self._target[1]), sys.exc_info())
            local_sock.close()
            return

        self.notify('INFO', 'Tunnel now open %r -> %r -> %r' % (local_sock.getsockname(), sshchan.getpeername(), self._target))
        self._connections.append((local_sock, sshchan))
class TunnelManager:
    """Keeps the SSH tunnels opened so far, indexed by their local port."""

    def __init__(self):
        # Maps the local port of each tunnel to its Tunnel thread
        self.tunnel_by_port = {}
        # Streams used by the request loop in wait_requests()
        self.inpipe = sys.stdin
        self.outpipe = sys.stdout

    def _address_port_tuple(self, raw_address, default_port):
        """Turn a 'host' or 'host:port' string into a (host, port) tuple.

        default_port is used when the string has no port or the port is not a number.
        Anything that is not a string is returned unchanged.
        """
        if not isinstance(raw_address, str):
            return raw_address
        if ':' not in raw_address:
            return (raw_address, default_port)
        address, port = raw_address.split(':', 1)
        try:
            port = int(port)
        except ValueError:
            port = default_port
        return (address, port)

    def _find_live_tunnel(self, server, username, target):
        """Return a running tunnel created for the given parameters, or None."""
        for tunnel in list(self.tunnel_by_port.values()):
            # is_alive() replaces isAlive(), which was removed in Python 3.9
            if tunnel.match(server, username, target) and tunnel.is_alive():
                return tunnel
        return None

    def lookup_tunnel(self, server, username, target):
        """Return the local port of a running tunnel matching the parameters, or None."""
        server = self._address_port_tuple(server, default_port=SSH_PORT)
        target = self._address_port_tuple(target, default_port=REMOTE_PORT)

        tunnel = self._find_live_tunnel(server, username, target)
        if tunnel is None:
            return None
        with tunnel.lock:
            return tunnel.local_port

    def open_tunnel(self, server, username, password, keyfile, target):
        """Open (or reuse) a tunnel without raising.

        Returns (True, port) on success and (False, traceback text) on failure.
        """
        try:
            port = self.open_ssh(server, username, password, keyfile, target)
        except Exception:
            traceback.print_exc()
            return (False, str(traceback.format_exc()))
        return (True, port)

    def open_ssh(self, server, username, password, keyfile, target):
        """Return the local port of a tunnel for the given parameters.

        A running tunnel with the same server, user name and target is reused;
        otherwise a new Tunnel thread is started and registered under its port.
        """
        server = self._address_port_tuple(server, default_port=SSH_PORT)
        target = self._address_port_tuple(target, default_port=REMOTE_PORT)

        password = password or ''
        keyfile = keyfile or None
        # Only bytes need decoding; a str key path has no decode() on Python 3.
        if isinstance(keyfile, bytes):
            keyfile = keyfile.decode('utf-8')

        found = self._find_live_tunnel(server, username, target)
        if found:
            with found.lock:
                log_debug('Reusing tunnel at port %d' % found.local_port)
                return found.local_port

        tunnel = Tunnel(queue.Queue(), server, username, target, password, keyfile)
        tunnel.start()
        tunnel.port_is_set.wait()
        with tunnel.lock:
            port = tunnel.local_port
        self.tunnel_by_port[port] = tunnel
        return port

    def _save_host_keys(self, client):
        """Write the host keys of client back to the file they were loaded from.

        Does nothing when the client has no host keys file name. The directory and the
        file are recreated when missing. IOError is left to the caller.
        """
        filename = client._host_keys_filename
        if filename is None:
            return
        directory = os.path.dirname(filename)
        if not os.path.isdir(directory):
            log_warning("Host_keys directory is missing, recreating it\n")
            os.makedirs(directory)
        if not os.path.exists(filename):
            log_warning("Host_keys file is missing, recreating it\n")
            open(filename, 'a').close()
        client.save_host_keys(filename)
        log_warning("Successfully saved host_keys file.\n")

    def wait_connection(self, port):
        """Wait until the tunnel at port finished connecting.

        Returns None when no error was reported, otherwise a message describing why
        the connection is not usable (including the case where a new server key was
        just stored and the tunnel has to be opened again).
        """
        tunnel = self.tunnel_by_port.get(port)
        if not tunnel:
            return 'Could not find a tunnel for port %s' % port

        error = None
        close_tunnel = False

        tunnel.port_is_set.wait()
        if tunnel.is_alive():
            while True:
                # Process any message in queue. Every retrieved message is printed.
                # If an error is detected in the queue, exit returning its message:
                try:
                    msg_type, msg = tunnel.q.get_nowait()
                except queue.Empty:
                    # Nothing to process: fall through to the exit check and the sleep
                    # below instead of spinning on the queue without ever leaving.
                    pass
                else:
                    if msg_type == 'KEY_ERROR':
                        key_error = msg['obj']
                        if mforms.Utilities.show_message("SSH Server Fingerprint Missing", msg['msg'], "Continue", "Cancel", "") == mforms.ResultOk:
                            key_error.client._host_keys.add(key_error.hostname, key_error.key.get_name(), key_error.key)
                            try:
                                self._save_host_keys(key_error.client)
                            except IOError as e:
                                error = str(e)
                                break
                            error = "Server key has been stored"
                        else:
                            error = "User cancelled"
                        close_tunnel = True
                        break  # Exit returning the error message
                    elif msg_type == 'IO_ERROR':
                        error = msg
                        break  # Exit returning the error message
                    else:
                        time.sleep(0.3)
                        _msg = msg
                        if isinstance(msg, tuple):
                            # An exc_info tuple: log the whole traceback, return the value
                            msg = '\n' + ''.join(traceback.format_exception(*msg))
                            _msg = str(_msg[1])
                        log_debug("%s: %s\n" % (msg_type, msg))
                        if msg_type == 'ERROR':
                            error = _msg
                            break  # Exit returning the error message

                if (not tunnel.is_connecting() or not tunnel.is_alive()) and tunnel.q.empty():
                    break
                time.sleep(0.3)

        log_debug("returning from wait_connection(%s): %s\n" % (port, error))
        # we need to close tunnel so it get opened again, without it we may have problems later
        if close_tunnel:
            tunnel.close()
            del self.tunnel_by_port[port]
        return error

    def get_message(self, port):
        """Return the next pending (msg_type, msg_object) of the tunnel at port, or None."""
        if port not in self.tunnel_by_port:
            log_error("Looking up invalid port %s\n" % port)
            return None
        tunnel = self.tunnel_by_port[port]
        try:
            return tunnel.q.get_nowait()
        except queue.Empty:
            return None

    def set_keepalive(self, port, keepalive):
        """Set the keepalive value on the SSH transport of the tunnel at port.

        A keepalive of 0 leaves the transport untouched.
        """
        if keepalive == 0:
            log_info("SSH KeepAlive setting skipped.\n")
            return
        tunnel = self.tunnel_by_port.get(port)
        if not tunnel:
            log_error("Looking up invalid port %s\n" % port)
            return
        transport = tunnel._client.get_transport()
        if transport is None:
            log_error("SSHTransport not ready yet %d\n" % port)
            return
        transport.set_keepalive(keepalive)

    def close(self, port):
        """Do nothing: tunnels auto-close when inactive."""
        pass

    def send(self, code, arg=''):
        """Write 'code arg' (or just 'code' when arg is empty) as one line to outpipe."""
        if arg:
            self.outpipe.write(code + ' ' + arg + '\n')
        else:
            self.outpipe.write(code + '\n')
        self.outpipe.flush()

    def shutdown(self):
        """Close every known tunnel and wait for its thread to finish."""
        for tunnel in list(self.tunnel_by_port.values()):
            tunnel.close()
            tunnel.join()

    # FIXME: It seems that this function is never called. Should we remove it?
    def wait_requests(self):
        """Serve requests read line by line from inpipe until it reaches end of file.

        Each request is a Python literal of the form (command, args); the reply is
        written with send().
        """
        self.send("READY")
        while True:
            request = self.inpipe.readline()
            if not request:
                break
            try:
                # SECURITY(review): eval() executes whatever arrives on inpipe. If this
                # loop is ever put back into use, parse the request with
                # ast.literal_eval() or a proper protocol instead.
                cmd, args = eval(request, {}, {})
            except Exception:
                self.send("ERROR", "Invalid request")
                continue

            if cmd == "LOOKUP":
                try:
                    port = self.lookup_tunnel(*args)
                    if port is not None:
                        self.send("OK", str(port))
                    else:
                        self.send("ERROR", "not found")
                except Exception as exc:
                    self.send("ERROR", str(exc))
            elif cmd == "OPENSSH":
                try:
                    port = self.open_ssh(*args)
                    self.send("OK", str(port))
                except Exception as exc:
                    self.send("ERROR", str(exc))
            elif cmd == "CLOSE":
                self.send("OK")
            elif cmd == "WAIT":
                # wait for the SSH connection to be established
                error = self.wait_connection(args)
                if not error:
                    self.send("OK")
                else:
                    self.send("ERROR", error)
            elif cmd == "MESSAGE":
                msg = self.get_message(args)
                if msg:
                    # msg is a (msg_type, msg_object) tuple; send() needs strings
                    self.send(str(msg[0]), str(msg[1]))
                else:
                    self.send("NONE")
            else:
                log_error("Invalid request %s\n" % request)
                self.send("ERROR", "Invalid request")
"""
if "--single" in sys.argv:
target = sys.argv[2]
if "-pw" in sys.argv:
password = sys.argv[sys.argv.index("-pw")+1]
else:
password = None
if "-i" in sys.argv:
keyfile = sys.argv[sys.argv.index("-i")+1]
else:
keyfile = None
server = sys.argv[-1]
tunnel = Tunnel(None, False)
if "@" in server:
username, server = server.split("@", 1)
else:
username = ""
print "Starting tunnel..."
if type(server) == str:
if ':' in server:
server = server.split(":", 1)
server = (server[0], int(server[1]))
else:
server = (server, SSH_PORT)
if type(target) == str:
if ':' in target:
target = target.split(":", 1)
target = (target[0], int(target[1]))
else:
target = (target, REMOTE_PORT)
tunnel.start(server, username, password or "", keyfile, target)
else:
tm = TunnelManager()
sys.stdout = sys.stderr
try:
tm.wait_requests()
except KeyboardInterrupt, e:
pass
"""