dcs-sdk 1.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_diff/__init__.py +221 -0
- data_diff/__main__.py +517 -0
- data_diff/abcs/__init__.py +13 -0
- data_diff/abcs/compiler.py +27 -0
- data_diff/abcs/database_types.py +402 -0
- data_diff/config.py +141 -0
- data_diff/databases/__init__.py +38 -0
- data_diff/databases/_connect.py +323 -0
- data_diff/databases/base.py +1417 -0
- data_diff/databases/bigquery.py +376 -0
- data_diff/databases/clickhouse.py +217 -0
- data_diff/databases/databricks.py +262 -0
- data_diff/databases/duckdb.py +207 -0
- data_diff/databases/mssql.py +343 -0
- data_diff/databases/mysql.py +189 -0
- data_diff/databases/oracle.py +238 -0
- data_diff/databases/postgresql.py +293 -0
- data_diff/databases/presto.py +222 -0
- data_diff/databases/redis.py +93 -0
- data_diff/databases/redshift.py +233 -0
- data_diff/databases/snowflake.py +222 -0
- data_diff/databases/sybase.py +720 -0
- data_diff/databases/trino.py +73 -0
- data_diff/databases/vertica.py +174 -0
- data_diff/diff_tables.py +489 -0
- data_diff/errors.py +17 -0
- data_diff/format.py +369 -0
- data_diff/hashdiff_tables.py +1026 -0
- data_diff/info_tree.py +76 -0
- data_diff/joindiff_tables.py +434 -0
- data_diff/lexicographic_space.py +253 -0
- data_diff/parse_time.py +88 -0
- data_diff/py.typed +0 -0
- data_diff/queries/__init__.py +13 -0
- data_diff/queries/api.py +213 -0
- data_diff/queries/ast_classes.py +811 -0
- data_diff/queries/base.py +38 -0
- data_diff/queries/extras.py +43 -0
- data_diff/query_utils.py +70 -0
- data_diff/schema.py +67 -0
- data_diff/table_segment.py +583 -0
- data_diff/thread_utils.py +112 -0
- data_diff/utils.py +1022 -0
- data_diff/version.py +15 -0
- dcs_core/__init__.py +13 -0
- dcs_core/__main__.py +17 -0
- dcs_core/__version__.py +15 -0
- dcs_core/cli/__init__.py +13 -0
- dcs_core/cli/cli.py +165 -0
- dcs_core/core/__init__.py +19 -0
- dcs_core/core/common/__init__.py +13 -0
- dcs_core/core/common/errors.py +50 -0
- dcs_core/core/common/models/__init__.py +13 -0
- dcs_core/core/common/models/configuration.py +284 -0
- dcs_core/core/common/models/dashboard.py +24 -0
- dcs_core/core/common/models/data_source_resource.py +75 -0
- dcs_core/core/common/models/metric.py +160 -0
- dcs_core/core/common/models/profile.py +75 -0
- dcs_core/core/common/models/validation.py +216 -0
- dcs_core/core/common/models/widget.py +44 -0
- dcs_core/core/configuration/__init__.py +13 -0
- dcs_core/core/configuration/config_loader.py +139 -0
- dcs_core/core/configuration/configuration_parser.py +262 -0
- dcs_core/core/configuration/configuration_parser_arc.py +328 -0
- dcs_core/core/datasource/__init__.py +13 -0
- dcs_core/core/datasource/base.py +62 -0
- dcs_core/core/datasource/manager.py +112 -0
- dcs_core/core/datasource/search_datasource.py +421 -0
- dcs_core/core/datasource/sql_datasource.py +1094 -0
- dcs_core/core/inspect.py +163 -0
- dcs_core/core/logger/__init__.py +13 -0
- dcs_core/core/logger/base.py +32 -0
- dcs_core/core/logger/default_logger.py +94 -0
- dcs_core/core/metric/__init__.py +13 -0
- dcs_core/core/metric/base.py +220 -0
- dcs_core/core/metric/combined_metric.py +98 -0
- dcs_core/core/metric/custom_metric.py +34 -0
- dcs_core/core/metric/manager.py +137 -0
- dcs_core/core/metric/numeric_metric.py +403 -0
- dcs_core/core/metric/reliability_metric.py +90 -0
- dcs_core/core/profiling/__init__.py +13 -0
- dcs_core/core/profiling/datasource_profiling.py +136 -0
- dcs_core/core/profiling/numeric_field_profiling.py +72 -0
- dcs_core/core/profiling/text_field_profiling.py +67 -0
- dcs_core/core/repository/__init__.py +13 -0
- dcs_core/core/repository/metric_repository.py +77 -0
- dcs_core/core/utils/__init__.py +13 -0
- dcs_core/core/utils/log.py +29 -0
- dcs_core/core/utils/tracking.py +105 -0
- dcs_core/core/utils/utils.py +44 -0
- dcs_core/core/validation/__init__.py +13 -0
- dcs_core/core/validation/base.py +230 -0
- dcs_core/core/validation/completeness_validation.py +153 -0
- dcs_core/core/validation/custom_query_validation.py +24 -0
- dcs_core/core/validation/manager.py +282 -0
- dcs_core/core/validation/numeric_validation.py +276 -0
- dcs_core/core/validation/reliability_validation.py +91 -0
- dcs_core/core/validation/uniqueness_validation.py +61 -0
- dcs_core/core/validation/validity_validation.py +738 -0
- dcs_core/integrations/__init__.py +13 -0
- dcs_core/integrations/databases/__init__.py +13 -0
- dcs_core/integrations/databases/bigquery.py +187 -0
- dcs_core/integrations/databases/databricks.py +51 -0
- dcs_core/integrations/databases/db2.py +652 -0
- dcs_core/integrations/databases/elasticsearch.py +61 -0
- dcs_core/integrations/databases/mssql.py +829 -0
- dcs_core/integrations/databases/mysql.py +409 -0
- dcs_core/integrations/databases/opensearch.py +64 -0
- dcs_core/integrations/databases/oracle.py +719 -0
- dcs_core/integrations/databases/postgres.py +482 -0
- dcs_core/integrations/databases/redshift.py +53 -0
- dcs_core/integrations/databases/snowflake.py +48 -0
- dcs_core/integrations/databases/spark_df.py +111 -0
- dcs_core/integrations/databases/sybase.py +1069 -0
- dcs_core/integrations/storage/__init__.py +13 -0
- dcs_core/integrations/storage/local_file.py +149 -0
- dcs_core/integrations/utils/__init__.py +13 -0
- dcs_core/integrations/utils/utils.py +36 -0
- dcs_core/report/__init__.py +13 -0
- dcs_core/report/dashboard.py +211 -0
- dcs_core/report/models.py +88 -0
- dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
- dcs_core/report/static/assets/images/docs.svg +6 -0
- dcs_core/report/static/assets/images/github.svg +4 -0
- dcs_core/report/static/assets/images/logo.svg +7 -0
- dcs_core/report/static/assets/images/slack.svg +13 -0
- dcs_core/report/static/index.js +2 -0
- dcs_core/report/static/index.js.LICENSE.txt +3971 -0
- dcs_sdk/__init__.py +13 -0
- dcs_sdk/__main__.py +18 -0
- dcs_sdk/__version__.py +15 -0
- dcs_sdk/cli/__init__.py +13 -0
- dcs_sdk/cli/cli.py +163 -0
- dcs_sdk/sdk/__init__.py +58 -0
- dcs_sdk/sdk/config/__init__.py +13 -0
- dcs_sdk/sdk/config/config_loader.py +491 -0
- dcs_sdk/sdk/data_diff/__init__.py +13 -0
- dcs_sdk/sdk/data_diff/data_differ.py +821 -0
- dcs_sdk/sdk/rules/__init__.py +15 -0
- dcs_sdk/sdk/rules/rules_mappping.py +31 -0
- dcs_sdk/sdk/rules/rules_repository.py +214 -0
- dcs_sdk/sdk/rules/schema_rules.py +65 -0
- dcs_sdk/sdk/utils/__init__.py +13 -0
- dcs_sdk/sdk/utils/serializer.py +25 -0
- dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
- dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
- dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
- dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
- dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
- dcs_sdk/sdk/utils/table.py +475 -0
- dcs_sdk/sdk/utils/themes.py +40 -0
- dcs_sdk/sdk/utils/utils.py +349 -0
- dcs_sdk-1.6.5.dist-info/METADATA +150 -0
- dcs_sdk-1.6.5.dist-info/RECORD +159 -0
- dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
- dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,1417 @@
|
|
|
1
|
+
# Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import abc
|
|
16
|
+
import contextvars
|
|
17
|
+
import decimal
|
|
18
|
+
import functools
|
|
19
|
+
import logging
|
|
20
|
+
import math
|
|
21
|
+
import random
|
|
22
|
+
import secrets
|
|
23
|
+
import statistics
|
|
24
|
+
import string
|
|
25
|
+
import sys
|
|
26
|
+
import threading
|
|
27
|
+
import time
|
|
28
|
+
from abc import abstractmethod
|
|
29
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
30
|
+
from datetime import datetime
|
|
31
|
+
from functools import partial, wraps
|
|
32
|
+
from typing import (
|
|
33
|
+
Any,
|
|
34
|
+
Callable,
|
|
35
|
+
ClassVar,
|
|
36
|
+
Dict,
|
|
37
|
+
Generator,
|
|
38
|
+
Iterator,
|
|
39
|
+
List,
|
|
40
|
+
NewType,
|
|
41
|
+
Optional,
|
|
42
|
+
Sequence,
|
|
43
|
+
Tuple,
|
|
44
|
+
Type,
|
|
45
|
+
TypeVar,
|
|
46
|
+
Union,
|
|
47
|
+
)
|
|
48
|
+
from uuid import UUID
|
|
49
|
+
|
|
50
|
+
import attrs
|
|
51
|
+
from loguru import logger as loguru_logger
|
|
52
|
+
from typing_extensions import Self
|
|
53
|
+
|
|
54
|
+
from data_diff.abcs.compiler import AbstractCompiler, Compilable
|
|
55
|
+
from data_diff.abcs.database_types import (
|
|
56
|
+
JSON,
|
|
57
|
+
ArithAlphanumeric,
|
|
58
|
+
ArithUnicodeString,
|
|
59
|
+
Array,
|
|
60
|
+
Boolean,
|
|
61
|
+
ColType,
|
|
62
|
+
ColType_UUID,
|
|
63
|
+
DbPath,
|
|
64
|
+
DbTime,
|
|
65
|
+
Decimal,
|
|
66
|
+
Float,
|
|
67
|
+
FractionalType,
|
|
68
|
+
Integer,
|
|
69
|
+
Native_UUID,
|
|
70
|
+
String_Alphanum,
|
|
71
|
+
String_UUID,
|
|
72
|
+
String_VaryingAlphanum,
|
|
73
|
+
String_VaryingUnicode,
|
|
74
|
+
Struct,
|
|
75
|
+
TemporalType,
|
|
76
|
+
Text,
|
|
77
|
+
TimestampTZ,
|
|
78
|
+
UnknownColType,
|
|
79
|
+
)
|
|
80
|
+
from data_diff.queries.api import SKIP, Code, Explain, Expr, Select, table, this
|
|
81
|
+
from data_diff.queries.ast_classes import (
|
|
82
|
+
Alias,
|
|
83
|
+
BinOp,
|
|
84
|
+
CaseWhen,
|
|
85
|
+
Cast,
|
|
86
|
+
Column,
|
|
87
|
+
Commit,
|
|
88
|
+
Concat,
|
|
89
|
+
ConstantTable,
|
|
90
|
+
Count,
|
|
91
|
+
CreateTable,
|
|
92
|
+
Cte,
|
|
93
|
+
CurrentTimestamp,
|
|
94
|
+
DropTable,
|
|
95
|
+
Func,
|
|
96
|
+
GroupBy,
|
|
97
|
+
In,
|
|
98
|
+
InsertToTable,
|
|
99
|
+
IsDistinctFrom,
|
|
100
|
+
ITable,
|
|
101
|
+
Join,
|
|
102
|
+
Param,
|
|
103
|
+
Random,
|
|
104
|
+
Root,
|
|
105
|
+
TableAlias,
|
|
106
|
+
TableOp,
|
|
107
|
+
TablePath,
|
|
108
|
+
TruncateTable,
|
|
109
|
+
UnaryOp,
|
|
110
|
+
WhenThen,
|
|
111
|
+
_ResolveColumn,
|
|
112
|
+
)
|
|
113
|
+
from data_diff.queries.extras import (
|
|
114
|
+
ApplyFuncAndNormalizeAsString,
|
|
115
|
+
Checksum,
|
|
116
|
+
NormalizeAsString,
|
|
117
|
+
)
|
|
118
|
+
from data_diff.schema import RawColumnInfo
|
|
119
|
+
from data_diff.utils import (
|
|
120
|
+
ArithDate,
|
|
121
|
+
ArithDateTime,
|
|
122
|
+
ArithString,
|
|
123
|
+
ArithTimestamp,
|
|
124
|
+
ArithTimestampTZ,
|
|
125
|
+
ArithUUID,
|
|
126
|
+
SybaseDriverTypes,
|
|
127
|
+
is_uuid,
|
|
128
|
+
join_iter,
|
|
129
|
+
safezip,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
# Module-level logger shared by all database/dialect operations in this module.
logger = logging.getLogger("database")
# Context-local storage for query parameters; written by BaseDialect.compile()
# and read back when rendering Param nodes.
cv_params = contextvars.ContextVar("params")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class CompileError(Exception):
    """Raised when an AST element cannot be compiled into SQL."""
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@attrs.define(frozen=True)
class Compiler(AbstractCompiler):
    """
    Compiler bears the context for a single compilation.

    There can be multiple compilations per app run.
    There can be multiple compilers in one compilation (with varying contexts).
    """

    # Database is needed to normalize tables. Dialect is needed for recursive compilations.
    # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
    # In practice, we currently bind the dialects to the specific database classes.
    database: "Database"

    in_select: bool = False  # Compilation runtime flag
    in_join: bool = False  # Compilation runtime flag

    # Tables currently in scope; used when rendering columns to pick a table alias.
    _table_context: List = attrs.field(factory=list)  # List[ITable]
    # CTEs accumulated during compilation, flushed into a WITH clause at the root.
    _subqueries: Dict[str, Any] = attrs.field(factory=dict)  # XXX not thread-safe
    # True only for the outermost compiler; recursive compilations clear it.
    root: bool = True

    # Single-cell list so the counter can be mutated despite frozen=True.
    _counter: List = attrs.field(factory=lambda: [0])

    @property
    def dialect(self) -> "BaseDialect":
        """The SQL dialect bound to this compiler's database."""
        return self.database.dialect

    # TODO: DEPRECATED: Remove once the dialect is used directly in all places.
    def compile(self, elem, params=None) -> str:
        """Compile *elem* to SQL by delegating to the dialect."""
        return self.dialect.compile(self, elem, params)

    def new_unique_name(self, prefix="tmp") -> str:
        """Return a fresh name such as 'tmp3', unique within this compilation."""
        self._counter[0] += 1
        return f"{prefix}{self._counter[0]}"

    def new_unique_table_name(self, prefix="tmp") -> DbPath:
        """Return a fresh table path; a random hex suffix lowers collision risk across runs."""
        self._counter[0] += 1
        table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}"
        return self.database.dialect.parse_table_name(table_name)

    def add_table_context(self, *tables: Sequence, **kw) -> Self:
        """Return an evolved copy of this compiler with *tables* appended to the table context."""
        return attrs.evolve(self, table_context=self._table_context + list(tables), **kw)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def parse_table_name(t):
    """Split a dotted table name into a tuple of its path components."""
    parts = t.split(".")
    return tuple(parts)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def import_helper(package: Optional[str] = None, text=""):
    """Decorator factory for lazy driver imports.

    Wraps a zero-argument import function so that a ModuleNotFoundError is
    re-raised with *text* and, when *package* is given, a pip-install hint
    for the matching ``dcs-cli`` extra.

    Args:
        package: Optional name of the pip extra that provides the missing module.
        text: Extra context prepended to the install hint.
    """

    def dec(f):
        @wraps(f)
        def _inner():
            try:
                return f()
            except ModuleNotFoundError as e:
                s = text
                if package:
                    s += f"Please complete setup by running: pip install 'dcs-cli[{package}]'."
                # Chain explicitly so the original import failure stays attached
                # as __cause__ instead of only implicit __context__.
                raise ModuleNotFoundError(f"{e}\n\n{s}\n") from e

        return _inner

    return dec
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class ConnectError(Exception):
    """Raised when a database connection cannot be established."""
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class QueryError(Exception):
    """Raised when a database query fails to execute."""
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _one(seq):
|
|
214
|
+
(x,) = seq
|
|
215
|
+
return x
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@attrs.define(frozen=False)
class ThreadLocalInterpreter:
    """An interpreter used to execute a sequence of queries within the same thread and cursor.

    Useful for cursor-sensitive operations, such as creating a temporary table.
    """

    compiler: Compiler
    # Generator that yields queries (Expr) and receives each query's result via send().
    gen: Generator

    def apply_queries(self, callback: Callable[[str], Any]) -> None:
        """Drive the generator: compile each yielded query, execute it through
        *callback*, and feed the result (or the raised exception) back in.

        Terminates when the generator is exhausted (StopIteration).
        """
        q: Expr = next(self.gen)
        while True:
            sql = self.compiler.database.dialect.compile(self.compiler, q)
            logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql)
            try:
                try:
                    # SKIP is a sentinel meaning "nothing to execute" — pass it through unchanged.
                    res = callback(sql) if sql is not SKIP else SKIP
                except Exception as e:
                    # Surface execution errors inside the generator so it can handle or re-raise.
                    q = self.gen.throw(type(e), e)
                else:
                    q = self.gen.send(res)
            except StopIteration:
                # Generator finished — every query in the sequence has been executed.
                break
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list:
    """Run *sql_code* through *callback* — either a single SQL string, or a
    ThreadLocalInterpreter session executed query by query."""
    if not isinstance(sql_code, ThreadLocalInterpreter):
        return callback(sql_code)
    return sql_code.apply_queries(callback)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
@attrs.define(frozen=False)
class BaseDialect(abc.ABC):
    # Capability flags, overridden by concrete dialects.
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
    SUPPORTS_INDEXES: ClassVar[bool] = False
    PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False
    # Maps database type names to their ColType classes.
    TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
    DEFAULT_NUMERIC_PRECISION: ClassVar[int] = 0  # effective precision when type is just "NUMERIC"

    PLACEHOLDER_TABLE = None  # Used for Oracle

    default_schema: Optional[str] = None
    TABLE_NAMES: List[str] = attrs.Factory(list)
    project: Optional[str] = None  # Used for BigQuery

    # Some databases do not support long strings, so concatenation might lead to type overflow.
    _prevent_overflow_when_concat: bool = False
    # Sybase driver flavor flags — presumably distinguishes ASE vs FreeTDS; confirm in sybase.py.
    sybase_driver_type: SybaseDriverTypes = attrs.Factory(lambda: SybaseDriverTypes())
    # Records which query variant (ASE or FreeTDS) has already been selected.
    query_config_for_free_tds: Dict = attrs.Factory(
        lambda: {
            "ase_query_chosen": False,
            "freetds_query_chosen": False,
        }
    )
    # Cached raw column info.  NOTE(review): keying scheme not visible here — confirm at call sites.
    table_schema: Dict[str, RawColumnInfo] = attrs.Factory(dict)
|
|
275
|
+
|
|
276
|
+
def enable_preventing_type_overflow(self) -> None:
    """Switch concat rendering into the overflow-safe (hashed) mode."""
    logger.info("Preventing type overflow when concatenation is enabled")
    self._prevent_overflow_when_concat = True
|
|
279
|
+
|
|
280
|
+
def parse_table_name(self, name: str) -> DbPath:
    """Parse the given table name into a DbPath."""
    path = parse_table_name(name)
    return path
|
|
283
|
+
|
|
284
|
+
def compile(self, compiler: Compiler, elem, params=None) -> str:
    """Compile *elem* to a SQL string, prepending a WITH clause for any collected CTEs.

    *params*, when given, are stored in a context variable and substituted
    wherever a Param node is rendered.
    """
    if params:
        cv_params.set(params)

    # A bare expression at the root is not a full statement — wrap it in a SELECT.
    if compiler.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
        from data_diff.queries.ast_classes import Select

        elem = Select(columns=[elem])

    res = self._compile(compiler, elem)
    # At the root, flush the CTEs that sub-compilations registered into one WITH clause.
    if compiler.root and compiler._subqueries:
        subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items())
        compiler._subqueries.clear()
        return f"WITH {subq}\n{res}"
    return res
|
|
299
|
+
|
|
300
|
+
def _compile(self, compiler: Compiler, elem) -> str:
    """Render one element to its SQL fragment/literal, dispatching on its Python type.

    Compilable AST nodes recurse through render_compilable(); plain Python
    values become SQL literals.  Unknown types fail the final assertion.
    """
    if elem is None:
        return "NULL"
    elif isinstance(elem, Compilable):
        # Recursive compilations are never "root" — prevents re-wrapping and premature CTE flushing.
        return self.render_compilable(attrs.evolve(compiler, root=False), elem)
    elif isinstance(elem, ColType):
        return self.render_coltype(attrs.evolve(compiler, root=False), elem)
    elif isinstance(elem, str):
        return f"'{elem}'"
    elif isinstance(elem, (int, float)):
        return str(elem)
    elif isinstance(elem, datetime):
        return self.timestamp_value(elem)
    elif isinstance(elem, bytes):
        return f"b'{elem.decode()}'"
    elif isinstance(elem, ArithUUID):
        # Trino needs an explicit cast for UUID literals.
        if self.name.lower() == "trino":
            return f"CAST('{elem.uuid}' AS UUID)"
        s = f"'{elem.uuid}'"
        # Honor the casing flags carried by the UUID wrapper.
        return s.upper() if elem.uppercase else s.lower() if elem.lowercase else s
    elif isinstance(elem, (ArithDateTime, ArithTimestamp, ArithTimestampTZ)):
        return self.timestamp_value(elem._dt)
    elif isinstance(elem, ArithDate):
        from datetime import time

        # Dates are rendered as midnight timestamps.
        return self.timestamp_value(datetime.combine(elem._date, time.min))
    elif isinstance(elem, ArithString):
        return f"'{elem}'"
    assert False, elem
|
|
329
|
+
|
|
330
|
+
def render_compilable(self, c: Compiler, elem: Compilable) -> str:
    """Dispatch an AST node to its dedicated render_* method.

    Falls back to a name-based ``render_<classname>`` lookup for node types
    not listed explicitly.
    """
    # All ifs are only for better code navigation, IDE usage detection, and type checking.
    # The last catch-all would render them anyway — it is a typical "visitor" pattern.
    if isinstance(elem, Column):
        return self.render_column(c, elem)
    elif isinstance(elem, Cte):
        return self.render_cte(c, elem)
    elif isinstance(elem, Commit):
        return self.render_commit(c, elem)
    elif isinstance(elem, Param):
        return self.render_param(c, elem)
    elif isinstance(elem, NormalizeAsString):
        return self.render_normalizeasstring(c, elem)
    elif isinstance(elem, ApplyFuncAndNormalizeAsString):
        return self.render_applyfuncandnormalizeasstring(c, elem)
    elif isinstance(elem, Checksum):
        return self.render_checksum(c, elem)
    elif isinstance(elem, Concat):
        return self.render_concat(c, elem)
    elif isinstance(elem, Func):
        return self.render_func(c, elem)
    elif isinstance(elem, WhenThen):
        return self.render_whenthen(c, elem)
    elif isinstance(elem, CaseWhen):
        return self.render_casewhen(c, elem)
    elif isinstance(elem, IsDistinctFrom):
        return self.render_isdistinctfrom(c, elem)
    elif isinstance(elem, UnaryOp):
        return self.render_unaryop(c, elem)
    elif isinstance(elem, BinOp):
        return self.render_binop(c, elem)
    elif isinstance(elem, TablePath):
        return self.render_tablepath(c, elem)
    elif isinstance(elem, TableAlias):
        return self.render_tablealias(c, elem)
    elif isinstance(elem, TableOp):
        return self.render_tableop(c, elem)
    elif isinstance(elem, Select):
        return self.render_select(c, elem)
    elif isinstance(elem, Join):
        return self.render_join(c, elem)
    elif isinstance(elem, GroupBy):
        return self.render_groupby(c, elem)
    elif isinstance(elem, Count):
        return self.render_count(c, elem)
    elif isinstance(elem, Alias):
        return self.render_alias(c, elem)
    elif isinstance(elem, In):
        return self.render_in(c, elem)
    elif isinstance(elem, Cast):
        return self.render_cast(c, elem)
    elif isinstance(elem, Random):
        return self.render_random(c, elem)
    elif isinstance(elem, Explain):
        return self.render_explain(c, elem)
    elif isinstance(elem, CurrentTimestamp):
        return self.render_currenttimestamp(c, elem)
    elif isinstance(elem, CreateTable):
        return self.render_createtable(c, elem)
    elif isinstance(elem, DropTable):
        return self.render_droptable(c, elem)
    elif isinstance(elem, TruncateTable):
        return self.render_truncatetable(c, elem)
    elif isinstance(elem, InsertToTable):
        return self.render_inserttotable(c, elem)
    elif isinstance(elem, Code):
        return self.render_code(c, elem)
    elif isinstance(elem, _ResolveColumn):
        return self.render__resolvecolumn(c, elem)

    # Fallback: look up render_<classname> by naming convention.
    method_name = f"render_{elem.__class__.__name__.lower()}"
    method = getattr(self, method_name, None)
    if method is not None:
        return method(c, elem)
    else:
        raise RuntimeError(f"Cannot render AST of type {elem.__class__}")
|
|
407
|
+
|
|
408
|
+
def render_coltype(self, c: Compiler, elem: ColType) -> str:
    """Render a column type as its SQL type expression."""
    type_sql = self.type_repr(elem)
    return type_sql
|
|
410
|
+
|
|
411
|
+
def render_column(self, c: Compiler, elem: Column) -> str:
    """Render a column reference, qualifying it with its table alias when
    more than one table is in scope."""
    context = c._table_context
    if context and len(context) > 1:
        matches = [
            t
            for t in context
            if isinstance(t, TableAlias) and t.source_table is elem.source_table
        ]
        if len(matches) > 1:
            raise CompileError(f"Too many aliases for column {elem.name}")
        if matches:
            (alias,) = matches
            return f"{self.quote(alias.name)}.{self.quote(elem.name)}"
    return self.quote(elem.name)
|
|
426
|
+
|
|
427
|
+
def render_cte(self, parent_c: Compiler, elem: Cte) -> str:
    """Compile a CTE's source query, register it in the parent compiler's
    subqueries (for the eventual WITH clause), and return the CTE's name."""
    sub_compiler: Compiler = attrs.evolve(parent_c, table_context=[], in_select=False)
    rendered = self.compile(sub_compiler, elem.source_table)

    cte_name = elem.name or parent_c.new_unique_name()
    key = f"{cte_name}({', '.join(elem.params)})" if elem.params else cte_name
    parent_c._subqueries[key] = rendered

    return cte_name
|
|
436
|
+
|
|
437
|
+
def render_commit(self, c: Compiler, elem: Commit) -> str:
    """Emit COMMIT, or the SKIP sentinel when the database auto-commits."""
    if c.database.is_autocommit:
        return SKIP
    return "COMMIT"
|
|
439
|
+
|
|
440
|
+
def render_param(self, c: Compiler, elem: Param) -> str:
    """Substitute a named parameter from the context-local params mapping."""
    bound = cv_params.get()
    value = bound[elem.name]
    return self._compile(c, value)
|
|
443
|
+
|
|
444
|
+
def render_normalizeasstring(self, c: Compiler, elem: NormalizeAsString) -> str:
    """Compile the inner expression and wrap it in type-aware string normalization."""
    compiled = self.compile(c, elem.expr)
    coltype = elem.expr_type or elem.expr.type
    return self.normalize_value_by_type(compiled, coltype)
|
|
447
|
+
|
|
448
|
+
def render_applyfuncandnormalizeasstring(self, c: Compiler, elem: ApplyFuncAndNormalizeAsString) -> str:
    """Apply an optional SQL template to an expression and normalize it to a string.

    The ordering of the two steps depends on the column type — see the
    branch comments below.
    """
    expr = elem.expr
    expr_type = expr.type

    if isinstance(expr_type, Native_UUID):
        # Normalize first, apply template after (for uuids)
        # Needed because min/max(uuid) fails in postgresql
        expr = NormalizeAsString(expr, expr_type)
        if elem.apply_func is not None:
            expr = elem.apply_func(expr)  # Apply template using Python's string formatting

    else:
        # Apply template before normalizing (for ints)
        if elem.apply_func is not None:
            expr = elem.apply_func(expr)  # Apply template using Python's string formatting
        expr = NormalizeAsString(expr, expr_type)

    return self.compile(c, expr)
|
|
466
|
+
|
|
467
|
+
def render_checksum(self, c: Compiler, elem: Checksum) -> str:
    """Render a summed md5 checksum over one or more expressions."""
    if len(elem.exprs) == 1:
        # No need to coalesce - safe to assume that key cannot be null
        (only,) = elem.exprs
        combined = self.compile(c, only)
    else:
        wrapped = [Code(f"coalesce({self.compile(c, e)}, '<null>')") for e in elem.exprs]
        combined = self.compile(c, Concat(wrapped, "|"))
    md5 = self.md5_as_int(combined)
    return f"sum({md5})"
|
|
478
|
+
|
|
479
|
+
def render_concat(self, c: Compiler, elem: Concat) -> str:
    """Render a string concatenation of the expressions, separated by elem.sep.

    In overflow-prevention mode each part is first reduced to an md5 hex
    digest so the concatenated value stays short.
    """
    if self._prevent_overflow_when_concat:
        items = [
            f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}"
            for expr in elem.exprs
        ]

    # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
    else:
        items = [
            f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')"
            for expr in elem.exprs
        ]

    assert items
    if len(items) == 1:
        # A single item needs no concatenation at all.
        return items[0]

    if elem.sep:
        # Interleave the separator literal between the items.
        items = list(join_iter(f"'{elem.sep}'", items))
    return self.concat(items)
|
|
500
|
+
|
|
501
|
+
def render_alias(self, c: Compiler, elem: Alias) -> str:
    """Render ``<expr> AS <quoted-name>``."""
    rendered = self.compile(c, elem.expr)
    return f"{rendered} AS {self.quote(elem.name)}"
|
|
503
|
+
|
|
504
|
+
def render_count(self, c: Compiler, elem: Count) -> str:
    """Render count(...) / count(distinct ...), defaulting to count(*)."""
    target = self.compile(c, elem.expr) if elem.expr else "*"
    template = "count(distinct {})" if elem.distinct else "count({})"
    return template.format(target)
|
|
509
|
+
|
|
510
|
+
def render_code(self, c: Compiler, elem: Code) -> str:
    """Render raw SQL code, substituting any compiled arguments into its template."""
    if elem.args:
        rendered_args = {name: self.compile(c, value) for name, value in elem.args.items()}
        return elem.code.format(**rendered_args)
    return elem.code
|
|
516
|
+
|
|
517
|
+
def render_func(self, c: Compiler, elem: Func) -> str:
    """Render a function call with compiled, comma-separated arguments."""
    arg_list = [self.compile(c, a) for a in elem.args]
    return f"{elem.name}({', '.join(arg_list)})"
|
|
520
|
+
|
|
521
|
+
def render_whenthen(self, c: Compiler, elem: WhenThen) -> str:
    """Render one ``WHEN ... THEN ...`` arm of a CASE expression."""
    when_sql = self.compile(c, elem.when)
    then_sql = self.compile(c, elem.then)
    return f"WHEN {when_sql} THEN {then_sql}"
|
|
523
|
+
|
|
524
|
+
def render_casewhen(self, c: Compiler, elem: CaseWhen) -> str:
    """Render a CASE expression; at least one WHEN/THEN arm is required."""
    assert elem.cases
    arms = " ".join(self.compile(c, arm) for arm in elem.cases)
    if elem.else_expr is not None:
        return f"CASE {arms} ELSE {self.compile(c, elem.else_expr)} END"
    return f"CASE {arms} END"
|
|
529
|
+
|
|
530
|
+
def render_isdistinctfrom(self, c: Compiler, elem: IsDistinctFrom) -> str:
    """Render a NULL-safe inequality between two comparable operands."""
    lhs = self.to_comparable(self.compile(c, elem.a), elem.a.type)
    rhs = self.to_comparable(self.compile(c, elem.b), elem.b.type)
    return self.is_distinct_from(lhs, rhs)
|
|
534
|
+
|
|
535
|
+
def render_unaryop(self, c: Compiler, elem: UnaryOp) -> str:
    """Render a parenthesized unary operation."""
    operand = self.compile(c, elem.expr)
    return f"({elem.op}{operand})"
|
|
537
|
+
|
|
538
|
+
def render_binop(self, c: Compiler, elem: BinOp) -> str:
    """Render a parenthesized n-ary operation, operands joined by the operator."""
    separator = f" {elem.op} "
    joined = separator.join(self.compile(c, operand) for operand in elem.args)
    return f"({joined})"
|
|
541
|
+
|
|
542
|
+
def render_tablepath(self, c: Compiler, elem: TablePath) -> str:
    """Render a dotted, quoted table path."""
    quoted_parts = [self.quote(part, True) for part in elem.path]
    return ".".join(quoted_parts)
|
|
545
|
+
|
|
546
|
+
def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
    """Render ``<source-table> <quoted-alias>``."""
    source_sql = self.compile(c, elem.source_table)
    return f"{source_sql} {self.quote(elem.name)}"
|
|
548
|
+
|
|
549
|
+
def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str:
    """Render a set operation (e.g. UNION) between two tables, parenthesizing
    and aliasing it as required by the surrounding context."""
    c: Compiler = attrs.evolve(parent_c, in_select=False)
    combined = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}"
    if parent_c.in_select:
        return f"({combined}) {c.new_unique_name()}"
    if parent_c.in_join:
        return f"({combined})"
    return combined
|
|
557
|
+
|
|
558
|
+
def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
    """Render the concrete column this lazy reference resolved to."""
    resolved = elem._get_resolved()
    return self.compile(c, resolved)
|
|
560
|
+
|
|
561
|
+
def render_select(self, parent_c: Compiler, elem: Select) -> str:
    """Assemble a SELECT statement clause by clause, in standard SQL order."""
    c: Compiler = attrs.evolve(parent_c, in_select=True)  # .add_table_context(self.table)
    compile_fn = functools.partial(self.compile, c)

    columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*"
    distinct = "DISTINCT " if elem.distinct else ""
    optimizer_hints = self.optimizer_hints(elem.optimizer_hints) if elem.optimizer_hints else ""
    select = f"SELECT {optimizer_hints}{distinct}{columns}"

    if elem.table:
        select += " FROM " + self.compile(c, elem.table)
    elif self.PLACEHOLDER_TABLE:
        # Some dialects require a FROM clause even for constant selects (e.g. Oracle's DUAL).
        select += f" FROM {self.PLACEHOLDER_TABLE}"

    if elem.where_exprs:
        select += " WHERE " + " AND ".join(map(compile_fn, elem.where_exprs))

    if elem.group_by_exprs:
        select += " GROUP BY " + ", ".join(map(compile_fn, elem.group_by_exprs))

    if elem.having_exprs:
        assert elem.group_by_exprs  # HAVING is only valid together with GROUP BY
        select += " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs))

    if elem.order_by_exprs:
        select += " ORDER BY " + ", ".join(map(compile_fn, elem.order_by_exprs))

    if elem.limit_expr is not None:
        has_order_by = bool(elem.order_by_exprs)
        # LIMIT/TOP syntax varies per dialect; delegate to the dialect hook.
        select = self.limit_select(select_query=select, offset=0, limit=elem.limit_expr, has_order_by=has_order_by)

    # Nested usage must be parenthesized; a subselect in FROM also needs an alias.
    if parent_c.in_select:
        select = f"({select}) {c.new_unique_name()}"
    elif parent_c.in_join:
        select = f"({select})"
    return select
|
|
597
|
+
|
|
598
|
+
def render_join(self, parent_c: Compiler, elem: Join) -> str:
    """Render a Join AST node into a SELECT ... FROM a JOIN b [ON ...] statement.

    Every source table is given an alias (existing TableAlias nodes are
    kept); the join is wrapped in parentheses when nested.
    """
    tables = [
        t if isinstance(t, TableAlias) else TableAlias(t, name=parent_c.new_unique_name())
        for t in elem.source_tables
    ]
    # Register the aliases so column references can resolve against them.
    c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
    op = " JOIN " if elem.op is None else f" {elem.op} JOIN "
    joined = op.join(self.compile(c, t) for t in tables)

    if elem.on_exprs:
        on = " AND ".join(self.compile(c, e) for e in elem.on_exprs)
        res = f"{joined} ON {on}"
    else:
        res = joined

    compile_fn = functools.partial(self.compile, c)
    columns = "*" if elem.columns is None else ", ".join(map(compile_fn, elem.columns))
    select = f"SELECT {columns} FROM {res}"

    if parent_c.in_select:
        select = f"({select}) {c.new_unique_name()}"
    elif parent_c.in_join:
        select = f"({select})"
    return select
|
|
622
|
+
|
|
623
|
+
def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
    """Render a GroupBy AST node into a SELECT ... GROUP BY statement.

    Grouping keys are referenced by ordinal position (1-based), which is
    portable across dialects. When the underlying table is a bare Select
    (no columns, no grouping), the GROUP BY is folded into it instead of
    producing an extra nesting level.
    """
    compile_fn = functools.partial(self.compile, c)

    if elem.values is None:
        raise CompileError(".group_by() must be followed by a call to .agg()")

    # Ordinal references to the grouping keys: "1", "2", ...
    keys = [str(i + 1) for i in range(len(elem.keys))]
    columns = (elem.keys or []) + (elem.values or [])
    if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None:
        # Fold the grouping directly into the inner select.
        return self.compile(
            c,
            attrs.evolve(
                elem.table,
                columns=columns,
                group_by_exprs=[Code(k) for k in keys],
                having_exprs=elem.having_exprs,
            ),
        )

    keys_str = ", ".join(keys)
    columns_str = ", ".join(self.compile(c, x) for x in columns)
    having_str = (
        " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else ""
    )
    # The inner table is compiled with in_select=True so it parenthesizes itself.
    select = f"SELECT {columns_str} FROM {self.compile(attrs.evolve(c, in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"

    if c.in_select:
        select = f"({select}) {c.new_unique_name()}"
    elif c.in_join:
        select = f"({select})"
    return select
|
|
654
|
+
|
|
655
|
+
def render_in(self, c: Compiler, elem: In) -> str:
    """Render an IN-list membership test, parenthesized."""
    rendered_items = [self.compile(c, item) for item in elem.list]
    lhs = self.compile(c, elem.expr)
    return "({} IN ({}))".format(lhs, ", ".join(rendered_items))
|
|
659
|
+
|
|
660
|
+
def render_cast(self, c: Compiler, elem: Cast) -> str:
    """Render a SQL CAST expression."""
    value_sql = self.compile(c, elem.expr)
    type_sql = self.compile(c, elem.target_type)
    return f"cast({value_sql} as {type_sql})"
|
|
662
|
+
|
|
663
|
+
def render_random(self, c: Compiler, elem: Random) -> str:
    """Render a random-number expression; delegates to the dialect's random()."""
    return self.random()

def render_explain(self, c: Compiler, elem: Explain) -> str:
    """Render an EXPLAIN of the wrapped select; delegates to explain_as_text()."""
    return self.explain_as_text(self.compile(c, elem.select))

def render_currenttimestamp(self, c: Compiler, elem: CurrentTimestamp) -> str:
    """Render a current-timestamp expression; delegates to current_timestamp()."""
    return self.current_timestamp()
|
|
671
|
+
|
|
672
|
+
def render_createtable(self, c: Compiler, elem: CreateTable) -> str:
    """Render a CREATE TABLE statement.

    If a source table is given, emits CREATE TABLE ... AS SELECT;
    otherwise builds the column list from the path's schema, with an
    optional PRIMARY KEY clause when the dialect supports it.
    """
    ne = "IF NOT EXISTS " if elem.if_not_exists else ""
    if elem.source_table:
        return f"CREATE TABLE {ne}{self.compile(c, elem.path)} AS {self.compile(c, elem.source_table)}"

    schema = ", ".join(f"{self.quote(k)} {self.type_repr(v)}" for k, v in elem.path.schema.items())
    pks = (
        ", PRIMARY KEY (%s)" % ", ".join(elem.primary_keys)
        if elem.primary_keys and self.SUPPORTS_PRIMARY_KEY
        else ""
    )
    return f"CREATE TABLE {ne}{self.compile(c, elem.path)}({schema}{pks})"
|
|
684
|
+
|
|
685
|
+
def render_droptable(self, c: Compiler, elem: DropTable) -> str:
    """Render a DROP TABLE statement, with optional IF EXISTS."""
    ie = "IF EXISTS " if elem.if_exists else ""
    return f"DROP TABLE {ie}{self.compile(c, elem.path)}"

def render_truncatetable(self, c: Compiler, elem: TruncateTable) -> str:
    """Render a TRUNCATE TABLE statement."""
    return f"TRUNCATE TABLE {self.compile(c, elem.path)}"
|
|
691
|
+
|
|
692
|
+
def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str:
    """Render an INSERT INTO statement.

    Literal row sets (ConstantTable) become a VALUES clause; any other
    expression (e.g. a select) is compiled as-is.
    """
    if isinstance(elem.expr, ConstantTable):
        expr = self.constant_values(elem.expr.rows)
    else:
        expr = self.compile(c, elem.expr)

    # Optional explicit column list, e.g. INSERT INTO t(a, b) ...
    columns = "(%s)" % ", ".join(map(self.quote, elem.columns)) if elem.columns is not None else ""

    return f"INSERT INTO {self.compile(c, elem.path)}{columns} {expr}"
|
|
701
|
+
|
|
702
|
+
def limit_select(
    self,
    select_query: str,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    has_order_by: Optional[bool] = None,
) -> str:
    """Wrap *select_query* so at most *limit* rows are returned.

    The base implementation does not support OFFSET; *has_order_by* is
    accepted for dialects whose override needs it (unused here).
    """
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    return "SELECT * FROM ({}) AS LIMITED_SELECT LIMIT {}".format(select_query, limit)
|
|
713
|
+
|
|
714
|
+
def concat(self, items: List[str]) -> str:
    """Provide SQL that concatenates the given column expressions into one string."""
    assert len(items) > 1
    return "concat({})".format(", ".join(items))
|
|
719
|
+
|
|
720
|
+
def to_comparable(self, value: str, coltype: ColType) -> str:
    """Ensure that the expression is comparable in ``IS DISTINCT FROM``.

    Base implementation is a no-op; dialects override when a column type
    needs casting before comparison.
    """
    return value
|
|
723
|
+
|
|
724
|
+
def is_distinct_from(self, a: str, b: str) -> str:
    """Provide SQL for a null-safe comparison (NULL = NULL is true)."""
    return "{} is distinct from {}".format(a, b)
|
|
727
|
+
|
|
728
|
+
def timestamp_value(self, t: DbTime) -> str:
|
|
729
|
+
"Provide SQL for the given timestamp value"
|
|
730
|
+
return f"'{t.isoformat()}'"
|
|
731
|
+
|
|
732
|
+
def random(self) -> str:
    """Provide SQL for generating a random number between 0 and 1."""
    return "random()"
|
|
735
|
+
|
|
736
|
+
def current_timestamp(self) -> str:
    """Provide SQL for the current timestamp (i.e. "now")."""
    return "current_timestamp()"
|
|
739
|
+
|
|
740
|
+
def current_database(self) -> str:
    """Provide SQL returning the session's current default database."""
    return "current_database()"
|
|
743
|
+
|
|
744
|
+
def current_schema(self) -> str:
    """Provide SQL returning the session's current default schema."""
    return "current_schema()"
|
|
747
|
+
|
|
748
|
+
def explain_as_text(self, query: str) -> str:
    """Provide SQL for explaining *query*, returned as table(varchar)."""
    return "EXPLAIN " + query
|
|
751
|
+
|
|
752
|
+
def _constant_value(self, v):
|
|
753
|
+
if v is None:
|
|
754
|
+
return "NULL"
|
|
755
|
+
elif isinstance(v, str):
|
|
756
|
+
return f"'{v}'"
|
|
757
|
+
elif isinstance(v, datetime):
|
|
758
|
+
return self.timestamp_value(v)
|
|
759
|
+
elif isinstance(v, UUID): # probably unused anymore in favour of ArithUUID
|
|
760
|
+
return f"'{v}'"
|
|
761
|
+
elif isinstance(v, ArithUUID):
|
|
762
|
+
return f"'{v.uuid}'"
|
|
763
|
+
elif isinstance(v, decimal.Decimal):
|
|
764
|
+
return str(v)
|
|
765
|
+
elif isinstance(v, bytearray):
|
|
766
|
+
return f"'{v.decode()}'"
|
|
767
|
+
elif isinstance(v, Code):
|
|
768
|
+
return v.code
|
|
769
|
+
elif isinstance(v, ArithAlphanumeric):
|
|
770
|
+
return f"'{v._str}'"
|
|
771
|
+
elif isinstance(v, ArithUnicodeString):
|
|
772
|
+
return f"'{v._str}'"
|
|
773
|
+
elif isinstance(v, ArithDate):
|
|
774
|
+
return f"'{str(v)}'"
|
|
775
|
+
elif isinstance(v, ArithTimestamp):
|
|
776
|
+
return f"'{str(v)}'"
|
|
777
|
+
elif isinstance(v, ArithTimestampTZ):
|
|
778
|
+
return f"'{str(v)}'"
|
|
779
|
+
elif isinstance(v, ArithDateTime):
|
|
780
|
+
return self.timestamp_value(v._dt)
|
|
781
|
+
return repr(v)
|
|
782
|
+
|
|
783
|
+
def constant_values(self, rows) -> str:
    """Render an iterable of rows as a SQL VALUES clause."""
    rendered = []
    for row in rows:
        rendered.append("(%s)" % ", ".join(self._constant_value(v) for v in row))
    return f"VALUES {', '.join(rendered)}"
|
|
786
|
+
|
|
787
|
+
def type_repr(self, t) -> str:
    """Return the SQL type name for *t*.

    Accepts a raw SQL type string (returned unchanged), a TimestampTZ
    coltype, or one of a few built-in Python types. Raises KeyError for
    anything else.
    """
    if isinstance(t, str):
        return t
    elif isinstance(t, TimestampTZ):
        # Clamp the precision to what the dialect supports by default.
        return f"TIMESTAMP({min(t.precision, DEFAULT_DATETIME_PRECISION)})"
    return {
        int: "INT",
        str: "VARCHAR",
        bool: "BOOLEAN",
        float: "FLOAT",
        datetime: "TIMESTAMP",
    }[t]
|
|
799
|
+
|
|
800
|
+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
    """Parse raw column type info, as returned by the database, into a ColType.

    Looks up the dialect's TYPE_CLASSES mapping by the raw data-type name
    and instantiates the matching ColType with the appropriate precision.
    Unmapped names produce UnknownColType rather than raising.
    """

    cls = self.TYPE_CLASSES.get(info.data_type)
    if cls is None:
        return UnknownColType(info.data_type)

    if issubclass(cls, TemporalType):
        return cls(
            precision=(
                info.datetime_precision if info.datetime_precision is not None else DEFAULT_DATETIME_PRECISION
            ),
            rounds=self.ROUNDS_ON_PREC_LOSS,
        )

    elif issubclass(cls, Integer):
        return cls()

    elif issubclass(cls, Boolean):
        return cls()

    elif issubclass(cls, Decimal):
        # Note: "precision" here means the number of fractional digits
        # (numeric_scale), not the SQL precision.
        if info.numeric_scale is None:
            return cls(precision=0)  # Needed for Oracle.
        return cls(precision=info.numeric_scale)

    elif issubclass(cls, Float):
        # assert numeric_scale is None
        # Floats report binary precision; convert to decimal digits.
        return cls(
            precision=self._convert_db_precision_to_digits(
                info.numeric_precision if info.numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
            )
        )

    elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
        return cls()

    raise TypeError(f"Parsing {info.data_type} returned an unknown type {cls!r}.")
|
|
838
|
+
|
|
839
|
+
def _convert_db_precision_to_digits(self, p: int) -> int:
|
|
840
|
+
"""Convert from binary precision, used by floats, to decimal precision."""
|
|
841
|
+
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
|
|
842
|
+
return math.floor(math.log(2**p, 10))
|
|
843
|
+
|
|
844
|
+
@property
@abstractmethod
def name(self) -> str:
    "Name of the dialect"

@property
@abstractmethod
def ROUNDS_ON_PREC_LOSS(self) -> bool:
    "True if db rounds real values when losing precision, False if it truncates."

@abstractmethod
def quote(self, s: str, is_table: bool = False) -> str:
    "Quote SQL name; `is_table` lets dialects quote table names differently."

@abstractmethod
def to_string(self, s: str) -> str:
    # TODO rewrite using cast_to(x, str)
    "Provide SQL for casting a column to string"

@abstractmethod
def set_timezone_to_utc(self) -> str:
    "Provide SQL for setting the session timezone to UTC"

@abstractmethod
def md5_as_int(self, s: str) -> str:
    "Provide SQL for computing md5 of `s` and returning an int"

@abstractmethod
def md5_as_hex(self, s: str) -> str:
    """Provide SQL for computing md5 of `s` and returning a hex string."""

@abstractmethod
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
    """Creates an SQL expression, that converts 'value' to a normalized timestamp.

    The returned expression must accept any SQL datetime/timestamp, and return a string.

    Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF``

    Precision of dates should be rounded up/down according to coltype.rounds
    e.g. precision 3 and coltype.rounds:
    - 1969-12-31 23:59:59.999999 -> 1970-01-01 00:00:00.000000
    - 1970-01-01 00:00:00.000888 -> 1970-01-01 00:00:00.001000
    - 1970-01-01 00:00:00.123123 -> 1970-01-01 00:00:00.123000

    Make sure NULLs remain NULLs
    """

@abstractmethod
def normalize_number(self, value: str, coltype: FractionalType) -> str:
    """Creates an SQL expression, that converts 'value' to a normalized number.

    The returned expression must accept any SQL int/numeric/float, and return a string.

    Floats/Decimals are expected in the format
    "I.P"

    Where I is the integer part of the number (as many digits as necessary),
    and must be at least one digit (0).
    P is the fractional digits, the amount of which is specified with
    coltype.precision. Trailing zeroes may be necessary.
    If P is 0, the dot is omitted.

    Note: We use 'precision' differently than most databases. For decimals,
    it's the same as ``numeric_scale``, and for floats, who use binary precision,
    it can be calculated as ``log10(2**numeric_precision)``.
    """
|
|
911
|
+
|
|
912
|
+
def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
    """Creates an SQL expression, that converts 'value' to either '0' or '1'."""
    return self.to_string(value)

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
    """Creates an SQL expression, that strips uuids of artifacts like whitespace."""
    if isinstance(coltype, String_UUID):
        return f"TRIM({value})"
    return self.to_string(value)

def normalize_json(self, value: str, _coltype: JSON) -> str:
    """Creates an SQL expression, that converts 'value' to its minified json string representation."""
    return self.to_string(value)

def normalize_array(self, value: str, _coltype: Array) -> str:
    """Creates an SQL expression, that serialized an array into a JSON string."""
    return self.to_string(value)

def normalize_struct(self, value: str, _coltype: Struct) -> str:
    """Creates an SQL expression, that serialized a typed struct into a JSON string."""
    return self.to_string(value)
|
|
933
|
+
|
|
934
|
+
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
    """Creates an SQL expression, that converts 'value' to a normalized representation.

    The returned expression must accept any SQL value, and return a string.

    Dispatches on the class of `coltype`:

    ::

        TemporalType   -> normalize_timestamp()
        FractionalType -> normalize_number()
        ColType_UUID   -> normalize_uuid()
        Boolean        -> normalize_boolean()
        JSON           -> normalize_json()
        Array          -> normalize_array()
        Struct         -> normalize_struct()
        *else*         -> to_string()

    (`Integer` falls in the *else* category)
    """
    # Ordered dispatch table; first isinstance() match wins, mirroring
    # the original if/elif chain exactly.
    dispatch = (
        (TemporalType, self.normalize_timestamp),
        (FractionalType, self.normalize_number),
        (ColType_UUID, self.normalize_uuid),
        (Boolean, self.normalize_boolean),
        (JSON, self.normalize_json),
        (Array, self.normalize_array),
        (Struct, self.normalize_struct),
    )
    for klass, handler in dispatch:
        if isinstance(coltype, klass):
            return handler(value, coltype)
    return self.to_string(value)
|
|
965
|
+
|
|
966
|
+
def optimizer_hints(self, hints: str) -> str:
    """Wrap raw optimizer hints in a SQL hint comment (with trailing space)."""
    return "/*+ " + hints + " */ "
|
|
968
|
+
|
|
969
|
+
def generate_view_name(self, view_name: str | None = None) -> str:
|
|
970
|
+
if view_name is not None:
|
|
971
|
+
return view_name
|
|
972
|
+
random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
|
|
973
|
+
timestamp = int(time.time())
|
|
974
|
+
return f"view_{timestamp}_{random_string.lower()}"
|
|
975
|
+
|
|
976
|
+
def create_view(self, query: str, schema: str | None, view_name: str | None = None) -> Tuple[str, str]:
    """Build a CREATE VIEW statement for *query*; returns (sql, view_name)."""
    view_name = self.generate_view_name(view_name=view_name)
    qualified = f"{schema}.{view_name}" if schema else view_name
    return f"CREATE VIEW {qualified} AS {query}", view_name
|
|
980
|
+
|
|
981
|
+
def drop_view(self, view_name: str, schema: str | None) -> str:
|
|
982
|
+
schema_prefix = f"{schema}." if schema else ""
|
|
983
|
+
return f"DROP VIEW {schema_prefix}{view_name}"
|
|
984
|
+
|
|
985
|
+
def get_column_raw_info(self, value: str) -> Optional[RawColumnInfo]:
    """Look up the cached raw schema entry for column *value*, or None if absent."""
    return self.table_schema.get(value)
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
T = TypeVar("T", bound=BaseDialect)
|
|
991
|
+
Row = Sequence[Any]
|
|
992
|
+
|
|
993
|
+
|
|
994
|
+
@attrs.define(frozen=True)
class QueryResult:
    """A thin, immutable sequence wrapper around the rows of a query result.

    Supports iteration, len() and indexing; optionally carries the
    result-set column names.
    """

    rows: List[Row]
    columns: Optional[list] = None

    def __iter__(self) -> Iterator[Row]:
        yield from self.rows

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, i) -> Row:
        return self.rows[i]
|
|
1007
|
+
|
|
1008
|
+
|
|
1009
|
+
@attrs.define(frozen=False, kw_only=True)
class Database(abc.ABC):
    """Base abstract class for databases.

    Used for providing connection code and implementation specific SQL utilities.

    Instantiated using :meth:`~data_diff.connect`
    """

    # Dialect class instantiated lazily by the `dialect` property.
    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect

    # Whether text columns may be refined to alphanumeric key types.
    SUPPORTS_ALPHANUMS: ClassVar[bool] = True
    # Whether unique constraints can be queried via information_schema.
    SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False
    # Keyword parameters accepted in the connect-URI path.
    CONNECT_URI_KWPARAMS: ClassVar[List[str]] = []

    # Schema assumed when a table path has no explicit schema.
    default_schema: Optional[str] = None
    # When True, each SELECT is EXPLAINed and confirmed interactively first.
    _interactive: bool = False
    # Set by close(); querying stops functioning afterwards.
    is_closed: bool = False
    # Cached dialect instance; see the `dialect` property.
    _dialect: BaseDialect = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Release connections when used as a context manager.
        self.close()

    @property
    def name(self):
        # Human-readable database name; defaults to the concrete class name.
        return type(self).__name__
|
|
1038
|
+
|
|
1039
|
+
def compile(self, sql_ast):
    """Compile a query AST into SQL text using this database's dialect."""
    return self.dialect.compile(Compiler(self), sql_ast)
|
|
1041
|
+
|
|
1042
|
+
def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None):
    """Query the given SQL code/AST, and attempt to convert the result to type 'res_type'

    If given a generator, it will execute all the yielded sql queries with the same thread and cursor.
    The results of the queries are returned by the `yield` stmt (using the .send() mechanism).
    It's a cleaner approach than exposing cursors, but may not be enough in all cases.
    """

    compiler = Compiler(self)
    if isinstance(sql_ast, Generator):
        # Run the whole generator on a single thread/cursor.
        sql_code = ThreadLocalInterpreter(compiler, sql_ast)
    elif isinstance(sql_ast, list):
        # Run all but the last query for side effects; only the last
        # one's result (converted to res_type) is returned.
        for i in sql_ast[:-1]:
            self.query(i)
        return self.query(sql_ast[-1], res_type)
    else:
        if isinstance(sql_ast, str):
            # Raw SQL string: passed through as-is.
            sql_code = sql_ast
        else:
            if res_type is None:
                res_type = sql_ast.type
            sql_code = self.compile(sql_ast)
            if sql_code is SKIP:
                return SKIP

    if log_message:
        logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
    else:
        logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

    if self._interactive and isinstance(sql_ast, Select):
        # Interactive mode: show the query plan and ask before running.
        explained_sql = self.compile(Explain(sql_ast))
        explain = self._query(explained_sql)
        for row in explain:
            # Most returned a 1-tuple. Presto returns a string
            if isinstance(row, tuple):
                (row,) = row
            logger.debug("EXPLAIN: %s", row)
        answer = input("Continue? [y/n] ")
        if answer.lower() not in ["y", "yes"]:
            sys.exit(1)

    res = self._query(sql_code)
    # Convert the raw result according to the requested res_type.
    if res_type is list:
        return list(res)
    elif res_type is int:
        # Expect exactly one row with one column.
        if not res:
            raise ValueError("Query returned 0 rows, expected 1")
        row = _one(res)
        if not row:
            raise ValueError("Row is empty, expected 1 column")
        res = _one(row)
        if res is None:  # May happen due to sum() of 0 items
            return None
        return int(res)
    elif res_type is datetime:
        res = _one(_one(res))
        if isinstance(res, str):
            res = datetime.fromisoformat(res[:23])  # TODO use a better parsing method
        return res
    elif res_type is tuple:
        assert len(res) == 1, (sql_code, res)
        return res[0]
    elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
        # Parameterized list types: List[int], List[str], List[tuple], List[dict].
        if res_type.__args__ in ((int,), (str,)):
            return [_one(row) for row in res]
        elif res_type.__args__ in [(Tuple,), (tuple,)]:
            return [tuple(row) for row in res]
        elif res_type.__args__ == (dict,):
            return [dict(safezip(res.columns, row)) for row in res]
        else:
            raise ValueError(res_type)
    return res
|
|
1115
|
+
|
|
1116
|
+
def enable_interactive(self):
    """Turn on interactive mode: SELECTs are EXPLAINed and confirmed before running."""
    self._interactive = True
|
|
1118
|
+
|
|
1119
|
+
def select_table_schema(self, path: DbPath) -> str:
    """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
    schema, name = self._normalize_table_path(path)

    # Standard information_schema query; dialects override where needed.
    return (
        "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
        "FROM information_schema.columns "
        f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
    )
|
|
1128
|
+
|
|
1129
|
+
def safe_get(self, lst, idx, default=None):
    """Return lst[idx] when idx is a valid non-negative index, else *default*."""
    if 0 <= idx < len(lst):
        return lst[idx]
    return default
|
|
1131
|
+
|
|
1132
|
+
def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
    """Query the table for its schema for table in 'path', and return {column: tuple}
    where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?)

    Note: This method exists instead of select_table_schema(), just because not all databases support
    accessing the schema using a SQL query.

    Raises RuntimeError when the table does not exist or has no columns.
    """
    rows = self.query(self.select_table_schema(path), list, log_message=path)
    if not rows:
        raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
    d = {
        # safe_get tolerates dialects whose schema query returns fewer
        # columns; missing positions default to None.
        r[0]: RawColumnInfo(
            column_name=self.safe_get(r, 0),
            data_type=self.safe_get(r, 1),
            datetime_precision=self.safe_get(r, 2),
            numeric_precision=self.safe_get(r, 3),
            numeric_scale=self.safe_get(r, 4),
            collation_name=self.safe_get(r, 5),
            character_maximum_length=self.safe_get(r, 6),
        )
        for r in rows
    }
    # Duplicate column names would silently collapse in the dict.
    assert len(d) == len(rows)
    # Cache the raw schema on the dialect the first time only.
    if not self.dialect.table_schema:
        self.dialect.table_schema = d
    return d
|
|
1158
|
+
|
|
1159
|
+
def select_table_unique_columns(self, path: DbPath) -> str:
    """Provide SQL for selecting the names of unique columns in the table"""
    schema, name = self._normalize_table_path(path)

    return (
        "SELECT column_name "
        "FROM information_schema.key_column_usage "
        f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
    )

def query_table_unique_columns(self, path: DbPath) -> List[str]:
    """Query the table for its unique columns for table in 'path', and return {column}"""
    if not self.SUPPORTS_UNIQUE_CONSTAINT:
        raise NotImplementedError("This database doesn't support 'unique' constraints")
    res = self.query(self.select_table_unique_columns(path), List[str], log_message=path)
    return list(res)
|
|
1175
|
+
|
|
1176
|
+
def create_view_from_query(
    self,
    query: str,
    schema: str | None = None,
    view_name: str | None = None,
) -> str:
    """Create a view from a select statement and return the view's name.

    Falls back to the database's default schema when none is given; the
    name is generated by the dialect when not provided.
    """
    schema = schema or self.default_schema
    view_query, v_name = self.dialect.create_view(query, schema, view_name)
    self.query(view_query)
    return v_name

def drop_view_from_db(self, view_name: str, schema: str | None = None) -> None:
    """Drop the named view, defaulting to the database's default schema."""
    schema = schema or self.default_schema
    view_drop_query = self.dialect.drop_view(view_name, schema)
    self.query(view_drop_query)
|
|
1192
|
+
|
|
1193
|
+
def _process_table_schema(
    self,
    path: DbPath,
    raw_schema: Dict[str, RawColumnInfo],
    filter_columns: Sequence[str] = None,
    where: str = None,
):
    """Process the result of query_table_schema().

    Done in a separate step, to minimize the amount of processed columns.
    Needed because processing each column may:
    * throw errors and warnings
    * query the database to sample values

    """
    if filter_columns is None:
        filtered_schema = raw_schema
    else:
        # Case-insensitive filter on column names.
        accept = {i.lower() for i in filter_columns}
        filtered_schema = {name: row for name, row in raw_schema.items() if name.lower() in accept}

    col_dict = {info.column_name: self.dialect.parse_type(path, info) for info in filtered_schema.values()}

    # May query the db to sample values (refines text columns to UUID/alphanum).
    self._refine_coltypes(path, col_dict, where)

    # Return a dict of form {name: type} after normalization
    return col_dict
|
|
1220
|
+
|
|
1221
|
+
def _refine_coltypes(
    self, table_path: DbPath, col_dict: Dict[str, ColType], where: Optional[str] = None, sample_size=64
) -> Dict[str, ColType]:
    """Refine the types in the column dict, by querying the database for a sample of their values

    'where' restricts the rows to be sampled.

    Mutates `col_dict` in place (and also returns it): Text columns whose
    sampled values all look like UUIDs become String_UUID; otherwise they
    become alphanum/unicode varying-string types usable (or not) as keys.
    """

    text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
    if not text_columns:
        return col_dict

    # Normalize each sampled value as a UUID candidate before testing it.
    fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]

    samples_by_row = self.query(
        table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size),
        list,
        log_message=table_path,
    )
    # Transpose rows -> per-column sample lists (empty lists when no rows).
    samples_by_col = list(zip(*samples_by_row)) if samples_by_row else [[]] * len(text_columns)
    for col_name, samples in safezip(text_columns, samples_by_col):
        uuid_samples = [s for s in samples if s and is_uuid(s)]

        if uuid_samples:
            if len(uuid_samples) != len(samples):
                # Mixed content: keep the Text type, UUID optimizations off.
                logger.warning(
                    f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
                )
            else:
                assert col_name in col_dict
                # Record casing so normalization can match both sides.
                col_dict[col_name] = String_UUID(
                    lowercase=all(s == s.lower() for s in uuid_samples),
                    uppercase=all(s == s.upper() for s in uuid_samples),
                )
            continue

        if self.SUPPORTS_ALPHANUMS:  # Anything but MySQL (so far)
            alphanum_samples = [s for s in samples if String_Alphanum.test_value(s)]
            if alphanum_samples:
                if len(alphanum_samples) != len(samples):
                    logger.debug(
                        f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}. It cannot be used as a key."
                    )
                    # Fallback to Unicode string type
                    assert col_name in col_dict
                    col_dict[col_name] = String_VaryingUnicode(collation=col_dict[col_name].collation)
                else:
                    assert col_name in col_dict
                    col_dict[col_name] = String_VaryingAlphanum(collation=col_dict[col_name].collation)
            else:
                # All samples failed alphanum test, fallback to Unicode string
                assert col_name in col_dict
                col_dict[col_name] = String_VaryingUnicode(collation=col_dict[col_name].collation)

    return col_dict
|
|
1276
|
+
|
|
1277
|
+
def _normalize_table_path(self, path: DbPath) -> DbPath:
    """Expand a 1- or 2-part table path into a (schema, table) pair.

    A single-element path gets the database's default schema prepended.
    """
    if len(path) == 2:
        return path
    if len(path) == 1:
        return self.default_schema, path[0]
    raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")
|
|
1284
|
+
|
|
1285
|
+
def _query_cursor(self, c, sql_code: str) -> QueryResult:
    """Execute *sql_code* on cursor *c*.

    Returns a QueryResult for SELECT/EXPLAIN/SHOW statements; for
    CREATE/DROP it commits (when the driver exposes a connection) and
    returns None. On any failure the transaction is rolled back
    (best-effort) and the original exception re-raised.
    """
    assert isinstance(sql_code, str), sql_code
    try:
        c.execute(sql_code)
        if sql_code.lower().startswith(("select", "explain", "show")):
            # DB-API: description holds (name, type, ...) per column.
            columns = [col[0] for col in c.description]

            fetched = c.fetchall()
            result = QueryResult(fetched, columns)
            return result
        elif sql_code.lower().startswith(("create", "drop")):
            try:
                c.connection.commit()
            except AttributeError:
                # Driver's cursor doesn't expose .connection — autocommit assumed.
                ...
    except Exception as _e:
        try:
            c.connection.rollback()
        except Exception as rollback_error:
            loguru_logger.error(f"Rollback failed: {rollback_error}")
        raise
|
|
1306
|
+
|
|
1307
|
+
def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
    """Open a cursor on *conn* and run *sql_code* (a string or an interpreter of many queries)."""
    c = conn.cursor()
    callback = partial(self._query_cursor, c)
    return apply_query(callback, sql_code)
|
|
1311
|
+
|
|
1312
|
+
def quote(self, s: str, is_table: bool = False):
    """Quote column name.

    NOTE(review): `is_table` is accepted but never forwarded to
    `self.dialect.quote()`, although the dialect's abstract signature
    takes it — confirm whether table names need distinct quoting here.
    """
    return self.dialect.quote(s)
|
|
1315
|
+
|
|
1316
|
+
def close(self):
    """Close connection(s) to the database instance. Querying will stop functioning."""
    self.is_closed = True

@property
def dialect(self) -> BaseDialect:
    "The dialect of the database. Used internally by Database, and also available publicly."

    # Instantiated lazily and cached on first access.
    if not self._dialect:
        self._dialect = self.DIALECT_CLASS()
    return self._dialect
|
|
1327
|
+
|
|
1328
|
+
@property
@abstractmethod
def CONNECT_URI_HELP(self) -> str:
    """Example connection URI shown to the user in help and error messages."""
|
|
1332
|
+
|
|
1333
|
+
@property
@abstractmethod
def CONNECT_URI_PARAMS(self) -> List[str]:
    """List of parameter names given in the path portion of the connection URI."""
|
|
1337
|
+
|
|
1338
|
+
@abstractmethod
def _query(self, sql_code: str) -> list:
    """Send a query to the database and return the result."""
|
|
1341
|
+
|
|
1342
|
+
@property
@abstractmethod
def is_autocommit(self) -> bool:
    """Return whether the database autocommits changes. When false, COMMIT statements are skipped."""
|
|
1346
|
+
|
|
1347
|
+
|
|
1348
|
+
@attrs.define(frozen=False)
class ThreadedDatabase(Database):
    """Access the database through singleton threads.

    Used for database connectors that do not support sharing their connection between different threads.
    """

    # Size of the worker pool; each worker holds its own connection.
    thread_count: int = 1

    # Exception raised while opening a worker's connection, if any.
    # Deferred here and re-raised on the next query (_query_in_worker),
    # rather than failing at pool start-up.
    _init_error: Optional[Exception] = None
    # Worker pool; created in __attrs_post_init__.
    _queue: Optional[ThreadPoolExecutor] = None
    # Per-thread storage holding each worker thread's private connection.
    thread_local: threading.local = attrs.field(factory=threading.local)

    def __attrs_post_init__(self) -> None:
        # Each pool thread opens its own connection via the set_conn initializer.
        self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn)
        logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.")

    def set_conn(self):
        # Runs once per worker thread (ThreadPoolExecutor initializer);
        # a connection must not already exist for this thread.
        assert not hasattr(self.thread_local, "conn")
        try:
            self.thread_local.conn = self.create_connection()
        except Exception as e:
            # Defer the failure; _query_in_worker re-raises it.
            self._init_error = e

    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
        # Hand the query to a pool worker and block for its result.
        r = self._queue.submit(self._query_in_worker, sql_code)
        return r.result()

    def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
        """This method runs in a worker thread"""
        if self._init_error:
            raise self._init_error
        return self._query_conn(self.thread_local.conn, sql_code)

    @abstractmethod
    def create_connection(self):
        """Return a connection instance, that supports the .cursor() method."""

    def close(self):
        super().close()
        self._queue.shutdown()
        # NOTE(review): only the calling thread's connection is visible via
        # thread_local here; worker threads' connections are not explicitly
        # closed by this method — confirm whether drivers clean up on GC.
        if hasattr(self.thread_local, "conn"):
            self.thread_local.conn.close()

    @property
    def is_autocommit(self) -> bool:
        return False
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
# Checksum sizing: how many hex digits of the MD5 digest are kept.
CHECKSUM_HEXDIGITS = 12  # Must be 12 or lower, otherwise SUM() overflows
MD5_HEXDIGITS = 32

_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS * 4  # 4 bits per hex digit
CHECKSUM_MASK = (1 << _CHECKSUM_BITSIZE) - 1

# A bigint is typically 8 bytes. A shorter checksum is zero-padded by most
# databases (0xFF -> 0x00000000000000FF), so its numeric value is always
# positive. That limits how many checksums can be summed before overflow.
# Subtracting an offset of half the maximum recenters the distribution to
# [-0.5*max, +0.5*max], letting negative values cancel positive ones and
# allowing many more checksums to be added together without overflowing.
CHECKSUM_OFFSET = CHECKSUM_MASK // 2

DEFAULT_DATETIME_PRECISION = 6
DEFAULT_NUMERIC_PRECISION = 24

TIMESTAMP_PRECISION_POS = 20  # len("2022-06-03 12:24:35.") == 20
|