dcs-sdk 1.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_diff/__init__.py +221 -0
- data_diff/__main__.py +517 -0
- data_diff/abcs/__init__.py +13 -0
- data_diff/abcs/compiler.py +27 -0
- data_diff/abcs/database_types.py +402 -0
- data_diff/config.py +141 -0
- data_diff/databases/__init__.py +38 -0
- data_diff/databases/_connect.py +323 -0
- data_diff/databases/base.py +1417 -0
- data_diff/databases/bigquery.py +376 -0
- data_diff/databases/clickhouse.py +217 -0
- data_diff/databases/databricks.py +262 -0
- data_diff/databases/duckdb.py +207 -0
- data_diff/databases/mssql.py +343 -0
- data_diff/databases/mysql.py +189 -0
- data_diff/databases/oracle.py +238 -0
- data_diff/databases/postgresql.py +293 -0
- data_diff/databases/presto.py +222 -0
- data_diff/databases/redis.py +93 -0
- data_diff/databases/redshift.py +233 -0
- data_diff/databases/snowflake.py +222 -0
- data_diff/databases/sybase.py +720 -0
- data_diff/databases/trino.py +73 -0
- data_diff/databases/vertica.py +174 -0
- data_diff/diff_tables.py +489 -0
- data_diff/errors.py +17 -0
- data_diff/format.py +369 -0
- data_diff/hashdiff_tables.py +1026 -0
- data_diff/info_tree.py +76 -0
- data_diff/joindiff_tables.py +434 -0
- data_diff/lexicographic_space.py +253 -0
- data_diff/parse_time.py +88 -0
- data_diff/py.typed +0 -0
- data_diff/queries/__init__.py +13 -0
- data_diff/queries/api.py +213 -0
- data_diff/queries/ast_classes.py +811 -0
- data_diff/queries/base.py +38 -0
- data_diff/queries/extras.py +43 -0
- data_diff/query_utils.py +70 -0
- data_diff/schema.py +67 -0
- data_diff/table_segment.py +583 -0
- data_diff/thread_utils.py +112 -0
- data_diff/utils.py +1022 -0
- data_diff/version.py +15 -0
- dcs_core/__init__.py +13 -0
- dcs_core/__main__.py +17 -0
- dcs_core/__version__.py +15 -0
- dcs_core/cli/__init__.py +13 -0
- dcs_core/cli/cli.py +165 -0
- dcs_core/core/__init__.py +19 -0
- dcs_core/core/common/__init__.py +13 -0
- dcs_core/core/common/errors.py +50 -0
- dcs_core/core/common/models/__init__.py +13 -0
- dcs_core/core/common/models/configuration.py +284 -0
- dcs_core/core/common/models/dashboard.py +24 -0
- dcs_core/core/common/models/data_source_resource.py +75 -0
- dcs_core/core/common/models/metric.py +160 -0
- dcs_core/core/common/models/profile.py +75 -0
- dcs_core/core/common/models/validation.py +216 -0
- dcs_core/core/common/models/widget.py +44 -0
- dcs_core/core/configuration/__init__.py +13 -0
- dcs_core/core/configuration/config_loader.py +139 -0
- dcs_core/core/configuration/configuration_parser.py +262 -0
- dcs_core/core/configuration/configuration_parser_arc.py +328 -0
- dcs_core/core/datasource/__init__.py +13 -0
- dcs_core/core/datasource/base.py +62 -0
- dcs_core/core/datasource/manager.py +112 -0
- dcs_core/core/datasource/search_datasource.py +421 -0
- dcs_core/core/datasource/sql_datasource.py +1094 -0
- dcs_core/core/inspect.py +163 -0
- dcs_core/core/logger/__init__.py +13 -0
- dcs_core/core/logger/base.py +32 -0
- dcs_core/core/logger/default_logger.py +94 -0
- dcs_core/core/metric/__init__.py +13 -0
- dcs_core/core/metric/base.py +220 -0
- dcs_core/core/metric/combined_metric.py +98 -0
- dcs_core/core/metric/custom_metric.py +34 -0
- dcs_core/core/metric/manager.py +137 -0
- dcs_core/core/metric/numeric_metric.py +403 -0
- dcs_core/core/metric/reliability_metric.py +90 -0
- dcs_core/core/profiling/__init__.py +13 -0
- dcs_core/core/profiling/datasource_profiling.py +136 -0
- dcs_core/core/profiling/numeric_field_profiling.py +72 -0
- dcs_core/core/profiling/text_field_profiling.py +67 -0
- dcs_core/core/repository/__init__.py +13 -0
- dcs_core/core/repository/metric_repository.py +77 -0
- dcs_core/core/utils/__init__.py +13 -0
- dcs_core/core/utils/log.py +29 -0
- dcs_core/core/utils/tracking.py +105 -0
- dcs_core/core/utils/utils.py +44 -0
- dcs_core/core/validation/__init__.py +13 -0
- dcs_core/core/validation/base.py +230 -0
- dcs_core/core/validation/completeness_validation.py +153 -0
- dcs_core/core/validation/custom_query_validation.py +24 -0
- dcs_core/core/validation/manager.py +282 -0
- dcs_core/core/validation/numeric_validation.py +276 -0
- dcs_core/core/validation/reliability_validation.py +91 -0
- dcs_core/core/validation/uniqueness_validation.py +61 -0
- dcs_core/core/validation/validity_validation.py +738 -0
- dcs_core/integrations/__init__.py +13 -0
- dcs_core/integrations/databases/__init__.py +13 -0
- dcs_core/integrations/databases/bigquery.py +187 -0
- dcs_core/integrations/databases/databricks.py +51 -0
- dcs_core/integrations/databases/db2.py +652 -0
- dcs_core/integrations/databases/elasticsearch.py +61 -0
- dcs_core/integrations/databases/mssql.py +829 -0
- dcs_core/integrations/databases/mysql.py +409 -0
- dcs_core/integrations/databases/opensearch.py +64 -0
- dcs_core/integrations/databases/oracle.py +719 -0
- dcs_core/integrations/databases/postgres.py +482 -0
- dcs_core/integrations/databases/redshift.py +53 -0
- dcs_core/integrations/databases/snowflake.py +48 -0
- dcs_core/integrations/databases/spark_df.py +111 -0
- dcs_core/integrations/databases/sybase.py +1069 -0
- dcs_core/integrations/storage/__init__.py +13 -0
- dcs_core/integrations/storage/local_file.py +149 -0
- dcs_core/integrations/utils/__init__.py +13 -0
- dcs_core/integrations/utils/utils.py +36 -0
- dcs_core/report/__init__.py +13 -0
- dcs_core/report/dashboard.py +211 -0
- dcs_core/report/models.py +88 -0
- dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
- dcs_core/report/static/assets/images/docs.svg +6 -0
- dcs_core/report/static/assets/images/github.svg +4 -0
- dcs_core/report/static/assets/images/logo.svg +7 -0
- dcs_core/report/static/assets/images/slack.svg +13 -0
- dcs_core/report/static/index.js +2 -0
- dcs_core/report/static/index.js.LICENSE.txt +3971 -0
- dcs_sdk/__init__.py +13 -0
- dcs_sdk/__main__.py +18 -0
- dcs_sdk/__version__.py +15 -0
- dcs_sdk/cli/__init__.py +13 -0
- dcs_sdk/cli/cli.py +163 -0
- dcs_sdk/sdk/__init__.py +58 -0
- dcs_sdk/sdk/config/__init__.py +13 -0
- dcs_sdk/sdk/config/config_loader.py +491 -0
- dcs_sdk/sdk/data_diff/__init__.py +13 -0
- dcs_sdk/sdk/data_diff/data_differ.py +821 -0
- dcs_sdk/sdk/rules/__init__.py +15 -0
- dcs_sdk/sdk/rules/rules_mappping.py +31 -0
- dcs_sdk/sdk/rules/rules_repository.py +214 -0
- dcs_sdk/sdk/rules/schema_rules.py +65 -0
- dcs_sdk/sdk/utils/__init__.py +13 -0
- dcs_sdk/sdk/utils/serializer.py +25 -0
- dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
- dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
- dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
- dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
- dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
- dcs_sdk/sdk/utils/table.py +475 -0
- dcs_sdk/sdk/utils/themes.py +40 -0
- dcs_sdk/sdk/utils/utils.py +349 -0
- dcs_sdk-1.6.5.dist-info/METADATA +150 -0
- dcs_sdk-1.6.5.dist-info/RECORD +159 -0
- dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
- dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
# Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
from typing import Any, ClassVar, Dict, Optional, Type
|
|
17
|
+
|
|
18
|
+
import attrs
|
|
19
|
+
from loguru import logger
|
|
20
|
+
|
|
21
|
+
from data_diff.abcs.database_types import (
|
|
22
|
+
JSON,
|
|
23
|
+
Boolean,
|
|
24
|
+
Date,
|
|
25
|
+
Datetime,
|
|
26
|
+
DbPath,
|
|
27
|
+
Decimal,
|
|
28
|
+
Float,
|
|
29
|
+
FractionalType,
|
|
30
|
+
Integer,
|
|
31
|
+
Native_UUID,
|
|
32
|
+
String_UUID,
|
|
33
|
+
TemporalType,
|
|
34
|
+
Text,
|
|
35
|
+
Time,
|
|
36
|
+
Timestamp,
|
|
37
|
+
TimestampTZ,
|
|
38
|
+
)
|
|
39
|
+
from data_diff.databases.base import (
|
|
40
|
+
CHECKSUM_HEXDIGITS,
|
|
41
|
+
CHECKSUM_OFFSET,
|
|
42
|
+
BaseDialect,
|
|
43
|
+
ConnectError,
|
|
44
|
+
QueryError,
|
|
45
|
+
ThreadedDatabase,
|
|
46
|
+
import_helper,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@import_helper("mssql")
def import_mssql():
    """Import the ``pyodbc`` driver on demand and return the module object.

    The import lives inside the function so that this module can be loaded
    even when ``pyodbc`` is not installed.
    """
    import pyodbc as odbc_driver

    return odbc_driver
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@attrs.define(frozen=False)
class Dialect(BaseDialect):
    """SQL-generation helpers for Microsoft SQL Server (T-SQL).

    Every method returns a fragment of SQL text; nothing here talks to the
    database directly except ``to_string``, which looks up column metadata
    through ``get_column_raw_info``.
    """

    name = "MsSQL"
    ROUNDS_ON_PREC_LOSS = True
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
    SUPPORTS_INDEXES = True
    # Maps the lower-case SQL Server type name to the internal column type.
    TYPE_CLASSES = {
        # Timestamps
        "datetimeoffset": TimestampTZ,
        "datetime": Datetime,
        "datetime2": Timestamp,
        "smalldatetime": Datetime,
        "timestamp": Datetime,
        "date": Date,
        "time": Time,
        # Numbers
        "float": Float,
        "real": Float,
        "decimal": Decimal,
        "money": Decimal,
        "smallmoney": Decimal,
        "numeric": Decimal,
        # int
        "int": Integer,
        "bigint": Integer,
        "tinyint": Integer,
        "smallint": Integer,
        # Text
        "varchar": Text,
        "char": Text,
        "text": Text,
        "ntext": Text,
        "nvarchar": Text,
        "nchar": Text,
        "binary": Text,
        "varbinary": Text,
        "xml": Text,
        # UUID
        "uniqueidentifier": Native_UUID,
        # Bool
        "bit": Boolean,
        # JSON
        "json": JSON,
    }

    def quote(self, s: str, is_table: bool = False) -> str:
        """Wrap ``s`` in square brackets.

        When ``is_table`` is set, a default schema is configured and ``s`` is
        one of the names recorded by ``parse_table_name``, the result is
        prefixed with the bracketed default schema.
        """
        needs_schema = s in self.TABLE_NAMES and self.default_schema and is_table
        if not needs_schema:
            return f"[{s}]"
        return f"[{self.default_schema}].[{s}]"

    def set_timezone_to_utc(self) -> str:
        """Always raises: this dialect offers no session timezone statement."""
        raise NotImplementedError("MsSQL does not support a session timezone setting.")

    def current_timestamp(self) -> str:
        """Return the SQL expression for the current timestamp."""
        return "GETDATE()"

    def current_database(self) -> str:
        """Return the SQL expression for the current database name."""
        return "DB_NAME()"

    def current_schema(self) -> str:
        """Return the select-list and FROM/WHERE text that yields the current user's default schema."""
        return """default_schema_name
        FROM sys.database_principals
        WHERE name = CURRENT_USER"""

    def to_string(self, s: str) -> str:
        """Return a ``CONVERT(...)`` expression that casts ``s`` to a character type.

        The quoting characters are stripped from ``s`` and the remainder is
        passed to ``get_column_raw_info``. If metadata comes back, its data
        type picks between ``VARCHAR`` and ``NVARCHAR`` and its maximum length
        (capped at 8000) becomes the target size. Without usable metadata the
        result is ``VARCHAR(MAX)``.
        """
        unquoted = re.sub(r'["\[\]`]', "", s)
        col_info = self.get_column_raw_info(unquoted)

        declared_len = col_info.character_maximum_length if col_info else None
        if not declared_len or declared_len <= 0:
            # Missing, zero or negative length: fall back to MAX.
            size = "MAX"
        else:
            size = min(declared_len, 8000)

        data_type = col_info.data_type.lower().strip() if col_info else None

        if data_type in ("nvarchar", "nchar"):
            return f"CONVERT(NVARCHAR({size}), {s})"
        if data_type == "text":
            return f"CONVERT(VARCHAR(MAX), {s})"
        if data_type == "ntext":
            return f"CONVERT(NVARCHAR(MAX), {s})"
        return f"CONVERT(VARCHAR({size}), {s})"

    def type_repr(self, t) -> str:
        """Return the SQL type name for Python type ``t``.

        ``bool`` and ``str`` are handled here; everything else is delegated
        to the base class.
        """
        local_name = {bool: "bit", str: "text"}.get(t)
        if local_name is not None:
            return local_name
        return super().type_repr(t)

    def random(self) -> str:
        """Return the SQL expression for a random number."""
        return "rand()"

    def is_distinct_from(self, a: str, b: str) -> str:
        """Return a NULL-safe inequality test between expressions ``a`` and ``b``."""
        # IS (NOT) DISTINCT FROM is available only since SQLServer 2022.
        # See: https://stackoverflow.com/a/18684859/857383
        differs_or_any_null = f"{a}<>{b} OR {a} IS NULL OR {b} IS NULL"
        both_null = f"{a} IS NULL AND {b} IS NULL"
        return f"(({differs_or_any_null}) AND NOT({both_null}))"

    def limit_select(
        self,
        select_query: str,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        has_order_by: Optional[bool] = None,
    ) -> str:
        """Append ordering and row-limit clauses to ``select_query``.

        ``ORDER BY 1`` is added when ``has_order_by`` is falsy, and an
        ``OFFSET 0 ROWS FETCH NEXT <limit> ROWS ONLY`` clause is added when
        ``limit`` is not ``None``.

        Raises:
            NotImplementedError: if a truthy ``offset`` is requested.
        """
        if offset:
            raise NotImplementedError("No support for OFFSET in query")

        order_clause = "" if has_order_by else "ORDER BY 1"
        fetch_clause = "" if limit is None else f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"

        return f"{select_query} {order_clause}{fetch_clause}"

    def constant_values(self, rows) -> str:
        """Render ``rows`` as a ``VALUES (...), (...)`` clause."""
        rendered_rows = []
        for row in rows:
            cells = ", ".join(self._constant_value(cell) for cell in row)
            rendered_rows.append(f"({cells})")
        return "VALUES " + ", ".join(rendered_rows)

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL that renders the temporal expression ``value`` as text.

        ``Datetime`` columns are formatted with ``FORMAT`` (with milliseconds
        when the column precision is positive) and NULL is passed through.
        Every other temporal type is cast to ``VARCHAR``.
        """
        if not isinstance(coltype, Datetime):
            return f"CAST({value} AS VARCHAR)"

        pattern = "yyyy-MM-dd HH:mm:ss.fff" if coltype.precision > 0 else "yyyy-MM-dd HH:mm:ss"
        return f"CASE WHEN {value} IS NULL THEN NULL ELSE FORMAT({value}, '{pattern}') END"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Cast ``value`` to ``DECIMAL(38, precision)`` and convert the result to text."""
        as_decimal = f"CAST({value} AS DECIMAL(38, {coltype.precision}))"
        return self.to_string(as_decimal)

    def md5_as_int(self, s: str) -> str:
        """Return SQL computing an integer checksum from the MD5 of ``s``.

        The last ``CHECKSUM_HEXDIGITS`` hex digits of the digest are converted
        to ``bigint`` and ``CHECKSUM_OFFSET`` is subtracted.
        """
        hex_digest = f"CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2)"
        hex_tail = f"RIGHT({hex_digest}, {CHECKSUM_HEXDIGITS})"
        return f"convert(bigint, convert(varbinary, '0x' + {hex_tail}, 1)) - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        """Return the ``HashBytes('MD5', ...)`` expression for ``s``."""
        return f"HashBytes('MD5', {s})"

    def parse_table_name(self, name: str) -> DbPath:
        """Split a dotted table name into a DbPath tuple.

        The last component is also appended to ``TABLE_NAMES``, which
        ``quote`` consults when deciding whether to add the schema.
        """
        parts = name.split(".")
        self.TABLE_NAMES.append(parts[-1])
        return tuple(parts)

    def normalize_uuid(self, value, coltype):
        """Render a UUID expression as text via ``to_string``; ``coltype`` is unused."""
        return self.to_string(value)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@attrs.define(frozen=False, init=False, kw_only=True)
class MsSQL(ThreadedDatabase):
    """SQL Server database driver backed by ``pyodbc``.

    A connection is opened eagerly in ``__init__``; ``create_connection`` can
    be called again to open further connections with the same settings.
    """

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
    CONNECT_URI_PARAMS = ["database", "schema"]

    default_database: str
    # Connection settings with ``None`` values removed; always contains "driver".
    _args: Dict[str, Any]
    # The imported pyodbc module (``None`` until ``create_connection`` runs).
    _mssql: Any
    # The connection most recently returned by ``create_connection``.
    _conn: Any

    def __init__(self, host, port, user, password, *, database, thread_count, **kw) -> None:
        """Store the connection settings and open the first connection.

        Args:
            host: Server host name; may be ``None`` if ``server`` is given in ``kw``.
            port: TCP port; any falsy value falls back to 1433.
            user: Login name.
            password: Login password.
            database: Default database (required).
            thread_count: Forwarded to the base class.
            **kw: Extra settings. ``schema`` is required; ``odbc_driver`` and
                ``server`` are read if present.

        Raises:
            ValueError: if no database or schema was supplied.
            ConnectError: if the connection cannot be established.
        """
        super().__init__(thread_count=thread_count)

        port = port if port else 1433
        args = dict(
            host=host,
            port=port,
            database=database,
            user=user,
            password=password,
            **kw,
        )
        self._args = {k: v for k, v in args.items() if v is not None}
        # ``None`` values were dropped above, so a present key means a real driver name.
        self._args["driver"] = self._args.pop("odbc_driver", "{ODBC Driver 18 for SQL Server}")
        try:
            self.default_database = self._args["database"]
            self.default_schema = self._args["schema"]
            self.dialect.default_schema = self.default_schema
        except KeyError as error:
            raise ValueError("Specify a default database and schema.") from error
        self._mssql = None
        self._conn = self.create_connection()

    def create_connection(self):
        """Open a new pyodbc connection, remember it in ``self._conn`` and return it.

        Raises:
            ConnectError: if every connection attempt fails or no host/server
                is configured.
        """
        # NOTE(review): each call overwrites ``self._conn``; ``close`` only closes the
        # most recent one. Confirm whether earlier connections are closed elsewhere.
        self._mssql = import_mssql()
        connection_params = self._build_connection_params(
            driver=self._args.get("driver"),
            database=self._args.get("database"),
            username=self._args.get("user"),
            password=self._args.get("password"),
        )
        try:
            self._conn = self._establish_connection(
                connection_params,
                self._args.get("host"),
                self._args.get("server"),
                self._args.get("port"),
            )
        except self._mssql.Error as error:
            raise ConnectError(*error.args) from error
        return self._conn

    def _prepare_driver_string(self, driver: str) -> str:
        """Wrap the ODBC driver name in braces unless it already starts with one."""
        return f"{{{driver}}}" if not driver.startswith("{") else driver

    def _build_connection_params(self, driver: str, database: str, username: str, password: str) -> dict:
        """Build the keyword arguments passed to ``pyodbc.connect`` (without SERVER)."""
        return {
            "DRIVER": self._prepare_driver_string(driver),
            "DATABASE": database,
            "UID": username,
            "PWD": password,
            "TrustServerCertificate": "yes",
        }

    def _establish_connection(self, conn_dict: dict, host: Optional[str], server: Optional[str], port: Any) -> Any:
        """Try the host/server candidates in order and return the first connection that opens.

        Candidates are ``host`` with and without the port, then ``server``
        with and without the port. ``conn_dict["SERVER"]`` is updated in place
        for each attempt.

        Raises:
            ConnectError: if neither ``host`` nor ``server`` is set.
            Exception: the error from the last failed attempt when none succeeds.
        """
        candidates = [
            (host, True),  # host with port
            (host, False),  # host without port
            (server, True),  # server with port
            (server, False),  # server without port
        ]

        attempted = set()
        last_error = None
        for server_value, use_port in candidates:
            if not server_value:
                continue
            target = f"{server_value},{port}" if use_port and port else server_value
            if target in attempted:
                # Same SERVER string as an earlier failed attempt; retrying cannot help.
                continue
            attempted.add(target)
            conn_dict["SERVER"] = target
            try:
                connection = self._mssql.connect(**conn_dict)
            except Exception as error:
                # Best effort: remember the failure and fall through to the next candidate.
                last_error = error
                logger.debug(f"MSSQL connection attempt using {target} failed: {error}")
                continue
            logger.info(f"Connected to MSSQL database using {conn_dict['SERVER']}")
            return connection

        if last_error is None:
            raise ConnectError("No host or server specified for the MSSQL connection.")
        # Surface the real failure instead of silently returning None.
        raise last_error

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
        database, schema, name = self._normalize_table_path(path)
        info_schema_path = ["information_schema", "columns"]
        if database:
            info_schema_path.insert(0, self.dialect.quote(database))

        return (
            "SELECT column_name, data_type, ISNULL(datetime_precision, 0) AS datetime_precision, ISNULL(numeric_precision, 0) AS numeric_precision, ISNULL(numeric_scale, 0) AS numeric_scale, collation_name, ISNULL(character_maximum_length, 0) AS character_maximum_length "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand ``path`` to a (database, schema, table) triple using the defaults.

        Raises:
            ValueError: if ``path`` does not have one, two or three components.
        """
        if len(path) == 1:
            return self.default_database, self.default_schema, path[0]
        elif len(path) == 2:
            return self.default_database, path[0], path[1]
        elif len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
        )

    def _query_cursor(self, c, sql_code: str):
        """Run ``sql_code`` on cursor ``c``, converting pyodbc database errors to ``QueryError``."""
        try:
            return super()._query_cursor(c, sql_code)
        except self._mssql.DatabaseError as e:
            raise QueryError(e) from e

    def close(self):
        """Close the base-class resources, then the stored connection if there is one."""
        super().close()
        if self._conn is not None:
            self._conn.close()
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, ClassVar, Dict, Tuple, Type, Union
|
|
16
|
+
|
|
17
|
+
import attrs
|
|
18
|
+
|
|
19
|
+
from data_diff.abcs.database_types import (
|
|
20
|
+
Boolean,
|
|
21
|
+
ColType_UUID,
|
|
22
|
+
Date,
|
|
23
|
+
Datetime,
|
|
24
|
+
DbPath,
|
|
25
|
+
Decimal,
|
|
26
|
+
Float,
|
|
27
|
+
FractionalType,
|
|
28
|
+
Integer,
|
|
29
|
+
TemporalType,
|
|
30
|
+
Text,
|
|
31
|
+
Timestamp,
|
|
32
|
+
)
|
|
33
|
+
from data_diff.databases.base import (
|
|
34
|
+
CHECKSUM_HEXDIGITS,
|
|
35
|
+
CHECKSUM_OFFSET,
|
|
36
|
+
MD5_HEXDIGITS,
|
|
37
|
+
TIMESTAMP_PRECISION_POS,
|
|
38
|
+
BaseDialect,
|
|
39
|
+
ConnectError,
|
|
40
|
+
ThreadedDatabase,
|
|
41
|
+
ThreadLocalInterpreter,
|
|
42
|
+
import_helper,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@import_helper("mysql")
def import_mysql():
    """Import ``mysql.connector`` on demand and return the module object.

    The import lives inside the function so that this module can be loaded
    even when the MySQL connector is not installed.
    """
    import mysql.connector as mysql_connector

    return mysql_connector
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@attrs.define(frozen=False)
class Dialect(BaseDialect):
    """SQL-generation helpers for MySQL. Every method returns a fragment of SQL text."""

    name = "MySQL"
    ROUNDS_ON_PREC_LOSS = True
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
    SUPPORTS_INDEXES = True
    # Maps the lower-case MySQL type name to the internal column type.
    TYPE_CLASSES = {
        # Dates
        "datetime": Datetime,
        "timestamp": Timestamp,
        "date": Date,
        # Numbers
        "double": Float,
        "float": Float,
        "decimal": Decimal,
        "int": Integer,
        "bigint": Integer,
        "mediumint": Integer,
        "smallint": Integer,
        "tinyint": Integer,
        # Text
        "varchar": Text,
        "char": Text,
        "varbinary": Text,
        "binary": Text,
        "text": Text,
        "mediumtext": Text,
        "longtext": Text,
        "tinytext": Text,
        # Boolean
        "boolean": Boolean,
    }

    def quote(self, s: str, is_table: bool = False) -> str:
        """Wrap ``s`` in backticks.

        When ``is_table`` is set, a default schema is configured and ``s`` is
        one of the names recorded by ``parse_table_name``, the result is
        prefixed with the backtick-quoted default schema.
        """
        needs_schema = s in self.TABLE_NAMES and self.default_schema and is_table
        if not needs_schema:
            return f"`{s}`"
        return f"`{self.default_schema}`.`{s}`"

    def to_string(self, s: str) -> str:
        """Return SQL casting expression ``s`` to ``char``."""
        return f"cast({s} as char)"

    def is_distinct_from(self, a: str, b: str) -> str:
        """Return a NULL-safe inequality test built on the ``<=>`` operator."""
        return f"not ({a} <=> {b})"

    def random(self) -> str:
        """Return the SQL expression for a random number."""
        return "RAND()"

    def type_repr(self, t) -> str:
        """Return the SQL type name for Python type ``t``.

        ``str`` is handled here; everything else is delegated to the base class.
        """
        local_name = {str: "VARCHAR(1024)"}.get(t)
        if local_name is not None:
            return local_name
        return super().type_repr(t)

    def explain_as_text(self, query: str) -> str:
        """Prefix ``query`` with ``EXPLAIN FORMAT=TREE``."""
        return f"EXPLAIN FORMAT=TREE {query}"

    def optimizer_hints(self, s: str):
        """Wrap ``s`` in an optimizer-hint comment, with a trailing space."""
        return f"/*+ {s} */ "

    def set_timezone_to_utc(self) -> str:
        """Return the statement that sets the session time zone to UTC."""
        return "SET @@session.time_zone='+00:00'"

    def md5_as_int(self, s: str) -> str:
        """Return SQL computing an integer checksum from the MD5 of ``s``.

        The trailing ``CHECKSUM_HEXDIGITS`` hex digits of the digest are
        converted to base 10 and ``CHECKSUM_OFFSET`` is subtracted.
        """
        tail_start = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
        return f"conv(substring(md5({s}), {tail_start}), 16, 10) - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        """Return the ``md5(...)`` expression for ``s``."""
        return f"md5({s})"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL that renders the temporal expression ``value`` as text.

        When ``coltype.rounds`` is set, the value is cast to the column
        precision and then to ``datetime(6)``. Otherwise the ``datetime(6)``
        text is padded with ``RPAD`` to a fixed width derived from
        ``TIMESTAMP_PRECISION_POS``.
        """
        if coltype.rounds:
            rounded = f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))"
            return self.to_string(rounded)

        as_text = self.to_string(f"cast({value} as datetime(6))")
        kept_width = TIMESTAMP_PRECISION_POS + coltype.precision
        full_width = TIMESTAMP_PRECISION_POS + 6
        return f"RPAD(RPAD({as_text}, {kept_width}, '.'), {full_width}, '0')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Cast ``value`` to ``decimal(38, precision)`` and convert the result to text."""
        as_decimal = f"cast({value} as decimal(38, {coltype.precision}))"
        return self.to_string(as_decimal)

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Return SQL casting ``value`` to ``char`` and trimming it; ``coltype`` is unused."""
        return f"TRIM(CAST({value} AS char))"

    def parse_table_name(self, name: str) -> DbPath:
        """Split a dotted table name into a DbPath tuple.

        The last component is also appended to ``TABLE_NAMES``, which
        ``quote`` consults when deciding whether to add the schema.
        """
        parts = name.split(".")
        self.TABLE_NAMES.append(parts[-1])
        return tuple(parts)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@attrs.define(frozen=False, init=False, kw_only=True)
class MySQL(ThreadedDatabase):
    """MySQL database driver backed by ``mysql.connector``."""

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    SUPPORTS_ALPHANUMS = False
    SUPPORTS_UNIQUE_CONSTAINT = True
    CONNECT_URI_HELP = "mysql://<user>:<password>@<host>/<database>"
    CONNECT_URI_PARAMS = ["database?"]

    # Keyword arguments forwarded to ``mysql.connector.connect``.
    _args: Dict[str, Any]

    def __init__(self, *, thread_count, **kw) -> None:
        """Store the connection settings.

        Falsy values and the ``schema`` key are removed from ``kw`` before it
        is kept as the connect arguments. The default schema is taken from
        ``kw["database"]``.

        Raises:
            ValueError: if ``kw`` has no ``database`` key.
        """
        super().__init__(thread_count=thread_count)
        connect_args = {key: value for key, value in kw.items() if value}
        connect_args.pop("schema", None)
        self._args = connect_args
        # In MySQL schema and database are synonymous
        try:
            database = kw["database"]
            self.default_schema = database
            self.dialect.default_schema = self.default_schema
        except KeyError:
            raise ValueError("MySQL URL must specify a database")

    def create_connection(self):
        """Open and return a new connection using the stored settings.

        Raises:
            ConnectError: on any connector error. Access-denied and
                unknown-database errors get a friendlier message.
        """
        mysql = import_mysql()
        try:
            return mysql.connect(charset="utf8", use_unicode=True, **self._args)
        except mysql.Error as error:
            friendly_messages = {
                mysql.errorcode.ER_ACCESS_DENIED_ERROR: "Bad user name or password",
                mysql.errorcode.ER_BAD_DB_ERROR: "Database does not exist",
            }
            message = friendly_messages.get(error.errno)
            if message is not None:
                raise ConnectError(message) from error
            raise ConnectError(*error.args) from error

    def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
        """Run ``sql_code`` on this thread's connection (this method runs in a worker thread).

        Re-raises a stored initialisation error, and pings the connection
        with reconnect enabled when it reports being disconnected.
        """
        if self._init_error:
            raise self._init_error
        conn = self.thread_local.conn
        if not conn.is_connected():
            conn.ping(reconnect=True, attempts=3, delay=5)
        return self._query_conn(conn, sql_code)

    def select_table_schema(self, path):
        """Return the ``information_schema.columns`` query describing the table at ``path``."""
        schema, name = self._normalize_table_path(path)
        selected = ", ".join(
            [
                "column_name",
                "data_type",
                "datetime_precision",
                "numeric_precision",
                "numeric_scale",
                "NULL as collation_name",
                "character_maximum_length",
            ]
        )
        return (
            f"SELECT {selected} "
            "FROM information_schema.columns "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )
|