dcs_sdk-1.6.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_diff/__init__.py +221 -0
- data_diff/__main__.py +517 -0
- data_diff/abcs/__init__.py +13 -0
- data_diff/abcs/compiler.py +27 -0
- data_diff/abcs/database_types.py +402 -0
- data_diff/config.py +141 -0
- data_diff/databases/__init__.py +38 -0
- data_diff/databases/_connect.py +323 -0
- data_diff/databases/base.py +1417 -0
- data_diff/databases/bigquery.py +376 -0
- data_diff/databases/clickhouse.py +217 -0
- data_diff/databases/databricks.py +262 -0
- data_diff/databases/duckdb.py +207 -0
- data_diff/databases/mssql.py +343 -0
- data_diff/databases/mysql.py +189 -0
- data_diff/databases/oracle.py +238 -0
- data_diff/databases/postgresql.py +293 -0
- data_diff/databases/presto.py +222 -0
- data_diff/databases/redis.py +93 -0
- data_diff/databases/redshift.py +233 -0
- data_diff/databases/snowflake.py +222 -0
- data_diff/databases/sybase.py +720 -0
- data_diff/databases/trino.py +73 -0
- data_diff/databases/vertica.py +174 -0
- data_diff/diff_tables.py +489 -0
- data_diff/errors.py +17 -0
- data_diff/format.py +369 -0
- data_diff/hashdiff_tables.py +1026 -0
- data_diff/info_tree.py +76 -0
- data_diff/joindiff_tables.py +434 -0
- data_diff/lexicographic_space.py +253 -0
- data_diff/parse_time.py +88 -0
- data_diff/py.typed +0 -0
- data_diff/queries/__init__.py +13 -0
- data_diff/queries/api.py +213 -0
- data_diff/queries/ast_classes.py +811 -0
- data_diff/queries/base.py +38 -0
- data_diff/queries/extras.py +43 -0
- data_diff/query_utils.py +70 -0
- data_diff/schema.py +67 -0
- data_diff/table_segment.py +583 -0
- data_diff/thread_utils.py +112 -0
- data_diff/utils.py +1022 -0
- data_diff/version.py +15 -0
- dcs_core/__init__.py +13 -0
- dcs_core/__main__.py +17 -0
- dcs_core/__version__.py +15 -0
- dcs_core/cli/__init__.py +13 -0
- dcs_core/cli/cli.py +165 -0
- dcs_core/core/__init__.py +19 -0
- dcs_core/core/common/__init__.py +13 -0
- dcs_core/core/common/errors.py +50 -0
- dcs_core/core/common/models/__init__.py +13 -0
- dcs_core/core/common/models/configuration.py +284 -0
- dcs_core/core/common/models/dashboard.py +24 -0
- dcs_core/core/common/models/data_source_resource.py +75 -0
- dcs_core/core/common/models/metric.py +160 -0
- dcs_core/core/common/models/profile.py +75 -0
- dcs_core/core/common/models/validation.py +216 -0
- dcs_core/core/common/models/widget.py +44 -0
- dcs_core/core/configuration/__init__.py +13 -0
- dcs_core/core/configuration/config_loader.py +139 -0
- dcs_core/core/configuration/configuration_parser.py +262 -0
- dcs_core/core/configuration/configuration_parser_arc.py +328 -0
- dcs_core/core/datasource/__init__.py +13 -0
- dcs_core/core/datasource/base.py +62 -0
- dcs_core/core/datasource/manager.py +112 -0
- dcs_core/core/datasource/search_datasource.py +421 -0
- dcs_core/core/datasource/sql_datasource.py +1094 -0
- dcs_core/core/inspect.py +163 -0
- dcs_core/core/logger/__init__.py +13 -0
- dcs_core/core/logger/base.py +32 -0
- dcs_core/core/logger/default_logger.py +94 -0
- dcs_core/core/metric/__init__.py +13 -0
- dcs_core/core/metric/base.py +220 -0
- dcs_core/core/metric/combined_metric.py +98 -0
- dcs_core/core/metric/custom_metric.py +34 -0
- dcs_core/core/metric/manager.py +137 -0
- dcs_core/core/metric/numeric_metric.py +403 -0
- dcs_core/core/metric/reliability_metric.py +90 -0
- dcs_core/core/profiling/__init__.py +13 -0
- dcs_core/core/profiling/datasource_profiling.py +136 -0
- dcs_core/core/profiling/numeric_field_profiling.py +72 -0
- dcs_core/core/profiling/text_field_profiling.py +67 -0
- dcs_core/core/repository/__init__.py +13 -0
- dcs_core/core/repository/metric_repository.py +77 -0
- dcs_core/core/utils/__init__.py +13 -0
- dcs_core/core/utils/log.py +29 -0
- dcs_core/core/utils/tracking.py +105 -0
- dcs_core/core/utils/utils.py +44 -0
- dcs_core/core/validation/__init__.py +13 -0
- dcs_core/core/validation/base.py +230 -0
- dcs_core/core/validation/completeness_validation.py +153 -0
- dcs_core/core/validation/custom_query_validation.py +24 -0
- dcs_core/core/validation/manager.py +282 -0
- dcs_core/core/validation/numeric_validation.py +276 -0
- dcs_core/core/validation/reliability_validation.py +91 -0
- dcs_core/core/validation/uniqueness_validation.py +61 -0
- dcs_core/core/validation/validity_validation.py +738 -0
- dcs_core/integrations/__init__.py +13 -0
- dcs_core/integrations/databases/__init__.py +13 -0
- dcs_core/integrations/databases/bigquery.py +187 -0
- dcs_core/integrations/databases/databricks.py +51 -0
- dcs_core/integrations/databases/db2.py +652 -0
- dcs_core/integrations/databases/elasticsearch.py +61 -0
- dcs_core/integrations/databases/mssql.py +829 -0
- dcs_core/integrations/databases/mysql.py +409 -0
- dcs_core/integrations/databases/opensearch.py +64 -0
- dcs_core/integrations/databases/oracle.py +719 -0
- dcs_core/integrations/databases/postgres.py +482 -0
- dcs_core/integrations/databases/redshift.py +53 -0
- dcs_core/integrations/databases/snowflake.py +48 -0
- dcs_core/integrations/databases/spark_df.py +111 -0
- dcs_core/integrations/databases/sybase.py +1069 -0
- dcs_core/integrations/storage/__init__.py +13 -0
- dcs_core/integrations/storage/local_file.py +149 -0
- dcs_core/integrations/utils/__init__.py +13 -0
- dcs_core/integrations/utils/utils.py +36 -0
- dcs_core/report/__init__.py +13 -0
- dcs_core/report/dashboard.py +211 -0
- dcs_core/report/models.py +88 -0
- dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
- dcs_core/report/static/assets/images/docs.svg +6 -0
- dcs_core/report/static/assets/images/github.svg +4 -0
- dcs_core/report/static/assets/images/logo.svg +7 -0
- dcs_core/report/static/assets/images/slack.svg +13 -0
- dcs_core/report/static/index.js +2 -0
- dcs_core/report/static/index.js.LICENSE.txt +3971 -0
- dcs_sdk/__init__.py +13 -0
- dcs_sdk/__main__.py +18 -0
- dcs_sdk/__version__.py +15 -0
- dcs_sdk/cli/__init__.py +13 -0
- dcs_sdk/cli/cli.py +163 -0
- dcs_sdk/sdk/__init__.py +58 -0
- dcs_sdk/sdk/config/__init__.py +13 -0
- dcs_sdk/sdk/config/config_loader.py +491 -0
- dcs_sdk/sdk/data_diff/__init__.py +13 -0
- dcs_sdk/sdk/data_diff/data_differ.py +821 -0
- dcs_sdk/sdk/rules/__init__.py +15 -0
- dcs_sdk/sdk/rules/rules_mappping.py +31 -0
- dcs_sdk/sdk/rules/rules_repository.py +214 -0
- dcs_sdk/sdk/rules/schema_rules.py +65 -0
- dcs_sdk/sdk/utils/__init__.py +13 -0
- dcs_sdk/sdk/utils/serializer.py +25 -0
- dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
- dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
- dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
- dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
- dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
- dcs_sdk/sdk/utils/table.py +475 -0
- dcs_sdk/sdk/utils/themes.py +40 -0
- dcs_sdk/sdk/utils/utils.py +349 -0
- dcs_sdk-1.6.5.dist-info/METADATA +150 -0
- dcs_sdk-1.6.5.dist-info/RECORD +159 -0
- dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
- dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
data_diff/databases/oracle.py
@@ -0,0 +1,238 @@
+# Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import secrets
+import string
+import time
+from typing import Any, ClassVar, Dict, List, Optional, Type
+
+import attrs
+
+from data_diff.abcs.database_types import (
+    ColType,
+    ColType_UUID,
+    DbPath,
+    DbTime,
+    Decimal,
+    Float,
+    FractionalType,
+    TemporalType,
+    Text,
+    Timestamp,
+    TimestampTZ,
+)
+from data_diff.databases.base import (
+    CHECKSUM_HEXDIGITS,
+    CHECKSUM_OFFSET,
+    MD5_HEXDIGITS,
+    TIMESTAMP_PRECISION_POS,
+    BaseDialect,
+    ConnectError,
+    QueryError,
+    ThreadedDatabase,
+    import_helper,
+)
+from data_diff.schema import RawColumnInfo
+from data_diff.utils import match_regexps
+
+SESSION_TIME_ZONE = None  # Changed by the tests
+
+
+@import_helper("oracle")
+def import_oracle():
+    import oracledb
+
+    return oracledb
+
+
+@attrs.define(frozen=False)
+class Dialect(
+    BaseDialect,
+):
+    name = "Oracle"
+    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
+    SUPPORTS_INDEXES = True
+    TYPE_CLASSES: Dict[str, type] = {
+        "NUMBER": Decimal,
+        "FLOAT": Float,
+        # Text
+        "CHAR": Text,
+        "NCHAR": Text,
+        "NVARCHAR2": Text,
+        "VARCHAR2": Text,
+        "DATE": Timestamp,
+    }
+    ROUNDS_ON_PREC_LOSS = True
+    PLACEHOLDER_TABLE = "DUAL"
+
+    def quote(self, s: str, is_table: bool = False) -> str:
+        if s in self.TABLE_NAMES and self.default_schema and is_table:
+            return f'"{self.default_schema}"."{s}"'
+        return f'"{s}"'
+
+    def to_string(self, s: str) -> str:
+        return f"cast({s} as varchar(1024))"
+
+    def limit_select(
+        self,
+        select_query: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
+        has_order_by: Optional[bool] = None,
+    ) -> str:
+        if offset:
+            raise NotImplementedError("No support for OFFSET in query")
+
+        return f"SELECT * FROM ({select_query}) FETCH NEXT {limit} ROWS ONLY"
+
+    def concat(self, items: List[str]) -> str:
+        joined_exprs = " || ".join(items)
+        return f"({joined_exprs})"
+
+    def timestamp_value(self, t: DbTime) -> str:
+        return "timestamp '%s'" % t.isoformat(" ")
+
+    def random(self) -> str:
+        return "dbms_random.value"
+
+    def is_distinct_from(self, a: str, b: str) -> str:
+        return f"DECODE({a}, {b}, 1, 0) = 0"
+
+    def type_repr(self, t) -> str:
+        try:
+            return {
+                str: "VARCHAR(1024)",
+            }[t]
+        except KeyError:
+            return super().type_repr(t)
+
+    def constant_values(self, rows) -> str:
+        return " UNION ALL ".join(
+            "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows
+        )
+
+    def explain_as_text(self, query: str) -> str:
+        raise NotImplementedError("Explain not yet implemented in Oracle")
+
+    def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
+        regexps = {
+            r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
+            r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
+            r"TIMESTAMP\((\d)\)": Timestamp,
+        }
+
+        for m, t_cls in match_regexps(regexps, info.data_type):
+            precision = int(m.group(1))
+            return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
+
+        return super().parse_type(table_path, info)
+
+    def set_timezone_to_utc(self) -> str:
+        return "ALTER SESSION SET TIME_ZONE = 'UTC'"
+
+    def current_timestamp(self) -> str:
+        return "LOCALTIMESTAMP"
+
+    def md5_as_int(self, s: str) -> str:
+        # standard_hash is faster than DBMS_CRYPTO.Hash
+        # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
+        return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}"
+
+    def md5_as_hex(self, s: str) -> str:
+        return f"standard_hash({s}, 'MD5')"
+
+    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
+        # Cast is necessary for correct MD5 (trimming not enough)
+        return f"CAST(TRIM({value}) AS VARCHAR(36))"
+
+    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+        if coltype.rounds:
+            return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
+
+        if coltype.precision > 0:
+            truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')"
+        else:
+            truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')"
+        return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')"
+
+    def normalize_number(self, value: str, coltype: FractionalType) -> str:
+        # FM999.9990
+        format_str = "FM" + "9" * (38 - coltype.precision)
+        if coltype.precision:
+            format_str += "0." + "9" * (coltype.precision - 1) + "0"
+        return f"to_char({value}, '{format_str}')"
+
+    def generate_view_name(self, view_name: str | None = None) -> str:
+        if view_name is not None:
+            return view_name.upper()
+        random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
+        timestamp = int(time.time())
+        return f"view_{timestamp}_{random_string.lower()}".upper()
+
+    def parse_table_name(self, name: str) -> DbPath:
+        "Parse the given table name into a DbPath"
+        self.TABLE_NAMES.append(name.split(".")[-1])
+        return tuple(name.split("."))
+
+
+@attrs.define(frozen=False, init=False, kw_only=True)
+class Oracle(ThreadedDatabase):
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
+    CONNECT_URI_HELP = "oracle://<user>:<password>@<host>/<database>"
+    CONNECT_URI_PARAMS = ["database?"]
+
+    kwargs: Dict[str, Any]
+    _oracle: Any
+    _conn: Any
+
+    def __init__(self, *, host, database, thread_count, **kw) -> None:
+        super().__init__(thread_count=thread_count)
+        self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)
+        self.default_schema = kw.get("schema", None) or kw.get("user").upper()
+        self.dialect.default_schema = self.default_schema
+        self.kwargs = {k: v for k, v in self.kwargs.items() if v}
+        self._oracle = None
+        if "schema" in self.kwargs:
+            del self.kwargs["schema"]
+        self._conn = self.create_connection()
+
+    def create_connection(self):
+        self._oracle = import_oracle()
+        try:
+            self._conn = self._oracle.connect(**self.kwargs)
+            if SESSION_TIME_ZONE:
+                self._conn.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'")
+            return self._conn
+        except Exception as e:
+            raise ConnectError(*e.args) from e
+
+    def _query_cursor(self, c, sql_code: str):
+        try:
+            return super()._query_cursor(c, sql_code)
+        except self._oracle.DatabaseError as e:
+            raise QueryError(e)
+
+    def select_table_schema(self, path: DbPath) -> str:
+        schema, name = self._normalize_table_path(path)
+
+        return (
+            f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, "
+            f"data_scale as numeric_scale, NULL as collation_name, char_length as character_maximum_length "
+            f"FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'"
+        )
+
+    def close(self):
+        super().close()
+        if self._conn is not None:
+            self._conn.close()
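The md5_as_int expression above is what lets heterogeneous databases agree on a row checksum: each dialect hashes the normalized value with MD5, keeps only the trailing CHECKSUM_HEXDIGITS hex digits, and reads them as an integer. A minimal Python sketch of that idea follows; it is not part of the wheel. CHECKSUM_HEXDIGITS = 15 is inferred from the 15-character 'x' format mask above, and CHECKSUM_OFFSET is a placeholder because data_diff/databases/base.py is not reproduced in this excerpt.

    import hashlib

    MD5_HEXDIGITS = 32        # an MD5 digest is 32 hex characters
    CHECKSUM_HEXDIGITS = 15   # matches the 'xxxxxxxxxxxxxxx' mask in md5_as_int above
    CHECKSUM_OFFSET = 0       # placeholder; the real constant lives in data_diff/databases/base.py

    def md5_as_int(value: str) -> int:
        # Keep the trailing CHECKSUM_HEXDIGITS hex digits of the MD5 digest,
        # read them as an integer, and shift by the offset -- same shape as the SQL expression.
        digest = hashlib.md5(value.encode()).hexdigest()
        return int(digest[MD5_HEXDIGITS - CHECKSUM_HEXDIGITS:], 16) - CHECKSUM_OFFSET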
data_diff/databases/postgresql.py
@@ -0,0 +1,293 @@
+# Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, ClassVar, Dict, List, Tuple, Type
+from urllib.parse import unquote
+
+import attrs
+
+from data_diff.abcs.database_types import (
+    JSON,
+    Boolean,
+    ColType,
+    Date,
+    DbPath,
+    Decimal,
+    Float,
+    FractionalType,
+    Integer,
+    Native_UUID,
+    TemporalType,
+    Text,
+    Time,
+    Timestamp,
+    TimestampTZ,
+)
+from data_diff.databases.base import (
+    _CHECKSUM_BITSIZE,
+    CHECKSUM_HEXDIGITS,
+    CHECKSUM_OFFSET,
+    MD5_HEXDIGITS,
+    TIMESTAMP_PRECISION_POS,
+    BaseDialect,
+    ConnectError,
+    QueryResult,
+    ThreadedDatabase,
+    import_helper,
+)
+
+SESSION_TIME_ZONE = None  # Changed by the tests
+
+
+@import_helper("postgresql")
+def import_postgresql():
+    import psycopg2.extras
+
+    psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
+    return psycopg2
+
+
+@attrs.define(frozen=False)
+class PostgresqlDialect(BaseDialect):
+    name = "PostgreSQL"
+    ROUNDS_ON_PREC_LOSS = True
+    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
+    SUPPORTS_INDEXES = True
+
+    # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-DECIMAL
+    # without any precision or scale creates an “unconstrained numeric” column
+    # in which numeric values of any length can be stored, up to the implementation limits.
+    # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-TABLE
+    DEFAULT_NUMERIC_PRECISION = 16383
+
+    TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
+        # Timestamps
+        "timestamp with time zone": TimestampTZ,
+        "timestamp without time zone": Timestamp,
+        "timestamp": Timestamp,
+        "date": Date,
+        "time with time zone": Time,
+        "time without time zone": Time,
+        # Numbers
+        "double precision": Float,
+        "real": Float,
+        "decimal": Decimal,
+        "smallint": Integer,
+        "integer": Integer,
+        "numeric": Decimal,
+        "bigint": Integer,
+        # Text
+        "character": Text,
+        "character varying": Text,
+        "varchar": Text,
+        "text": Text,
+        "json": JSON,
+        "jsonb": JSON,
+        "uuid": Native_UUID,
+        "boolean": Boolean,
+    }
+
+    def quote(self, s: str, is_table: bool = False) -> str:
+        if s in self.TABLE_NAMES and self.default_schema and is_table:
+            return f'"{self.default_schema}"."{s}"'
+        return f'"{s}"'
+
+    def to_string(self, s: str):
+        return f"{s}::varchar"
+
+    def concat(self, items: List[str]) -> str:
+        joined_exprs = " || ".join(items)
+        return f"({joined_exprs})"
+
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        # Subtracting 2 due to wierd precision issues in PostgreSQL
+        return super()._convert_db_precision_to_digits(p) - 2
+
+    def set_timezone_to_utc(self) -> str:
+        return "SET TIME ZONE 'UTC'"
+
+    def current_timestamp(self) -> str:
+        return "current_timestamp"
+
+    def type_repr(self, t) -> str:
+        if isinstance(t, TimestampTZ):
+            return f"timestamp ({t.precision}) with time zone"
+        return super().type_repr(t)
+
+    def md5_as_int(self, s: str) -> str:
+        return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}"
+
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
+    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+        def _add_padding(coltype: TemporalType, timestamp6: str):
+            return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
+
+        try:
+            is_date = coltype.is_date
+            is_time = coltype.is_time
+        except:
+            is_date = False
+            is_time = False
+
+        if isinstance(coltype, Date) or is_date:
+            return f"cast({value} as varchar)"
+
+        if isinstance(coltype, Time) or is_time:
+            seconds = f"EXTRACT( epoch from {value})"
+            rounded = f"ROUND({seconds}, {coltype.precision})"
+            time_value = f"CAST('00:00:00' as time) + make_interval(0, 0, 0, 0, 0, 0, {rounded})"  # 6th arg = seconds
+            converted = f"to_char({time_value}, 'hh24:mi:ss.ff6')"
+            return converted
+
+        if coltype.rounds:
+            # NULL value expected to return NULL after normalization
+            null_case_begin = f"CASE WHEN {value} IS NULL THEN NULL ELSE "
+            null_case_end = "END"
+
+            # 294277 or 4714 BC would be out of range, make sure we can't round to that
+            # TODO test timezones for overflow?
+            max_timestamp = "294276-12-31 23:59:59.0000"
+            min_timestamp = "4713-01-01 00:00:00.00 BC"
+            timestamp = f"least('{max_timestamp}'::timestamp(6), {value}::timestamp(6))"
+            timestamp = f"greatest('{min_timestamp}'::timestamp(6), {timestamp})"
+
+            interval = format((0.5 * (10 ** (-coltype.precision))), f".{coltype.precision+1}f")
+
+            rounded_timestamp = (
+                f"left(to_char(least('{max_timestamp}'::timestamp, {timestamp})"
+                f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US'),"
+                f"length(to_char(least('{max_timestamp}'::timestamp, {timestamp})"
+                f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US')) - (6-{coltype.precision}))"
+            )
+
+            padded = _add_padding(coltype, rounded_timestamp)
+            return f"{null_case_begin} {padded} {null_case_end}"
+
+        # TODO years with > 4 digits not padded correctly
+        # current w/ precision 6: 294276-12-31 23:59:59.0000
+        # should be 294276-12-31 23:59:59.000000
+        else:
+            rounded_timestamp = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
+            padded = _add_padding(coltype, rounded_timestamp)
+            return padded
+
+    def normalize_number(self, value: str, coltype: FractionalType) -> str:
+        precision = min(coltype.precision, 10)
+        return self.to_string(f"{value}::decimal(38, {precision})")
+
+    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
+        return self.to_string(f"{value}::int")
+
+    def normalize_json(self, value: str, _coltype: JSON) -> str:
+        return f"{value}::text"
+
+    def parse_table_name(self, name: str) -> DbPath:
+        "Parse the given table name into a DbPath"
+        self.TABLE_NAMES.append(name.split(".")[-1])
+        return tuple(name.split("."))
+
+
+@attrs.define(frozen=False, init=False, kw_only=True)
+class PostgreSQL(ThreadedDatabase):
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = PostgresqlDialect
+    SUPPORTS_UNIQUE_CONSTAINT = True
+    CONNECT_URI_HELP = "postgresql://<user>:<password>@<host>/<database>"
+    CONNECT_URI_PARAMS = ["database?"]
+
+    _args: Dict[str, Any]
+    _conn: Any
+
+    def __init__(self, *, thread_count, **kw) -> None:
+        super().__init__(thread_count=thread_count)
+        self._args = kw
+        self.default_schema = self._args.get("schema", "public")
+        self.dialect.default_schema = self.default_schema
+
+    def create_connection(self):
+        if not self._args:
+            self._args["host"] = None  # psycopg2 requires 1+ arguments
+
+        pg = import_postgresql()
+        try:
+            self._args["password"] = unquote(self._args["password"])
+            self._conn = pg.connect(
+                database=self._args.get("database"),
+                user=self._args.get("user"),
+                password=self._args.get("password"),
+                host=self._args.get("host"),
+                port=self._args.get("port"),
+                keepalives=1,
+                keepalives_idle=5,
+                keepalives_interval=2,
+                keepalives_count=2,
+                options="-c search_path={}".format(self.default_schema),
+            )
+            if SESSION_TIME_ZONE:
+                self._conn.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'")
+            return self._conn
+        except pg.OperationalError as e:
+            raise ConnectError(*e.args) from e
+
+    def select_table_schema(self, path: DbPath) -> str:
+        database, schema, table = self._normalize_table_path(path)
+
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, database)
+        return (
+            f"SELECT column_name, data_type, datetime_precision, "
+            f"CASE WHEN data_type = 'numeric' "
+            f"THEN coalesce(numeric_precision, 131072 + {self.dialect.DEFAULT_NUMERIC_PRECISION}) "
+            f"ELSE numeric_precision END AS numeric_precision, "
+            f"CASE WHEN data_type = 'numeric' "
+            f"THEN coalesce(numeric_scale, {self.dialect.DEFAULT_NUMERIC_PRECISION}) "
+            f"ELSE numeric_scale END AS numeric_scale, "
+            f"COALESCE(collation_name, NULL) AS collation_name, "
+            f"CASE WHEN data_type = 'character varying' "
+            f"THEN character_maximum_length END AS character_maximum_length "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+        )
+
+    def select_table_unique_columns(self, path: DbPath) -> str:
+        database, schema, table = self._normalize_table_path(path)
+
+        info_schema_path = ["information_schema", "key_column_usage"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            "SELECT column_name "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+        )
+
+    def _normalize_table_path(self, path: DbPath) -> DbPath:
+        if len(path) == 1:
+            return None, self.default_schema, path[0]
+        elif len(path) == 2:
+            return None, path[0], path[1]
+        elif len(path) == 3:
+            return path
+
+        raise ValueError(
+            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
+        )
+
+    def close(self):
+        super().close()
+        if self._conn is not None:
+            self._conn.close()
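Both normalize_timestamp implementations shown in this diff end by truncating to the column's precision and right-padding with zeros, so the same timestamp read from Oracle and from PostgreSQL normalizes to a byte-identical string before hashing. A rough Python rendering of that padding rule, illustrative only and not part of the package, assuming TIMESTAMP_PRECISION_POS equals the length of the 'YYYY-MM-DD HH24:MI:SS.' prefix (20), since data_diff/databases/base.py is not included in this excerpt:

    TIMESTAMP_PRECISION_POS = len("YYYY-MM-DD HH:MM:SS.")  # 20, assumed

    def add_padding(precision: int, timestamp6: str) -> str:
        # LEFT(ts, pos + precision) then RPAD(..., pos + 6, '0'), mirroring _add_padding above
        truncated = timestamp6[: TIMESTAMP_PRECISION_POS + precision]
        return truncated.ljust(TIMESTAMP_PRECISION_POS + 6, "0")

    print(add_padding(3, "2024-05-01 12:30:45.123456"))  # -> 2024-05-01 12:30:45.123000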