dcs-sdk 1.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. data_diff/__init__.py +221 -0
  2. data_diff/__main__.py +517 -0
  3. data_diff/abcs/__init__.py +13 -0
  4. data_diff/abcs/compiler.py +27 -0
  5. data_diff/abcs/database_types.py +402 -0
  6. data_diff/config.py +141 -0
  7. data_diff/databases/__init__.py +38 -0
  8. data_diff/databases/_connect.py +323 -0
  9. data_diff/databases/base.py +1417 -0
  10. data_diff/databases/bigquery.py +376 -0
  11. data_diff/databases/clickhouse.py +217 -0
  12. data_diff/databases/databricks.py +262 -0
  13. data_diff/databases/duckdb.py +207 -0
  14. data_diff/databases/mssql.py +343 -0
  15. data_diff/databases/mysql.py +189 -0
  16. data_diff/databases/oracle.py +238 -0
  17. data_diff/databases/postgresql.py +293 -0
  18. data_diff/databases/presto.py +222 -0
  19. data_diff/databases/redis.py +93 -0
  20. data_diff/databases/redshift.py +233 -0
  21. data_diff/databases/snowflake.py +222 -0
  22. data_diff/databases/sybase.py +720 -0
  23. data_diff/databases/trino.py +73 -0
  24. data_diff/databases/vertica.py +174 -0
  25. data_diff/diff_tables.py +489 -0
  26. data_diff/errors.py +17 -0
  27. data_diff/format.py +369 -0
  28. data_diff/hashdiff_tables.py +1026 -0
  29. data_diff/info_tree.py +76 -0
  30. data_diff/joindiff_tables.py +434 -0
  31. data_diff/lexicographic_space.py +253 -0
  32. data_diff/parse_time.py +88 -0
  33. data_diff/py.typed +0 -0
  34. data_diff/queries/__init__.py +13 -0
  35. data_diff/queries/api.py +213 -0
  36. data_diff/queries/ast_classes.py +811 -0
  37. data_diff/queries/base.py +38 -0
  38. data_diff/queries/extras.py +43 -0
  39. data_diff/query_utils.py +70 -0
  40. data_diff/schema.py +67 -0
  41. data_diff/table_segment.py +583 -0
  42. data_diff/thread_utils.py +112 -0
  43. data_diff/utils.py +1022 -0
  44. data_diff/version.py +15 -0
  45. dcs_core/__init__.py +13 -0
  46. dcs_core/__main__.py +17 -0
  47. dcs_core/__version__.py +15 -0
  48. dcs_core/cli/__init__.py +13 -0
  49. dcs_core/cli/cli.py +165 -0
  50. dcs_core/core/__init__.py +19 -0
  51. dcs_core/core/common/__init__.py +13 -0
  52. dcs_core/core/common/errors.py +50 -0
  53. dcs_core/core/common/models/__init__.py +13 -0
  54. dcs_core/core/common/models/configuration.py +284 -0
  55. dcs_core/core/common/models/dashboard.py +24 -0
  56. dcs_core/core/common/models/data_source_resource.py +75 -0
  57. dcs_core/core/common/models/metric.py +160 -0
  58. dcs_core/core/common/models/profile.py +75 -0
  59. dcs_core/core/common/models/validation.py +216 -0
  60. dcs_core/core/common/models/widget.py +44 -0
  61. dcs_core/core/configuration/__init__.py +13 -0
  62. dcs_core/core/configuration/config_loader.py +139 -0
  63. dcs_core/core/configuration/configuration_parser.py +262 -0
  64. dcs_core/core/configuration/configuration_parser_arc.py +328 -0
  65. dcs_core/core/datasource/__init__.py +13 -0
  66. dcs_core/core/datasource/base.py +62 -0
  67. dcs_core/core/datasource/manager.py +112 -0
  68. dcs_core/core/datasource/search_datasource.py +421 -0
  69. dcs_core/core/datasource/sql_datasource.py +1094 -0
  70. dcs_core/core/inspect.py +163 -0
  71. dcs_core/core/logger/__init__.py +13 -0
  72. dcs_core/core/logger/base.py +32 -0
  73. dcs_core/core/logger/default_logger.py +94 -0
  74. dcs_core/core/metric/__init__.py +13 -0
  75. dcs_core/core/metric/base.py +220 -0
  76. dcs_core/core/metric/combined_metric.py +98 -0
  77. dcs_core/core/metric/custom_metric.py +34 -0
  78. dcs_core/core/metric/manager.py +137 -0
  79. dcs_core/core/metric/numeric_metric.py +403 -0
  80. dcs_core/core/metric/reliability_metric.py +90 -0
  81. dcs_core/core/profiling/__init__.py +13 -0
  82. dcs_core/core/profiling/datasource_profiling.py +136 -0
  83. dcs_core/core/profiling/numeric_field_profiling.py +72 -0
  84. dcs_core/core/profiling/text_field_profiling.py +67 -0
  85. dcs_core/core/repository/__init__.py +13 -0
  86. dcs_core/core/repository/metric_repository.py +77 -0
  87. dcs_core/core/utils/__init__.py +13 -0
  88. dcs_core/core/utils/log.py +29 -0
  89. dcs_core/core/utils/tracking.py +105 -0
  90. dcs_core/core/utils/utils.py +44 -0
  91. dcs_core/core/validation/__init__.py +13 -0
  92. dcs_core/core/validation/base.py +230 -0
  93. dcs_core/core/validation/completeness_validation.py +153 -0
  94. dcs_core/core/validation/custom_query_validation.py +24 -0
  95. dcs_core/core/validation/manager.py +282 -0
  96. dcs_core/core/validation/numeric_validation.py +276 -0
  97. dcs_core/core/validation/reliability_validation.py +91 -0
  98. dcs_core/core/validation/uniqueness_validation.py +61 -0
  99. dcs_core/core/validation/validity_validation.py +738 -0
  100. dcs_core/integrations/__init__.py +13 -0
  101. dcs_core/integrations/databases/__init__.py +13 -0
  102. dcs_core/integrations/databases/bigquery.py +187 -0
  103. dcs_core/integrations/databases/databricks.py +51 -0
  104. dcs_core/integrations/databases/db2.py +652 -0
  105. dcs_core/integrations/databases/elasticsearch.py +61 -0
  106. dcs_core/integrations/databases/mssql.py +829 -0
  107. dcs_core/integrations/databases/mysql.py +409 -0
  108. dcs_core/integrations/databases/opensearch.py +64 -0
  109. dcs_core/integrations/databases/oracle.py +719 -0
  110. dcs_core/integrations/databases/postgres.py +482 -0
  111. dcs_core/integrations/databases/redshift.py +53 -0
  112. dcs_core/integrations/databases/snowflake.py +48 -0
  113. dcs_core/integrations/databases/spark_df.py +111 -0
  114. dcs_core/integrations/databases/sybase.py +1069 -0
  115. dcs_core/integrations/storage/__init__.py +13 -0
  116. dcs_core/integrations/storage/local_file.py +149 -0
  117. dcs_core/integrations/utils/__init__.py +13 -0
  118. dcs_core/integrations/utils/utils.py +36 -0
  119. dcs_core/report/__init__.py +13 -0
  120. dcs_core/report/dashboard.py +211 -0
  121. dcs_core/report/models.py +88 -0
  122. dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
  123. dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
  124. dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
  125. dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
  126. dcs_core/report/static/assets/images/docs.svg +6 -0
  127. dcs_core/report/static/assets/images/github.svg +4 -0
  128. dcs_core/report/static/assets/images/logo.svg +7 -0
  129. dcs_core/report/static/assets/images/slack.svg +13 -0
  130. dcs_core/report/static/index.js +2 -0
  131. dcs_core/report/static/index.js.LICENSE.txt +3971 -0
  132. dcs_sdk/__init__.py +13 -0
  133. dcs_sdk/__main__.py +18 -0
  134. dcs_sdk/__version__.py +15 -0
  135. dcs_sdk/cli/__init__.py +13 -0
  136. dcs_sdk/cli/cli.py +163 -0
  137. dcs_sdk/sdk/__init__.py +58 -0
  138. dcs_sdk/sdk/config/__init__.py +13 -0
  139. dcs_sdk/sdk/config/config_loader.py +491 -0
  140. dcs_sdk/sdk/data_diff/__init__.py +13 -0
  141. dcs_sdk/sdk/data_diff/data_differ.py +821 -0
  142. dcs_sdk/sdk/rules/__init__.py +15 -0
  143. dcs_sdk/sdk/rules/rules_mappping.py +31 -0
  144. dcs_sdk/sdk/rules/rules_repository.py +214 -0
  145. dcs_sdk/sdk/rules/schema_rules.py +65 -0
  146. dcs_sdk/sdk/utils/__init__.py +13 -0
  147. dcs_sdk/sdk/utils/serializer.py +25 -0
  148. dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
  149. dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
  150. dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
  151. dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
  152. dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
  153. dcs_sdk/sdk/utils/table.py +475 -0
  154. dcs_sdk/sdk/utils/themes.py +40 -0
  155. dcs_sdk/sdk/utils/utils.py +349 -0
  156. dcs_sdk-1.6.5.dist-info/METADATA +150 -0
  157. dcs_sdk-1.6.5.dist-info/RECORD +159 -0
  158. dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
  159. dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,1417 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ import contextvars
17
+ import decimal
18
+ import functools
19
+ import logging
20
+ import math
21
+ import random
22
+ import secrets
23
+ import statistics
24
+ import string
25
+ import sys
26
+ import threading
27
+ import time
28
+ from abc import abstractmethod
29
+ from concurrent.futures import ThreadPoolExecutor
30
+ from datetime import datetime
31
+ from functools import partial, wraps
32
+ from typing import (
33
+ Any,
34
+ Callable,
35
+ ClassVar,
36
+ Dict,
37
+ Generator,
38
+ Iterator,
39
+ List,
40
+ NewType,
41
+ Optional,
42
+ Sequence,
43
+ Tuple,
44
+ Type,
45
+ TypeVar,
46
+ Union,
47
+ )
48
+ from uuid import UUID
49
+
50
+ import attrs
51
+ from loguru import logger as loguru_logger
52
+ from typing_extensions import Self
53
+
54
+ from data_diff.abcs.compiler import AbstractCompiler, Compilable
55
+ from data_diff.abcs.database_types import (
56
+ JSON,
57
+ ArithAlphanumeric,
58
+ ArithUnicodeString,
59
+ Array,
60
+ Boolean,
61
+ ColType,
62
+ ColType_UUID,
63
+ DbPath,
64
+ DbTime,
65
+ Decimal,
66
+ Float,
67
+ FractionalType,
68
+ Integer,
69
+ Native_UUID,
70
+ String_Alphanum,
71
+ String_UUID,
72
+ String_VaryingAlphanum,
73
+ String_VaryingUnicode,
74
+ Struct,
75
+ TemporalType,
76
+ Text,
77
+ TimestampTZ,
78
+ UnknownColType,
79
+ )
80
+ from data_diff.queries.api import SKIP, Code, Explain, Expr, Select, table, this
81
+ from data_diff.queries.ast_classes import (
82
+ Alias,
83
+ BinOp,
84
+ CaseWhen,
85
+ Cast,
86
+ Column,
87
+ Commit,
88
+ Concat,
89
+ ConstantTable,
90
+ Count,
91
+ CreateTable,
92
+ Cte,
93
+ CurrentTimestamp,
94
+ DropTable,
95
+ Func,
96
+ GroupBy,
97
+ In,
98
+ InsertToTable,
99
+ IsDistinctFrom,
100
+ ITable,
101
+ Join,
102
+ Param,
103
+ Random,
104
+ Root,
105
+ TableAlias,
106
+ TableOp,
107
+ TablePath,
108
+ TruncateTable,
109
+ UnaryOp,
110
+ WhenThen,
111
+ _ResolveColumn,
112
+ )
113
+ from data_diff.queries.extras import (
114
+ ApplyFuncAndNormalizeAsString,
115
+ Checksum,
116
+ NormalizeAsString,
117
+ )
118
+ from data_diff.schema import RawColumnInfo
119
+ from data_diff.utils import (
120
+ ArithDate,
121
+ ArithDateTime,
122
+ ArithString,
123
+ ArithTimestamp,
124
+ ArithTimestampTZ,
125
+ ArithUUID,
126
+ SybaseDriverTypes,
127
+ is_uuid,
128
+ join_iter,
129
+ safezip,
130
+ )
131
+
132
+ logger = logging.getLogger("database")
133
+ cv_params = contextvars.ContextVar("params")
134
+
135
+
136
class CompileError(Exception):
    """Raised when an AST element cannot be rendered into SQL."""
138
+
139
+
140
@attrs.define(frozen=True)
class Compiler(AbstractCompiler):
    """
    Compiler bears the context for a single compilation.

    There can be multiple compilation per app run.
    There can be multiple compilers in one compilation (with varying contexts).
    """

    # Database is needed to normalize tables. Dialect is needed for recursive compilations.
    # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
    # In practice, we currently bind the dialects to the specific database classes.
    database: "Database"

    in_select: bool = False  # Compilation runtime flag
    in_join: bool = False  # Compilation runtime flag

    _table_context: List = attrs.field(factory=list)  # List[ITable]
    _subqueries: Dict[str, Any] = attrs.field(factory=dict)  # XXX not thread-safe
    root: bool = True

    # Mutable one-element list so unique-name counting works despite frozen=True
    # (the list object itself is shared, only its content is mutated).
    _counter: List = attrs.field(factory=lambda: [0])

    @property
    def dialect(self) -> "BaseDialect":
        # Dialect is always reached through the bound database instance.
        return self.database.dialect

    # TODO: DEPRECATED: Remove once the dialect is used directly in all places.
    def compile(self, elem, params=None) -> str:
        return self.dialect.compile(self, elem, params)

    def new_unique_name(self, prefix="tmp") -> str:
        # Monotonic per-compilation counter; see _counter note above.
        self._counter[0] += 1
        return f"{prefix}{self._counter[0]}"

    def new_unique_table_name(self, prefix="tmp") -> DbPath:
        # Random suffix lowers the chance of collisions across concurrent runs.
        self._counter[0] += 1
        table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}"
        return self.database.dialect.parse_table_name(table_name)

    def add_table_context(self, *tables: Sequence, **kw) -> Self:
        # attrs strips the leading underscore for init/evolve argument names,
        # hence `table_context=` here evolves the private `_table_context` field.
        return attrs.evolve(self, table_context=self._table_context + list(tables), **kw)
182
+
183
+
184
def parse_table_name(t):
    """Split a dotted table name, e.g. ``"schema.table"``, into a path tuple."""
    parts = t.split(".")
    return tuple(parts)
186
+
187
+
188
def import_helper(package: Optional[str] = None, text=""):
    """Decorator factory: wrap a zero-argument import function so that a
    ModuleNotFoundError is augmented with setup instructions.

    :param package: optional extras name; when given, a `pip install
        'dcs-cli[<package>]'` hint is appended to the error message.
    :param text: extra explanatory text prepended to the hint.
    """

    def dec(f):
        @wraps(f)
        def _inner():
            try:
                return f()
            except ModuleNotFoundError as e:
                s = text
                if package:
                    s += f"Please complete setup by running: pip install 'dcs-cli[{package}]'."
                # Chain explicitly so the original import failure is preserved
                # as the direct cause rather than "during handling" noise.
                raise ModuleNotFoundError(f"{e}\n\n{s}\n") from e

        return _inner

    return dec
203
+
204
+
205
class ConnectError(Exception):
    """Raised when a database connection cannot be established."""
207
+
208
+
209
class QueryError(Exception):
    """Raised when executing a query against the database fails."""
211
+
212
+
213
+ def _one(seq):
214
+ (x,) = seq
215
+ return x
216
+
217
+
218
@attrs.define(frozen=False)
class ThreadLocalInterpreter:
    """An interpreter used to execute a sequence of queries within the same thread and cursor.

    Useful for cursor-sensitive operations, such as creating a temporary table.
    """

    compiler: Compiler
    gen: Generator  # yields query expressions; receives each query's result back

    def apply_queries(self, callback: Callable[[str], Any]) -> None:
        # Drive the generator: compile each yielded query, execute it via
        # `callback`, then send the result back in (or throw the exception
        # into the generator) until the generator is exhausted.
        q: Expr = next(self.gen)
        while True:
            sql = self.compiler.database.dialect.compile(self.compiler, q)
            logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql)
            try:
                try:
                    # SKIP sentinel queries are not executed; SKIP is passed back instead.
                    res = callback(sql) if sql is not SKIP else SKIP
                except Exception as e:
                    # Let the generator decide how to handle the failure.
                    q = self.gen.throw(type(e), e)
                else:
                    q = self.gen.send(res)
            except StopIteration:
                break
242
+
243
+
244
def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list:
    """Execute *sql_code* via *callback*.

    Plain SQL strings are run directly; a ThreadLocalInterpreter drives the
    callback over its own generated query sequence.
    """
    if not isinstance(sql_code, ThreadLocalInterpreter):
        return callback(sql_code)
    return sql_code.apply_queries(callback)
249
+
250
+
251
@attrs.define(frozen=False)
class BaseDialect(abc.ABC):
    """Shared SQL-rendering logic; concrete dialects override the hooks and flags."""

    # Capability flags — concrete dialects override as appropriate.
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
    SUPPORTS_INDEXES: ClassVar[bool] = False
    PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False
    # Maps database type names to ColType classes (filled in by subclasses).
    TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
    DEFAULT_NUMERIC_PRECISION: ClassVar[int] = 0  # effective precision when type is just "NUMERIC"

    PLACEHOLDER_TABLE = None  # Used for Oracle

    default_schema: Optional[str] = None
    TABLE_NAMES: List[str] = attrs.Factory(list)
    project: Optional[str] = None  # Used for BigQuery

    # Some database do not support long string so concatenation might lead to type overflow
    _prevent_overflow_when_concat: bool = False
    # NOTE(review): presumably describes which Sybase driver variant is active — confirm semantics.
    sybase_driver_type: SybaseDriverTypes = attrs.Factory(lambda: SybaseDriverTypes())
    # Runtime flags used to pick between ASE and FreeTDS query variants.
    query_config_for_free_tds: Dict = attrs.Factory(
        lambda: {
            "ase_query_chosen": False,
            "freetds_query_chosen": False,
        }
    )
    # Cached raw column metadata keyed by name.
    table_schema: Dict[str, RawColumnInfo] = attrs.Factory(dict)
275
+
276
+ def enable_preventing_type_overflow(self) -> None:
277
+ logger.info("Preventing type overflow when concatenation is enabled")
278
+ self._prevent_overflow_when_concat = True
279
+
280
+ def parse_table_name(self, name: str) -> DbPath:
281
+ "Parse the given table name into a DbPath"
282
+ return parse_table_name(name)
283
+
284
+ def compile(self, compiler: Compiler, elem, params=None) -> str:
285
+ if params:
286
+ cv_params.set(params)
287
+
288
+ if compiler.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
289
+ from data_diff.queries.ast_classes import Select
290
+
291
+ elem = Select(columns=[elem])
292
+
293
+ res = self._compile(compiler, elem)
294
+ if compiler.root and compiler._subqueries:
295
+ subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items())
296
+ compiler._subqueries.clear()
297
+ return f"WITH {subq}\n{res}"
298
+ return res
299
+
300
    def _compile(self, compiler: Compiler, elem) -> str:
        """Render a single element — AST node, column type, or Python literal — to SQL text."""
        if elem is None:
            return "NULL"
        elif isinstance(elem, Compilable):
            # Recurse as non-root so the WITH/CTE wrapping is not re-applied.
            return self.render_compilable(attrs.evolve(compiler, root=False), elem)
        elif isinstance(elem, ColType):
            return self.render_coltype(attrs.evolve(compiler, root=False), elem)
        elif isinstance(elem, str):
            # NOTE(review): no quote-escaping performed; assumes trusted values — confirm.
            return f"'{elem}'"
        elif isinstance(elem, (int, float)):
            # NOTE: bool is a subclass of int, so True/False would render as "True"/"False".
            return str(elem)
        elif isinstance(elem, datetime):
            return self.timestamp_value(elem)
        elif isinstance(elem, bytes):
            return f"b'{elem.decode()}'"
        elif isinstance(elem, ArithUUID):
            if self.name.lower() == "trino":
                # Trino needs an explicit cast to compare against native UUID columns.
                return f"CAST('{elem.uuid}' AS UUID)"
            s = f"'{elem.uuid}'"
            # Preserve the casing convention recorded on the value.
            return s.upper() if elem.uppercase else s.lower() if elem.lowercase else s
        elif isinstance(elem, (ArithDateTime, ArithTimestamp, ArithTimestampTZ)):
            return self.timestamp_value(elem._dt)
        elif isinstance(elem, ArithDate):
            from datetime import time

            # Dates render as midnight timestamps.
            return self.timestamp_value(datetime.combine(elem._date, time.min))
        elif isinstance(elem, ArithString):
            return f"'{elem}'"
        # Reaching here means an unsupported element type — a programming error upstream.
        assert False, elem
329
+
330
    def render_compilable(self, c: Compiler, elem: Compilable) -> str:
        """Dispatch an AST node to its dedicated render_* method (visitor pattern).

        The explicit isinstance chain exists for IDE navigation and type checking;
        the getattr fallback at the end would dispatch any node by naming convention.
        """
        # All ifs are only for better code navigation, IDE usage detection, and type checking.
        # The last catch-all would render them anyway — it is a typical "visitor" pattern.
        if isinstance(elem, Column):
            return self.render_column(c, elem)
        elif isinstance(elem, Cte):
            return self.render_cte(c, elem)
        elif isinstance(elem, Commit):
            return self.render_commit(c, elem)
        elif isinstance(elem, Param):
            return self.render_param(c, elem)
        elif isinstance(elem, NormalizeAsString):
            return self.render_normalizeasstring(c, elem)
        elif isinstance(elem, ApplyFuncAndNormalizeAsString):
            return self.render_applyfuncandnormalizeasstring(c, elem)
        elif isinstance(elem, Checksum):
            return self.render_checksum(c, elem)
        elif isinstance(elem, Concat):
            return self.render_concat(c, elem)
        elif isinstance(elem, Func):
            return self.render_func(c, elem)
        elif isinstance(elem, WhenThen):
            return self.render_whenthen(c, elem)
        elif isinstance(elem, CaseWhen):
            return self.render_casewhen(c, elem)
        elif isinstance(elem, IsDistinctFrom):
            return self.render_isdistinctfrom(c, elem)
        elif isinstance(elem, UnaryOp):
            return self.render_unaryop(c, elem)
        elif isinstance(elem, BinOp):
            return self.render_binop(c, elem)
        elif isinstance(elem, TablePath):
            return self.render_tablepath(c, elem)
        elif isinstance(elem, TableAlias):
            return self.render_tablealias(c, elem)
        elif isinstance(elem, TableOp):
            return self.render_tableop(c, elem)
        elif isinstance(elem, Select):
            return self.render_select(c, elem)
        elif isinstance(elem, Join):
            return self.render_join(c, elem)
        elif isinstance(elem, GroupBy):
            return self.render_groupby(c, elem)
        elif isinstance(elem, Count):
            return self.render_count(c, elem)
        elif isinstance(elem, Alias):
            return self.render_alias(c, elem)
        elif isinstance(elem, In):
            return self.render_in(c, elem)
        elif isinstance(elem, Cast):
            return self.render_cast(c, elem)
        elif isinstance(elem, Random):
            return self.render_random(c, elem)
        elif isinstance(elem, Explain):
            return self.render_explain(c, elem)
        elif isinstance(elem, CurrentTimestamp):
            return self.render_currenttimestamp(c, elem)
        elif isinstance(elem, CreateTable):
            return self.render_createtable(c, elem)
        elif isinstance(elem, DropTable):
            return self.render_droptable(c, elem)
        elif isinstance(elem, TruncateTable):
            return self.render_truncatetable(c, elem)
        elif isinstance(elem, InsertToTable):
            return self.render_inserttotable(c, elem)
        elif isinstance(elem, Code):
            return self.render_code(c, elem)
        elif isinstance(elem, _ResolveColumn):
            return self.render__resolvecolumn(c, elem)

        # Fallback: dispatch by naming convention, e.g. MyNode -> render_mynode.
        method_name = f"render_{elem.__class__.__name__.lower()}"
        method = getattr(self, method_name, None)
        if method is not None:
            return method(c, elem)
        else:
            raise RuntimeError(f"Cannot render AST of type {elem.__class__}")
        # return elem.compile(compiler.replace(root=False))
407
+
408
+ def render_coltype(self, c: Compiler, elem: ColType) -> str:
409
+ return self.type_repr(elem)
410
+
411
+ def render_column(self, c: Compiler, elem: Column) -> str:
412
+ if c._table_context:
413
+ if len(c._table_context) > 1:
414
+ aliases = [
415
+ t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is elem.source_table
416
+ ]
417
+ if not aliases:
418
+ return self.quote(elem.name)
419
+ elif len(aliases) > 1:
420
+ raise CompileError(f"Too many aliases for column {elem.name}")
421
+ (alias,) = aliases
422
+
423
+ return f"{self.quote(alias.name)}.{self.quote(elem.name)}"
424
+
425
+ return self.quote(elem.name)
426
+
427
+ def render_cte(self, parent_c: Compiler, elem: Cte) -> str:
428
+ c: Compiler = attrs.evolve(parent_c, table_context=[], in_select=False)
429
+ compiled = self.compile(c, elem.source_table)
430
+
431
+ name = elem.name or parent_c.new_unique_name()
432
+ name_params = f"{name}({', '.join(elem.params)})" if elem.params else name
433
+ parent_c._subqueries[name_params] = compiled
434
+
435
+ return name
436
+
437
+ def render_commit(self, c: Compiler, elem: Commit) -> str:
438
+ return "COMMIT" if not c.database.is_autocommit else SKIP
439
+
440
+ def render_param(self, c: Compiler, elem: Param) -> str:
441
+ params = cv_params.get()
442
+ return self._compile(c, params[elem.name])
443
+
444
+ def render_normalizeasstring(self, c: Compiler, elem: NormalizeAsString) -> str:
445
+ expr = self.compile(c, elem.expr)
446
+ return self.normalize_value_by_type(expr, elem.expr_type or elem.expr.type)
447
+
448
+ def render_applyfuncandnormalizeasstring(self, c: Compiler, elem: ApplyFuncAndNormalizeAsString) -> str:
449
+ expr = elem.expr
450
+ expr_type = expr.type
451
+
452
+ if isinstance(expr_type, Native_UUID):
453
+ # Normalize first, apply template after (for uuids)
454
+ # Needed because min/max(uuid) fails in postgresql
455
+ expr = NormalizeAsString(expr, expr_type)
456
+ if elem.apply_func is not None:
457
+ expr = elem.apply_func(expr) # Apply template using Python's string formatting
458
+
459
+ else:
460
+ # Apply template before normalizing (for ints)
461
+ if elem.apply_func is not None:
462
+ expr = elem.apply_func(expr) # Apply template using Python's string formatting
463
+ expr = NormalizeAsString(expr, expr_type)
464
+
465
+ return self.compile(c, expr)
466
+
467
+ def render_checksum(self, c: Compiler, elem: Checksum) -> str:
468
+ if len(elem.exprs) > 1:
469
+ exprs = [Code(f"coalesce({self.compile(c, expr)}, '<null>')") for expr in elem.exprs]
470
+ # exprs = [self.compile(c, e) for e in exprs]
471
+ expr = Concat(exprs, "|")
472
+ else:
473
+ # No need to coalesce - safe to assume that key cannot be null
474
+ (expr,) = elem.exprs
475
+ expr = self.compile(c, expr)
476
+ md5 = self.md5_as_int(expr)
477
+ return f"sum({md5})"
478
+
479
+ def render_concat(self, c: Compiler, elem: Concat) -> str:
480
+ if self._prevent_overflow_when_concat:
481
+ items = [
482
+ f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}"
483
+ for expr in elem.exprs
484
+ ]
485
+
486
+ # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
487
+ else:
488
+ items = [
489
+ f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')"
490
+ for expr in elem.exprs
491
+ ]
492
+
493
+ assert items
494
+ if len(items) == 1:
495
+ return items[0]
496
+
497
+ if elem.sep:
498
+ items = list(join_iter(f"'{elem.sep}'", items))
499
+ return self.concat(items)
500
+
501
+ def render_alias(self, c: Compiler, elem: Alias) -> str:
502
+ return f"{self.compile(c, elem.expr)} AS {self.quote(elem.name)}"
503
+
504
+ def render_count(self, c: Compiler, elem: Count) -> str:
505
+ expr = self.compile(c, elem.expr) if elem.expr else "*"
506
+ if elem.distinct:
507
+ return f"count(distinct {expr})"
508
+ return f"count({expr})"
509
+
510
+ def render_code(self, c: Compiler, elem: Code) -> str:
511
+ if not elem.args:
512
+ return elem.code
513
+
514
+ args = {k: self.compile(c, v) for k, v in elem.args.items()}
515
+ return elem.code.format(**args)
516
+
517
+ def render_func(self, c: Compiler, elem: Func) -> str:
518
+ args = ", ".join(self.compile(c, e) for e in elem.args)
519
+ return f"{elem.name}({args})"
520
+
521
+ def render_whenthen(self, c: Compiler, elem: WhenThen) -> str:
522
+ return f"WHEN {self.compile(c, elem.when)} THEN {self.compile(c, elem.then)}"
523
+
524
+ def render_casewhen(self, c: Compiler, elem: CaseWhen) -> str:
525
+ assert elem.cases
526
+ when_thens = " ".join(self.compile(c, case) for case in elem.cases)
527
+ else_expr = (" ELSE " + self.compile(c, elem.else_expr)) if elem.else_expr is not None else ""
528
+ return f"CASE {when_thens}{else_expr} END"
529
+
530
+ def render_isdistinctfrom(self, c: Compiler, elem: IsDistinctFrom) -> str:
531
+ a = self.to_comparable(self.compile(c, elem.a), elem.a.type)
532
+ b = self.to_comparable(self.compile(c, elem.b), elem.b.type)
533
+ return self.is_distinct_from(a, b)
534
+
535
+ def render_unaryop(self, c: Compiler, elem: UnaryOp) -> str:
536
+ return f"({elem.op}{self.compile(c, elem.expr)})"
537
+
538
+ def render_binop(self, c: Compiler, elem: BinOp) -> str:
539
+ expr = f" {elem.op} ".join(self.compile(c, a) for a in elem.args)
540
+ return f"({expr})"
541
+
542
+ def render_tablepath(self, c: Compiler, elem: TablePath) -> str:
543
+ path = elem.path # c.database._normalize_table_path(self.name)
544
+ return ".".join(map(lambda p: self.quote(p, True), path))
545
+
546
+ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
547
+ return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}"
548
+
549
+ def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str:
550
+ c: Compiler = attrs.evolve(parent_c, in_select=False)
551
+ table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}"
552
+ if parent_c.in_select:
553
+ table_expr = f"({table_expr}) {c.new_unique_name()}"
554
+ elif parent_c.in_join:
555
+ table_expr = f"({table_expr})"
556
+ return table_expr
557
+
558
+ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
559
+ return self.compile(c, elem._get_resolved())
560
+
561
+ def render_select(self, parent_c: Compiler, elem: Select) -> str:
562
+ c: Compiler = attrs.evolve(parent_c, in_select=True) # .add_table_context(self.table)
563
+ compile_fn = functools.partial(self.compile, c)
564
+
565
+ columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*"
566
+ distinct = "DISTINCT " if elem.distinct else ""
567
+ optimizer_hints = self.optimizer_hints(elem.optimizer_hints) if elem.optimizer_hints else ""
568
+ select = f"SELECT {optimizer_hints}{distinct}{columns}"
569
+
570
+ if elem.table:
571
+ select += " FROM " + self.compile(c, elem.table)
572
+ elif self.PLACEHOLDER_TABLE:
573
+ select += f" FROM {self.PLACEHOLDER_TABLE}"
574
+
575
+ if elem.where_exprs:
576
+ select += " WHERE " + " AND ".join(map(compile_fn, elem.where_exprs))
577
+
578
+ if elem.group_by_exprs:
579
+ select += " GROUP BY " + ", ".join(map(compile_fn, elem.group_by_exprs))
580
+
581
+ if elem.having_exprs:
582
+ assert elem.group_by_exprs
583
+ select += " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs))
584
+
585
+ if elem.order_by_exprs:
586
+ select += " ORDER BY " + ", ".join(map(compile_fn, elem.order_by_exprs))
587
+
588
+ if elem.limit_expr is not None:
589
+ has_order_by = bool(elem.order_by_exprs)
590
+ select = self.limit_select(select_query=select, offset=0, limit=elem.limit_expr, has_order_by=has_order_by)
591
+
592
+ if parent_c.in_select:
593
+ select = f"({select}) {c.new_unique_name()}"
594
+ elif parent_c.in_join:
595
+ select = f"({select})"
596
+ return select
597
+
598
+ def render_join(self, parent_c: Compiler, elem: Join) -> str:
599
+ tables = [
600
+ t if isinstance(t, TableAlias) else TableAlias(t, name=parent_c.new_unique_name())
601
+ for t in elem.source_tables
602
+ ]
603
+ c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
604
+ op = " JOIN " if elem.op is None else f" {elem.op} JOIN "
605
+ joined = op.join(self.compile(c, t) for t in tables)
606
+
607
+ if elem.on_exprs:
608
+ on = " AND ".join(self.compile(c, e) for e in elem.on_exprs)
609
+ res = f"{joined} ON {on}"
610
+ else:
611
+ res = joined
612
+
613
+ compile_fn = functools.partial(self.compile, c)
614
+ columns = "*" if elem.columns is None else ", ".join(map(compile_fn, elem.columns))
615
+ select = f"SELECT {columns} FROM {res}"
616
+
617
+ if parent_c.in_select:
618
+ select = f"({select}) {c.new_unique_name()}"
619
+ elif parent_c.in_join:
620
+ select = f"({select})"
621
+ return select
622
+
623
    def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
        """Render a GROUP BY, merging into the underlying Select when possible.

        Group keys are referenced by 1-based ordinal position in the GROUP BY
        clause rather than repeating the key expressions.
        """
        compile_fn = functools.partial(self.compile, c)

        if elem.values is None:
            raise CompileError(".group_by() must be followed by a call to .agg()")

        # Ordinal positions of the grouping keys within the column list.
        keys = [str(i + 1) for i in range(len(elem.keys))]
        columns = (elem.keys or []) + (elem.values or [])
        # If the inner Select has no columns and no grouping yet, fold the
        # GROUP BY directly into it instead of wrapping another SELECT.
        if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None:
            return self.compile(
                c,
                attrs.evolve(
                    elem.table,
                    columns=columns,
                    group_by_exprs=[Code(k) for k in keys],
                    having_exprs=elem.having_exprs,
                ),
            )

        keys_str = ", ".join(keys)
        columns_str = ", ".join(self.compile(c, x) for x in columns)
        having_str = (
            " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else ""
        )
        select = f"SELECT {columns_str} FROM {self.compile(attrs.evolve(c, in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"

        # Parenthesize/alias when nested inside another SELECT or JOIN.
        if c.in_select:
            select = f"({select}) {c.new_unique_name()}"
        elif c.in_join:
            select = f"({select})"
        return select
654
+
655
+ def render_in(self, c: Compiler, elem: In) -> str:
656
+ compile_fn = functools.partial(self.compile, c)
657
+ elems = ", ".join(map(compile_fn, elem.list))
658
+ return f"({self.compile(c, elem.expr)} IN ({elems}))"
659
+
660
+ def render_cast(self, c: Compiler, elem: Cast) -> str:
661
+ return f"cast({self.compile(c, elem.expr)} as {self.compile(c, elem.target_type)})"
662
+
663
+ def render_random(self, c: Compiler, elem: Random) -> str:
664
+ return self.random()
665
+
666
+ def render_explain(self, c: Compiler, elem: Explain) -> str:
667
+ return self.explain_as_text(self.compile(c, elem.select))
668
+
669
+ def render_currenttimestamp(self, c: Compiler, elem: CurrentTimestamp) -> str:
670
+ return self.current_timestamp()
671
+
672
    def render_createtable(self, c: Compiler, elem: CreateTable) -> str:
        """Render CREATE TABLE, either as CTAS (from a source query) or from a schema.

        Primary keys are only emitted when the dialect's SUPPORTS_PRIMARY_KEY is set.
        """
        ne = "IF NOT EXISTS " if elem.if_not_exists else ""
        if elem.source_table:
            # CREATE TABLE ... AS <select>
            return f"CREATE TABLE {ne}{self.compile(c, elem.path)} AS {self.compile(c, elem.source_table)}"

        # Column definitions come from the table path's attached schema.
        schema = ", ".join(f"{self.quote(k)} {self.type_repr(v)}" for k, v in elem.path.schema.items())
        pks = (
            ", PRIMARY KEY (%s)" % ", ".join(elem.primary_keys)
            if elem.primary_keys and self.SUPPORTS_PRIMARY_KEY
            else ""
        )
        return f"CREATE TABLE {ne}{self.compile(c, elem.path)}({schema}{pks})"
684
+
685
+ def render_droptable(self, c: Compiler, elem: DropTable) -> str:
686
+ ie = "IF EXISTS " if elem.if_exists else ""
687
+ return f"DROP TABLE {ie}{self.compile(c, elem.path)}"
688
+
689
+ def render_truncatetable(self, c: Compiler, elem: TruncateTable) -> str:
690
+ return f"TRUNCATE TABLE {self.compile(c, elem.path)}"
691
+
692
+ def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str:
693
+ if isinstance(elem.expr, ConstantTable):
694
+ expr = self.constant_values(elem.expr.rows)
695
+ else:
696
+ expr = self.compile(c, elem.expr)
697
+
698
+ columns = "(%s)" % ", ".join(map(self.quote, elem.columns)) if elem.columns is not None else ""
699
+
700
+ return f"INSERT INTO {self.compile(c, elem.path)}{columns} {expr}"
701
+
702
+ def limit_select(
703
+ self,
704
+ select_query: str,
705
+ offset: Optional[int] = None,
706
+ limit: Optional[int] = None,
707
+ has_order_by: Optional[bool] = None,
708
+ ) -> str:
709
+ if offset:
710
+ raise NotImplementedError("No support for OFFSET in query")
711
+
712
+ return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT LIMIT {limit}"
713
+
714
+ def concat(self, items: List[str]) -> str:
715
+ "Provide SQL for concatenating a bunch of columns into a string"
716
+ assert len(items) > 1
717
+ joined_exprs = ", ".join(items)
718
+ return f"concat({joined_exprs})"
719
+
720
+ def to_comparable(self, value: str, coltype: ColType) -> str:
721
+ """Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
722
+ return value
723
+
724
+ def is_distinct_from(self, a: str, b: str) -> str:
725
+ "Provide SQL for a comparison where NULL = NULL is true"
726
+ return f"{a} is distinct from {b}"
727
+
728
+ def timestamp_value(self, t: DbTime) -> str:
729
+ "Provide SQL for the given timestamp value"
730
+ return f"'{t.isoformat()}'"
731
+
732
+ def random(self) -> str:
733
+ "Provide SQL for generating a random number betweein 0..1"
734
+ return "random()"
735
+
736
+ def current_timestamp(self) -> str:
737
+ "Provide SQL for returning the current timestamp, aka now"
738
+ return "current_timestamp()"
739
+
740
+ def current_database(self) -> str:
741
+ "Provide SQL for returning the current default database."
742
+ return "current_database()"
743
+
744
+ def current_schema(self) -> str:
745
+ "Provide SQL for returning the current default schema."
746
+ return "current_schema()"
747
+
748
+ def explain_as_text(self, query: str) -> str:
749
+ "Provide SQL for explaining a query, returned as table(varchar)"
750
+ return f"EXPLAIN {query}"
751
+
752
+ def _constant_value(self, v):
753
+ if v is None:
754
+ return "NULL"
755
+ elif isinstance(v, str):
756
+ return f"'{v}'"
757
+ elif isinstance(v, datetime):
758
+ return self.timestamp_value(v)
759
+ elif isinstance(v, UUID): # probably unused anymore in favour of ArithUUID
760
+ return f"'{v}'"
761
+ elif isinstance(v, ArithUUID):
762
+ return f"'{v.uuid}'"
763
+ elif isinstance(v, decimal.Decimal):
764
+ return str(v)
765
+ elif isinstance(v, bytearray):
766
+ return f"'{v.decode()}'"
767
+ elif isinstance(v, Code):
768
+ return v.code
769
+ elif isinstance(v, ArithAlphanumeric):
770
+ return f"'{v._str}'"
771
+ elif isinstance(v, ArithUnicodeString):
772
+ return f"'{v._str}'"
773
+ elif isinstance(v, ArithDate):
774
+ return f"'{str(v)}'"
775
+ elif isinstance(v, ArithTimestamp):
776
+ return f"'{str(v)}'"
777
+ elif isinstance(v, ArithTimestampTZ):
778
+ return f"'{str(v)}'"
779
+ elif isinstance(v, ArithDateTime):
780
+ return self.timestamp_value(v._dt)
781
+ return repr(v)
782
+
783
+ def constant_values(self, rows) -> str:
784
+ values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
785
+ return f"VALUES {values}"
786
+
787
+ def type_repr(self, t) -> str:
788
+ if isinstance(t, str):
789
+ return t
790
+ elif isinstance(t, TimestampTZ):
791
+ return f"TIMESTAMP({min(t.precision, DEFAULT_DATETIME_PRECISION)})"
792
+ return {
793
+ int: "INT",
794
+ str: "VARCHAR",
795
+ bool: "BOOLEAN",
796
+ float: "FLOAT",
797
+ datetime: "TIMESTAMP",
798
+ }[t]
799
+
800
    def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
        "Parse type info as returned by the database"

        # Map the raw data_type string to a column-type class; unknown names
        # are wrapped rather than rejected, so diffing can still proceed.
        cls = self.TYPE_CLASSES.get(info.data_type)
        if cls is None:
            return UnknownColType(info.data_type)

        if issubclass(cls, TemporalType):
            return cls(
                precision=(
                    info.datetime_precision if info.datetime_precision is not None else DEFAULT_DATETIME_PRECISION
                ),
                rounds=self.ROUNDS_ON_PREC_LOSS,
            )

        elif issubclass(cls, Integer):
            return cls()

        elif issubclass(cls, Boolean):
            return cls()

        elif issubclass(cls, Decimal):
            # Note: 'precision' here is the decimal scale (fractional digits).
            if info.numeric_scale is None:
                return cls(precision=0)  # Needed for Oracle.
            return cls(precision=info.numeric_scale)

        elif issubclass(cls, Float):
            # assert numeric_scale is None
            # Floats report binary precision; convert it to decimal digits.
            return cls(
                precision=self._convert_db_precision_to_digits(
                    info.numeric_precision if info.numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
                )
            )

        elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
            return cls()

        raise TypeError(f"Parsing {info.data_type} returned an unknown type {cls!r}.")
838
+
839
+ def _convert_db_precision_to_digits(self, p: int) -> int:
840
+ """Convert from binary precision, used by floats, to decimal precision."""
841
+ # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
842
+ return math.floor(math.log(2**p, 10))
843
+
844
    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the dialect."""
848
+
849
    @property
    @abstractmethod
    def ROUNDS_ON_PREC_LOSS(self) -> bool:
        """True if db rounds real values when losing precision, False if it truncates.

        Used when parsing temporal types, to decide how to normalize them.
        """
853
+
854
    @abstractmethod
    def quote(self, s: str, is_table: bool = False) -> str:
        """Quote an SQL name (identifier) for this dialect.

        :param is_table: may be used by dialects that quote table names differently.
        """
857
+
858
    @abstractmethod
    def to_string(self, s: str) -> str:
        # TODO rewrite using cast_to(x, str)
        """Provide SQL for casting a column to string."""
862
+
863
    @abstractmethod
    def set_timezone_to_utc(self) -> str:
        """Provide SQL for setting the session timezone to UTC."""
866
+
867
    @abstractmethod
    def md5_as_int(self, s: str) -> str:
        """Provide SQL for computing md5 of `s` and returning it as an integer."""
870
+
871
    @abstractmethod
    def md5_as_hex(self, s: str) -> str:
        """Provide SQL for computing md5 of `s`, returned as a hex string."""
874
+
875
    @abstractmethod
    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized timestamp.

        The returned expression must accept any SQL datetime/timestamp, and return a string.

        Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF``

        Precision of dates should be rounded up/down according to coltype.rounds
        e.g. precision 3 and coltype.rounds:
        - 1969-12-31 23:59:59.999999 -> 1970-01-01 00:00:00.000000
        - 1970-01-01 00:00:00.000888 -> 1970-01-01 00:00:00.001000
        - 1970-01-01 00:00:00.123123 -> 1970-01-01 00:00:00.123000

        Make sure NULLs remain NULLs
        """
891
+
892
    @abstractmethod
    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized number.

        The returned expression must accept any SQL int/numeric/float, and return a string.

        Floats/Decimals are expected in the format
        "I.P"

        Where I is the integer part of the number (as many digits as necessary),
        and must be at least one digit (0).
        P is the fractional digits, the amount of which is specified with
        coltype.precision. Trailing zeroes may be necessary.
        If P is 0, the dot is omitted.

        Note: We use 'precision' differently than most databases. For decimals,
        it's the same as ``numeric_scale``, and for floats, who use binary precision,
        it can be calculated as ``log10(2**numeric_precision)``.
        """
911
+
912
+ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
913
+ """Creates an SQL expression, that converts 'value' to either '0' or '1'."""
914
+ return self.to_string(value)
915
+
916
+ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
917
+ """Creates an SQL expression, that strips uuids of artifacts like whitespace."""
918
+ if isinstance(coltype, String_UUID):
919
+ return f"TRIM({value})"
920
+ return self.to_string(value)
921
+
922
+ def normalize_json(self, value: str, _coltype: JSON) -> str:
923
+ """Creates an SQL expression, that converts 'value' to its minified json string representation."""
924
+ return self.to_string(value)
925
+
926
+ def normalize_array(self, value: str, _coltype: Array) -> str:
927
+ """Creates an SQL expression, that serialized an array into a JSON string."""
928
+ return self.to_string(value)
929
+
930
+ def normalize_struct(self, value: str, _coltype: Struct) -> str:
931
+ """Creates an SQL expression, that serialized a typed struct into a JSON string."""
932
+ return self.to_string(value)
933
+
934
    def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized representation.

        The returned expression must accept any SQL value, and return a string.

        The default implementation dispatches to a method according to `coltype`:

        ::

            TemporalType -> normalize_timestamp()
            FractionalType -> normalize_number()
            *else* -> to_string()

        (`Integer` falls in the *else* category)

        """
        # Dispatch order matters: check the more specific column types first.
        if isinstance(coltype, TemporalType):
            return self.normalize_timestamp(value, coltype)
        elif isinstance(coltype, FractionalType):
            return self.normalize_number(value, coltype)
        elif isinstance(coltype, ColType_UUID):
            return self.normalize_uuid(value, coltype)
        elif isinstance(coltype, Boolean):
            return self.normalize_boolean(value, coltype)
        elif isinstance(coltype, JSON):
            return self.normalize_json(value, coltype)
        elif isinstance(coltype, Array):
            return self.normalize_array(value, coltype)
        elif isinstance(coltype, Struct):
            return self.normalize_struct(value, coltype)
        # Everything else (including Integer) is simply cast to string.
        return self.to_string(value)
965
+
966
+ def optimizer_hints(self, hints: str) -> str:
967
+ return f"/*+ {hints} */ "
968
+
969
+ def generate_view_name(self, view_name: str | None = None) -> str:
970
+ if view_name is not None:
971
+ return view_name
972
+ random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
973
+ timestamp = int(time.time())
974
+ return f"view_{timestamp}_{random_string.lower()}"
975
+
976
+ def create_view(self, query: str, schema: str | None, view_name: str | None = None) -> Tuple[str, str]:
977
+ view_name = self.generate_view_name(view_name=view_name)
978
+ schema_prefix = f"{schema}." if schema else ""
979
+ return f"CREATE VIEW {schema_prefix}{view_name} AS {query}", view_name
980
+
981
+ def drop_view(self, view_name: str, schema: str | None) -> str:
982
+ schema_prefix = f"{schema}." if schema else ""
983
+ return f"DROP VIEW {schema_prefix}{view_name}"
984
+
985
+ def get_column_raw_info(self, value: str) -> Optional[RawColumnInfo]:
986
+ raw_column_info = self.table_schema.get(value, None)
987
+ return raw_column_info
988
+
989
+
990
T = TypeVar("T", bound=BaseDialect)  # Type variable for generics over dialect subclasses.
Row = Sequence[Any]  # A single result row, as returned by database cursors.
992
+
993
+
994
@attrs.define(frozen=True)
class QueryResult:
    """Immutable result of a query: a list of rows plus optional column names.

    Behaves like a read-only sequence of rows (iteration, len, indexing).
    """

    rows: List[Row]
    # Column names matching each row's positions, when the driver provides them.
    columns: Optional[list] = None

    def __iter__(self) -> Iterator[Row]:
        return iter(self.rows)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, i) -> Row:
        return self.rows[i]
1007
+
1008
+
1009
@attrs.define(frozen=False, kw_only=True)
class Database(abc.ABC):
    """Base abstract class for databases.

    Used for providing connection code and implementation specific SQL utilities.

    Instanciated using :meth:`~data_diff.connect`
    """

    # Dialect class; instantiated lazily by the `dialect` property.
    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect

    # Whether text values may be treated as alphanumeric keys (see _refine_coltypes).
    SUPPORTS_ALPHANUMS: ClassVar[bool] = True
    # Whether unique constraints can be queried (see query_table_unique_columns).
    SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False
    CONNECT_URI_KWPARAMS: ClassVar[List[str]] = []

    # Schema assumed when a table path has no explicit schema component.
    default_schema: Optional[str] = None
    # When True, each Select is EXPLAINed and the user is prompted before it runs.
    _interactive: bool = False
    is_closed: bool = False
    # Cached dialect instance; created on first access of `dialect`.
    _dialect: BaseDialect = None

    def __enter__(self):
        # Context-manager support: `with connect(...) as db:`
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @property
    def name(self):
        """Name of this database class, used in logs and error messages."""
        return type(self).__name__

    def compile(self, sql_ast):
        """Compile an SQL AST into SQL text, using this database's dialect."""
        return self.dialect.compile(Compiler(self), sql_ast)

    def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None):
        """Query the given SQL code/AST, and attempt to convert the result to type 'res_type'

        If given a generator, it will execute all the yielded sql queries with the same thread and cursor.
        The results of the queries a returned by the `yield` stmt (using the .send() mechanism).
        It's a cleaner approach than exposing cursors, but may not be enough in all cases.
        """

        compiler = Compiler(self)
        if isinstance(sql_ast, Generator):
            # Generator of queries: run them all on one thread/cursor.
            sql_code = ThreadLocalInterpreter(compiler, sql_ast)
        elif isinstance(sql_ast, list):
            # List of queries: run all, return the result of the last one.
            for i in sql_ast[:-1]:
                self.query(i)
            return self.query(sql_ast[-1], res_type)
        else:
            if isinstance(sql_ast, str):
                sql_code = sql_ast
            else:
                # Infer the result type from the AST when not given explicitly.
                if res_type is None:
                    res_type = sql_ast.type
                sql_code = self.compile(sql_ast)
                if sql_code is SKIP:
                    return SKIP

        if log_message:
            logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
        else:
            logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

        # Interactive mode: show the query plan and ask for confirmation first.
        if self._interactive and isinstance(sql_ast, Select):
            explained_sql = self.compile(Explain(sql_ast))
            explain = self._query(explained_sql)
            for row in explain:
                # Most returned a 1-tuple. Presto returns a string
                if isinstance(row, tuple):
                    (row,) = row
                logger.debug("EXPLAIN: %s", row)
            answer = input("Continue? [y/n] ")
            if answer.lower() not in ["y", "yes"]:
                sys.exit(1)

        res = self._query(sql_code)
        # Coerce the raw result according to the requested res_type.
        if res_type is list:
            return list(res)
        elif res_type is int:
            # Expect exactly one row with one column.
            if not res:
                raise ValueError("Query returned 0 rows, expected 1")
            row = _one(res)
            if not row:
                raise ValueError("Row is empty, expected 1 column")
            res = _one(row)
            if res is None:  # May happen due to sum() of 0 items
                return None
            return int(res)
        elif res_type is datetime:
            res = _one(_one(res))
            if isinstance(res, str):
                res = datetime.fromisoformat(res[:23])  # TODO use a better parsing method
            return res
        elif res_type is tuple:
            assert len(res) == 1, (sql_code, res)
            return res[0]
        elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
            # Handle parameterized list types: List[int], List[str], List[tuple], List[dict]
            if res_type.__args__ in ((int,), (str,)):
                return [_one(row) for row in res]
            elif res_type.__args__ in [(Tuple,), (tuple,)]:
                return [tuple(row) for row in res]
            elif res_type.__args__ == (dict,):
                return [dict(safezip(res.columns, row)) for row in res]
            else:
                raise ValueError(res_type)
        return res

    def enable_interactive(self):
        """Turn on interactive mode: every Select is EXPLAINed and confirmed first."""
        self._interactive = True

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
        schema, name = self._normalize_table_path(path)

        return (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
            "FROM information_schema.columns "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )

    def safe_get(self, lst, idx, default=None):
        """Return lst[idx] if idx is a valid non-negative index, else `default`."""
        return lst[idx] if 0 <= idx < len(lst) else default

    def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
        """Query the table for its schema for table in 'path', and return {column: tuple}
        where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?)

        Note: This method exists instead of select_table_schema(), just because not all databases support
        accessing the schema using a SQL query.
        """
        rows = self.query(self.select_table_schema(path), list, log_message=path)
        if not rows:
            raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
        # safe_get tolerates drivers/dialects that return fewer columns than others;
        # missing positions become None.
        d = {
            r[0]: RawColumnInfo(
                column_name=self.safe_get(r, 0),
                data_type=self.safe_get(r, 1),
                datetime_precision=self.safe_get(r, 2),
                numeric_precision=self.safe_get(r, 3),
                numeric_scale=self.safe_get(r, 4),
                collation_name=self.safe_get(r, 5),
                character_maximum_length=self.safe_get(r, 6),
            )
            for r in rows
        }
        # Duplicate column names would silently collapse into one entry.
        assert len(d) == len(rows)
        # Cache the raw schema on the dialect, only the first time around.
        if not self.dialect.table_schema:
            self.dialect.table_schema = d
        return d

    def select_table_unique_columns(self, path: DbPath) -> str:
        """Provide SQL for selecting the names of unique columns in the table"""
        schema, name = self._normalize_table_path(path)

        return (
            "SELECT column_name "
            "FROM information_schema.key_column_usage "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )

    def query_table_unique_columns(self, path: DbPath) -> List[str]:
        """Query the table for its unique columns for table in 'path', and return {column}"""
        if not self.SUPPORTS_UNIQUE_CONSTAINT:
            raise NotImplementedError("This database doesn't support 'unique' constraints")
        res = self.query(self.select_table_unique_columns(path), List[str], log_message=path)
        return list(res)

    def create_view_from_query(
        self,
        query: str,
        schema: str | None = None,
        view_name: str | None = None,
    ) -> str:
        """Create a view from a select statement; returns the resolved view name."""
        schema = schema or self.default_schema
        view_query, v_name = self.dialect.create_view(query, schema, view_name)
        self.query(view_query)
        return v_name

    def drop_view_from_db(self, view_name: str, schema: str | None = None) -> None:
        """Drop a view previously created with create_view_from_query()."""
        schema = schema or self.default_schema
        view_drop_query = self.dialect.drop_view(view_name, schema)
        self.query(view_drop_query)

    def _process_table_schema(
        self,
        path: DbPath,
        raw_schema: Dict[str, RawColumnInfo],
        filter_columns: Sequence[str] = None,
        where: str = None,
    ):
        """Process the result of query_table_schema().

        Done in a separate step, to minimize the amount of processed columns.
        Needed because processing each column may:
        * throw errors and warnings
        * query the database to sample values

        """
        if filter_columns is None:
            filtered_schema = raw_schema
        else:
            # Column filtering is case-insensitive.
            accept = {i.lower() for i in filter_columns}
            filtered_schema = {name: row for name, row in raw_schema.items() if name.lower() in accept}

        col_dict = {info.column_name: self.dialect.parse_type(path, info) for info in filtered_schema.values()}

        self._refine_coltypes(path, col_dict, where)

        # Return a dict of form {name: type} after normalization
        return col_dict

    def _refine_coltypes(
        self, table_path: DbPath, col_dict: Dict[str, ColType], where: Optional[str] = None, sample_size=64
    ) -> Dict[str, ColType]:
        """Refine the types in the column dict, by querying the database for a sample of their values

        'where' restricts the rows to be sampled.
        """

        # Only Text columns are ambiguous enough to need sampling.
        text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
        if not text_columns:
            return col_dict

        fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]

        samples_by_row = self.query(
            table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size),
            list,
            log_message=table_path,
        )
        # Transpose rows -> per-column sample lists (empty lists when no rows).
        samples_by_col = list(zip(*samples_by_row)) if samples_by_row else [[]] * len(text_columns)
        for col_name, samples in safezip(text_columns, samples_by_col):
            uuid_samples = [s for s in samples if s and is_uuid(s)]

            if uuid_samples:
                if len(uuid_samples) != len(samples):
                    # Mixed content: keep the plain Text type; UUID logic won't apply.
                    logger.warning(
                        f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
                    )
                else:
                    assert col_name in col_dict
                    col_dict[col_name] = String_UUID(
                        lowercase=all(s == s.lower() for s in uuid_samples),
                        uppercase=all(s == s.upper() for s in uuid_samples),
                    )
                continue

            if self.SUPPORTS_ALPHANUMS:  # Anything but MySQL (so far)
                alphanum_samples = [s for s in samples if String_Alphanum.test_value(s)]
                if alphanum_samples:
                    if len(alphanum_samples) != len(samples):
                        logger.debug(
                            f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}. It cannot be used as a key."
                        )
                        # Fallback to Unicode string type
                        assert col_name in col_dict
                        col_dict[col_name] = String_VaryingUnicode(collation=col_dict[col_name].collation)
                    else:
                        assert col_name in col_dict
                        col_dict[col_name] = String_VaryingAlphanum(collation=col_dict[col_name].collation)
                else:
                    # All samples failed alphanum test, fallback to Unicode string
                    assert col_name in col_dict
                    col_dict[col_name] = String_VaryingUnicode(collation=col_dict[col_name].collation)

        return col_dict

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Resolve a table path into (schema, table), applying the default schema."""
        if len(path) == 1:
            return self.default_schema, path[0]
        elif len(path) == 2:
            return path

        raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")

    def _query_cursor(self, c, sql_code: str) -> QueryResult:
        """Execute `sql_code` on cursor `c`; fetch results for reads, commit DDL.

        Rolls back (best-effort) and re-raises on any execution error.
        """
        assert isinstance(sql_code, str), sql_code
        try:
            c.execute(sql_code)
            if sql_code.lower().startswith(("select", "explain", "show")):
                columns = [col[0] for col in c.description]

                fetched = c.fetchall()
                result = QueryResult(fetched, columns)
                return result
            elif sql_code.lower().startswith(("create", "drop")):
                try:
                    c.connection.commit()
                except AttributeError:
                    # Some drivers' cursors don't expose .connection; best-effort commit.
                    ...
        except Exception as _e:
            try:
                c.connection.rollback()
            except Exception as rollback_error:
                loguru_logger.error(f"Rollback failed: {rollback_error}")
            # Re-raise the original execution error, not the rollback failure.
            raise

    def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
        """Run `sql_code` (plain SQL or an interpreter) on a fresh cursor of `conn`."""
        c = conn.cursor()
        callback = partial(self._query_cursor, c)
        return apply_query(callback, sql_code)

    def quote(self, s: str, is_table: bool = False):
        """Quote column name"""
        # NOTE(review): `is_table` is accepted but not forwarded to the dialect here.
        return self.dialect.quote(s)

    def close(self):
        """Close connection(s) to the database instance. Querying will stop functioning."""
        self.is_closed = True

    @property
    def dialect(self) -> BaseDialect:
        "The dialect of the database. Used internally by Database, and also available publicly."

        # Lazily instantiate and cache the dialect.
        if not self._dialect:
            self._dialect = self.DIALECT_CLASS()
        return self._dialect

    @property
    @abstractmethod
    def CONNECT_URI_HELP(self) -> str:
        "Example URI to show the user in help and error messages"

    @property
    @abstractmethod
    def CONNECT_URI_PARAMS(self) -> List[str]:
        "List of parameters given in the path of the URI"

    @abstractmethod
    def _query(self, sql_code: str) -> list:
        "Send query to database and return result"

    @property
    @abstractmethod
    def is_autocommit(self) -> bool:
        "Return whether the database autocommits changes. When false, COMMIT statements are skipped."
1346
+
1347
+
1348
@attrs.define(frozen=False)
class ThreadedDatabase(Database):
    """Access the database through singleton threads.

    Used for database connectors that do not support sharing their connection between different threads.
    """

    thread_count: int = 1

    # Error raised while opening a worker-thread connection; re-raised on first query.
    _init_error: Optional[Exception] = None
    _queue: Optional[ThreadPoolExecutor] = None
    # Holds each worker thread's private connection (set by set_conn).
    thread_local: threading.local = attrs.field(factory=threading.local)

    def __attrs_post_init__(self) -> None:
        # Each worker thread opens its own connection via the pool initializer.
        self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn)
        logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.")

    def set_conn(self):
        """Open this worker thread's connection. Runs once per pool thread."""
        assert not hasattr(self.thread_local, "conn")
        try:
            self.thread_local.conn = self.create_connection()
        except Exception as e:
            # NOTE(review): failures from multiple threads overwrite each other;
            # only the last error is kept — confirm this is acceptable.
            self._init_error = e

    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
        # Submit to the pool and block for the result.
        r = self._queue.submit(self._query_in_worker, sql_code)
        return r.result()

    def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
        """This method runs in a worker thread"""
        if self._init_error:
            raise self._init_error
        return self._query_conn(self.thread_local.conn, sql_code)

    @abstractmethod
    def create_connection(self):
        """Return a connection instance, that supports the .cursor() method."""

    def close(self):
        super().close()
        self._queue.shutdown()
        # NOTE(review): this only closes the calling thread's connection, if any;
        # worker-thread connections are not explicitly closed here — confirm intended.
        if hasattr(self.thread_local, "conn"):
            self.thread_local.conn.close()

    @property
    def is_autocommit(self) -> bool:
        return False
1395
+
1396
+
1397
CHECKSUM_HEXDIGITS = 12  # Must be 12 or lower, otherwise SUM() overflows
MD5_HEXDIGITS = 32

# Checksum size in bits: 4 bits per hex digit.
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2
CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1

# bigint is typically 8 bytes
# if checksum is shorter, most databases will pad it with zeros
# 0xFF → 0x00000000000000FF;
# because of that, the numeric representation is always positive,
# which limits the number of checksums that we can add together before overflowing.
# we can fix that by adding a negative offset of half the max value,
# so that the distribution is from -0.5*max to +0.5*max.
# then negative numbers can compensate for the positive ones allowing to add more checksums together
# without overflowing.
CHECKSUM_OFFSET = CHECKSUM_MASK // 2

# Fallback precisions used when information_schema doesn't report them.
DEFAULT_DATETIME_PRECISION = 6
DEFAULT_NUMERIC_PRECISION = 24

TIMESTAMP_PRECISION_POS = 20  # len("2022-06-03 12:24:35.") == 20