File: C:/Program Files/MySQL/MySQL Workbench 8.0/modules/db_mysql_re_grt.py
# Copyright (c) 2012, 2019, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have included with MySQL.
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from wb import DefineModule
import grt
from workbench import db_utils
from workbench.db_utils import MySQLConnection, escape_sql_string, escape_sql_identifier
from workbench.exceptions import NotConnectedError
# Module descriptor for "DbMySQLRE". The @ModuleInfo.export(...) decorators used
# below declare each exported function's return type followed by its argument types.
ModuleInfo = DefineModule(name= "DbMySQLRE", author= "Oracle Corp.", version="1.0")
######################### Non exposed functions and variables #################
def check_interruption():
    """Raise grt.UserInterrupt when grt.query_status() reports a truthy status.

    Called between long-running steps so that an operation can be aborted.
    """
    interrupted = grt.query_status()
    if interrupted:
        raise grt.UserInterrupt()
# Open connections, keyed by the __id__ of the connection object handed to
# connect(); the values are the MySQLConnection instances created there.
_connections = {}
def get_connection(connection_object):
    """Return the registered connection for the given connection object.

    The lookup key is connection_object.__id__.

    Raises:
        NotConnectedError: if no connection is registered under that id.
    """
    key = connection_object.__id__
    if key not in _connections:
        raise NotConnectedError("No open connection to %s" % connection_object.hostIdentifier)
    return _connections[key]
def execute_query(connection_object, query):
    """Run query on the registered connection for connection_object and return its result.

    Raises NotConnectedError (from get_connection) when there is no such connection.
    """
    active_connection = get_connection(connection_object)
    return active_connection.executeQuery(query)
###############################################################################
@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.STRING)
def connect(connection, password):
    """Make sure there is a live, registered connection for the given connection object.

    If a connection is already registered it is pinged and, when the ping
    fails, disconnected and connected again. Otherwise a new MySQLConnection
    is created with the given password, connected, registered under
    connection.__id__, and the server version is reported via grt.send_info.

    Always returns 1; failures surface as exceptions.
    """
    try:
        existing = get_connection(connection)
        try:
            existing.ping()
        except Exception:
            # The registered connection is no longer usable: cycle it.
            grt.send_info("Reconnecting to %s..." % connection.hostIdentifier)
            existing.disconnect()
            existing.connect()
            grt.send_info("Connection restablished")
    except NotConnectedError:
        fresh = MySQLConnection(connection, password = password)
        host_identifier = connection.hostIdentifier
        grt.send_info("Connecting to %s..." % host_identifier)
        fresh.connect()
        # Register before querying, since execute_query() looks the connection up.
        _connections[connection.__id__] = fresh

        version_result = execute_query(connection, "SHOW VARIABLES LIKE 'version'")
        if version_result and version_result.nextRow():
            version = version_result.stringByIndex(2)
        else:
            version = "Unknown version"
        grt.send_info("Connected to %s, %s" % (host_identifier, version))
    return 1
@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection)
def disconnect(connection):
    """Close and unregister the connection for the given connection object, if any.

    Always returns 0, whether or not a connection was registered.
    """
    key = connection.__id__
    if key in _connections:
        _connections[key].disconnect()
        del _connections[key]
    return 0
@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection)
def isConnected(connection):
    """Return 1 if a connection is registered for the given connection object, else 0.

    Only the registry is consulted; the connection itself is not pinged.
    """
    registered = connection.__id__ in _connections
    return int(registered)
@ModuleInfo.export(grt.STRING)
def getTargetDBMSName():
    """Return the name of the DBMS this module targets."""
    dbms_name = 'Mysql'
    return dbms_name
@ModuleInfo.export(grt.classes.GrtVersion, grt.classes.db_mgmt_Connection)
def getServerVersion(connection):
    """Returns a GrtVersion instance containing information about the server version.

    The value of the server's 'version' variable is expected to start with
    "<major>.<minor>.<release>"; anything after that prefix is ignored and
    buildNumber is always set to 0.

    Returns None when the query yields no row or the value does not start
    with three dot-separated numbers.
    """
    import re

    result = execute_query(connection, "SHOW VARIABLES LIKE 'version'")
    if result and result.nextRow():
        # Raw string: "\." in a normal string literal is an invalid escape sequence.
        # "+" instead of "*" so an empty component can never reach int().
        match = re.match(r"([0-9]+)\.([0-9]+)\.([0-9]+)", result.stringByIndex(2))
        if match:
            major, minor, release = [int(part) for part in match.groups()]
            version = grt.classes.GrtVersion()
            version.majorNumber = major
            version.minorNumber = minor
            version.releaseNumber = release
            version.buildNumber = 0
            return version
    return None
@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection)
def getCatalogNames(connection):
    """Returns a list of the available catalogs.

    The list is fixed to the single entry "def"; the connection is not queried.
    """
    catalogs = ["def"]
    return catalogs
@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING)
def getSchemaNames(connection, catalog):
    """Returns a list of schemas for the given connection object.

    The names come from SHOW DATABASES; the catalog argument is not used.
    Raises grt.UserInterrupt if an interruption is pending after the query.
    """
    result = execute_query(connection, "SHOW DATABASES")
    check_interruption()

    schemas = []
    while result and result.nextRow():
        schemas.append(result.stringByIndex(1))
    return schemas
def getAllTableNames(connection, catalog_name, schema_name):
    """Return every name listed by SHOW TABLES for the schema.

    No Table_type filter is applied, unlike getTableNames/getViewNames.
    catalog_name is not used.
    """
    show_tables = "SHOW TABLES FROM `%s`" % escape_sql_identifier(schema_name)
    result = execute_query(connection, show_tables)

    found = []
    while result and result.nextRow():
        found.append(result.stringByIndex(1))
    return found
@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getTableNames(connection, catalog_name, schema_name):
    """Return the names of the tables in the schema, leaving out views where possible.

    Servers newer than 5.0.1 are asked for SHOW FULL TABLES with views
    filtered out; older ones get a plain SHOW TABLES. catalog_name is not used.
    """
    server = getServerVersion(connection)
    server_release = (server.majorNumber, server.minorNumber, server.releaseNumber)

    query = ("SHOW FULL TABLES FROM `%s` WHERE Table_type <> 'VIEW'"
             if server_release > (5, 0, 1)
             else "SHOW TABLES FROM `%s`")
    result = execute_query(connection, query % escape_sql_identifier(schema_name))

    tables = []
    while result and result.nextRow():
        tables.append(result.stringByIndex(1))
    return tables
@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getViewNames(connection, catalog_name, schema_name):
    """Return the names of the views in the schema.

    Servers at or below 5.0.1 are not queried and yield an empty list.
    catalog_name is not used.
    """
    server = getServerVersion(connection)
    if (server.majorNumber, server.minorNumber, server.releaseNumber) <= (5, 0, 1):
        return []

    show_views = "SHOW FULL TABLES FROM `%s` WHERE Table_type = 'VIEW'" % escape_sql_identifier(schema_name)
    result = execute_query(connection, show_views)

    views = []
    while result and result.nextRow():
        views.append(result.stringByIndex(1))
    return views
@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getTriggerNames(connection, catalog_name, schema_name):
    """Return the names of the triggers in the schema.

    Servers older than 5.0.10 are not queried and yield an empty list.
    catalog_name is not used.
    """
    server = getServerVersion(connection)
    if (server.majorNumber, server.minorNumber, server.releaseNumber) < (5, 0, 10):
        return []

    show_triggers = "SHOW TRIGGERS FROM `%s`" % escape_sql_identifier(schema_name)
    result = execute_query(connection, show_triggers)

    triggers = []
    while result and result.nextRow():
        triggers.append(result.stringByIndex(1))
    return triggers
@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getProcedureNames(connection, catalog_name, schema_name):
    """Return the names of the stored procedures in the schema.

    Servers older than 5.0.0 are not queried and yield an empty list.
    catalog_name is not used.
    """
    server = getServerVersion(connection)
    if (server.majorNumber, server.minorNumber, server.releaseNumber) < (5, 0, 0):
        return []

    # The schema name is a string literal here (not an identifier), hence escape_sql_string.
    show_procedures = "SHOW PROCEDURE STATUS WHERE Db='%s'" % escape_sql_string(schema_name)
    result = execute_query(connection, show_procedures)

    procedures = []
    while result and result.nextRow():
        procedures.append(result.stringByName("Name"))
    return procedures
@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getFunctionNames(connection, catalog_name, schema_name):
    """Return the names of the stored functions in the schema.

    Servers older than 5.0.0 are not queried and yield an empty list.
    catalog_name is not used.
    """
    server = getServerVersion(connection)
    if (server.majorNumber, server.minorNumber, server.releaseNumber) < (5, 0, 0):
        return []

    # The schema name is a string literal here (not an identifier), hence escape_sql_string.
    show_functions = "SHOW FUNCTION STATUS WHERE Db='%s'" % escape_sql_string(schema_name)
    result = execute_query(connection, show_functions)

    functions = []
    while result and result.nextRow():
        functions.append(result.stringByName("Name"))
    return functions
@ModuleInfo.export(grt.STRING, grt.classes.db_mgmt_Connection)
def getOS(connection):
    """Return the OS family the server was built for, from @@version_compile_os.

    Returns 'linux', 'darwin' or 'windows', or None when the value is not
    recognized, the query returns no row, or the query fails (the failure is
    reported through grt.send_error).

    Raises NotConnectedError (from get_connection) if there is no open connection.
    """
    conn = get_connection(connection)
    if conn:
        try:
            result = conn.executeQuery("SELECT @@version_compile_os")
        except db_utils.QueryError as e:
            grt.send_error("Error executing query: %s." % e)
            return None
        if result and result.nextRow():
            compile_os = result.stringByIndex(1).lower()
            if 'linux' in compile_os:
                return 'linux'
            # The macOS tags must be tested before 'win': "darwin" contains
            # the substring "win" and would otherwise be reported as windows.
            if any(tag in compile_os for tag in ('osx', 'macos', 'darwin')):
                return 'darwin'
            if 'win' in compile_os:
                return 'windows'
    return None
@ModuleInfo.export(grt.STRING, grt.classes.db_mgmt_Connection)
def getServerMode(connection):
    """Return the session's sql_mode string, or None if the query yields no row."""
    result = execute_query(connection, "SELECT @@SESSION.sql_mode")
    if not result or not result.nextRow():
        return None
    return result.stringByIndex(1)
######### Reverse Engineering functions #########
@ModuleInfo.export(grt.classes.db_Catalog, grt.classes.db_mgmt_Connection, grt.STRING, (grt.LIST, grt.STRING), grt.DICT)
def reverseEngineer(connection, catalog_name, schemata_list, context):
    """Reverse engineer the given schemata from the server into a new catalog.

    connection    -- connection object for which connect() was already called
    catalog_name  -- name given to the resulting catalog
    schemata_list -- names of the schemata to process
    context       -- options dict; the boolean keys "reverseEngineerTables",
                     "reverseEngineerViews", "reverseEngineerTriggers" and
                     "reverseEngineerRoutines" all default to True. Triggers are
                     only processed on servers >= 5.1.21.

    The first 10% of the reported progress covers collecting the object names,
    the remaining 90% covers fetching and parsing each object's CREATE statement.

    Returns the populated db_mysql_Catalog. Stub tables left behind by the
    parser are removed with a warning, and a schema that ends up empty because
    of that is dropped from the catalog.

    Raises grt.UserInterrupt when interrupted, and Exception when the CREATE
    statement of an object cannot be fetched.
    """
    catalog = grt.classes.db_mysql_Catalog()
    catalog.name = catalog_name

    catalog.simpleDatatypes.remove_all()
    catalog.simpleDatatypes.extend(connection.driver.owner.simpleDatatypes)

    table_names_per_schema = {}
    routine_names_per_schema = {}
    trigger_names_per_schema = {}

    def filter_warnings(mtype, text, detail):
        # filter out parser warnings about stub creation/reuse from the message stream, since
        # they're harmless
        if mtype == "WARNING" and (" stub " in text or "Stub " in text):
            grt.send_info(text)
            return True
        return False

    version = getServerVersion(connection)

    get_tables = context.get("reverseEngineerTables", True)
    get_triggers = context.get("reverseEngineerTriggers", True) and (version.majorNumber, version.minorNumber, version.releaseNumber) >= (5, 1, 21)
    get_views = context.get("reverseEngineerViews", True)
    get_routines = context.get("reverseEngineerRoutines", True)

    # calculate total workload 1st
    # 10% of the progress is for preparation
    grt.send_progress(0, "Preparing...")
    total = 0
    i = 0.0
    for schema_name in schemata_list:
        check_interruption()

        if get_tables and get_views:
            table_names = getAllTableNames(connection, catalog_name, schema_name)
        elif get_tables:
            table_names = getTableNames(connection, catalog_name, schema_name)
        elif get_views:
            table_names = getViewNames(connection, catalog_name, schema_name)
        else:
            # Neither tables nor views requested. This must bind table_names,
            # otherwise the lines below fail or reuse the previous schema's list.
            table_names = []

        total += len(table_names)
        table_names_per_schema[schema_name] = table_names

        check_interruption()
        if get_routines:
            procedure_names = getProcedureNames(connection, catalog_name, schema_name)
            check_interruption()
            function_names = getFunctionNames(connection, catalog_name, schema_name)
            check_interruption()
            total += len(procedure_names)
            total += len(function_names)
            routine_names_per_schema[schema_name] = procedure_names, function_names
        else:
            routine_names_per_schema[schema_name] = [], []

        if get_triggers:
            trigger_names = getTriggerNames(connection, catalog_name, schema_name)
            total += len(trigger_names)
        else:
            trigger_names = []
        trigger_names_per_schema[schema_name] = trigger_names

        grt.send_progress(0.1 * (i/len(schemata_list)), "Preparing...")
        i += 1.0

    def wrap_sql(sql, schema):
        return "USE `%s`;\n%s"%(escape_sql_identifier(schema), sql)

    def wrap_routine_sql(sql):
        return "DELIMITER " + grt.root.wb.options.options['SqlDelimiter'] + "\n" + sql

    def progress_at(done):
        # Only called while processing an object, so total is at least 1 here.
        return 0.1 + 0.9 * (done / total)

    def reverse_engineer_objects(parser_context, options, schema_name, names, done,
                                 statement, label, error_label, read_sql, as_routine):
        """Fetch and parse the CREATE statement of each named object.

        Each object counts as one unit of work: half for fetching, half for
        parsing. Returns the updated work counter.
        """
        for name in names:
            check_interruption()
            grt.send_progress(progress_at(done), "Retrieving %s %s.%s..." % (label, schema_name, name))
            result = execute_query(connection, "SHOW CREATE %s `%s`.`%s`" % (statement, escape_sql_identifier(schema_name), escape_sql_identifier(name)))
            done += 0.5
            grt.send_progress(progress_at(done), "Reverse engineering %s.%s..." % (schema_name, name))
            if not (result and result.nextRow()):
                raise Exception("Could not fetch %s information for %s.%s" % (error_label, schema_name, name))

            sql = read_sql(result)
            if as_routine:
                sql = wrap_routine_sql(sql)
            else:
                # Stub warnings are only filtered while parsing tables/views.
                grt.push_message_handler(filter_warnings)
            try:
                grt.begin_progress_step(progress_at(done), progress_at(done + 0.5))
                try:
                    grt.modules.MySQLParserServices.parseSQLIntoCatalogSql(parser_context, catalog, wrap_sql(sql, schema_name), options)
                finally:
                    # Keep begin/end balanced even if the parser raises.
                    grt.end_progress_step()
            finally:
                if not as_routine:
                    grt.pop_message_handler()
            done += 0.5
        return done

    i = 0.0
    for schema_name in schemata_list:
        schema = grt.classes.db_mysql_Schema()
        schema.owner = catalog
        schema.name = schema_name
        catalog.schemata.append(schema)
        # Named parser_context so the options dict passed in as `context` is not shadowed.
        parser_context = grt.modules.MySQLParserServices.createNewParserContext(catalog.characterSets, getServerVersion(connection), getServerMode(connection), 1)
        options = {}

        if get_tables or get_views:
            grt.send_info("Reverse engineering tables from %s" % schema_name)
            i = reverse_engineer_objects(parser_context, options, schema_name,
                                         table_names_per_schema[schema_name], i,
                                         "TABLE", "table", "table",
                                         lambda row: row.stringByIndex(2), False)

        if get_triggers:
            grt.send_info("Reverse engineering triggers from %s" % schema_name)
            i = reverse_engineer_objects(parser_context, options, schema_name,
                                         trigger_names_per_schema[schema_name], i,
                                         "TRIGGER", "trigger", "trigger",
                                         lambda row: row.stringByName("SQL Original Statement"), True)

        if get_routines:
            grt.send_info("Reverse engineering stored procedures from %s" % schema_name)
            procedure_names, function_names = routine_names_per_schema[schema_name]
            i = reverse_engineer_objects(parser_context, options, schema_name,
                                         procedure_names, i,
                                         "PROCEDURE", "stored procedure", "procedure",
                                         lambda row: row.stringByName("Create Procedure"), True)

            grt.send_info("Reverse engineering functions from %s" % schema_name)
            i = reverse_engineer_objects(parser_context, options, schema_name,
                                         function_names, i,
                                         "FUNCTION", "function", "function",
                                         lambda row: row.stringByName("Create Function"), True)

    grt.send_progress(1.0, "Reverse engineered %i objects" % total)

    # check for any stub tables left
    empty_schemas = []
    for schema in catalog.schemata:
        schema_has_stub_tables = False
        for table in reversed(schema.tables):
            if table.isStub:
                grt.send_warning("Table %s was referenced from another table, but was not reverse engineered" % table.name)
                schema.tables.remove(table)
                schema_has_stub_tables = True
        # Only drop schemas that became empty because their stubs were removed.
        if not schema.tables and not schema.views and not schema.routines and schema_has_stub_tables:
            empty_schemas.append(schema)
    for schema in empty_schemas:
        catalog.schemata.remove(schema)

    return catalog
@ModuleInfo.export(grt.classes.db_mysql_Table, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def reverseEngineerTable(connection, schema_name, table_name):
    """Reverse engineer a single table into a throw-away catalog and return it.

    A fresh catalog with one schema named schema_name is created, the table's
    SHOW CREATE TABLE output is parsed into it, and the table whose name
    equals table_name is returned (None if the parser produced no such table).

    Raises Exception when the CREATE statement cannot be fetched.
    """
    scratch_catalog = grt.classes.db_mysql_Catalog()
    target_schema = grt.classes.db_mysql_Schema()
    target_schema.owner = scratch_catalog
    target_schema.name = schema_name
    scratch_catalog.schemata.append(target_schema)

    show_create = "SHOW CREATE TABLE `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(table_name))
    result = execute_query(connection, show_create)
    if not (result and result.nextRow()):
        raise Exception("Could not fetch table information for %s" % table_name)

    create_sql = result.stringByIndex(2)
    parser_context = grt.modules.MySQLParserServices.createNewParserContext(scratch_catalog.characterSets, getServerVersion(connection), getServerMode(connection), 1)
    parser_options = {}
    use_prefix = "USE `%s`;\n" % escape_sql_identifier(schema_name)
    grt.modules.MySQLParserServices.parseSQLIntoCatalogSql(parser_context, scratch_catalog, use_prefix + create_sql, parser_options)

    for candidate in scratch_catalog.schemata[0].tables:
        if candidate.name == table_name:
            return candidate
    return None
@ModuleInfo.export(grt.classes.db_mysql_Table, grt.classes.db_mgmt_Connection, grt.classes.db_mysql_Catalog, grt.STRING, grt.STRING)
def reverseEngineerTableToCatalog(connection, catalog, schema_name, table_name):
    """Reverse engineer a single table into an existing catalog and return it.

    The schema named schema_name is reused when the catalog already has one,
    otherwise it is created and appended. The table's SHOW CREATE TABLE output
    is then parsed into the catalog, and the table of that schema whose name
    equals table_name is returned (None if there is no such table).

    Raises Exception when the CREATE statement cannot be fetched.
    """
    # First schema with a matching name, or None.
    schema = next((existing for existing in catalog.schemata if existing.name == schema_name), None)
    if not schema:
        schema = grt.classes.db_mysql_Schema()
        schema.owner = catalog
        schema.name = schema_name
        catalog.schemata.append(schema)

    show_create = "SHOW CREATE TABLE `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(table_name))
    result = execute_query(connection, show_create)
    if not (result and result.nextRow()):
        raise Exception("Could not fetch table information for %s" % table_name)

    create_sql = result.stringByIndex(2)
    parser_context = grt.modules.MySQLParserServices.createNewParserContext(catalog.characterSets, getServerVersion(connection), getServerMode(connection), 1)
    parser_options = {}
    use_prefix = "USE `%s`;\n" % escape_sql_identifier(schema_name)
    grt.modules.MySQLParserServices.parseSQLIntoCatalogSql(parser_context, catalog, use_prefix + create_sql, parser_options)

    for candidate in schema.tables:
        if candidate.name == table_name:
            return candidate
    return None