db-condenser 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,191 @@
1
+ import json
2
+ import sys
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ from typing import Literal
6
+
7
+
8
+ @dataclass
9
+ class InitialTarget:
10
+ table: str
11
+ percent: float | None = None
12
+ where: str | None = None
13
+
14
+ def __post_init__(self):
15
+ # Exactly one of where/percent must be set
16
+ if (self.where is None) == (self.percent is None):
17
+ raise ValueError(
18
+ "Initial Target must specify exactly one of 'where' or 'percent'"
19
+ )
20
+
21
+
22
class DbType(str, Enum):
    """Database engines recognised in the ``db_type`` configuration value.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """

    POSTGRES = "postgres"
    MYSQL = "mysql"
25
+
26
+
27
@dataclass
class DbConnectInfo:
    """Connection parameters for a single database, as read from the config."""

    user_name: str
    host: str
    db_name: str
    port: int
    # Optional SSL mode; None means no SSL mode was configured.
    ssl_mode: str | None = None
    # No password will prompt user
    password: str | None = None
36
+
37
+
38
+ @dataclass
39
+ class UpstreamFilter:
40
+ condition: str
41
+ table: str | None = None
42
+ column: str | None = None
43
+
44
+ def __post_init__(self):
45
+ # Exactly one of table/column must be set
46
+ if (self.table is None) == (self.column is None):
47
+ raise ValueError(
48
+ "Upstream filters must specify exactly one of 'table' or 'column'"
49
+ )
50
+
51
+
52
@dataclass
class DependencyBreak:
    """A (fk_table, target_table) relation listed under ``dependency_breaks``."""

    fk_table: str
    target_table: str
    # Opt-in flag; relations with it set are reported by
    # Config.preserve_fk_opportunistically.
    # NOTE(review): the exact runtime effect lives outside this module - confirm.
    preserve_fk_opportunistically: bool = False
57
+
58
+
59
@dataclass
class FkAugmentation:
    """An additional foreign-key relation declared in the configuration.

    ``fk_columns`` and ``target_columns`` must have the same number of entries,
    otherwise ``ValueError`` is raised at construction time.
    """

    fk_table: str
    fk_columns: list[str]
    target_table: str
    target_columns: list[str]

    def __post_init__(self):
        if len(self.fk_columns) == len(self.target_columns):
            return
        raise ValueError("fk_columns and target_columns must be the same length")
69
+
70
+
71
@dataclass
class Config:
    """Complete, typed configuration for one run.

    Only ``db_type``, ``initial_targets`` and the two connection blocks are
    required; every other field has a default.
    """

    db_type: DbType
    initial_targets: list[InitialTarget]
    source_db_connection_info: DbConnectInfo
    destination_db_connection_info: DbConnectInfo
    keep_disconnected_tables: bool = False
    upstream_filters: list[UpstreamFilter] = field(default_factory=list)
    excluded_tables: list[str] = field(default_factory=list)
    passthrough_tables: list[str] = field(default_factory=list)
    dependency_breaks: list[DependencyBreak] = field(default_factory=list)
    fk_augmentation: list[FkAugmentation] = field(default_factory=list)
    max_rows_per_table: int | Literal["ALL"] | None = None
    use_temp_tables: bool = False
    use_copy_protocol: bool = False
    pre_constraint_sql: list[str] = field(default_factory=list)
    post_subset_sql: list[str] = field(default_factory=list)

    @property
    def dependency_break_set(self) -> set[tuple[str, str]]:
        """All ``(fk_table, target_table)`` pairs from ``dependency_breaks``."""
        pairs = set()
        for brk in self.dependency_breaks:
            pairs.add((brk.fk_table, brk.target_table))
        return pairs

    @property
    def preserve_fk_opportunistically(self) -> set[tuple[str, str]]:
        """The subset of break pairs whose opportunistic-preserve flag is set."""
        flagged = set()
        for brk in self.dependency_breaks:
            if brk.preserve_fk_opportunistically:
                flagged.add((brk.fk_table, brk.target_table))
        return flagged

    @property
    def initial_target_tables(self) -> list[str]:
        """Table names of the initial targets, in configuration order."""
        names = []
        for entry in self.initial_targets:
            names.append(entry.table)
        return names
104
+
105
+
106
# Module-level configuration singleton: None until initialize() assigns it,
# and set back to None by reset_config().
config: Config | None = None
107
+
108
+
109
def _raw_dict_to_config(raw_config: dict) -> Config:
    """Convert the parsed JSON configuration into a typed ``Config``.

    Args:
        raw_config: Mapping produced by ``json.load``. The keys ``db_type``,
            ``initial_targets``, ``source_db_connection_info`` and
            ``destination_db_connection_info`` are required; all others are
            optional and fall back to empty lists / ``False`` / ``None``.

    Returns:
        The populated ``Config``.

    Raises:
        KeyError: A required key is missing.
        ValueError: ``db_type`` is not a known ``DbType`` or one of the nested
            dataclasses rejects its values.
        TypeError: A nested entry carries a key its dataclass does not accept.
    """
    db_type = DbType(raw_config["db_type"].lower())

    initial_targets = [
        InitialTarget(**target) for target in raw_config["initial_targets"]
    ]

    source_db = DbConnectInfo(**raw_config["source_db_connection_info"])
    dest_db = DbConnectInfo(**raw_config["destination_db_connection_info"])

    upstream_filters = [
        UpstreamFilter(**entry) for entry in raw_config.get("upstream_filters", [])
    ]

    excluded_tables = list(raw_config.get("excluded_tables", []))
    # De-duplicate while keeping the configured order; a set round-trip would
    # make the order vary from run to run.
    passthrough_tables = list(dict.fromkeys(raw_config.get("passthrough_tables", [])))
    dependency_breaks = [
        DependencyBreak(**relation)
        for relation in raw_config.get("dependency_breaks", [])
    ]

    fk_augmentation = []
    for fka in raw_config.get("fk_augmentation", []):
        # The two schema keys are optional and independent of each other:
        # each one, when present, qualifies its own table name.
        if "fk_schema" in fka or "target_schema" in fka:
            fk_table = fka["fk_table"]
            if "fk_schema" in fka:
                fk_table = fka["fk_schema"] + "." + fk_table
            target_table = fka["target_table"]
            if "target_schema" in fka:
                target_table = fka["target_schema"] + "." + target_table
            fka = {
                "fk_table": fk_table,
                "fk_columns": fka["fk_columns"],
                "target_table": target_table,
                "target_columns": fka["target_columns"],
            }
        fk_augmentation.append(FkAugmentation(**fka))

    return Config(
        db_type=db_type,
        initial_targets=initial_targets,
        source_db_connection_info=source_db,
        destination_db_connection_info=dest_db,
        keep_disconnected_tables=bool(
            raw_config.get("keep_disconnected_tables", False)
        ),
        upstream_filters=upstream_filters,
        excluded_tables=excluded_tables,
        passthrough_tables=passthrough_tables,
        dependency_breaks=dependency_breaks,
        fk_augmentation=fk_augmentation,
        max_rows_per_table=raw_config.get("max_rows_per_table", None),
        use_temp_tables=bool(raw_config.get("use_temp_tables", False)),
        use_copy_protocol=bool(raw_config.get("use_copy_protocol", False)),
        pre_constraint_sql=list(raw_config.get("pre_constraint_sql", [])),
        post_subset_sql=list(raw_config.get("post_subset_sql", [])),
    )
167
+
168
+
169
def initialize(file_like=None):
    """Load the JSON configuration and store it in the module-level ``config``.

    Args:
        file_like: Optional readable object containing the JSON document. When
            omitted (or falsy), ``config.json`` in the current working
            directory is read instead.

    A warning is written to stderr if a configuration is already loaded; the
    new configuration still replaces the old one.
    """
    global config
    if config:
        print("WARNING: Attempted to initialize configuration twice.", file=sys.stderr)

    if not file_like:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open("config.json", encoding="utf-8") as fp:
            raw_config = json.load(fp)
    else:
        raw_config = json.load(file_like)

    config = _raw_dict_to_config(raw_config)
181
+
182
+
183
def get_config() -> Config:
    """Return the loaded configuration, or raise ``RuntimeError`` if none is loaded."""
    if config is not None:
        return config
    raise RuntimeError("Config not initialized — call initialize() first")
187
+
188
+
189
def reset_config():
    """Discard the loaded configuration by setting the module-level ``config`` to None."""
    global config
    config = None
@@ -0,0 +1,52 @@
1
+ from typing import Any, Final
2
+
3
+ from faker import Faker
4
+
5
# Shared Faker instance used by the DataMasking helpers in this module.
fake = Faker()
6
+
7
+
8
+ class DataMasking:
9
+ @staticmethod
10
+ def null_out(_: str):
11
+ return None
12
+
13
+ @staticmethod
14
+ def mask_numbers(value: Any) -> str | None:
15
+ """
16
+ Mask certain strings that may contain a mixture of letters,
17
+ normal characters, whitespaces, or special characters
18
+ """
19
+ if value is None:
20
+ return None
21
+ str_value = str(value)
22
+ return "".join(
23
+ str(fake.random_digit()) if c.isdigit() else c for c in str_value
24
+ )
25
+
26
+ @staticmethod
27
+ def mask_characters(value: Any) -> str | None:
28
+ """
29
+ Mask certain strings that may contain a mixture of letters,
30
+ normal characters, whitespaces, or special characters
31
+ """
32
+ if value is None:
33
+ return None
34
+ str_value = str(value)
35
+ return "".join(fake.random_letter() if c.isalpha() else c for c in str_value)
36
+
37
+ @staticmethod
38
+ def mask_email(email: Any) -> str | None:
39
+ if email is None:
40
+ return None
41
+ s = str(email).split("@")
42
+ if len(s) < 2:
43
+ return fake.email()
44
+ return f"{fake.user_name()}@{s[1]}"
45
+
46
+
47
# Lookup table from masking-strategy name to the DataMasking helper that
# implements it.
DATA_MASKING_MAPPER: Final = {
    "null_out": DataMasking.null_out,
    "mask_numbers": DataMasking.mask_numbers,
    "mask_characters": DataMasking.mask_characters,
    "mask_email": DataMasking.mask_email,
}
@@ -0,0 +1,13 @@
1
+ from db_condenser.config_reader import DbType, get_config
2
+
3
+
4
def get_specific_helper():
    """Return the helper module matching the configured database type.

    The Postgres helper is returned for ``DbType.POSTGRES``; any other value
    yields the MySQL helper. Imports are deferred so only the selected helper
    module is loaded.
    """
    wants_postgres = get_config().db_type == DbType.POSTGRES
    if not wants_postgres:
        from db_condenser import mysql_database_helper

        return mysql_database_helper

    from db_condenser import psql_database_helper

    return psql_database_helper
@@ -0,0 +1,126 @@
1
+ import getpass
2
+ import sys
3
+ import time
4
+ from datetime import datetime
5
+
6
+ import mysql.connector
7
+ import psycopg
8
+
9
+ from db_condenser.config_reader import DbConnectInfo, DbType
10
+
11
+
12
class DbConnection:
    """Base wrapper holding a driver connection and forwarding commit/close to it."""

    def __init__(self, connection):
        # The underlying driver connection object.
        self.connection = connection

    def commit(self):
        """Commit on the wrapped connection."""
        self.connection.commit()

    def close(self):
        """Close the wrapped connection."""
        self.connection.close()
21
+
22
+
23
class LoggingCursor:
    """Cursor proxy that optionally prints each query and how long it took.

    Only ``execute`` is intercepted; every other attribute is forwarded to the
    wrapped cursor. The proxy is also a context manager that delegates to the
    wrapped cursor's own ``__enter__`` / ``__exit__``.
    """

    def __init__(self, cursor, verbose=False):
        self.inner_cursor = cursor
        self._verbose = verbose

    def execute(self, query, params=None):
        """Run ``query`` on the wrapped cursor and return whatever it returns.

        When verbose, the query text is printed before execution and the
        elapsed time afterwards.
        """
        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by system clock adjustments.
        start_time = time.perf_counter()
        if self._verbose:
            print("Beginning query @ {}:\n\t{}".format(str(datetime.now()), query))
            sys.stdout.flush()
        retval = self.inner_cursor.execute(query, params)
        if self._verbose:
            print("\tQuery completed in {}s".format(time.perf_counter() - start_time))
            sys.stdout.flush()
        return retval

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If inner_cursor itself is
        # missing (instance created without __init__), fail cleanly instead of
        # recursing back into this method forever.
        if name == "inner_cursor":
            raise AttributeError(name)
        # getattr() honours the wrapped object's full lookup protocol,
        # including its own __getattr__ fallback.
        return getattr(self.inner_cursor, name)

    def __exit__(self, a, b, c):
        return self.inner_cursor.__exit__(a, b, c)

    def __enter__(self):
        return LoggingCursor(self.inner_cursor.__enter__(), self._verbose)
47
+
48
+
49
# Thin wrapper around the driver connection so that cursor() has the same
# interface for MySQL and Postgres. This is the Postgres variant.
class PsqlConnection(DbConnection):
    """Postgres connection opened with psycopg from a connection-settings object."""

    def __init__(self, connect, read_repeatable, verbose=False):
        connection_args = {
            "dbname": connect.db_name,
            "user": connect.user,
            "password": connect.password,
            "host": connect.host,
            "port": connect.port,
        }
        # sslmode is only passed when one was configured.
        if connect.ssl_mode:
            connection_args["sslmode"] = connect.ssl_mode

        super().__init__(psycopg.connect(**connection_args))
        self._verbose = verbose
        if read_repeatable:
            self.connection.isolation_level = psycopg.IsolationLevel.REPEATABLE_READ

    def cursor(self, name=None, withhold=False):
        """Return a LoggingCursor around a driver cursor created with the given options."""
        raw_cursor = self.connection.cursor(name=name, withhold=withhold)
        return LoggingCursor(raw_cursor, self._verbose)
73
+
74
+
75
# Thin wrapper around the driver connection so that cursor() has the same
# interface for MySQL and Postgres. This is the MySQL variant.
class MySqlConnection(DbConnection):
    """MySQL connection opened with mysql.connector from a connection-settings object."""

    def __init__(self, connect, read_repeatable, verbose=False):
        raw_connection = mysql.connector.connect(
            host=connect.host,
            port=connect.port,
            user=connect.user,
            password=connect.password,
            database=connect.db_name,
        )
        super().__init__(raw_connection)

        self.db_name = connect.db_name
        self._verbose = verbose

        if read_repeatable:
            self.connection.start_transaction(isolation_level="REPEATABLE READ")

    def cursor(self, name=None, withhold=False):
        """Return a LoggingCursor around a plain driver cursor.

        ``name`` and ``withhold`` are accepted for signature parity with the
        Postgres wrapper and are not used here.
        """
        raw_cursor = self.connection.cursor()
        return LoggingCursor(raw_cursor, self._verbose)
98
+
99
+
100
class DbConnect:
    """Connection settings for one database plus a factory for live connections.

    The constructor copies the fields of ``connection_info`` onto the instance
    (``user``, ``password``, ``host``, ``port``, ``db_name``, ``ssl_mode``).
    When no password is configured the user is prompted via ``getpass`` and
    the answer is stored back on ``connection_info``.
    """

    def __init__(self, db_type: DbType, connection_info: DbConnectInfo, verbose=False):
        if connection_info.password is None:
            connection_info.password = getpass.getpass(
                "Enter password for {0} on host {1}: ".format(
                    connection_info.user_name, connection_info.host
                )
            )

        self.user = connection_info.user_name
        self.password = connection_info.password
        self.host = connection_info.host
        self.port = connection_info.port
        self.db_name = connection_info.db_name
        self.ssl_mode = connection_info.ssl_mode
        self.__db_type = db_type
        self._verbose = verbose

    def get_db_connection(
        self, read_repeatable=False
    ) -> PsqlConnection | MySqlConnection:
        """Open and return a new connection of the configured database type.

        Args:
            read_repeatable: Passed through to the connection wrapper, which
                uses it to request REPEATABLE READ isolation.

        Raises:
            ValueError: The stored db_type is neither Postgres nor MySQL.
        """
        if self.__db_type == DbType.POSTGRES:
            return PsqlConnection(self, read_repeatable, self._verbose)
        elif self.__db_type == DbType.MYSQL:
            return MySqlConnection(self, read_repeatable, self._verbose)
        else:
            # Format rather than concatenate: a non-string db_type (e.g. None)
            # must still produce this ValueError, not a TypeError.
            raise ValueError(f"unknown db_type {self.__db_type}")
@@ -0,0 +1,128 @@
1
+ import argparse
2
+ import sys
3
+ import time
4
+
5
+ from db_condenser import config_reader, database_helper, result_tabulator
6
+ from db_condenser.config_reader import DbConnectInfo, DbType
7
+ from db_condenser.db_connect import DbConnect, MySqlConnection, PsqlConnection
8
+ from db_condenser.mysql_database_creator import MySqlDatabaseCreator
9
+ from db_condenser.psql_database_creator import PsqlDatabaseCreator
10
+ from db_condenser.subset import Subset
11
+ from db_condenser.subset_utils import print_progress
12
+
13
+
14
def db_creator(
    db_type: DbType, source: DbConnect, dest: DbConnect
) -> PsqlDatabaseCreator | MySqlDatabaseCreator:
    """Return the database creator matching ``db_type``.

    Raises:
        ValueError: ``db_type`` is neither Postgres nor MySQL. Failing here
            avoids handing ``None`` to the caller, which would only surface
            later as an unrelated AttributeError.
    """
    if db_type == DbType.POSTGRES:
        return PsqlDatabaseCreator(source, dest, False)
    if db_type == DbType.MYSQL:
        return MySqlDatabaseCreator(source, dest)
    raise ValueError(f"unknown db_type {db_type}")
21
+
22
+
23
+ def _parse_args():
24
+ parser = argparse.ArgumentParser(description="Database Condenser")
25
+ parser.add_argument("--stdin", action="store_true", help="Read config from stdin")
26
+ parser.add_argument(
27
+ "-y", "--yes", action="store_true", help="Skip destination confirmation prompt"
28
+ )
29
+ parser.add_argument(
30
+ "--no-constraints", action="store_true", help="Skip adding constraints"
31
+ )
32
+ parser.add_argument(
33
+ "-v", "--verbose", action="store_true", help="Log every query with timing"
34
+ )
35
+ return parser.parse_args()
36
+
37
+
38
def _confirm_destination(dest_info: DbConnectInfo):
    """Show the destination and exit with status 1 unless the user confirms.

    Only ``y`` / ``yes`` (any case, surrounding whitespace ignored) counts as
    confirmation. If no answer can be read, the run is aborted the same way as
    an explicit "no".
    """
    print(
        f"\nDestination: {dest_info.host}:{dest_info.port}/{dest_info.db_name}"
        f" (user: {dest_info.user_name})"
    )
    try:
        response = input("Proceed with subsetting into this destination? [y/N] ")
    except EOFError:
        # stdin is closed or already consumed (e.g. the config was piped in
        # via --stdin), so nobody can answer: treat it as "no".
        response = ""
    if response.strip().lower() not in ("y", "yes"):
        print("Aborted.")
        sys.exit(1)
47
+
48
+
49
def main():
    """Command-line entry point: rebuild the destination and subset into it.

    Steps, in order: load the configuration (stdin or ``config.json``), confirm
    a non-local destination unless ``--yes`` was given, tear down and recreate
    the destination database, run the subset, run the configured
    pre-constraint SQL, add constraints unless ``--no-constraints`` was given,
    run the post-subset SQL, reset sequence numbering (Postgres only) and
    tabulate the results. Temp databases and connections are cleaned up even
    when the run fails or is interrupted.
    """
    args = _parse_args()

    if args.stdin:
        config_reader.initialize(sys.stdin)
    else:
        config_reader.initialize()

    config = config_reader.get_config()

    db_type = config.db_type
    source_dbc = DbConnect(
        db_type, config.source_db_connection_info, verbose=args.verbose
    )

    dest_info = config.destination_db_connection_info
    # Local destinations are not confirmed; anything else is, unless --yes.
    if not args.yes and dest_info.host not in ("localhost", "127.0.0.1"):
        _confirm_destination(dest_info)

    destination_dbc = DbConnect(db_type, dest_info, verbose=args.verbose)

    database = db_creator(db_type, source_dbc, destination_dbc)
    database.teardown()
    database.create()

    # Get list of tables to operate on
    db_helper = database_helper.get_specific_helper()
    all_tables = db_helper.list_all_tables(source_dbc)
    all_tables = [x for x in all_tables if x not in config.excluded_tables]

    subsetter = Subset(source_dbc, destination_dbc, all_tables)

    try:
        subsetter.prep_temp_dbs()
        subsetter.run_middle_out()

        print("Beginning pre constraint SQL calls")
        start_time = time.time()
        for idx, sql in enumerate(config.pre_constraint_sql):
            print_progress(sql, idx + 1, len(config.pre_constraint_sql))
            db_helper.run_query(sql, destination_dbc.get_db_connection())
        print(
            "Completed pre constraint SQL calls in {}s".format(time.time() - start_time)
        )

        # Report what actually happens instead of always claiming that
        # constraints are being added.
        if args.no_constraints:
            print("Skipping database constraints (--no-constraints)")
        else:
            print("Adding database constraints")
            database.add_constraints()

        print("Beginning post subset SQL calls")
        start_time = time.time()
        for idx, sql in enumerate(config.post_subset_sql):
            print_progress(sql, idx + 1, len(config.post_subset_sql))
            db_helper.run_query(sql, destination_dbc.get_db_connection())
        print("Completed post subset SQL calls in {}s".format(time.time() - start_time))

        print("Resetting sequence numbering")
        # Tables whose name contains "pgbench" are left out of the sequence reset.
        all_tables_no_pg = [table for table in all_tables if "pgbench" not in table]
        dest_conn = destination_dbc.get_db_connection()
        if db_type == DbType.POSTGRES:
            assert isinstance(dest_conn, PsqlConnection)
            db_helper.update_sequence_numbering(dest_conn, all_tables_no_pg)
        elif db_type == DbType.MYSQL:
            # TODO update sequencing for mysql
            assert isinstance(dest_conn, MySqlConnection)
            # db_helper.update_sequence_numbering(
            #     dest_conn, all_tables_no_pg
            # )

        result_tabulator.tabulate(source_dbc, destination_dbc, all_tables)
    except KeyboardInterrupt:
        print("\nInterrupted — closing connections...")
        raise
    finally:
        subsetter.unprep_temp_dbs()
        subsetter.close_connections()
126
+
127
# Run the command-line entry point when this module is executed as a script.
if __name__ == "__main__":
    main()
@@ -0,0 +1,128 @@
1
+ import os
2
+ import subprocess
3
+
4
+ from db_condenser import config_reader, db_connect
5
+
6
+
7
class MySqlDatabaseCreator:
    """Creates and drops the destination MySQL database via the MySQL CLI tools.

    The schema is copied by piping ``mysqldump --no-data --routines`` output
    from the source into ``mysql`` on the destination. The tools are invoked by
    their path inside the directory returned by ``get_mysql_bin_path()``, so
    the process working directory is never changed.
    """

    def __init__(self, source_connect, destination_connect):
        self.__source_connect = source_connect
        self.__destination_connect = destination_connect

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` as a backtick-quoted MySQL identifier."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _tool(name):
        """Return the path of a MySQL CLI tool (a bare name if MYSQL_PATH is unset)."""
        return os.path.join(get_mysql_bin_path(), name)

    @staticmethod
    def _run(args, failure_summary, stdin_data=None):
        """Run a CLI tool, raising with its decoded stderr on a non-zero exit."""
        result = subprocess.run(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            input=stdin_data,
        )
        if result.returncode != 0:
            details = result.stderr.decode(errors="replace")
            raise Exception("{} Details:\n{}".format(failure_summary, details))
        return result

    def create(self):
        """Create the destination database and load the source schema into it.

        Raises:
            Exception: Dumping the schema, creating the database, or loading
                the schema exited with a non-zero status.
        """
        source_args = connection_args(self.__source_connect)
        dump = self._run(
            [self._tool("mysqldump"), "--no-data", "--routines"]
            + source_args
            + [self.__source_connect.db_name],
            "Capturing schema failed.",
        )
        commands_to_create_schema = dump.stdout

        dest_args = connection_args(self.__destination_connect)
        dest_db = self.__destination_connect.db_name
        mysql_tool = self._tool("mysql")
        self._run(
            [mysql_tool]
            + dest_args
            + ["-e", "CREATE DATABASE " + self._quote_identifier(dest_db)],
            "Creating destination database failed.",
        )

        self._run(
            [mysql_tool, "-D", dest_db] + dest_args,
            "Creating destination schema failed.",
            stdin_data=commands_to_create_schema,
        )

    def teardown(self):
        """Drop the destination database if it exists."""
        self.run_query_on_destination(
            "DROP DATABASE IF EXISTS "
            + self._quote_identifier(self.__destination_connect.db_name)
            + ";"
        )

    def add_constraints(self):
        # no-op for mysql
        pass

    def run_query_on_destination(self, command):
        """Execute one SQL statement on the destination server via ``mysql -e``.

        Raises:
            Exception: The ``mysql`` client exited with a non-zero status.
        """
        dest_args = connection_args(self.__destination_connect)
        self._run(
            [self._tool("mysql")] + dest_args + ["-e", command],
            "Failed to run command '{}'.".format(command),
        )
85
+
86
+
87
def get_mysql_bin_path():
    """Return the directory holding the MySQL CLI tools ("" means: use PATH).

    The directory comes from the ``MYSQL_PATH`` environment variable when set.
    ``mysqldump --help`` is run once as a probe to confirm the tools can be
    started.

    Raises:
        Exception: The probe could not be started or exited with a non-zero
            status.
    """
    mysql_bin_path = os.environ.get("MYSQL_PATH", "")
    mysqldump = os.path.join(mysql_bin_path, "mysqldump")
    try:
        # An argument list (no shell) keeps the environment-supplied path from
        # being interpreted as shell syntax.
        probe = subprocess.run([mysqldump, "--help"], stdout=subprocess.DEVNULL)
        tools_found = probe.returncode == 0
    except OSError:
        # Executable missing or not runnable.
        tools_found = False
    if not tools_found:
        raise Exception(
            "Couldn't find MySQL utilities, consider specifying MYSQL_PATH environment variable if MySQL isn't "
            + "in your PATH."
        )
    return mysql_bin_path
105
+
106
+
107
def connection_args(connect):
    """Build the ``--host/--port/--user/--password`` CLI options for ``connect``.

    Returns the four options as a list, in that order, each formatted as
    ``--name=value``.
    """
    option_values = (
        ("host", connect.host),
        ("port", connect.port),
        ("user", connect.user),
        ("password", connect.password),
    )
    return [f"--{option}={value}" for option, value in option_values]
113
+
114
+
115
# This is just for unit testing the creation and tear down processes.
# It loads config.json, builds MySQL connection settings for the source and
# destination, then drops and recreates the destination database.
if __name__ == "__main__":
    config_reader.initialize()

    config = config_reader.get_config()
    src_connect = db_connect.DbConnect(
        config_reader.DbType.MYSQL, config.source_db_connection_info
    )
    dest_connect = db_connect.DbConnect(
        config_reader.DbType.MYSQL, config.destination_db_connection_info
    )
    msdbc = MySqlDatabaseCreator(src_connect, dest_connect)
    msdbc.teardown()
    msdbc.create()