dcs-sdk 1.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_diff/__init__.py +221 -0
- data_diff/__main__.py +517 -0
- data_diff/abcs/__init__.py +13 -0
- data_diff/abcs/compiler.py +27 -0
- data_diff/abcs/database_types.py +402 -0
- data_diff/config.py +141 -0
- data_diff/databases/__init__.py +38 -0
- data_diff/databases/_connect.py +323 -0
- data_diff/databases/base.py +1417 -0
- data_diff/databases/bigquery.py +376 -0
- data_diff/databases/clickhouse.py +217 -0
- data_diff/databases/databricks.py +262 -0
- data_diff/databases/duckdb.py +207 -0
- data_diff/databases/mssql.py +343 -0
- data_diff/databases/mysql.py +189 -0
- data_diff/databases/oracle.py +238 -0
- data_diff/databases/postgresql.py +293 -0
- data_diff/databases/presto.py +222 -0
- data_diff/databases/redis.py +93 -0
- data_diff/databases/redshift.py +233 -0
- data_diff/databases/snowflake.py +222 -0
- data_diff/databases/sybase.py +720 -0
- data_diff/databases/trino.py +73 -0
- data_diff/databases/vertica.py +174 -0
- data_diff/diff_tables.py +489 -0
- data_diff/errors.py +17 -0
- data_diff/format.py +369 -0
- data_diff/hashdiff_tables.py +1026 -0
- data_diff/info_tree.py +76 -0
- data_diff/joindiff_tables.py +434 -0
- data_diff/lexicographic_space.py +253 -0
- data_diff/parse_time.py +88 -0
- data_diff/py.typed +0 -0
- data_diff/queries/__init__.py +13 -0
- data_diff/queries/api.py +213 -0
- data_diff/queries/ast_classes.py +811 -0
- data_diff/queries/base.py +38 -0
- data_diff/queries/extras.py +43 -0
- data_diff/query_utils.py +70 -0
- data_diff/schema.py +67 -0
- data_diff/table_segment.py +583 -0
- data_diff/thread_utils.py +112 -0
- data_diff/utils.py +1022 -0
- data_diff/version.py +15 -0
- dcs_core/__init__.py +13 -0
- dcs_core/__main__.py +17 -0
- dcs_core/__version__.py +15 -0
- dcs_core/cli/__init__.py +13 -0
- dcs_core/cli/cli.py +165 -0
- dcs_core/core/__init__.py +19 -0
- dcs_core/core/common/__init__.py +13 -0
- dcs_core/core/common/errors.py +50 -0
- dcs_core/core/common/models/__init__.py +13 -0
- dcs_core/core/common/models/configuration.py +284 -0
- dcs_core/core/common/models/dashboard.py +24 -0
- dcs_core/core/common/models/data_source_resource.py +75 -0
- dcs_core/core/common/models/metric.py +160 -0
- dcs_core/core/common/models/profile.py +75 -0
- dcs_core/core/common/models/validation.py +216 -0
- dcs_core/core/common/models/widget.py +44 -0
- dcs_core/core/configuration/__init__.py +13 -0
- dcs_core/core/configuration/config_loader.py +139 -0
- dcs_core/core/configuration/configuration_parser.py +262 -0
- dcs_core/core/configuration/configuration_parser_arc.py +328 -0
- dcs_core/core/datasource/__init__.py +13 -0
- dcs_core/core/datasource/base.py +62 -0
- dcs_core/core/datasource/manager.py +112 -0
- dcs_core/core/datasource/search_datasource.py +421 -0
- dcs_core/core/datasource/sql_datasource.py +1094 -0
- dcs_core/core/inspect.py +163 -0
- dcs_core/core/logger/__init__.py +13 -0
- dcs_core/core/logger/base.py +32 -0
- dcs_core/core/logger/default_logger.py +94 -0
- dcs_core/core/metric/__init__.py +13 -0
- dcs_core/core/metric/base.py +220 -0
- dcs_core/core/metric/combined_metric.py +98 -0
- dcs_core/core/metric/custom_metric.py +34 -0
- dcs_core/core/metric/manager.py +137 -0
- dcs_core/core/metric/numeric_metric.py +403 -0
- dcs_core/core/metric/reliability_metric.py +90 -0
- dcs_core/core/profiling/__init__.py +13 -0
- dcs_core/core/profiling/datasource_profiling.py +136 -0
- dcs_core/core/profiling/numeric_field_profiling.py +72 -0
- dcs_core/core/profiling/text_field_profiling.py +67 -0
- dcs_core/core/repository/__init__.py +13 -0
- dcs_core/core/repository/metric_repository.py +77 -0
- dcs_core/core/utils/__init__.py +13 -0
- dcs_core/core/utils/log.py +29 -0
- dcs_core/core/utils/tracking.py +105 -0
- dcs_core/core/utils/utils.py +44 -0
- dcs_core/core/validation/__init__.py +13 -0
- dcs_core/core/validation/base.py +230 -0
- dcs_core/core/validation/completeness_validation.py +153 -0
- dcs_core/core/validation/custom_query_validation.py +24 -0
- dcs_core/core/validation/manager.py +282 -0
- dcs_core/core/validation/numeric_validation.py +276 -0
- dcs_core/core/validation/reliability_validation.py +91 -0
- dcs_core/core/validation/uniqueness_validation.py +61 -0
- dcs_core/core/validation/validity_validation.py +738 -0
- dcs_core/integrations/__init__.py +13 -0
- dcs_core/integrations/databases/__init__.py +13 -0
- dcs_core/integrations/databases/bigquery.py +187 -0
- dcs_core/integrations/databases/databricks.py +51 -0
- dcs_core/integrations/databases/db2.py +652 -0
- dcs_core/integrations/databases/elasticsearch.py +61 -0
- dcs_core/integrations/databases/mssql.py +829 -0
- dcs_core/integrations/databases/mysql.py +409 -0
- dcs_core/integrations/databases/opensearch.py +64 -0
- dcs_core/integrations/databases/oracle.py +719 -0
- dcs_core/integrations/databases/postgres.py +482 -0
- dcs_core/integrations/databases/redshift.py +53 -0
- dcs_core/integrations/databases/snowflake.py +48 -0
- dcs_core/integrations/databases/spark_df.py +111 -0
- dcs_core/integrations/databases/sybase.py +1069 -0
- dcs_core/integrations/storage/__init__.py +13 -0
- dcs_core/integrations/storage/local_file.py +149 -0
- dcs_core/integrations/utils/__init__.py +13 -0
- dcs_core/integrations/utils/utils.py +36 -0
- dcs_core/report/__init__.py +13 -0
- dcs_core/report/dashboard.py +211 -0
- dcs_core/report/models.py +88 -0
- dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
- dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
- dcs_core/report/static/assets/images/docs.svg +6 -0
- dcs_core/report/static/assets/images/github.svg +4 -0
- dcs_core/report/static/assets/images/logo.svg +7 -0
- dcs_core/report/static/assets/images/slack.svg +13 -0
- dcs_core/report/static/index.js +2 -0
- dcs_core/report/static/index.js.LICENSE.txt +3971 -0
- dcs_sdk/__init__.py +13 -0
- dcs_sdk/__main__.py +18 -0
- dcs_sdk/__version__.py +15 -0
- dcs_sdk/cli/__init__.py +13 -0
- dcs_sdk/cli/cli.py +163 -0
- dcs_sdk/sdk/__init__.py +58 -0
- dcs_sdk/sdk/config/__init__.py +13 -0
- dcs_sdk/sdk/config/config_loader.py +491 -0
- dcs_sdk/sdk/data_diff/__init__.py +13 -0
- dcs_sdk/sdk/data_diff/data_differ.py +821 -0
- dcs_sdk/sdk/rules/__init__.py +15 -0
- dcs_sdk/sdk/rules/rules_mappping.py +31 -0
- dcs_sdk/sdk/rules/rules_repository.py +214 -0
- dcs_sdk/sdk/rules/schema_rules.py +65 -0
- dcs_sdk/sdk/utils/__init__.py +13 -0
- dcs_sdk/sdk/utils/serializer.py +25 -0
- dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
- dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
- dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
- dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
- dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
- dcs_sdk/sdk/utils/table.py +475 -0
- dcs_sdk/sdk/utils/themes.py +40 -0
- dcs_sdk/sdk/utils/utils.py +349 -0
- dcs_sdk-1.6.5.dist-info/METADATA +150 -0
- dcs_sdk-1.6.5.dist-info/RECORD +159 -0
- dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
- dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,583 @@
|
|
|
1
|
+
# Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
import time
from decimal import Decimal
from itertools import product
from typing import Container, Dict, Iterable, List, Optional, Sequence, Tuple

import attrs
import numpy as np
from loguru import logger
from typing_extensions import Self

from data_diff.abcs.database_types import DbKey, DbPath, DbTime, IKey, NumericType
from data_diff.databases.base import Database
from data_diff.databases.redis import RedisBackend
from data_diff.queries.api import (
    SKIP,
    Code,
    Count,
    Expr,
    and_,
    max_,
    min_,
    or_,
    table,
    this,
)
from data_diff.queries.extras import (
    ApplyFuncAndNormalizeAsString,
    Checksum,
    NormalizeAsString,
)
from data_diff.schema import RawColumnInfo, Schema, create_schema
from data_diff.utils import (
    ArithDate,
    ArithDateTime,
    ArithString,
    ArithTimestamp,
    ArithTimestampTZ,
    ArithUnicodeString,
    JobCancelledError,
    Vector,
    safezip,
    split_space,
)
|
|
59
|
+
|
|
60
|
+
# logger = logging.getLogger("table_segment")
|
|
61
|
+
|
|
62
|
+
# Threshold, in seconds, above which a single count+checksum query is considered slow.
# `TableSegment.count_and_checksum` converts its measured query time from ms to seconds
# and logs a warning when it exceeds this value.
RECOMMENDED_CHECKSUM_DURATION = 20
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def split_key_space(min_key: DbKey, max_key: DbKey, count: int) -> List[DbKey]:
    """Compute split points between ``min_key`` and ``max_key``.

    Returns ``[min_key] + checkpoints + [max_key]``. For arithmetic string/temporal keys the
    inner checkpoints come from ``min_key.range(max_key, n)``; for other keys they come from
    ``split_space`` (``Decimal`` bounds are converted to ``float`` first, and the converted
    bounds are what get returned). When ``max_key - min_key <= count``, a count of 1 is used.
    Every inner checkpoint must lie strictly between the two bounds.
    """
    assert min_key < max_key

    # A span no wider than `count` is only split once.
    n_parts = 1 if max_key - min_key <= count else count

    arithmetic_key_types = (ArithString, ArithUnicodeString, ArithDateTime, ArithDate, ArithTimestamp, ArithTimestampTZ)
    if not isinstance(min_key, arithmetic_key_types):
        # Numeric keys: split_space works on the float form of Decimal bounds.
        low = float(min_key) if isinstance(min_key, Decimal) else min_key
        high = float(max_key) if isinstance(max_key, Decimal) else max_key
        inner = split_space(low, high, n_parts)
    else:
        assert type(min_key) is type(max_key)
        low, high = min_key, max_key
        inner = min_key.range(max_key, n_parts)

    assert all(low < point < high for point in inner)
    return [low] + inner + [high]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def int_product(nums: Iterable[int]) -> int:
    """Return the product of all values in ``nums`` (``1`` for an empty input).

    Accepts any iterable, including generators (``create_mesh_from_points`` passes one).
    """
    p = 1
    for n in nums:
        p *= n
    return p
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def split_compound_key_space(mn: Vector, mx: Vector, count: int) -> List[List[DbKey]]:
    """Split each dimension of a compound key independently.

    For every key position, ``split_key_space`` is applied to the matching lower/upper bound,
    yielding one checkpoint list per dimension: together they describe an N-dimensional grid.
    """
    grid: List[List[DbKey]] = []
    for lower, upper in safezip(mn, mx):
        grid.append(split_key_space(lower, upper, count))
    return grid
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector]]:
    """Build the boxes of an N-dimensional grid from per-axis split points.

    Given a list of values along each axis of an N-dimensional space, return every box whose
    start- and end-points align with adjacent values on each axis. Together the boxes
    constitute a mesh filling the space between the first and last value of every axis.

    Assumes the given values are already ordered ascending and that each axis has at least
    two values. The number of boxes equals the product over all axes of ``len(values) - 1``.

    Example:
    ::

        >>> d1 = 'a', 'b', 'c'
        >>> d2 = 1, 2, 3
        >>> d3 = 'X', 'Y'
        >>> create_mesh_from_points(d1, d2, d3)
        [
            (('a', 1, 'X'), ('b', 2, 'Y')),
            (('a', 2, 'X'), ('b', 3, 'Y')),
            (('b', 1, 'X'), ('c', 2, 'Y')),
            (('b', 2, 'X'), ('c', 3, 'Y')),
        ]

    Each box is a ``(start, end)`` tuple of ``Vector`` points (printed above as plain tuples).
    """
    assert all(len(v) >= 2 for v in values_per_dim), values_per_dim

    # Pair each value with its successor: the 1-D intervals along every axis.
    ranges = [list(zip(values[:-1], values[1:])) for values in values_per_dim]

    assert all(a <= b for r in ranges for a, b in r)

    # Each combination of one interval per axis is a box; transposing the chosen
    # (start, end) pairs gives the box's start point and end point.
    res = [tuple(Vector(a) for a in safezip(*r)) for r in product(*ranges)]

    expected_len = int_product(len(v) - 1 for v in values_per_dim)
    assert len(res) == expected_len, (len(res), expected_len)
    return res
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@attrs.define(frozen=True)
|
|
139
|
+
class TableSegment:
|
|
140
|
+
"""Signifies a segment of rows (and selected columns) within a table
|
|
141
|
+
|
|
142
|
+
Parameters:
|
|
143
|
+
database (Database): Database instance. See :meth:`connect`
|
|
144
|
+
table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')`
|
|
145
|
+
key_columns (Tuple[str]): Name of the key column, which uniquely identifies each row (usually id)
|
|
146
|
+
update_column (str, optional): Name of updated column, which signals that rows changed.
|
|
147
|
+
Usually updated_at or last_update. Used by `min_update` and `max_update`.
|
|
148
|
+
extra_columns (Tuple[str, ...], optional): Extra columns to compare
|
|
149
|
+
transform_columns (Dict[str, str], optional): A dictionary mapping column names to SQL transformation expressions.
|
|
150
|
+
These expressions are applied directly to the specified columns within the
|
|
151
|
+
comparison query, *before* the data is hashed or compared. Useful for
|
|
152
|
+
on-the-fly normalization (e.g., type casting, timezone conversions) without
|
|
153
|
+
requiring intermediate views or staging tables. Defaults to an empty dict.
|
|
154
|
+
min_key (:data:`Vector`, optional): Lowest key value, used to restrict the segment
|
|
155
|
+
max_key (:data:`Vector`, optional): Highest key value, used to restrict the segment
|
|
156
|
+
min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
|
|
157
|
+
max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
|
|
158
|
+
where (str, optional): An additional 'where' expression to restrict the search space.
|
|
159
|
+
|
|
160
|
+
case_sensitive (bool): If false, the case of column names will adjust according to the schema. Default is true.
|
|
161
|
+
|
|
162
|
+
"""
|
|
163
|
+
|
|
164
|
+
# Location of table
database: Database
table_path: DbPath

# Columns
key_columns: Tuple[str, ...]
update_column: Optional[str] = None
extra_columns: Tuple[str, ...] = ()
# Per-column SQL templates; "{column}" is replaced with the quoted column name.
# A factory gives every instance its own dict: attrs reuses a plain `= {}` default object
# across all instances, so a mutation through one segment would leak into the others.
transform_columns: Dict[str, str] = attrs.Factory(dict)
ignored_columns: Container[str] = frozenset()

# Restrict the segment
min_key: Optional[Vector] = None
max_key: Optional[Vector] = None
min_update: Optional[DbTime] = None
max_update: Optional[DbTime] = None
where: Optional[str] = None

case_sensitive: Optional[bool] = True
# attrs exposes this as the `schema=` init argument (leading underscore dropped).
_schema: Optional[Schema] = None

# Timing statistics per query kind, updated by `_update_stats`. The factory builds a fresh,
# zeroed bucket (with its own `_query_times` list) for each key on every new instance.
query_stats: Dict = attrs.Factory(
    lambda: {
        stats_key: {
            "total_queries": 0,
            "min_time_ms": 0.0,
            "max_time_ms": 0.0,
            "avg_time_ms": 0.0,
            "p90_time_ms": 0.0,
            "total_time_taken_ms": 0.0,
            "_query_times": [],
        }
        for stats_key in (
            "count_queries_stats",
            "checksum_queries_stats",
            "row_fetch_queries_stats",
            "schema_queries_stats",
        )
    }
)
job_id: Optional[int] = None
|
|
226
|
+
|
|
227
|
+
def __attrs_post_init__(self) -> None:
|
|
228
|
+
if not self.update_column and (self.min_update or self.max_update):
|
|
229
|
+
raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.")
|
|
230
|
+
|
|
231
|
+
if self.min_key is not None and self.max_key is not None and self.min_key >= self.max_key:
|
|
232
|
+
raise ValueError(f"Error: min_key expected to be smaller than max_key! ({self.min_key} >= {self.max_key})")
|
|
233
|
+
|
|
234
|
+
if self.min_update is not None and self.max_update is not None and self.min_update >= self.max_update:
|
|
235
|
+
raise ValueError(
|
|
236
|
+
f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
def _update_stats(self, stats_key: str, query_time_ms: float) -> None:
    """Record one query duration in ``self.query_stats[stats_key]`` and refresh its aggregates.

    Logs the duration, increments the query counter and running total, appends the sample to
    ``_query_times``, then recomputes min, max, mean and linear-interpolated 90th percentile
    over all recorded samples.
    """
    logger.info(f"Query time for {stats_key.replace('_stats', '')}: {query_time_ms:.2f} ms")
    bucket = self.query_stats[stats_key]
    samples = bucket["_query_times"]

    bucket["total_queries"] += 1
    samples.append(query_time_ms)
    bucket["total_time_taken_ms"] += query_time_ms

    # `samples` always holds at least the entry appended above, so every aggregate is defined.
    bucket["min_time_ms"] = min(samples)
    bucket["max_time_ms"] = max(samples)
    as_array = np.array(samples)
    bucket["avg_time_ms"] = float(np.mean(as_array))
    bucket["p90_time_ms"] = float(np.percentile(as_array, 90, method="linear"))
|
|
254
|
+
|
|
255
|
+
def _where(self) -> Optional[str]:
|
|
256
|
+
return f"({self.where})" if self.where else None
|
|
257
|
+
|
|
258
|
+
def _column_expr(self, column_name: str) -> Expr:
    """Build the query expression used to read ``column_name``.

    If ``transform_columns`` holds a template for the column, the template is rendered with
    ``{column}`` bound to the quoted column name and returned as raw ``Code``; otherwise the
    plain column reference ``this[column_name]`` is returned.
    """
    quoted = self.database.quote(s=column_name)
    has_transform = bool(self.transform_columns) and column_name in self.transform_columns
    if not has_transform:
        return this[column_name]
    template = self.transform_columns[column_name]
    return Code(template.format(column=quoted))
|
|
266
|
+
|
|
267
|
+
def _with_raw_schema(self, raw_schema: Dict[str, RawColumnInfo]) -> Self:
    """Process raw column metadata and return a copy of this segment carrying the schema.

    The database's ``_process_table_schema`` is given the relevant columns and the wrapped
    ``where`` clause; its result is turned into a ``Schema`` via ``create_schema`` honoring
    ``case_sensitive``. ``transform_columns`` is passed through explicitly to the copy.
    """
    processed = self.database._process_table_schema(
        self.table_path, raw_schema, self.relevant_columns, self._where()
    )
    built_schema = create_schema(self.database.name, self.table_path, processed, self.case_sensitive)
    return self.new(schema=built_schema, transform_columns=self.transform_columns)
|
|
274
|
+
|
|
275
|
+
def with_schema(self) -> Self:
    """Return a segment that has a schema attached.

    If a (non-empty) schema is already present, ``self`` is returned unchanged. Otherwise the
    raw table schema is queried, the query time is recorded under ``"schema_queries_stats"``,
    and a new instance built by ``_with_raw_schema`` is returned.
    """
    if not self._schema:
        began = time.monotonic()
        raw = self.database.query_table_schema(self.table_path)
        elapsed_ms = (time.monotonic() - began) * 1000
        self._update_stats("schema_queries_stats", elapsed_ms)
        return self._with_raw_schema(raw)
    return self
|
|
285
|
+
|
|
286
|
+
def get_schema(self) -> Dict[str, RawColumnInfo]:
    """Fetch the raw column metadata of ``table_path`` directly from the database.

    Unlike ``with_schema``, the result is not processed into a ``Schema`` and the query is
    not recorded in ``query_stats``.
    """
    db = self.database
    return db.query_table_schema(self.table_path)
|
|
288
|
+
|
|
289
|
+
def _make_key_range(self):
    """Yield ``Code`` predicates restricting key columns to ``[min_key, max_key)``.

    For each key column, ``<col> >= <min>`` is yielded when ``min_key`` is set and
    ``<col> < <max>`` when ``max_key`` is set. ``<col>`` is the dialect-quoted name, or the
    rendered ``transform_columns`` template for that column when one exists.
    """

    def _column_sql(column):
        quoted = self.database.dialect.quote(column)
        if self.transform_columns and column in self.transform_columns:
            return self.transform_columns[column].format(column=quoted)
        return quoted

    if self.min_key is not None:
        for lower, column in safezip(self.min_key, self.key_columns):
            column_sql = _column_sql(column)
            literal = self.database.dialect._constant_value(lower)
            yield Code(f"{column_sql} >= {literal}")

    if self.max_key is not None:
        for column, upper in safezip(self.key_columns, self.max_key):
            column_sql = _column_sql(column)
            literal = self.database.dialect._constant_value(upper)
            yield Code(f"{column_sql} < {literal}")
|
|
310
|
+
|
|
311
|
+
def _make_update_range(self):
    """Yield predicates bounding ``update_column`` to ``[min_update, max_update)``.

    Nothing is yielded for a bound that is ``None``.
    """
    lower, upper = self.min_update, self.max_update
    if lower is not None:
        yield lower <= this[self.update_column]
    if upper is not None:
        yield this[self.update_column] < upper
|
|
316
|
+
|
|
317
|
+
@property
def source_table(self):
    """Query-builder table node for ``table_path``, carrying this segment's schema."""
    path_parts = self.table_path
    return table(*path_parts, schema=self._schema)
|
|
320
|
+
|
|
321
|
+
def make_select(self):
    """Return a query over the source table filtered down to this segment.

    Applies the key-range predicates, then the update-range predicates, then the wrapped
    user ``where`` clause (or ``SKIP`` when no clause is set).
    """
    source = self.source_table
    predicates = [*self._make_key_range(), *self._make_update_range()]
    predicates.append(Code(self._where()) if self.where else SKIP)
    return source.where(*predicates)
|
|
325
|
+
|
|
326
|
+
def get_values(self) -> list:
    """Fetch every row of the segment, selecting all relevant columns.

    Each column is wrapped in ``NormalizeAsString`` (via ``_relevant_columns_repr``), and the
    query time is recorded under ``"row_fetch_queries_stats"``.

    Raises:
        JobCancelledError: if the job has been revoked before the query is issued.
    """
    if self._is_cancelled():
        raise JobCancelledError(self.job_id)

    query = self.make_select().select(*self._relevant_columns_repr)
    began = time.monotonic()
    rows = self.database.query(query, List[Tuple])
    elapsed_ms = (time.monotonic() - began) * 1000
    self._update_stats("row_fetch_queries_stats", elapsed_ms)
    return rows
|
|
341
|
+
|
|
342
|
+
# def get_sample_data(self, limit: int = 100) -> list:
|
|
343
|
+
# "Download all the relevant values of the segment from the database"
|
|
344
|
+
|
|
345
|
+
# exprs = []
|
|
346
|
+
# for c in self.key_columns:
|
|
347
|
+
# quoted = self.database.dialect.quote(c)
|
|
348
|
+
# exprs.append(NormalizeAsString(Code(quoted), self._schema[c]))
|
|
349
|
+
# if self.where:
|
|
350
|
+
# select = self.source_table.select(*self._relevant_columns_repr).where(Code(self._where())).limit(limit)
|
|
351
|
+
# self.key_columns
|
|
352
|
+
# else:
|
|
353
|
+
# select = self.source_table.select(*self._relevant_columns_repr).limit(limit)
|
|
354
|
+
|
|
355
|
+
# start_time = time.monotonic()
|
|
356
|
+
# result = self.database.query(select, List[Tuple])
|
|
357
|
+
# query_time_ms = (time.monotonic() - start_time) * 1000
|
|
358
|
+
# self._update_stats("row_fetch_queries_stats", query_time_ms)
|
|
359
|
+
|
|
360
|
+
def get_sample_data(self, limit: int = 100, sample_keys: Optional[List[List[DbKey]]] = None) -> list:
    """Fetch up to ``limit`` rows of the segment, optionally restricted to specific keys.

    Parameters:
        limit (int): Maximum number of rows to return (default: 100).
        sample_keys (Optional[List[List[DbKey]]]): Composite keys to match. Each inner list
            is paired positionally with ``key_columns``; a row matches when all its key
            columns equal the given values (``None`` matches SQL NULL). Rows matching any
            of the keys are returned.

    Returns:
        list: List of tuples containing the queried row data.

    Raises:
        JobCancelledError: if the job has been revoked before the query is issued.
    """
    if self._is_cancelled():
        raise JobCancelledError(self.job_id)
    select = self.make_select().select(*self._relevant_columns_repr)

    def _column_sql(col):
        # Quoted column name, or its transform template rendered around the quoted name.
        quoted = self.database.dialect.quote(col)
        if self.transform_columns and col in self.transform_columns:
            return self.transform_columns[col].format(column=quoted)
        return quoted

    def _equality_sql(col, val):
        # SQL text asserting that `col` equals the key value `val`.
        target = _column_sql(col)
        col_type = self._schema[col]
        if val is None:
            return target + " IS NULL"
        typed_value = col_type.make_value(val)
        literal = self.database.dialect._constant_value(typed_value)
        # Special handling for Sybase timestamp equality to handle precision mismatches:
        # dialects exposing `timestamp_equality_condition` get it for values with a `_dt`.
        if hasattr(self.database.dialect, "timestamp_equality_condition") and hasattr(typed_value, "_dt"):
            return self.database.dialect.timestamp_equality_condition(target, literal)
        return f"{target} = {literal}"

    filters = []
    if sample_keys:
        per_key_matches = []
        for key_values in sample_keys:
            conjuncts = [Code(_equality_sql(col, val)) for col, val in safezip(self.key_columns, key_values)]
            if conjuncts:
                per_key_matches.append(and_(*conjuncts))
        if per_key_matches:
            filters.append(or_(*per_key_matches))

    if filters or self.where:
        select = select.where(*filters)
    else:
        logger.warning("No filters applied; fetching up to {} rows without key restrictions", limit)

    select = select.limit(limit)

    began = time.monotonic()
    rows = self.database.query(select, List[Tuple])
    elapsed_ms = (time.monotonic() - began) * 1000
    self._update_stats("row_fetch_queries_stats", elapsed_ms)
    return rows
|
|
423
|
+
|
|
424
|
+
def choose_checkpoints(self, count: int) -> List[List[DbKey]]:
    """Suggest evenly-spaced checkpoints to split by, including start and end.

    With N key columns, ``per_dim = floor(count ** (1/N))`` (at least 1) is passed to
    ``split_compound_key_space`` as the per-dimension count, so the resulting grid stays in
    proportion to ``count`` boxes overall.

    Returns:
        One list of checkpoints per key column, each starting at ``min_key`` and ending at
        ``max_key`` for that column.
    """
    assert self.is_bounded

    n_dims = len(self.key_columns)
    per_dim = 1
    if count > 1:
        # Integer N-th root of count. The float root alone can land just below an exact root
        # (e.g. 64 ** (1/3) == 3.9999999999999996 -> 3), so correct it with integer checks.
        per_dim = max(int(round(count ** (1 / n_dims))), 1)
        while per_dim > 1 and per_dim**n_dims > count:
            per_dim -= 1
        while (per_dim + 1) ** n_dims <= count:
            per_dim += 1

    return split_compound_key_space(self.min_key, self.max_key, per_dim)
|
|
433
|
+
|
|
434
|
+
def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableSegment"]:
    """Split this segment into sub-segments, one per box of the checkpoint grid.

    ``checkpoints`` holds one ascending list of split values per key column; each box from
    ``create_mesh_from_points`` becomes a segment via ``new_key_bounds``.
    """
    boxes = create_mesh_from_points(*checkpoints)
    return [self.new_key_bounds(min_key=lower, max_key=upper) for lower, upper in boxes]
|
|
437
|
+
|
|
438
|
+
def new(self, **kwargs) -> Self:
    """Return a copy of this segment with the given fields replaced, via ``attrs.evolve``."""
    replacement = attrs.evolve(self, **kwargs)
    return replacement
|
|
441
|
+
|
|
442
|
+
def new_key_bounds(self, min_key: Vector, max_key: Vector, *, key_types: Optional[Sequence[IKey]] = None) -> Self:
    """Return a copy of this segment restricted to ``[min_key, max_key)``.

    The new bounds must lie within the current ones (checked with asserts when the current
    bounds are set). When ``key_types`` is given, each bound's values are re-created with
    ``make_value`` of the matching key type, so key meta-params (e.g. UUID casing) from the
    other side of a diff do not leak into this one; otherwise the bounds are used as given.
    """
    if self.min_key is not None:
        assert self.min_key <= min_key, (self.min_key, min_key)
        assert self.min_key < max_key

    if self.max_key is not None:
        assert min_key < self.max_key
        assert max_key <= self.max_key

    if key_types is not None:

        def _cast(bound):
            return Vector(key_type.make_value(value) for key_type, value in safezip(key_types, bound))

        min_key = _cast(min_key)
        max_key = _cast(max_key)

    return attrs.evolve(self, min_key=min_key, max_key=max_key)
|
|
459
|
+
|
|
460
|
+
@property
|
|
461
|
+
def relevant_columns(self) -> List[str]:
|
|
462
|
+
extras = list(self.extra_columns)
|
|
463
|
+
|
|
464
|
+
if self.update_column and self.update_column not in extras:
|
|
465
|
+
extras = [self.update_column] + extras
|
|
466
|
+
|
|
467
|
+
return list(self.key_columns) + extras
|
|
468
|
+
|
|
469
|
+
@property
def _relevant_columns_repr(self) -> List[Expr]:
    """Select expressions for ``relevant_columns``: each column's (possibly transformed)
    expression wrapped in ``NormalizeAsString`` with its schema type."""

    def _normalized(name):
        col_type = self._schema[name]
        return NormalizeAsString(self._column_expr(name), col_type)

    return [_normalized(name) for name in self.relevant_columns]
|
|
477
|
+
|
|
478
|
+
def count(self) -> int:
    """Return the number of rows in the segment using a single ``COUNT`` query.

    The query time is recorded under ``"count_queries_stats"``.

    Raises:
        JobCancelledError: if the job has been revoked before the query is issued.
    """
    if self._is_cancelled():
        raise JobCancelledError(self.job_id)
    began = time.monotonic()
    total = self.database.query(self.make_select().select(Count()), int)
    elapsed_ms = (time.monotonic() - began) * 1000
    self._update_stats("count_queries_stats", elapsed_ms)
    return total
|
|
488
|
+
|
|
489
|
+
def count_and_checksum(self) -> Tuple[int, Optional[int]]:
    """Count the rows in the segment and checksum them, in one query.

    Every relevant column not listed in ``ignored_columns`` takes part in the checksum; each
    is rendered through ``_column_expr`` (so configured transforms apply) and wrapped in
    ``NormalizeAsString``. The query time is recorded under ``"checksum_queries_stats"`` and
    a warning is logged when it exceeds ``RECOMMENDED_CHECKSUM_DURATION`` seconds.

    Returns:
        ``(count, checksum)``: ``count`` is 0 for an empty segment, in which case the
        checksum is ``None``.

    Raises:
        JobCancelledError: if the job has been revoked before the query is issued.
    """
    if self._is_cancelled():
        raise JobCancelledError(self.job_id)
    checked_columns = [c for c in self.relevant_columns if c not in self.ignored_columns]
    # Build transformed expressions for checksum, honoring transforms and normalization
    checksum_exprs: List[Expr] = []
    for c in checked_columns:
        schema = self._schema[c]
        checksum_exprs.append(NormalizeAsString(self._column_expr(c), schema))

    q = self.make_select().select(Count(), Checksum(checksum_exprs))
    start_time = time.monotonic()
    count, checksum = self.database.query(q, tuple)
    query_time_ms = (time.monotonic() - start_time) * 1000

    self._update_stats("checksum_queries_stats", query_time_ms)

    duration = query_time_ms / 1000
    if duration > RECOMMENDED_CHECKSUM_DURATION:
        # `logger` is loguru, which formats positional args with str.format-style braces;
        # a %-style placeholder would be left as-is and the duration silently dropped.
        logger.warning(
            "Checksum is taking longer than expected ({:.2f}). "
            "We recommend increasing --bisection-factor or decreasing --threads.",
            duration,
        )

    if count:
        assert checksum, (count, checksum)
    return count or 0, int(checksum) if count else None
|
|
518
|
+
|
|
519
|
+
def query_key_range(self) -> Tuple[tuple, tuple]:
    """Query the minimum and maximum value of every key column in this segment.

    Each min/max is computed on the (possibly transformed) column expression and then
    normalized as a string. Used for setting the initial bounds.

    Returns:
        ``(min_key, max_key)``, each a tuple with one value per key column.

    Raises:
        ValueError: if any returned value is ``None`` (the table appears to be empty).
    """
    base = self.make_select()
    bound_exprs = (
        ApplyFuncAndNormalizeAsString(self._column_expr(col), agg)
        for col in self.key_columns
        for agg in (min_, max_)
    )
    row = tuple(self.database.query(base.select(bound_exprs), tuple))

    if any(value is None for value in row):
        raise ValueError("Table appears to be empty")

    # The result alternates min, max for each key column.
    lows = row[0::2]
    highs = row[1::2]
    assert len(lows) == len(highs)

    return lows, highs
|
|
535
|
+
|
|
536
|
+
@property
def is_bounded(self):
    """Whether both ends of the key range are known (neither ``min_key`` nor ``max_key`` is None)."""
    return not (self.min_key is None or self.max_key is None)
|
|
539
|
+
|
|
540
|
+
# def approximate_size(self, row_count: Optional[int] = None):
|
|
541
|
+
# if not self.is_bounded:
|
|
542
|
+
# raise RuntimeError("Cannot approximate the size of an unbounded segment. Must have min_key and max_key.")
|
|
543
|
+
# diff = self.max_key - self.min_key
|
|
544
|
+
# assert all(d > 0 for d in diff)
|
|
545
|
+
# return int_product(diff)
|
|
546
|
+
|
|
547
|
+
def approximate_size(self, row_count: Optional[int] = None) -> int:
    """Estimate how many rows this bounded segment spans.

    When every key column's schema entry is a ``NumericType`` and every
    per-column difference ``max_key - min_key`` is positive, the estimate is
    ``int_product`` of those differences. In every other case (non-numeric
    keys, a non-positive span, or bounds that raise ``ValueError``/``TypeError``
    when subtracted or compared) the estimate falls back to ``row_count`` if it
    was given, else to ``self.count()``.

    Args:
        row_count: Optional pre-computed row count used as the fallback, which
            avoids calling ``self.count()``.

    Returns:
        The estimated number of rows.

    Raises:
        RuntimeError: If the segment is not bounded (``min_key`` or ``max_key`` is None).
    """
    if not self.is_bounded:
        raise RuntimeError("Cannot approximate the size of an unbounded segment. Must have min_key and max_key.")

    schema = self.get_schema()
    key_schemas = [schema[col] for col in self.key_columns]

    if all(isinstance(key_schema, NumericType) for key_schema in key_schemas):
        try:
            diff = [mx - mn for mn, mx in zip(self.min_key, self.max_key)]
            if all(d > 0 for d in diff):
                return int_product(diff)
        except (ValueError, TypeError):
            # Bounds that cannot be subtracted or compared are not an error here;
            # fall through to the row-count estimate below.
            pass

    # The fallback lives outside the try on purpose: an error raised by count()
    # must propagate instead of being swallowed and retried with a second query.
    return row_count if row_count is not None else self.count()
|
|
564
|
+
|
|
565
|
+
def _is_cancelled(self) -> bool:
|
|
566
|
+
run_id = self.job_id
|
|
567
|
+
if not run_id:
|
|
568
|
+
return False
|
|
569
|
+
run_id = f"revoke_job:{run_id}"
|
|
570
|
+
try:
|
|
571
|
+
backend = RedisBackend.get_instance()
|
|
572
|
+
val = backend.client.get(run_id)
|
|
573
|
+
if not val:
|
|
574
|
+
return False
|
|
575
|
+
if isinstance(val, bytes):
|
|
576
|
+
try:
|
|
577
|
+
val = val.decode()
|
|
578
|
+
except Exception:
|
|
579
|
+
val = str(val)
|
|
580
|
+
return isinstance(val, str) and val.strip().lower() == "revoke"
|
|
581
|
+
except Exception:
|
|
582
|
+
logger.warning("Unable to query Redis for cancellation for run_id=%s", run_id)
|
|
583
|
+
return False
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import itertools
|
|
16
|
+
from collections import deque
|
|
17
|
+
from collections.abc import Iterable
|
|
18
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
19
|
+
from concurrent.futures.thread import _WorkItem
|
|
20
|
+
from queue import PriorityQueue
|
|
21
|
+
from time import sleep
|
|
22
|
+
from typing import Any, Callable, Iterator, Optional
|
|
23
|
+
|
|
24
|
+
import attrs
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AutoPriorityQueue(PriorityQueue):
    """PriorityQueue that takes each _WorkItem's priority from its kwargs.

    ``put`` pops the ``priority`` keyword from ``item.kwargs``, so it is not
    forwarded to the target callable. An item without that keyword gets
    priority 0. ``None`` is also accepted with priority 0 (presumably the
    executor's wake-up/shutdown sentinel -- confirm against the
    ThreadPoolExecutor version in use).

    Entries are stored as ``(-priority, sequence, item)``:
    - higher priorities are returned first;
    - the unique sequence number keeps equal-priority items FIFO;
    - the sequence number also means _WorkItem objects are never compared.
    """

    # Shared by all instances; only the relative order inside one queue matters.
    _counter = itertools.count().__next__

    def put(self, item: Optional[_WorkItem], block=True, timeout=None) -> None:
        # Default to 0 instead of raising KeyError for items submitted without a priority.
        priority = 0 if item is None else item.kwargs.pop("priority", 0)
        super().put((-priority, self._counter(), item), block, timeout)

    def get(self, block=True, timeout=None) -> Optional[_WorkItem]:
        _priority, _sequence, work_item = super().get(block, timeout)
        return work_item
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class PriorityThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor whose work queue is an AutoPriorityQueue.

    Tasks submitted with a ``priority=`` keyword are dequeued highest-priority
    first. Positional and keyword constructor arguments are forwarded unchanged
    to ``ThreadPoolExecutor``. The private ``_work_queue`` attribute is then
    replaced.

    XXX WARNING: Might break in future versions of Python
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Swapping the private queue relies on no worker thread having been
        # started by __init__ (the case in CPython, which spawns workers on
        # submit) -- re-verify when upgrading Python.
        self._work_queue = AutoPriorityQueue()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@attrs.define(frozen=False, init=False)
class ThreadedYielder(Iterable):
    """Yields results from multiple threads into a single iterator.

    To add a source iterator, call ``submit()`` with a function that returns an iterator.
    Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first)

    Priority only decides the order in which submitted functions are *started*.
    Results are yielded in the order the workers push them.

    Result handling:
    - if ``yield_list`` is true, each non-None return value is yielded as a single item;
    - otherwise each non-None return value is iterated and its elements are yielded;
    - an exception raised by a submitted function is re-raised from the iterator.
    """

    _pool: ThreadPoolExecutor
    _futures: deque
    _yield: deque = attrs.field(alias="_yield")  # Python keyword!
    _exception: Optional[Exception]
    yield_list: bool

    def __init__(self, max_workers: Optional[int] = None, yield_list: bool = False) -> None:
        super().__init__()
        self._pool = PriorityThreadPoolExecutor(max_workers)
        self._futures = deque()
        self._yield = deque()
        self._exception = None
        self.yield_list = yield_list

    def _worker(self, fn, *args, **kwargs) -> None:
        """Run ``fn`` in a pool thread and queue its result for ``__iter__``.

        Any exception is stored in ``_exception``, overwriting a previously
        stored one, instead of propagating to the future.
        """
        try:
            res = fn(*args, **kwargs)
            if res is not None:
                if self.yield_list:
                    self._yield.append(res)
                else:
                    self._yield += res
        except Exception as e:
            self._exception = e

    def submit(self, fn: Callable, *args, priority: int = 0, **kwargs) -> None:
        """Schedule ``fn(*args, **kwargs)``; ``priority`` is consumed by the pool's queue."""
        self._futures.append(self._pool.submit(self._worker, fn, *args, priority=priority, **kwargs))

    def __iter__(self) -> Iterator[Any]:
        while True:
            if self._exception is not None:
                raise self._exception

            while self._yield:
                yield self._yield.popleft()

            if not self._futures:
                # No more tasks
                return

            # Retire futures in submission order; poll briefly while the oldest is still running.
            if self._futures[0].done():
                self._futures.popleft()
            else:
                sleep(0.001)
|