db-condenser 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- db_condenser/__init__.py +0 -0
- db_condenser/config_reader.py +191 -0
- db_condenser/data_masking.py +52 -0
- db_condenser/database_helper.py +13 -0
- db_condenser/db_connect.py +126 -0
- db_condenser/direct_subset.py +128 -0
- db_condenser/mysql_database_creator.py +128 -0
- db_condenser/mysql_database_helper.py +302 -0
- db_condenser/psql_database_creator.py +243 -0
- db_condenser/psql_database_helper.py +428 -0
- db_condenser/result_tabulator.py +49 -0
- db_condenser/subset.py +590 -0
- db_condenser/subset_utils.py +226 -0
- db_condenser/topo_orderer.py +46 -0
- db_condenser-1.0.0.dist-info/METADATA +119 -0
- db_condenser-1.0.0.dist-info/RECORD +19 -0
- db_condenser-1.0.0.dist-info/WHEEL +4 -0
- db_condenser-1.0.0.dist-info/entry_points.txt +3 -0
- db_condenser-1.0.0.dist-info/licenses/LICENSE +9 -0
db_condenser/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import sys
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class InitialTarget:
|
|
10
|
+
table: str
|
|
11
|
+
percent: float | None = None
|
|
12
|
+
where: str | None = None
|
|
13
|
+
|
|
14
|
+
def __post_init__(self):
|
|
15
|
+
# Exactly one of where/percent must be set
|
|
16
|
+
if (self.where is None) == (self.percent is None):
|
|
17
|
+
raise ValueError(
|
|
18
|
+
"Initial Target must specify exactly one of 'where' or 'percent'"
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DbType(str, Enum):
    """Supported database engines; values are the lowercase names used in config."""

    POSTGRES = "postgres"
    MYSQL = "mysql"

    @classmethod
    def _missing_(cls, value):
        # Make value lookup case-insensitive ("Postgres", "MYSQL", ...), so
        # callers constructing DbType directly need not lowercase first.
        if isinstance(value, str):
            lowered = value.lower()
            for member in cls:
                if member.value == lowered:
                    return member
        return None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
|
|
28
|
+
class DbConnectInfo:
|
|
29
|
+
user_name: str
|
|
30
|
+
host: str
|
|
31
|
+
db_name: str
|
|
32
|
+
port: int
|
|
33
|
+
ssl_mode: str | None = None
|
|
34
|
+
# No password will prompt user
|
|
35
|
+
password: str | None = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
|
|
39
|
+
class UpstreamFilter:
|
|
40
|
+
condition: str
|
|
41
|
+
table: str | None = None
|
|
42
|
+
column: str | None = None
|
|
43
|
+
|
|
44
|
+
def __post_init__(self):
|
|
45
|
+
# Exactly one of table/column must be set
|
|
46
|
+
if (self.table is None) == (self.column is None):
|
|
47
|
+
raise ValueError(
|
|
48
|
+
"Upstream filters must specify exactly one of 'table' or 'column'"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class DependencyBreak:
    """A foreign-key dependency (``fk_table`` -> ``target_table``) listed under
    ``dependency_breaks`` in the config."""

    fk_table: str
    target_table: str
    # Pairs with this flag set are reported by Config.preserve_fk_opportunistically.
    preserve_fk_opportunistically: bool = field(default=False)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class FkAugmentation:
    """An extra foreign-key relationship supplied through the config
    (presumably one not declared in the database itself — confirm).

    ``fk_columns`` and ``target_columns`` must have the same length.
    """

    fk_table: str
    fk_columns: list[str]
    target_table: str
    target_columns: list[str]

    def __post_init__(self):
        fk_count = len(self.fk_columns)
        target_count = len(self.target_columns)
        # The two column lists are presumably paired by position.
        if fk_count == target_count:
            return
        raise ValueError("fk_columns and target_columns must be the same length")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class Config:
    """Complete condenser configuration, built from the parsed JSON config."""

    db_type: DbType
    initial_targets: list[InitialTarget]
    source_db_connection_info: DbConnectInfo
    destination_db_connection_info: DbConnectInfo
    keep_disconnected_tables: bool = False
    upstream_filters: list[UpstreamFilter] = field(default_factory=list)
    excluded_tables: list[str] = field(default_factory=list)
    passthrough_tables: list[str] = field(default_factory=list)
    dependency_breaks: list[DependencyBreak] = field(default_factory=list)
    fk_augmentation: list[FkAugmentation] = field(default_factory=list)
    max_rows_per_table: int | Literal["ALL"] | None = None
    use_temp_tables: bool = False
    use_copy_protocol: bool = False
    pre_constraint_sql: list[str] = field(default_factory=list)
    post_subset_sql: list[str] = field(default_factory=list)

    def _break_pairs(self, only_preserved: bool) -> set[tuple[str, str]]:
        """Return (fk_table, target_table) pairs from dependency_breaks,
        optionally only those flagged preserve_fk_opportunistically."""
        return {
            (brk.fk_table, brk.target_table)
            for brk in self.dependency_breaks
            if not only_preserved or brk.preserve_fk_opportunistically
        }

    @property
    def dependency_break_set(self) -> set[tuple[str, str]]:
        return self._break_pairs(only_preserved=False)

    @property
    def preserve_fk_opportunistically(self) -> set[tuple[str, str]]:
        return self._break_pairs(only_preserved=True)

    @property
    def initial_target_tables(self) -> list[str]:
        return [initial.table for initial in self.initial_targets]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# Module-level configuration singleton: assigned by initialize(), read via
# get_config(), cleared by reset_config(). None until initialize() runs.
config: Config | None = None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _raw_dict_to_config(raw_config: dict) -> Config:
    """Convert the parsed JSON configuration dict into a Config.

    ``db_type``, ``initial_targets``, ``source_db_connection_info`` and
    ``destination_db_connection_info`` are required; every other key is
    optional and falls back to an empty list, False or None.

    Raises:
        KeyError: a required key is missing (including ``target_schema`` and
            friends in an ``fk_augmentation`` entry that specifies ``fk_schema``).
        ValueError: ``db_type`` is not a known DbType, or a nested entry fails
            its own validation.
        TypeError: a nested entry has unexpected or missing fields.
    """
    db_type = DbType(raw_config["db_type"].lower())

    initial_targets = [
        InitialTarget(**target) for target in raw_config["initial_targets"]
    ]

    source_db = DbConnectInfo(**raw_config["source_db_connection_info"])
    dest_db = DbConnectInfo(**raw_config["destination_db_connection_info"])

    upstream_filters = [
        UpstreamFilter(**table) for table in raw_config.get("upstream_filters", [])
    ]

    excluded_tables = list(raw_config.get("excluded_tables", []))
    # dict.fromkeys de-duplicates while keeping the configured order; a set's
    # iteration order over str varies between runs (hash randomization).
    passthrough_tables = list(
        dict.fromkeys(raw_config.get("passthrough_tables", []))
    )
    dependency_breaks = [
        DependencyBreak(**relation)
        for relation in raw_config.get("dependency_breaks", [])
    ]

    fk_augmentation = []
    for fka in raw_config.get("fk_augmentation", []):
        if "fk_schema" in fka:
            # Schema-qualified form: fold each schema into its table name and
            # keep only the four fields FkAugmentation accepts.
            fka = {
                "fk_table": fka["fk_schema"] + "." + fka["fk_table"],
                "fk_columns": fka["fk_columns"],
                "target_table": fka["target_schema"] + "." + fka["target_table"],
                "target_columns": fka["target_columns"],
            }
        fk_augmentation.append(FkAugmentation(**fka))

    pre_constraint_sql = list(raw_config.get("pre_constraint_sql", []))
    post_subset_sql = list(raw_config.get("post_subset_sql", []))
    max_rows_per_table = raw_config.get("max_rows_per_table", None)
    use_temp_tables = bool(raw_config.get("use_temp_tables", False))
    use_copy_protocol = bool(raw_config.get("use_copy_protocol", False))
    return Config(
        db_type=db_type,
        initial_targets=initial_targets,
        source_db_connection_info=source_db,
        destination_db_connection_info=dest_db,
        keep_disconnected_tables=bool(
            raw_config.get("keep_disconnected_tables", False)
        ),
        upstream_filters=upstream_filters,
        excluded_tables=excluded_tables,
        passthrough_tables=passthrough_tables,
        dependency_breaks=dependency_breaks,
        fk_augmentation=fk_augmentation,
        max_rows_per_table=max_rows_per_table,
        use_temp_tables=use_temp_tables,
        use_copy_protocol=use_copy_protocol,
        pre_constraint_sql=pre_constraint_sql,
        post_subset_sql=post_subset_sql,
    )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def initialize(file_like=None):
    """Load the JSON configuration into the module-level ``config``.

    Args:
        file_like: readable text stream containing the JSON config. When None,
            ``config.json`` in the current working directory is read instead.

    Calling this again prints a warning to stderr, but the newly loaded
    configuration still replaces the previous one.
    """
    global config
    if config is not None:
        print("WARNING: Attempted to initialize configuration twice.", file=sys.stderr)

    if file_like is None:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default
        # encoding, which differs on e.g. Windows.
        with open("config.json", "r", encoding="utf-8") as fp:
            raw_config = json.load(fp)
    else:
        raw_config = json.load(file_like)

    config = _raw_dict_to_config(raw_config)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def get_config() -> Config:
    """Return the loaded configuration.

    Raises:
        RuntimeError: initialize() has not been called yet (or reset_config()
            cleared it).
    """
    current = config
    if current is not None:
        return current
    raise RuntimeError("Config not initialized — call initialize() first")
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def reset_config():
    """Forget the loaded configuration.

    A later initialize() call then runs without the "initialize twice"
    warning, and get_config() raises until it does.
    """
    global config
    config = None
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import Any, Final
|
|
2
|
+
|
|
3
|
+
from faker import Faker
|
|
4
|
+
|
|
5
|
+
# Shared Faker instance the DataMasking methods draw replacement values from.
fake = Faker()
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DataMasking:
|
|
9
|
+
@staticmethod
|
|
10
|
+
def null_out(_: str):
|
|
11
|
+
return None
|
|
12
|
+
|
|
13
|
+
@staticmethod
|
|
14
|
+
def mask_numbers(value: Any) -> str | None:
|
|
15
|
+
"""
|
|
16
|
+
Mask certain strings that may contain a mixture of letters,
|
|
17
|
+
normal characters, whitespaces, or special characters
|
|
18
|
+
"""
|
|
19
|
+
if value is None:
|
|
20
|
+
return None
|
|
21
|
+
str_value = str(value)
|
|
22
|
+
return "".join(
|
|
23
|
+
str(fake.random_digit()) if c.isdigit() else c for c in str_value
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
@staticmethod
|
|
27
|
+
def mask_characters(value: Any) -> str | None:
|
|
28
|
+
"""
|
|
29
|
+
Mask certain strings that may contain a mixture of letters,
|
|
30
|
+
normal characters, whitespaces, or special characters
|
|
31
|
+
"""
|
|
32
|
+
if value is None:
|
|
33
|
+
return None
|
|
34
|
+
str_value = str(value)
|
|
35
|
+
return "".join(fake.random_letter() if c.isalpha() else c for c in str_value)
|
|
36
|
+
|
|
37
|
+
@staticmethod
|
|
38
|
+
def mask_email(email: Any) -> str | None:
|
|
39
|
+
if email is None:
|
|
40
|
+
return None
|
|
41
|
+
s = str(email).split("@")
|
|
42
|
+
if len(s) < 2:
|
|
43
|
+
return fake.email()
|
|
44
|
+
return f"{fake.user_name()}@{s[1]}"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Masking strategy name -> the DataMasking function that implements it.
DATA_MASKING_MAPPER: Final = {
    strategy: getattr(DataMasking, strategy)
    for strategy in ("null_out", "mask_numbers", "mask_characters", "mask_email")
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from db_condenser.config_reader import DbType, get_config
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_specific_helper():
    """Return the database-helper module for the configured ``db_type``.

    Postgres gets ``psql_database_helper``; every other type gets
    ``mysql_database_helper``. Both imports are done lazily inside this
    function.
    """
    if get_config().db_type != DbType.POSTGRES:
        from db_condenser import mysql_database_helper

        return mysql_database_helper

    from db_condenser import psql_database_helper

    return psql_database_helper
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import getpass
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
import mysql.connector
|
|
7
|
+
import psycopg
|
|
8
|
+
|
|
9
|
+
from db_condenser.config_reader import DbConnectInfo, DbType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DbConnection:
    """Base wrapper that holds a driver connection and forwards commit()/close() to it."""

    def __init__(self, connection):
        # The wrapped driver connection object.
        self.connection = connection

    def commit(self):
        """Commit the wrapped connection's current transaction."""
        conn = self.connection
        conn.commit()

    def close(self):
        """Close the wrapped connection."""
        conn = self.connection
        conn.close()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LoggingCursor:
    """Proxy around a driver cursor that can log each execute() and its duration.

    Attributes not defined on this class are forwarded to the wrapped cursor,
    and iteration is forwarded too, so the proxy can be used like the cursor.
    """

    def __init__(self, cursor, verbose=False):
        self.inner_cursor = cursor
        self._verbose = verbose

    def execute(self, query, params=None):
        """Run ``query`` on the wrapped cursor; when verbose, print it and its timing."""
        # perf_counter is monotonic, so durations are unaffected by clock changes.
        start_time = time.perf_counter()
        if self._verbose:
            print("Beginning query @ {}:\n\t{}".format(str(datetime.now()), query))
            sys.stdout.flush()
        retval = self.inner_cursor.execute(query, params)
        if self._verbose:
            print("\tQuery completed in {}s".format(time.perf_counter() - start_time))
            sys.stdout.flush()
        return retval

    def __getattr__(self, name):
        # Only reached for names not found normally. Guard inner_cursor itself so
        # an instance created without __init__ (e.g. via __new__ by copy/pickle)
        # raises AttributeError instead of recursing until RecursionError.
        if name == "inner_cursor":
            raise AttributeError(name)
        # getattr (not __getattribute__) so the inner object's own __getattr__
        # fallback still applies.
        return getattr(self.inner_cursor, name)

    def __iter__(self):
        # Dunder lookups bypass __getattr__, so iteration must be forwarded explicitly.
        return iter(self.inner_cursor)

    def __exit__(self, a, b, c):
        return self.inner_cursor.__exit__(a, b, c)

    def __enter__(self):
        return LoggingCursor(self.inner_cursor.__enter__(), self._verbose)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Thin wrapper around a psycopg connection that gives Postgres the same
# cursor() interface as the MySQL wrapper.
class PsqlConnection(DbConnection):
    """Postgres connection opened with psycopg from a DbConnect-like object.

    ``connect`` must expose db_name, user, password, host, port and ssl_mode.
    With ``read_repeatable`` the session uses REPEATABLE READ isolation.
    """

    def __init__(self, connect, read_repeatable, verbose=False):
        conn_kwargs = {
            "dbname": connect.db_name,
            "user": connect.user,
            "password": connect.password,
            "host": connect.host,
            "port": connect.port,
        }
        ssl_mode = connect.ssl_mode
        if ssl_mode:
            conn_kwargs["sslmode"] = ssl_mode

        super().__init__(psycopg.connect(**conn_kwargs))
        self._verbose = verbose
        if read_repeatable:
            self.connection.isolation_level = psycopg.IsolationLevel.REPEATABLE_READ

    def cursor(self, name=None, withhold=False):
        """Return a LoggingCursor over a psycopg cursor (server-side when ``name`` is given)."""
        driver_cursor = self.connection.cursor(name=name, withhold=withhold)
        return LoggingCursor(driver_cursor, self._verbose)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Thin wrapper around a mysql.connector connection that gives MySQL the same
# cursor() interface as the Postgres wrapper.
class MySqlConnection(DbConnection):
    """MySQL connection opened with mysql.connector from a DbConnect-like object.

    ``connect`` must expose host, port, user, password and db_name. With
    ``read_repeatable`` a REPEATABLE READ transaction is started immediately.
    """

    def __init__(self, connect, read_repeatable, verbose=False):
        driver_connection = mysql.connector.connect(
            host=connect.host,
            port=connect.port,
            user=connect.user,
            password=connect.password,
            database=connect.db_name,
        )
        super().__init__(driver_connection)

        self.db_name = connect.db_name
        self._verbose = verbose

        if read_repeatable:
            self.connection.start_transaction(isolation_level="REPEATABLE READ")

    def cursor(self, name=None, withhold=False):
        """Return a LoggingCursor over a plain MySQL cursor.

        ``name`` and ``withhold`` exist only for parity with PsqlConnection and
        are ignored here.
        """
        driver_cursor = self.connection.cursor()
        return LoggingCursor(driver_cursor, self._verbose)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class DbConnect:
    """Connection settings for one database plus a factory for live connections.

    If ``connection_info.password`` is None, the user is prompted via getpass
    and the answer is written back onto ``connection_info``.
    """

    def __init__(self, db_type: DbType, connection_info: DbConnectInfo, verbose=False):
        if connection_info.password is None:
            connection_info.password = getpass.getpass(
                "Enter password for {0} on host {1}: ".format(
                    connection_info.user_name, connection_info.host
                )
            )

        self.user = connection_info.user_name
        self.password = connection_info.password
        self.host = connection_info.host
        self.port = connection_info.port
        self.db_name = connection_info.db_name
        self.ssl_mode = connection_info.ssl_mode
        self.__db_type = db_type
        self._verbose = verbose

    def get_db_connection(
        self, read_repeatable=False
    ) -> PsqlConnection | MySqlConnection:
        """Open and return a new connection wrapper for this database.

        Raises:
            ValueError: the db_type given at construction is not supported.
        """
        if self.__db_type == DbType.POSTGRES:
            return PsqlConnection(self, read_repeatable, self._verbose)
        if self.__db_type == DbType.MYSQL:
            return MySqlConnection(self, read_repeatable, self._verbose)
        # format() rather than "+": a non-str db_type must still produce this
        # ValueError, not a TypeError from string concatenation.
        raise ValueError("unknown db_type {}".format(self.__db_type))
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from db_condenser import config_reader, database_helper, result_tabulator
|
|
6
|
+
from db_condenser.config_reader import DbConnectInfo, DbType
|
|
7
|
+
from db_condenser.db_connect import DbConnect, MySqlConnection, PsqlConnection
|
|
8
|
+
from db_condenser.mysql_database_creator import MySqlDatabaseCreator
|
|
9
|
+
from db_condenser.psql_database_creator import PsqlDatabaseCreator
|
|
10
|
+
from db_condenser.subset import Subset
|
|
11
|
+
from db_condenser.subset_utils import print_progress
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def db_creator(
    db_type: DbType, source: DbConnect, dest: DbConnect
) -> PsqlDatabaseCreator | MySqlDatabaseCreator:
    """Return the schema creator matching ``db_type``.

    Raises:
        ValueError: ``db_type`` is not a supported DbType. Without this, the
            function would fall off the end and return None, and the caller's
            next teardown() call would fail with an AttributeError.
    """
    if db_type == DbType.POSTGRES:
        # The third argument is passed as False; its meaning is defined by
        # PsqlDatabaseCreator (not visible here).
        return PsqlDatabaseCreator(source, dest, False)
    if db_type == DbType.MYSQL:
        return MySqlDatabaseCreator(source, dest)
    raise ValueError("unknown db_type {}".format(db_type))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _parse_args():
|
|
24
|
+
parser = argparse.ArgumentParser(description="Database Condenser")
|
|
25
|
+
parser.add_argument("--stdin", action="store_true", help="Read config from stdin")
|
|
26
|
+
parser.add_argument(
|
|
27
|
+
"-y", "--yes", action="store_true", help="Skip destination confirmation prompt"
|
|
28
|
+
)
|
|
29
|
+
parser.add_argument(
|
|
30
|
+
"--no-constraints", action="store_true", help="Skip adding constraints"
|
|
31
|
+
)
|
|
32
|
+
parser.add_argument(
|
|
33
|
+
"-v", "--verbose", action="store_true", help="Log every query with timing"
|
|
34
|
+
)
|
|
35
|
+
return parser.parse_args()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _confirm_destination(dest_info: DbConnectInfo):
    """Show the destination and exit with status 1 unless the user confirms.

    Only "y" or "yes" counts as consent; case and surrounding whitespace are
    ignored, so a stray space no longer aborts the run.
    """
    print(
        f"\nDestination: {dest_info.host}:{dest_info.port}/{dest_info.db_name}"
        f" (user: {dest_info.user_name})"
    )
    response = input("Proceed with subsetting into this destination? [y/N] ")
    if response.strip().lower() not in ("y", "yes"):
        print("Aborted.")
        sys.exit(1)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _run_sql_statements(label, statements, db_helper, dbc):
    """Run each SQL statement in order on its own connection to ``dbc``, printing progress.

    The connection opened for each statement is closed afterwards instead of
    being left open for the rest of the process.
    """
    print("Beginning {} SQL calls".format(label))
    start_time = time.time()
    total = len(statements)
    for idx, sql in enumerate(statements):
        print_progress(sql, idx + 1, total)
        conn = dbc.get_db_connection()
        try:
            db_helper.run_query(sql, conn)
        finally:
            conn.close()
    print("Completed {} SQL calls in {}s".format(label, time.time() - start_time))


def main():
    """Command-line entry point: rebuild the destination schema and fill it with a subset.

    The steps, in order:
    1. Read the config (``config.json`` or stdin) and confirm non-local
       destinations.
    2. Tear down and recreate the destination schema.
    3. Run the subsetter.
    4. Run pre-constraint SQL, then add constraints (unless
       ``--no-constraints``), then run post-subset SQL.
    5. Reset sequence numbering (Postgres only).
    6. Tabulate the results.

    Temp DBs are unprepped and connections closed even on failure or Ctrl-C.
    """
    args = _parse_args()

    if args.stdin:
        config_reader.initialize(sys.stdin)
    else:
        config_reader.initialize()

    config = config_reader.get_config()

    db_type = config.db_type
    source_dbc = DbConnect(
        db_type, config.source_db_connection_info, verbose=args.verbose
    )

    dest_info = config.destination_db_connection_info
    # The destination is torn down below, so non-local targets need confirmation.
    if not args.yes and dest_info.host not in ("localhost", "127.0.0.1"):
        _confirm_destination(dest_info)

    destination_dbc = DbConnect(db_type, dest_info, verbose=args.verbose)

    database = db_creator(db_type, source_dbc, destination_dbc)
    database.teardown()
    database.create()

    # Get list of tables to operate on
    db_helper = database_helper.get_specific_helper()
    all_tables = db_helper.list_all_tables(source_dbc)
    all_tables = [x for x in all_tables if x not in config.excluded_tables]

    subsetter = Subset(source_dbc, destination_dbc, all_tables)

    try:
        subsetter.prep_temp_dbs()
        subsetter.run_middle_out()

        _run_sql_statements(
            "pre constraint", config.pre_constraint_sql, db_helper, destination_dbc
        )

        if args.no_constraints:
            print("Skipping database constraints (--no-constraints)")
        else:
            print("Adding database constraints")
            database.add_constraints()

        _run_sql_statements(
            "post subset", config.post_subset_sql, db_helper, destination_dbc
        )

        print("Resetting sequence numbering")
        all_tables_no_pg = [table for table in all_tables if "pgbench" not in table]
        dest_conn = destination_dbc.get_db_connection()
        try:
            if db_type == DbType.POSTGRES:
                assert isinstance(dest_conn, PsqlConnection)
                db_helper.update_sequence_numbering(dest_conn, all_tables_no_pg)
            elif db_type == DbType.MYSQL:
                # TODO update sequencing for mysql
                assert isinstance(dest_conn, MySqlConnection)
        finally:
            dest_conn.close()

        result_tabulator.tabulate(source_dbc, destination_dbc, all_tables)
    except KeyboardInterrupt:
        print("\nInterrupted — closing connections...")
        raise
    finally:
        # Nested so connections are still closed if unprepping the temp DBs fails.
        try:
            subsetter.unprep_temp_dbs()
        finally:
            subsetter.close_connections()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import subprocess
|
|
3
|
+
|
|
4
|
+
from db_condenser import config_reader, db_connect
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MySqlDatabaseCreator:
    """Copy a MySQL database's schema (no rows) into a fresh destination database.

    Works by running the ``mysqldump`` and ``mysql`` command-line clients from
    the directory returned by ``get_mysql_bin_path()`` ("" means PATH).
    """

    def __init__(self, source_connect, destination_connect):
        # Both objects must expose host, port, user, password and db_name
        # (read by connection_args() and the methods below).
        self.__source_connect = source_connect
        self.__destination_connect = destination_connect

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` as a backtick-quoted MySQL identifier.

        Embedded backticks are doubled, so names containing characters such as
        ``-`` work in CREATE/DROP DATABASE statements.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _stderr_text(result):
        """Decode a finished process's captured stderr for use in an error message."""
        return result.stderr.decode(errors="replace")

    def create(self):
        """Create the destination database and load the source schema into it.

        Runs ``mysqldump --no-data --routines`` against the source, then
        ``CREATE DATABASE`` on the destination, then pipes the dump into
        ``mysql``.

        Raises:
            Exception: any of the three client invocations exits non-zero.
        """
        # Invoke the clients by full path instead of chdir-ing into the bin
        # directory: on POSIX a bare program name is searched on PATH only,
        # so chdir did not make MYSQL_PATH take effect, and it leaked the
        # changed working directory whenever a step raised.
        bin_path = get_mysql_bin_path()
        mysqldump_exe = os.path.join(bin_path, "mysqldump")
        mysql_exe = os.path.join(bin_path, "mysql")
        source = self.__source_connect
        dest = self.__destination_connect

        result = subprocess.run(
            [mysqldump_exe, "--no-data", "--routines"]
            + connection_args(source)
            + [source.db_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if result.returncode != 0:
            raise Exception(
                "Capturing schema failed. Details:\n{}".format(
                    self._stderr_text(result)
                )
            )
        commands_to_create_schema = result.stdout

        dest_args = connection_args(dest)
        result = subprocess.run(
            [mysql_exe]
            + dest_args
            + ["-e", "CREATE DATABASE " + self._quote_identifier(dest.db_name)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if result.returncode != 0:
            raise Exception(
                "Creating destination database failed. Details:\n{}".format(
                    self._stderr_text(result)
                )
            )

        result = subprocess.run(
            [mysql_exe, "-D", dest.db_name] + dest_args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            input=commands_to_create_schema,
        )
        if result.returncode != 0:
            raise Exception(
                "Creating destination schema failed. Details:\n{}".format(
                    self._stderr_text(result)
                )
            )

    def teardown(self):
        """Drop the destination database if it exists."""
        self.run_query_on_destination(
            "DROP DATABASE IF EXISTS "
            + self._quote_identifier(self.__destination_connect.db_name)
            + ";"
        )

    def add_constraints(self):
        # no-op for mysql
        pass

    def run_query_on_destination(self, command):
        """Run ``command`` via ``mysql -e`` against the destination server.

        No default database is selected for the command.

        Raises:
            Exception: the mysql client exits non-zero.
        """
        mysql_exe = os.path.join(get_mysql_bin_path(), "mysql")
        args = [mysql_exe] + connection_args(self.__destination_connect) + ["-e", command]
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            raise Exception(
                "Failed to run command '{}'. Details:\n{}".format(
                    command, self._stderr_text(result)
                )
            )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_mysql_bin_path():
    """Return the directory holding the MySQL client tools ("" means "use PATH").

    The directory comes from the MYSQL_PATH environment variable. Running
    ``mysqldump --help`` from that location serves as a probe that the tools
    are actually present.

    Raises:
        Exception: mysqldump could not be run from that location.
    """
    mysql_bin_path = os.environ.get("MYSQL_PATH", "")
    probe = [os.path.join(mysql_bin_path, "mysqldump"), "--help"]
    # Run without a shell: the previous os.system() command string broke on
    # paths containing quotes and let MYSQL_PATH inject shell syntax.
    try:
        found = subprocess.run(probe, stdout=subprocess.DEVNULL).returncode == 0
    except OSError:
        found = False
    if not found:
        raise Exception(
            "Couldn't find MySQL utilities, consider specifying MYSQL_PATH environment variable if MySQL isn't "
            + "in your PATH."
        )
    return mysql_bin_path
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def connection_args(connect):
    """Build the --host/--port/--user/--password flags for the mysql CLI tools.

    ``connect`` must expose host, port, user and password.

    NOTE(review): the password is placed on the command line, where it is
    typically visible to other local users through process listings; consider
    passing it via the MYSQL_PWD environment variable or --defaults-extra-file.
    """
    settings = (
        ("host", connect.host),
        ("port", connect.port),
        ("user", connect.user),
        ("password", connect.password),
    )
    return [f"--{flag}={value}" for flag, value in settings]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Manual smoke test of the teardown/create cycle: run this module directly with
# a config.json describing a MySQL source and destination.
if __name__ == "__main__":
    config_reader.initialize()

    cfg = config_reader.get_config()
    mysql_type = config_reader.DbType.MYSQL
    # Source is constructed first, so its password prompt (if any) comes first.
    creator = MySqlDatabaseCreator(
        db_connect.DbConnect(mysql_type, cfg.source_db_connection_info),
        db_connect.DbConnect(mysql_type, cfg.destination_db_connection_info),
    )
    creator.teardown()
    creator.create()
|