dcs-sdk 1.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. data_diff/__init__.py +221 -0
  2. data_diff/__main__.py +517 -0
  3. data_diff/abcs/__init__.py +13 -0
  4. data_diff/abcs/compiler.py +27 -0
  5. data_diff/abcs/database_types.py +402 -0
  6. data_diff/config.py +141 -0
  7. data_diff/databases/__init__.py +38 -0
  8. data_diff/databases/_connect.py +323 -0
  9. data_diff/databases/base.py +1417 -0
  10. data_diff/databases/bigquery.py +376 -0
  11. data_diff/databases/clickhouse.py +217 -0
  12. data_diff/databases/databricks.py +262 -0
  13. data_diff/databases/duckdb.py +207 -0
  14. data_diff/databases/mssql.py +343 -0
  15. data_diff/databases/mysql.py +189 -0
  16. data_diff/databases/oracle.py +238 -0
  17. data_diff/databases/postgresql.py +293 -0
  18. data_diff/databases/presto.py +222 -0
  19. data_diff/databases/redis.py +93 -0
  20. data_diff/databases/redshift.py +233 -0
  21. data_diff/databases/snowflake.py +222 -0
  22. data_diff/databases/sybase.py +720 -0
  23. data_diff/databases/trino.py +73 -0
  24. data_diff/databases/vertica.py +174 -0
  25. data_diff/diff_tables.py +489 -0
  26. data_diff/errors.py +17 -0
  27. data_diff/format.py +369 -0
  28. data_diff/hashdiff_tables.py +1026 -0
  29. data_diff/info_tree.py +76 -0
  30. data_diff/joindiff_tables.py +434 -0
  31. data_diff/lexicographic_space.py +253 -0
  32. data_diff/parse_time.py +88 -0
  33. data_diff/py.typed +0 -0
  34. data_diff/queries/__init__.py +13 -0
  35. data_diff/queries/api.py +213 -0
  36. data_diff/queries/ast_classes.py +811 -0
  37. data_diff/queries/base.py +38 -0
  38. data_diff/queries/extras.py +43 -0
  39. data_diff/query_utils.py +70 -0
  40. data_diff/schema.py +67 -0
  41. data_diff/table_segment.py +583 -0
  42. data_diff/thread_utils.py +112 -0
  43. data_diff/utils.py +1022 -0
  44. data_diff/version.py +15 -0
  45. dcs_core/__init__.py +13 -0
  46. dcs_core/__main__.py +17 -0
  47. dcs_core/__version__.py +15 -0
  48. dcs_core/cli/__init__.py +13 -0
  49. dcs_core/cli/cli.py +165 -0
  50. dcs_core/core/__init__.py +19 -0
  51. dcs_core/core/common/__init__.py +13 -0
  52. dcs_core/core/common/errors.py +50 -0
  53. dcs_core/core/common/models/__init__.py +13 -0
  54. dcs_core/core/common/models/configuration.py +284 -0
  55. dcs_core/core/common/models/dashboard.py +24 -0
  56. dcs_core/core/common/models/data_source_resource.py +75 -0
  57. dcs_core/core/common/models/metric.py +160 -0
  58. dcs_core/core/common/models/profile.py +75 -0
  59. dcs_core/core/common/models/validation.py +216 -0
  60. dcs_core/core/common/models/widget.py +44 -0
  61. dcs_core/core/configuration/__init__.py +13 -0
  62. dcs_core/core/configuration/config_loader.py +139 -0
  63. dcs_core/core/configuration/configuration_parser.py +262 -0
  64. dcs_core/core/configuration/configuration_parser_arc.py +328 -0
  65. dcs_core/core/datasource/__init__.py +13 -0
  66. dcs_core/core/datasource/base.py +62 -0
  67. dcs_core/core/datasource/manager.py +112 -0
  68. dcs_core/core/datasource/search_datasource.py +421 -0
  69. dcs_core/core/datasource/sql_datasource.py +1094 -0
  70. dcs_core/core/inspect.py +163 -0
  71. dcs_core/core/logger/__init__.py +13 -0
  72. dcs_core/core/logger/base.py +32 -0
  73. dcs_core/core/logger/default_logger.py +94 -0
  74. dcs_core/core/metric/__init__.py +13 -0
  75. dcs_core/core/metric/base.py +220 -0
  76. dcs_core/core/metric/combined_metric.py +98 -0
  77. dcs_core/core/metric/custom_metric.py +34 -0
  78. dcs_core/core/metric/manager.py +137 -0
  79. dcs_core/core/metric/numeric_metric.py +403 -0
  80. dcs_core/core/metric/reliability_metric.py +90 -0
  81. dcs_core/core/profiling/__init__.py +13 -0
  82. dcs_core/core/profiling/datasource_profiling.py +136 -0
  83. dcs_core/core/profiling/numeric_field_profiling.py +72 -0
  84. dcs_core/core/profiling/text_field_profiling.py +67 -0
  85. dcs_core/core/repository/__init__.py +13 -0
  86. dcs_core/core/repository/metric_repository.py +77 -0
  87. dcs_core/core/utils/__init__.py +13 -0
  88. dcs_core/core/utils/log.py +29 -0
  89. dcs_core/core/utils/tracking.py +105 -0
  90. dcs_core/core/utils/utils.py +44 -0
  91. dcs_core/core/validation/__init__.py +13 -0
  92. dcs_core/core/validation/base.py +230 -0
  93. dcs_core/core/validation/completeness_validation.py +153 -0
  94. dcs_core/core/validation/custom_query_validation.py +24 -0
  95. dcs_core/core/validation/manager.py +282 -0
  96. dcs_core/core/validation/numeric_validation.py +276 -0
  97. dcs_core/core/validation/reliability_validation.py +91 -0
  98. dcs_core/core/validation/uniqueness_validation.py +61 -0
  99. dcs_core/core/validation/validity_validation.py +738 -0
  100. dcs_core/integrations/__init__.py +13 -0
  101. dcs_core/integrations/databases/__init__.py +13 -0
  102. dcs_core/integrations/databases/bigquery.py +187 -0
  103. dcs_core/integrations/databases/databricks.py +51 -0
  104. dcs_core/integrations/databases/db2.py +652 -0
  105. dcs_core/integrations/databases/elasticsearch.py +61 -0
  106. dcs_core/integrations/databases/mssql.py +829 -0
  107. dcs_core/integrations/databases/mysql.py +409 -0
  108. dcs_core/integrations/databases/opensearch.py +64 -0
  109. dcs_core/integrations/databases/oracle.py +719 -0
  110. dcs_core/integrations/databases/postgres.py +482 -0
  111. dcs_core/integrations/databases/redshift.py +53 -0
  112. dcs_core/integrations/databases/snowflake.py +48 -0
  113. dcs_core/integrations/databases/spark_df.py +111 -0
  114. dcs_core/integrations/databases/sybase.py +1069 -0
  115. dcs_core/integrations/storage/__init__.py +13 -0
  116. dcs_core/integrations/storage/local_file.py +149 -0
  117. dcs_core/integrations/utils/__init__.py +13 -0
  118. dcs_core/integrations/utils/utils.py +36 -0
  119. dcs_core/report/__init__.py +13 -0
  120. dcs_core/report/dashboard.py +211 -0
  121. dcs_core/report/models.py +88 -0
  122. dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
  123. dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
  124. dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
  125. dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
  126. dcs_core/report/static/assets/images/docs.svg +6 -0
  127. dcs_core/report/static/assets/images/github.svg +4 -0
  128. dcs_core/report/static/assets/images/logo.svg +7 -0
  129. dcs_core/report/static/assets/images/slack.svg +13 -0
  130. dcs_core/report/static/index.js +2 -0
  131. dcs_core/report/static/index.js.LICENSE.txt +3971 -0
  132. dcs_sdk/__init__.py +13 -0
  133. dcs_sdk/__main__.py +18 -0
  134. dcs_sdk/__version__.py +15 -0
  135. dcs_sdk/cli/__init__.py +13 -0
  136. dcs_sdk/cli/cli.py +163 -0
  137. dcs_sdk/sdk/__init__.py +58 -0
  138. dcs_sdk/sdk/config/__init__.py +13 -0
  139. dcs_sdk/sdk/config/config_loader.py +491 -0
  140. dcs_sdk/sdk/data_diff/__init__.py +13 -0
  141. dcs_sdk/sdk/data_diff/data_differ.py +821 -0
  142. dcs_sdk/sdk/rules/__init__.py +15 -0
  143. dcs_sdk/sdk/rules/rules_mappping.py +31 -0
  144. dcs_sdk/sdk/rules/rules_repository.py +214 -0
  145. dcs_sdk/sdk/rules/schema_rules.py +65 -0
  146. dcs_sdk/sdk/utils/__init__.py +13 -0
  147. dcs_sdk/sdk/utils/serializer.py +25 -0
  148. dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
  149. dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
  150. dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
  151. dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
  152. dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
  153. dcs_sdk/sdk/utils/table.py +475 -0
  154. dcs_sdk/sdk/utils/themes.py +40 -0
  155. dcs_sdk/sdk/utils/utils.py +349 -0
  156. dcs_sdk-1.6.5.dist-info/METADATA +150 -0
  157. dcs_sdk-1.6.5.dist-info/RECORD +159 -0
  158. dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
  159. dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,343 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Any, ClassVar, Dict, Optional, Type
17
+
18
+ import attrs
19
+ from loguru import logger
20
+
21
+ from data_diff.abcs.database_types import (
22
+ JSON,
23
+ Boolean,
24
+ Date,
25
+ Datetime,
26
+ DbPath,
27
+ Decimal,
28
+ Float,
29
+ FractionalType,
30
+ Integer,
31
+ Native_UUID,
32
+ String_UUID,
33
+ TemporalType,
34
+ Text,
35
+ Time,
36
+ Timestamp,
37
+ TimestampTZ,
38
+ )
39
+ from data_diff.databases.base import (
40
+ CHECKSUM_HEXDIGITS,
41
+ CHECKSUM_OFFSET,
42
+ BaseDialect,
43
+ ConnectError,
44
+ QueryError,
45
+ ThreadedDatabase,
46
+ import_helper,
47
+ )
48
+
49
+
@import_helper("mssql")
def import_mssql():
    """Import the ``pyodbc`` driver on demand and return the module.

    The ``import_helper("mssql")`` decorator wraps this function (presumably to
    turn a missing optional driver into a friendlier error — see its definition).
    """
    import pyodbc as odbc_driver

    return odbc_driver
55
+
56
+
@attrs.define(frozen=False)
class Dialect(BaseDialect):
    """SQL dialect for Microsoft SQL Server (T-SQL)."""

    name = "MsSQL"
    ROUNDS_ON_PREC_LOSS = True
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
    SUPPORTS_INDEXES = True
    # Maps SQL Server type names to data_diff column types.
    TYPE_CLASSES = {
        # Timestamps
        "datetimeoffset": TimestampTZ,
        "datetime": Datetime,
        "datetime2": Timestamp,
        "smalldatetime": Datetime,
        "timestamp": Datetime,
        "date": Date,
        "time": Time,
        # Numbers
        "float": Float,
        "real": Float,
        "decimal": Decimal,
        "money": Decimal,
        "smallmoney": Decimal,
        "numeric": Decimal,
        # int
        "int": Integer,
        "bigint": Integer,
        "tinyint": Integer,
        "smallint": Integer,
        # Text
        "varchar": Text,
        "char": Text,
        "text": Text,
        "ntext": Text,
        "nvarchar": Text,
        "nchar": Text,
        "binary": Text,
        "varbinary": Text,
        "xml": Text,
        # UUID
        "uniqueidentifier": Native_UUID,
        # Bool
        "bit": Boolean,
        # JSON
        "json": JSON,
    }

    def quote(self, s: str, is_table: bool = False) -> str:
        """Bracket-quote the identifier ``s``.

        Names recorded by ``parse_table_name`` are prefixed with the default
        schema when ``is_table`` is true and a default schema is set.
        """
        if s in self.TABLE_NAMES and self.default_schema and is_table:
            return f"[{self.default_schema}].[{s}]"
        return f"[{s}]"

    def set_timezone_to_utc(self) -> str:
        """Unsupported: SQL Server has no per-session time zone setting."""
        raise NotImplementedError("MsSQL does not support a session timezone setting.")

    def current_timestamp(self) -> str:
        return "GETDATE()"

    def current_database(self) -> str:
        return "DB_NAME()"

    def current_schema(self) -> str:
        """Return a query tail yielding the current user's default schema.

        NOTE(review): this is not a standalone expression (it includes FROM/WHERE);
        it presumably relies on the caller prefixing ``SELECT`` — confirm.
        """
        return """default_schema_name
            FROM sys.database_principals
            WHERE name = CURRENT_USER"""

    def to_string(self, s: str) -> str:
        """Return a T-SQL expression converting ``s`` to a character string.

        The column's metadata is looked up after stripping identifier quoting
        (``"``, ``[``, ``]``, backtick). ``text``/``ntext`` columns become
        ``(N)VARCHAR(MAX)``. For other columns the declared
        ``character_maximum_length`` is used, clamped to SQL Server's limits:
        8000 for VARCHAR and 4000 for NVARCHAR. Unknown or non-positive lengths
        fall back to ``MAX``.
        """
        col_info = self.get_column_raw_info(re.sub(r'["\[\]`]', "", s))
        data_type = col_info.data_type.lower().strip() if col_info else None

        if data_type == "text":
            return f"CONVERT(VARCHAR(MAX), {s})"
        if data_type == "ntext":
            return f"CONVERT(NVARCHAR(MAX), {s})"

        is_unicode = data_type in ("nvarchar", "nchar")
        ch_len = (col_info and col_info.character_maximum_length) or None
        if ch_len is None or ch_len <= 0:
            length = "MAX"
        else:
            # VARCHAR(n) accepts n <= 8000 but NVARCHAR(n) only n <= 4000.
            length = min(ch_len, 4000 if is_unicode else 8000)

        if is_unicode:
            return f"CONVERT(NVARCHAR({length}), {s})"
        return f"CONVERT(VARCHAR({length}), {s})"

    def type_repr(self, t) -> str:
        try:
            return {bool: "bit", str: "text"}[t]
        except KeyError:
            return super().type_repr(t)

    def random(self) -> str:
        return "rand()"

    def is_distinct_from(self, a: str, b: str) -> str:
        # IS (NOT) DISTINCT FROM is available only since SQLServer 2022.
        # See: https://stackoverflow.com/a/18684859/857383
        return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))"

    def limit_select(
        self,
        select_query: str,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        has_order_by: Optional[bool] = None,
    ) -> str:
        """Append T-SQL paging clauses to ``select_query``.

        T-SQL has no ``LIMIT``. Paging is expressed as
        ``OFFSET n ROWS [FETCH NEXT m ROWS ONLY]``, which is only valid after
        an ``ORDER BY``. ``ORDER BY 1`` is appended when the query has none.

        Args:
            select_query: The SELECT statement to page.
            offset: Number of leading rows to skip (``None``/0 means none).
            limit: Maximum number of rows to return (``None`` means unlimited).
            has_order_by: Whether ``select_query`` already ends with ORDER BY.
        """
        clauses = []
        if not has_order_by:
            clauses.append("ORDER BY 1")
        if offset or limit is not None:
            clauses.append(f"OFFSET {offset or 0} ROWS")
        if limit is not None:
            clauses.append(f"FETCH NEXT {limit} ROWS ONLY")
        return f"{select_query} {' '.join(clauses)}"

    def constant_values(self, rows) -> str:
        """Render ``rows`` (an iterable of value sequences) as a VALUES clause."""
        values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
        return f"VALUES {values}"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Render a temporal column as a comparable string.

        ``Datetime`` columns are formatted to second precision, plus
        milliseconds when the column has fractional precision. NULL stays NULL.
        Other temporal types are cast to VARCHAR as-is.
        """
        if isinstance(coltype, Datetime):
            fmt = "yyyy-MM-dd HH:mm:ss.fff" if coltype.precision > 0 else "yyyy-MM-dd HH:mm:ss"
            return f"CASE WHEN {value} IS NULL THEN NULL ELSE FORMAT({value}, '{fmt}') END"
        return f"CAST({value} AS VARCHAR)"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        return self.to_string(f"CAST({value} AS DECIMAL(38, {coltype.precision}))")

    def md5_as_int(self, s: str) -> str:
        return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1)) - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        return f"HashBytes('MD5', {s})"

    def parse_table_name(self, name: str) -> DbPath:
        """Parse the given table name into a DbPath.

        The bare table name (last dotted component) is recorded in
        ``TABLE_NAMES`` so ``quote(..., is_table=True)`` can schema-qualify it.
        """
        path = tuple(name.split("."))
        # Record each table once; repeated parsing must not grow the list unboundedly.
        if path[-1] not in self.TABLE_NAMES:
            self.TABLE_NAMES.append(path[-1])
        return path

    def normalize_uuid(self, value, coltype):
        return self.to_string(value)
221
+
222
+
@attrs.define(frozen=False, init=False, kw_only=True)
class MsSQL(ThreadedDatabase):
    """Microsoft SQL Server database accessed through ``pyodbc``."""

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
    CONNECT_URI_PARAMS = ["database", "schema"]

    default_database: str
    _args: Dict[str, Any]
    _mssql: Any
    _conn: Any

    def __init__(self, host, port, user, password, *, database, thread_count, **kw) -> None:
        """Store connection arguments and open an initial connection.

        ``port`` defaults to 1433. ``odbc_driver`` (if given) selects the ODBC
        driver, otherwise "ODBC Driver 18 for SQL Server" is used.

        Raises:
            ValueError: if no database or schema is provided.
            ConnectError: if the initial connection cannot be established.
        """
        super().__init__(thread_count=thread_count)

        port = port if port else 1433
        args = dict(
            host=host,
            port=port,
            database=database,
            user=user,
            password=password,
            **kw,
        )
        self._args = {k: v for k, v in args.items() if v is not None}
        odbc_driver = self._args.pop("odbc_driver", None)
        self._args["driver"] = odbc_driver if odbc_driver is not None else "{ODBC Driver 18 for SQL Server}"
        try:
            self.default_database = self._args["database"]
            self.default_schema = self._args["schema"]
        except KeyError:
            raise ValueError("Specify a default database and schema.") from None
        self.dialect.default_schema = self.default_schema
        self._mssql = None
        self._conn = self.create_connection()

    def create_connection(self):
        """Open and return a new pyodbc connection built from the stored arguments.

        The connection is returned, not stored on ``self``. This method is called
        from ``__init__`` and (apparently, via ``ThreadedDatabase``) once per
        worker thread. Assigning ``self._conn`` here would let every worker
        overwrite it, orphaning the ``__init__`` connection and making ``close()``
        close an arbitrary worker's connection.

        Raises:
            ConnectError: if no candidate server address accepts the connection.
        """
        self._mssql = import_mssql()
        connection_params = self._build_connection_params(
            driver=self._args.get("driver"),
            database=self._args.get("database"),
            username=self._args.get("user"),
            password=self._args.get("password"),
        )
        try:
            return self._establish_connection(
                connection_params,
                self._args.get("host"),
                self._args.get("server"),
                self._args.get("port"),
            )
        except self._mssql.Error as error:
            raise ConnectError(*error.args) from error

    def _prepare_driver_string(self, driver: str) -> str:
        """Wrap the driver name in braces unless it already starts with one."""
        return f"{{{driver}}}" if not driver.startswith("{") else driver

    def _build_connection_params(self, driver: str, database: str, username: str, password: str) -> dict:
        """Build keyword arguments for ``pyodbc.connect`` (SERVER is added later)."""
        return {
            "DRIVER": self._prepare_driver_string(driver),
            "DATABASE": database,
            "UID": username,
            "PWD": password,
            "TrustServerCertificate": "yes",
        }

    def _establish_connection(self, conn_dict: dict, host: str, server: str, port: str) -> Any:
        """Try candidate SERVER values in order and return the first working connection.

        Candidates are ``host,port``, ``host``, ``server,port`` and ``server``.
        Empty values and duplicate candidates are skipped. ``conn_dict["SERVER"]``
        is set to the candidate being tried.

        Raises:
            ConnectError: if every attempt fails; chained to the last driver error.
        """
        candidates = []
        for server_value, use_port in ((host, True), (host, False), (server, True), (server, False)):
            if not server_value:
                continue
            target = f"{server_value},{port}" if use_port and port else server_value
            if target not in candidates:
                candidates.append(target)

        last_error = None
        for target in candidates:
            conn_dict["SERVER"] = target
            try:
                connection = self._mssql.connect(**conn_dict)
            except Exception as error:  # driver errors vary by candidate; try the next one
                last_error = error
                logger.debug(f"MSSQL connection attempt using {target} failed: {error}")
                continue
            logger.info(f"Connected to MSSQL database using {target}")
            return connection

        tried = ", ".join(candidates) or "no host or server given"
        raise ConnectError(f"Could not connect to MSSQL database (tried: {tried}): {last_error}") from last_error

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL selecting the table's column metadata.

        Columns returned: column_name, data_type, datetime_precision,
        numeric_precision, numeric_scale, collation_name,
        character_maximum_length (NULL precisions/lengths mapped to 0).
        """
        database, schema, name = self._normalize_table_path(path)
        info_schema_path = ["information_schema", "columns"]
        if database:
            info_schema_path.insert(0, self.dialect.quote(database))

        # Double embedded single quotes so names like O'Brien don't break the literal.
        name_literal = str(name).replace("'", "''")
        schema_literal = str(schema).replace("'", "''")
        return (
            "SELECT column_name, data_type, ISNULL(datetime_precision, 0) AS datetime_precision, ISNULL(numeric_precision, 0) AS numeric_precision, ISNULL(numeric_scale, 0) AS numeric_scale, collation_name, ISNULL(character_maximum_length, 0) AS character_maximum_length "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = '{name_literal}' AND table_schema = '{schema_literal}'"
        )

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand a 1-, 2- or 3-part path to (database, schema, table) using defaults."""
        if len(path) == 1:
            return self.default_database, self.default_schema, path[0]
        elif len(path) == 2:
            return self.default_database, path[0], path[1]
        elif len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
        )

    def _query_cursor(self, c, sql_code: str):
        """Run ``sql_code`` on cursor ``c``, translating driver errors to QueryError."""
        try:
            return super()._query_cursor(c, sql_code)
        except self._mssql.DatabaseError as e:
            raise QueryError(e) from e

    def close(self):
        """Run base-class cleanup, then close the connection opened in ``__init__``."""
        super().close()
        if self._conn is not None:
            self._conn.close()
            # Drop the reference so a second close() does not touch a closed connection.
            self._conn = None
@@ -0,0 +1,189 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, ClassVar, Dict, Tuple, Type, Union
16
+
17
+ import attrs
18
+
19
+ from data_diff.abcs.database_types import (
20
+ Boolean,
21
+ ColType_UUID,
22
+ Date,
23
+ Datetime,
24
+ DbPath,
25
+ Decimal,
26
+ Float,
27
+ FractionalType,
28
+ Integer,
29
+ TemporalType,
30
+ Text,
31
+ Timestamp,
32
+ )
33
+ from data_diff.databases.base import (
34
+ CHECKSUM_HEXDIGITS,
35
+ CHECKSUM_OFFSET,
36
+ MD5_HEXDIGITS,
37
+ TIMESTAMP_PRECISION_POS,
38
+ BaseDialect,
39
+ ConnectError,
40
+ ThreadedDatabase,
41
+ ThreadLocalInterpreter,
42
+ import_helper,
43
+ )
44
+
45
+
@import_helper("mysql")
def import_mysql():
    """Import MySQL Connector/Python on demand and return ``mysql.connector``.

    The ``import_helper("mysql")`` decorator wraps this function (presumably to
    turn a missing optional driver into a friendlier error — see its definition).
    """
    from mysql import connector

    return connector
51
+
52
+
@attrs.define(frozen=False)
class Dialect(BaseDialect):
    """SQL dialect for MySQL."""

    name = "MySQL"
    ROUNDS_ON_PREC_LOSS = True
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
    SUPPORTS_INDEXES = True
    # Maps MySQL type names to data_diff column types.
    TYPE_CLASSES = {
        # Dates
        "datetime": Datetime,
        "timestamp": Timestamp,
        "date": Date,
        # Numbers
        "double": Float,
        "float": Float,
        "decimal": Decimal,
        "int": Integer,
        "bigint": Integer,
        "mediumint": Integer,
        "smallint": Integer,
        "tinyint": Integer,
        # Text
        "varchar": Text,
        "char": Text,
        "varbinary": Text,
        "binary": Text,
        "text": Text,
        "mediumtext": Text,
        "longtext": Text,
        "tinytext": Text,
        # Boolean
        "boolean": Boolean,
    }

    def quote(self, s: str, is_table: bool = False) -> str:
        """Backtick-quote the identifier ``s``.

        Names recorded by ``parse_table_name`` are prefixed with the default
        schema when ``is_table`` is true and a default schema is set.
        """
        if s in self.TABLE_NAMES and self.default_schema and is_table:
            return f"`{self.default_schema}`.`{s}`"
        return f"`{s}`"

    def to_string(self, s: str) -> str:
        return f"cast({s} as char)"

    def is_distinct_from(self, a: str, b: str) -> str:
        # <=> is MySQL's NULL-safe equality operator.
        return f"not ({a} <=> {b})"

    def random(self) -> str:
        return "RAND()"

    def type_repr(self, t) -> str:
        try:
            return {
                str: "VARCHAR(1024)",
            }[t]
        except KeyError:
            return super().type_repr(t)

    def explain_as_text(self, query: str) -> str:
        return f"EXPLAIN FORMAT=TREE {query}"

    def optimizer_hints(self, s: str):
        return f"/*+ {s} */ "

    def set_timezone_to_utc(self) -> str:
        return "SET @@session.time_zone='+00:00'"

    def md5_as_int(self, s: str) -> str:
        return f"conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        return f"md5({s})"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Render a temporal column as a string padded to 6 fractional digits.

        Rounding column types are cast through their own precision first.
        Otherwise the value is truncated to ``coltype.precision`` digits and
        zero-padded.
        """
        if coltype.rounds:
            return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))")

        s = self.to_string(f"cast({value} as datetime(6))")
        return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        return f"TRIM(CAST({value} AS char))"

    def parse_table_name(self, name: str) -> DbPath:
        """Parse the given table name into a DbPath.

        The bare table name (last dotted component) is recorded in
        ``TABLE_NAMES`` so ``quote(..., is_table=True)`` can schema-qualify it.
        """
        path = tuple(name.split("."))
        # Record each table once; repeated parsing must not grow the list unboundedly.
        if path[-1] not in self.TABLE_NAMES:
            self.TABLE_NAMES.append(path[-1])
        return path
140
+
141
+
@attrs.define(frozen=False, init=False, kw_only=True)
class MySQL(ThreadedDatabase):
    """MySQL database accessed through MySQL Connector/Python."""

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    SUPPORTS_ALPHANUMS = False
    SUPPORTS_UNIQUE_CONSTAINT = True
    CONNECT_URI_HELP = "mysql://<user>:<password>@<host>/<database>"
    CONNECT_URI_PARAMS = ["database?"]

    _args: Dict[str, Any]

    def __init__(self, *, thread_count, **kw) -> None:
        """Store connection arguments; ``database`` doubles as the default schema.

        Raises:
            ValueError: if no ``database`` keyword is given.
        """
        super().__init__(thread_count=thread_count)
        # Keep only truthy options, and drop "schema": in MySQL schema and
        # database are synonymous, so the connector only needs "database".
        self._args = {k: v for k, v in kw.items() if v and k != "schema"}
        try:
            self.default_schema = kw["database"]
        except KeyError:
            raise ValueError("MySQL URL must specify a database") from None
        self.dialect.default_schema = self.default_schema

    def create_connection(self):
        """Open and return a new connection, mapping common failures to ConnectError."""
        mysql = import_mysql()
        try:
            return mysql.connect(charset="utf8", use_unicode=True, **self._args)
        except mysql.Error as e:
            if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR:
                raise ConnectError("Bad user name or password") from e
            elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
                raise ConnectError("Database does not exist") from e
            raise ConnectError(*e.args) from e

    def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
        "This method runs in a worker thread"
        if self._init_error:
            raise self._init_error
        # Reconnect a dropped thread-local connection before querying.
        if not self.thread_local.conn.is_connected():
            self.thread_local.conn.ping(reconnect=True, attempts=3, delay=5)
        return self._query_conn(self.thread_local.conn, sql_code)

    def select_table_schema(self, path):
        """Provide SQL selecting the table's column metadata.

        Columns returned: column_name, data_type, datetime_precision,
        numeric_precision, numeric_scale, collation_name (always NULL),
        character_maximum_length.
        """
        schema, name = self._normalize_table_path(path)
        # Double embedded single quotes so names like O'Brien don't break the literal.
        name_literal = str(name).replace("'", "''")
        schema_literal = str(schema).replace("'", "''")
        return (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale, NULL as collation_name, character_maximum_length "
            "FROM information_schema.columns "
            f"WHERE table_name = '{name_literal}' AND table_schema = '{schema_literal}'"
        )