dcs-sdk 1.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159) hide show
  1. data_diff/__init__.py +221 -0
  2. data_diff/__main__.py +517 -0
  3. data_diff/abcs/__init__.py +13 -0
  4. data_diff/abcs/compiler.py +27 -0
  5. data_diff/abcs/database_types.py +402 -0
  6. data_diff/config.py +141 -0
  7. data_diff/databases/__init__.py +38 -0
  8. data_diff/databases/_connect.py +323 -0
  9. data_diff/databases/base.py +1417 -0
  10. data_diff/databases/bigquery.py +376 -0
  11. data_diff/databases/clickhouse.py +217 -0
  12. data_diff/databases/databricks.py +262 -0
  13. data_diff/databases/duckdb.py +207 -0
  14. data_diff/databases/mssql.py +343 -0
  15. data_diff/databases/mysql.py +189 -0
  16. data_diff/databases/oracle.py +238 -0
  17. data_diff/databases/postgresql.py +293 -0
  18. data_diff/databases/presto.py +222 -0
  19. data_diff/databases/redis.py +93 -0
  20. data_diff/databases/redshift.py +233 -0
  21. data_diff/databases/snowflake.py +222 -0
  22. data_diff/databases/sybase.py +720 -0
  23. data_diff/databases/trino.py +73 -0
  24. data_diff/databases/vertica.py +174 -0
  25. data_diff/diff_tables.py +489 -0
  26. data_diff/errors.py +17 -0
  27. data_diff/format.py +369 -0
  28. data_diff/hashdiff_tables.py +1026 -0
  29. data_diff/info_tree.py +76 -0
  30. data_diff/joindiff_tables.py +434 -0
  31. data_diff/lexicographic_space.py +253 -0
  32. data_diff/parse_time.py +88 -0
  33. data_diff/py.typed +0 -0
  34. data_diff/queries/__init__.py +13 -0
  35. data_diff/queries/api.py +213 -0
  36. data_diff/queries/ast_classes.py +811 -0
  37. data_diff/queries/base.py +38 -0
  38. data_diff/queries/extras.py +43 -0
  39. data_diff/query_utils.py +70 -0
  40. data_diff/schema.py +67 -0
  41. data_diff/table_segment.py +583 -0
  42. data_diff/thread_utils.py +112 -0
  43. data_diff/utils.py +1022 -0
  44. data_diff/version.py +15 -0
  45. dcs_core/__init__.py +13 -0
  46. dcs_core/__main__.py +17 -0
  47. dcs_core/__version__.py +15 -0
  48. dcs_core/cli/__init__.py +13 -0
  49. dcs_core/cli/cli.py +165 -0
  50. dcs_core/core/__init__.py +19 -0
  51. dcs_core/core/common/__init__.py +13 -0
  52. dcs_core/core/common/errors.py +50 -0
  53. dcs_core/core/common/models/__init__.py +13 -0
  54. dcs_core/core/common/models/configuration.py +284 -0
  55. dcs_core/core/common/models/dashboard.py +24 -0
  56. dcs_core/core/common/models/data_source_resource.py +75 -0
  57. dcs_core/core/common/models/metric.py +160 -0
  58. dcs_core/core/common/models/profile.py +75 -0
  59. dcs_core/core/common/models/validation.py +216 -0
  60. dcs_core/core/common/models/widget.py +44 -0
  61. dcs_core/core/configuration/__init__.py +13 -0
  62. dcs_core/core/configuration/config_loader.py +139 -0
  63. dcs_core/core/configuration/configuration_parser.py +262 -0
  64. dcs_core/core/configuration/configuration_parser_arc.py +328 -0
  65. dcs_core/core/datasource/__init__.py +13 -0
  66. dcs_core/core/datasource/base.py +62 -0
  67. dcs_core/core/datasource/manager.py +112 -0
  68. dcs_core/core/datasource/search_datasource.py +421 -0
  69. dcs_core/core/datasource/sql_datasource.py +1094 -0
  70. dcs_core/core/inspect.py +163 -0
  71. dcs_core/core/logger/__init__.py +13 -0
  72. dcs_core/core/logger/base.py +32 -0
  73. dcs_core/core/logger/default_logger.py +94 -0
  74. dcs_core/core/metric/__init__.py +13 -0
  75. dcs_core/core/metric/base.py +220 -0
  76. dcs_core/core/metric/combined_metric.py +98 -0
  77. dcs_core/core/metric/custom_metric.py +34 -0
  78. dcs_core/core/metric/manager.py +137 -0
  79. dcs_core/core/metric/numeric_metric.py +403 -0
  80. dcs_core/core/metric/reliability_metric.py +90 -0
  81. dcs_core/core/profiling/__init__.py +13 -0
  82. dcs_core/core/profiling/datasource_profiling.py +136 -0
  83. dcs_core/core/profiling/numeric_field_profiling.py +72 -0
  84. dcs_core/core/profiling/text_field_profiling.py +67 -0
  85. dcs_core/core/repository/__init__.py +13 -0
  86. dcs_core/core/repository/metric_repository.py +77 -0
  87. dcs_core/core/utils/__init__.py +13 -0
  88. dcs_core/core/utils/log.py +29 -0
  89. dcs_core/core/utils/tracking.py +105 -0
  90. dcs_core/core/utils/utils.py +44 -0
  91. dcs_core/core/validation/__init__.py +13 -0
  92. dcs_core/core/validation/base.py +230 -0
  93. dcs_core/core/validation/completeness_validation.py +153 -0
  94. dcs_core/core/validation/custom_query_validation.py +24 -0
  95. dcs_core/core/validation/manager.py +282 -0
  96. dcs_core/core/validation/numeric_validation.py +276 -0
  97. dcs_core/core/validation/reliability_validation.py +91 -0
  98. dcs_core/core/validation/uniqueness_validation.py +61 -0
  99. dcs_core/core/validation/validity_validation.py +738 -0
  100. dcs_core/integrations/__init__.py +13 -0
  101. dcs_core/integrations/databases/__init__.py +13 -0
  102. dcs_core/integrations/databases/bigquery.py +187 -0
  103. dcs_core/integrations/databases/databricks.py +51 -0
  104. dcs_core/integrations/databases/db2.py +652 -0
  105. dcs_core/integrations/databases/elasticsearch.py +61 -0
  106. dcs_core/integrations/databases/mssql.py +829 -0
  107. dcs_core/integrations/databases/mysql.py +409 -0
  108. dcs_core/integrations/databases/opensearch.py +64 -0
  109. dcs_core/integrations/databases/oracle.py +719 -0
  110. dcs_core/integrations/databases/postgres.py +482 -0
  111. dcs_core/integrations/databases/redshift.py +53 -0
  112. dcs_core/integrations/databases/snowflake.py +48 -0
  113. dcs_core/integrations/databases/spark_df.py +111 -0
  114. dcs_core/integrations/databases/sybase.py +1069 -0
  115. dcs_core/integrations/storage/__init__.py +13 -0
  116. dcs_core/integrations/storage/local_file.py +149 -0
  117. dcs_core/integrations/utils/__init__.py +13 -0
  118. dcs_core/integrations/utils/utils.py +36 -0
  119. dcs_core/report/__init__.py +13 -0
  120. dcs_core/report/dashboard.py +211 -0
  121. dcs_core/report/models.py +88 -0
  122. dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
  123. dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
  124. dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
  125. dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
  126. dcs_core/report/static/assets/images/docs.svg +6 -0
  127. dcs_core/report/static/assets/images/github.svg +4 -0
  128. dcs_core/report/static/assets/images/logo.svg +7 -0
  129. dcs_core/report/static/assets/images/slack.svg +13 -0
  130. dcs_core/report/static/index.js +2 -0
  131. dcs_core/report/static/index.js.LICENSE.txt +3971 -0
  132. dcs_sdk/__init__.py +13 -0
  133. dcs_sdk/__main__.py +18 -0
  134. dcs_sdk/__version__.py +15 -0
  135. dcs_sdk/cli/__init__.py +13 -0
  136. dcs_sdk/cli/cli.py +163 -0
  137. dcs_sdk/sdk/__init__.py +58 -0
  138. dcs_sdk/sdk/config/__init__.py +13 -0
  139. dcs_sdk/sdk/config/config_loader.py +491 -0
  140. dcs_sdk/sdk/data_diff/__init__.py +13 -0
  141. dcs_sdk/sdk/data_diff/data_differ.py +821 -0
  142. dcs_sdk/sdk/rules/__init__.py +15 -0
  143. dcs_sdk/sdk/rules/rules_mappping.py +31 -0
  144. dcs_sdk/sdk/rules/rules_repository.py +214 -0
  145. dcs_sdk/sdk/rules/schema_rules.py +65 -0
  146. dcs_sdk/sdk/utils/__init__.py +13 -0
  147. dcs_sdk/sdk/utils/serializer.py +25 -0
  148. dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
  149. dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
  150. dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
  151. dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
  152. dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
  153. dcs_sdk/sdk/utils/table.py +475 -0
  154. dcs_sdk/sdk/utils/themes.py +40 -0
  155. dcs_sdk/sdk/utils/utils.py +349 -0
  156. dcs_sdk-1.6.5.dist-info/METADATA +150 -0
  157. dcs_sdk-1.6.5.dist-info/RECORD +159 -0
  158. dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
  159. dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,376 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import json
17
+ import os
18
+ import re
19
+ import secrets
20
+ import string
21
+ import time
22
+ from typing import Any, ClassVar, List, Optional, Tuple, Type, Union
23
+
24
+ import attrs
25
+
26
+ from data_diff.abcs.database_types import (
27
+ JSON,
28
+ Array,
29
+ Boolean,
30
+ ColType,
31
+ Date,
32
+ Datetime,
33
+ DbPath,
34
+ Decimal,
35
+ Float,
36
+ FractionalType,
37
+ Integer,
38
+ Struct,
39
+ TemporalType,
40
+ Text,
41
+ Time,
42
+ Timestamp,
43
+ UnknownColType,
44
+ )
45
+ from data_diff.databases.base import (
46
+ CHECKSUM_HEXDIGITS,
47
+ CHECKSUM_OFFSET,
48
+ MD5_HEXDIGITS,
49
+ TIMESTAMP_PRECISION_POS,
50
+ BaseDialect,
51
+ ConnectError,
52
+ Database,
53
+ QueryResult,
54
+ ThreadLocalInterpreter,
55
+ apply_query,
56
+ import_helper,
57
+ parse_table_name,
58
+ )
59
+ from data_diff.schema import RawColumnInfo
60
+
61
+
62
@import_helper(text="Please install BigQuery and configure your google-cloud access.")
def import_bigquery():
    """Lazily import the BigQuery client library.

    The ``import_helper`` decorator converts a missing-dependency ImportError
    into a friendly installation message.
    """
    from google.cloud import bigquery as bigquery_module

    return bigquery_module
67
+
68
+
69
def import_bigquery_service_account():
    """Lazily import ``google.oauth2.service_account`` for keyfile-based auth."""
    from google.oauth2 import service_account as service_account_module

    return service_account_module
73
+
74
+
75
def import_bigquery_service_account_impersonation():
    """Lazily import ``google.auth.impersonated_credentials`` for service-account impersonation."""
    from google.auth import impersonated_credentials as impersonation_module

    return impersonation_module
79
+
80
+
81
@attrs.define(frozen=False)
class Dialect(BaseDialect):
    """BigQuery SQL dialect: type parsing, identifier quoting, and value
    normalization used by the diffing engine."""

    name = "BigQuery"
    ROUNDS_ON_PREC_LOSS = False  # Technically BigQuery doesn't allow implicit rounding or truncation
    TYPE_CLASSES = {
        # Dates
        "TIMESTAMP": Timestamp,
        "DATETIME": Datetime,
        "DATE": Date,
        "TIME": Time,
        # Numbers
        "INT64": Integer,
        "INT32": Integer,
        "NUMERIC": Decimal,
        "BIGNUMERIC": Decimal,
        "FLOAT64": Float,
        "FLOAT32": Float,
        "STRING": Text,
        "BOOL": Boolean,
        "JSON": JSON,
    }
    TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
    TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
    # [BIG]NUMERIC, [BIG]NUMERIC(precision, scale), [BIG]NUMERIC(precision)
    TYPE_NUMERIC_RE = re.compile(r"^((BIG)?NUMERIC)(?:\((\d+)(?:, (\d+))?\))?$")
    # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#parameterized_decimal_type
    # The default scale is 9, which means a number can have up to 9 digits after the decimal point.
    DEFAULT_NUMERIC_PRECISION = 9

    def random(self) -> str:
        """Return the SQL expression for a random float in [0, 1)."""
        return "RAND()"

    def quote(self, s: str, is_table: bool = False) -> str:
        """Backtick-quote an identifier.

        Table names previously registered via ``parse_table_name`` are fully
        qualified with the project and/or default dataset when those are known.
        """
        if s in self.TABLE_NAMES and (self.default_schema or self.project) and is_table:
            if self.project and self.default_schema:
                return f"`{self.project}`.`{self.default_schema}`.`{s}`"
            elif self.project:
                return f"`{self.project}`.`{s}`"
            elif self.default_schema:
                return f"`{self.default_schema}`.`{s}`"
        return f"`{s}`"

    def to_string(self, s: str) -> str:
        """Cast an arbitrary SQL expression to STRING."""
        return f"cast({s} as string)"

    def type_repr(self, t) -> str:
        """Map Python types to BigQuery type names; defer unknowns to the base dialect."""
        try:
            return {str: "STRING", float: "FLOAT64"}[t]
        except KeyError:
            return super().type_repr(t)

    def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
        """Resolve a raw column type, handling ARRAY<>, STRUCT<>, and
        parameterized [BIG]NUMERIC forms the base dialect cannot parse."""
        col_type = super().parse_type(table_path, info)
        if not isinstance(col_type, UnknownColType):
            return col_type

        m = self.TYPE_ARRAY_RE.fullmatch(info.data_type)
        if m:
            # Recursively parse the element type of ARRAY<item_type>.
            item_info = attrs.evolve(info, data_type=m.group(1))
            item_type = self.parse_type(table_path, item_info)
            col_type = Array(item_type=item_type)
            return col_type

        # We currently ignore structs' structure, but later can parse it too. Examples:
        # - STRUCT<INT64, STRING(10)> (unnamed)
        # - STRUCT<foo INT64, bar STRING(10)> (named)
        # - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
        # - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
        m = self.TYPE_STRUCT_RE.fullmatch(info.data_type)
        if m:
            col_type = Struct()
            return col_type

        m = self.TYPE_NUMERIC_RE.fullmatch(info.data_type)
        if m:
            precision = int(m.group(3)) if m.group(3) else None
            scale = int(m.group(4)) if m.group(4) else None

            if scale is not None:
                # NUMERIC(..., scale) — scale is set explicitly
                effective_precision = scale
            elif precision is not None:
                # NUMERIC(...) — scale is missing but precision is set
                # effectively the same as NUMERIC(..., 0)
                effective_precision = 0
            else:
                # NUMERIC → default scale is 9
                effective_precision = 9
            col_type = Decimal(precision=effective_precision)
            return col_type

        return col_type

    def to_comparable(self, value: str, coltype: ColType) -> str:
        """Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
        if isinstance(coltype, (JSON, Array, Struct)):
            # These types cannot be compared directly; normalize to strings first.
            return self.normalize_value_by_type(value, coltype)
        else:
            return super().to_comparable(value, coltype)

    def set_timezone_to_utc(self) -> str:
        # BigQuery has no session-level timezone setting.
        raise NotImplementedError()

    def parse_table_name(self, name: str) -> DbPath:
        """Split a dotted table name into a path tuple, remembering the bare
        table name so ``quote`` can later fully qualify it."""
        self.TABLE_NAMES.append(name.split(".")[-1])
        path = parse_table_name(name)
        return tuple(i for i in path if i is not None)

    def md5_as_int(self, s: str) -> str:
        """SQL expression for the truncated MD5 of *s* as a signed number,
        shifted by CHECKSUM_OFFSET to match the other dialects."""
        return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        """SQL expression for the raw MD5 digest of *s*."""
        return f"md5({s})"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Render a temporal value as a canonical string with exactly six
        fractional digits, padding or rounding to ``coltype.precision``."""
        try:
            is_date = coltype.is_date
            is_time = coltype.is_time
        except AttributeError:
            # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit. Only the missing-attribute case is expected here
            # (some ColType variants do not define these flags).
            is_date = False
            is_time = False
        if isinstance(coltype, Date) or is_date:
            return f"FORMAT_DATE('%F', {value})"
        if isinstance(coltype, Time) or is_time:
            # Round a TIME value to the requested precision via its
            # microseconds-since-midnight representation.
            microseconds = f"TIME_DIFF( {value}, cast('00:00:00' as time), microsecond)"
            rounded = f"ROUND({microseconds}, -6 + {coltype.precision})"
            time_value = f"TIME_ADD(cast('00:00:00' as time), interval cast({rounded} as int64) microsecond)"
            converted = f"FORMAT_TIME('%H:%M:%E6S', {time_value})"
            return converted

        if coltype.rounds:
            timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
            return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"

        if coltype.precision == 0:
            return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})"
        elif coltype.precision == 6:
            return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"

        # Truncate to the column's precision, then right-pad with zeros back
        # to six fractional digits so all dialects produce comparable strings.
        timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
        return (
            f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
        )

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Format a fractional number with a fixed number of decimal places."""
        return f"format('%.{coltype.precision}f', {value})"

    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
        """Render a boolean as '0'/'1' for cross-database comparison."""
        return self.to_string(f"cast({value} as int)")

    def normalize_json(self, value: str, _coltype: JSON) -> str:
        # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
        # Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
        # So we do the best effort and compare it as strings, hoping that the JSON forms
        # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
        return f"to_json_string({value})"

    def normalize_array(self, value: str, _coltype: Array) -> str:
        # Same string-comparison workaround as normalize_json (see above).
        return f"to_json_string({value})"

    def normalize_struct(self, value: str, _coltype: Struct) -> str:
        # Same string-comparison workaround as normalize_json (see above).
        return f"to_json_string({value})"

    def generate_view_name(self, view_name: Optional[str] = None) -> str:
        """Return *view_name* if given, else a unique ``view_<ts>_<random>`` name."""
        if view_name is not None:
            return view_name
        random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
        timestamp = int(time.time())
        return f"view_{timestamp}_{random_string.lower()}"

    def create_view(self, query: str, dataset: Optional[str], view_name: Optional[str] = None) -> Tuple[str, str]:
        """Build a CREATE VIEW statement; returns (sql, view_name)."""
        view_name = self.generate_view_name(view_name=view_name)
        full_name = f"`{self.project}`.`{dataset}`.`{view_name}`" if dataset else f"`{view_name}`"
        return f"CREATE VIEW {full_name} AS {query}", view_name

    def drop_view(self, view_name: str, dataset: Optional[str]) -> str:
        """Build the DROP VIEW statement matching ``create_view``'s naming."""
        full_name = f"`{self.project}`.`{dataset}`.`{view_name}`" if dataset else f"`{view_name}`"
        return f"DROP VIEW {full_name}"
267
+
268
+
269
@attrs.define(frozen=False, init=False, kw_only=True)
class BigQuery(Database):
    """BigQuery database connector built on ``google.cloud.bigquery.Client``."""

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
    CONNECT_URI_PARAMS = ["dataset"]

    # GCP project id
    project: str
    # Default dataset (BigQuery's notion of a schema)
    dataset: str
    # Underlying google.cloud.bigquery.Client instance
    _client: Any

    def __init__(self, project, *, dataset, bigquery_credentials=None, **kw) -> None:
        """Connect to BigQuery.

        Credential resolution order:
          1. explicit ``bigquery_credentials``;
          2. ``keyfile`` kwarg — a path to a service-account JSON file, or
             (if not a file) a base64-encoded or raw JSON key string;
          3. ``impersonate_service_account`` kwarg — impersonated credentials
             layered on top of whatever credentials were resolved so far;
          4. otherwise, the client library's application-default credentials.
        """
        super().__init__()
        credentials = bigquery_credentials
        bigquery = import_bigquery()

        keyfile = kw.pop("keyfile", None)
        impersonate_service_account = kw.pop("impersonate_service_account", None)
        # Drop falsy kwargs so they don't reach bigquery.Client().
        kw = {k: v for k, v in kw.items() if v}
        if keyfile:
            bigquery_service_account = import_bigquery_service_account()
            if os.path.isfile(keyfile):
                credentials = bigquery_service_account.Credentials.from_service_account_file(
                    keyfile,
                    scopes=["https://www.googleapis.com/auth/cloud-platform"],
                )
            else:
                # Not a path: try base64-encoded JSON first, then raw JSON.
                try:
                    decoded_key = base64.b64decode(keyfile).decode("utf-8")
                    key_info = json.loads(decoded_key)
                except Exception:
                    key_info = json.loads(keyfile)
                credentials = bigquery_service_account.Credentials.from_service_account_info(
                    key_info,
                    scopes=["https://www.googleapis.com/auth/cloud-platform"],
                )
        elif impersonate_service_account:
            bigquery_service_account_impersonation = import_bigquery_service_account_impersonation()
            credentials = bigquery_service_account_impersonation.Credentials(
                source_credentials=credentials,
                target_principal=impersonate_service_account,
                target_scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        # A generic "schema" kwarg overrides the dataset argument.
        if "schema" in kw:
            dataset = kw.pop("schema")
        self._client = bigquery.Client(project=project, credentials=credentials, **kw)
        self.project = project
        self.dataset = dataset

        # Propagate defaults to the dialect so quote() can fully qualify names.
        self.default_schema = dataset
        self.dialect.default_schema = self.default_schema
        self.dialect.project = self.project

    def _normalize_returned_value(self, value):
        """Decode bytes cells to str; other values pass through unchanged."""
        if isinstance(value, bytes):
            return value.decode()
        return value

    def _query_atom(self, sql_code: str):
        """Execute one SQL statement and return a QueryResult.

        Raises ConnectError wrapping any client-side or server-side failure.
        """
        from google.cloud import bigquery

        try:
            result = self._client.query(sql_code).result()
            columns = [c.name for c in result.schema]
            rows = list(result)
        except Exception as e:
            msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s"
            raise ConnectError(msg % (sql_code, e))

        # Convert bigquery Row objects into plain tuples of normalized values.
        if rows and isinstance(rows[0], bigquery.table.Row):
            rows = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in rows]
        return QueryResult(rows, columns)

    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
        """Run a query or an interpreter-driven sequence of queries."""
        return apply_query(self._query_atom, sql_code)

    def close(self):
        """Close the underlying BigQuery client."""
        super().close()
        self._client.close()

    def select_table_schema(self, path: DbPath) -> str:
        """Build the INFORMATION_SCHEMA query describing a table's columns.

        BigQuery does not expose per-column precision/scale here, so fixed
        values (6 / 38 / 9) are reported for datetime/numeric metadata.
        """
        project, schema, name = self._normalize_table_path(path)
        return (
            "SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale, "
            "NULL AS collation_name, NULL AS character_maximum_length "
            f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )

    def query_table_unique_columns(self, path: DbPath) -> List[str]:
        # BigQuery has no unique constraints to discover.
        return []

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand a 1–3 element path to (project, schema, table)."""
        if len(path) == 0:
            raise ValueError(f"{self.name}: Bad table path for {self}: ()")
        elif len(path) == 1:
            return (self.project, self.default_schema, path[0])
        elif len(path) == 2:
            return (self.project,) + path
        elif len(path) == 3:
            return path
        else:
            raise ValueError(
                f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: [project.]schema.table"
            )

    @property
    def is_autocommit(self) -> bool:
        # BigQuery statements are individually committed; no transactions here.
        return True
@@ -0,0 +1,217 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, ClassVar, Dict, Optional, Type
16
+
17
+ import attrs
18
+
19
+ from data_diff.abcs.database_types import (
20
+ Boolean,
21
+ ColType,
22
+ DbPath,
23
+ Decimal,
24
+ Float,
25
+ FractionalType,
26
+ Integer,
27
+ Native_UUID,
28
+ TemporalType,
29
+ Text,
30
+ Timestamp,
31
+ )
32
+ from data_diff.databases.base import (
33
+ CHECKSUM_HEXDIGITS,
34
+ CHECKSUM_OFFSET,
35
+ MD5_HEXDIGITS,
36
+ TIMESTAMP_PRECISION_POS,
37
+ BaseDialect,
38
+ ConnectError,
39
+ ThreadedDatabase,
40
+ import_helper,
41
+ )
42
+ from data_diff.schema import RawColumnInfo
43
+
44
+ # https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database
45
+ DEFAULT_DATABASE = "default"
46
+
47
+
48
@import_helper("clickhouse")
def import_clickhouse():
    """Lazily import the ``clickhouse_driver`` package."""
    import clickhouse_driver as clickhouse_module

    return clickhouse_module
53
+
54
+
55
@attrs.define(frozen=False)
class Dialect(BaseDialect):
    """ClickHouse SQL dialect: type mapping, quoting, and value normalization."""

    name = "Clickhouse"
    ROUNDS_ON_PREC_LOSS = False
    TYPE_CLASSES = {
        "Int8": Integer,
        "Int16": Integer,
        "Int32": Integer,
        "Int64": Integer,
        "Int128": Integer,
        "Int256": Integer,
        "UInt8": Integer,
        "UInt16": Integer,
        "UInt32": Integer,
        "UInt64": Integer,
        "UInt128": Integer,
        "UInt256": Integer,
        "Float32": Float,
        "Float64": Float,
        "Decimal": Decimal,
        "UUID": Native_UUID,
        "String": Text,
        "FixedString": Text,
        "DateTime": Timestamp,
        "DateTime64": Timestamp,
        "Bool": Boolean,
    }

    def quote(self, s: str, is_table: bool = False) -> str:
        """Double-quote an identifier."""
        return f'"{s}"'

    def to_string(self, s: str) -> str:
        """Cast a SQL expression to String."""
        return f"toString({s})"

    def _convert_db_precision_to_digits(self, p: int) -> int:
        # Done the same as for PostgreSQL but need to rewrite in another way
        # because it does not help for float with a big integer part.
        return super()._convert_db_precision_to_digits(p) - 2

    def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
        """Strip Nullable(...) wrappers and type parameters (e.g. Decimal(38, 9)
        → Decimal) before delegating to the base dialect's parser."""
        nullable_prefix = "Nullable("
        if info.data_type.startswith(nullable_prefix):
            info = attrs.evolve(info, data_type=info.data_type[len(nullable_prefix) :].rstrip(")"))

        # Drop parameters so the bare type name matches TYPE_CLASSES keys.
        if info.data_type.startswith("Decimal"):
            info = attrs.evolve(info, data_type="Decimal")
        elif info.data_type.startswith("FixedString"):
            info = attrs.evolve(info, data_type="FixedString")
        elif info.data_type.startswith("DateTime64"):
            info = attrs.evolve(info, data_type="DateTime64")

        return super().parse_type(table_path, info)

    # def timestamp_value(self, t: DbTime) -> str:
    #     # return f"'{t}'"
    #     return f"'{str(t)[:19]}'"

    def set_timezone_to_utc(self) -> str:
        # No session-level timezone statement is implemented for ClickHouse.
        raise NotImplementedError()

    def current_timestamp(self) -> str:
        """SQL expression for the current timestamp."""
        return "now()"

    def md5_as_int(self, s: str) -> str:
        """SQL expression for the truncated MD5 of *s* as an integer, shifted
        by CHECKSUM_OFFSET to match the other dialects."""
        substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
        return (
            f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}"
        )

    def md5_as_hex(self, s: str) -> str:
        """SQL expression for the hex MD5 digest of *s*."""
        return f"hex(MD5({s}))"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
        # For example:
        #     select toString(toDecimal128(1.10, 2)); -- the result is 1.1
        #     select toString(toDecimal128(1.00, 2)); -- the result is 1
        # So, we should use some custom approach to save these trailing zeros.
        # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting.
        # For examples above it looks like:
        #     select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101
        # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10
        # So, the algorithm is:
        # 1. Cast to decimal with precision + 1
        # 2. Add a small value 10^(-precision-1)
        # 3. Cast the result to string
        # 4. Drop the extra digit from the string. To do that, we need to slice the string
        #    with length = digits in an integer part + 1 (symbol of ".") + precision

        if coltype.precision == 0:
            return self.to_string(f"round({value})")

        precision = coltype.precision
        # TODO: too complex, is there better performance way?
        value = f"""
            if({value} >= 0, '', '-') || left(
                toString(
                    toDecimal128(
                        round(abs({value}), {precision}),
                        {precision} + 1
                    )
                    +
                    toDecimal128(
                        exp10(-{precision + 1}),
                        {precision} + 1
                    )
                ),
                toUInt8(
                    greatest(
                        floor(log10(abs({value}))) + 1,
                        1
                    )
                ) + 1 + {precision}
            )
        """
        return value

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Render a timestamp as a canonical string with six fractional digits,
        either rounding (coltype.rounds) or zero-padding to the precision."""
        prec = coltype.precision
        if coltype.rounds:
            timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)"
            return self.to_string(timestamp)

        # Extract microseconds, left-pad to 6 digits, append after a fixed
        # second-resolution format, then right-pad the whole string.
        fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000"
        fractional = f"lpad({self.to_string(fractional)}, 6, '0')"
        value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}"
        return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')"
182
+
183
+
184
@attrs.define(frozen=False, init=False, kw_only=True)
class Clickhouse(ThreadedDatabase):
    """ClickHouse connector using ``clickhouse_driver``'s DB-API layer,
    with one connection per worker thread."""

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    CONNECT_URI_HELP = "clickhouse://<user>:<password>@<host>/<database>"
    CONNECT_URI_PARAMS = ["database?"]

    # Raw connection kwargs, forwarded verbatim to the driver per thread.
    _args: Dict[str, Any]

    def __init__(self, *, thread_count: int, **kw) -> None:
        super().__init__(thread_count=thread_count)

        self._args = kw
        # In Clickhouse database and schema are the same
        self.default_schema = kw.get("database", DEFAULT_DATABASE)

    def create_connection(self):
        """Create one non-thread-safe connection (called per worker thread).

        Raises ConnectError if the driver fails to connect.
        """
        clickhouse = import_clickhouse()

        class SingleConnection(clickhouse.dbapi.connection.Connection):
            """Not thread-safe connection to Clickhouse"""

            def cursor(self, cursor_factory=None):
                # Reuse a single cursor per connection: create one lazily on
                # first request, then always hand back the same instance.
                if not len(self.cursors):
                    _ = super().cursor()
                return self.cursors[0]

        try:
            return SingleConnection(**self._args)
        except clickhouse.OperationError as e:
            raise ConnectError(*e.args) from e

    @property
    def is_autocommit(self) -> bool:
        # ClickHouse has no transactions in this driver usage; every statement commits.
        return True