dcs_sdk-1.6.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. data_diff/__init__.py +221 -0
  2. data_diff/__main__.py +517 -0
  3. data_diff/abcs/__init__.py +13 -0
  4. data_diff/abcs/compiler.py +27 -0
  5. data_diff/abcs/database_types.py +402 -0
  6. data_diff/config.py +141 -0
  7. data_diff/databases/__init__.py +38 -0
  8. data_diff/databases/_connect.py +323 -0
  9. data_diff/databases/base.py +1417 -0
  10. data_diff/databases/bigquery.py +376 -0
  11. data_diff/databases/clickhouse.py +217 -0
  12. data_diff/databases/databricks.py +262 -0
  13. data_diff/databases/duckdb.py +207 -0
  14. data_diff/databases/mssql.py +343 -0
  15. data_diff/databases/mysql.py +189 -0
  16. data_diff/databases/oracle.py +238 -0
  17. data_diff/databases/postgresql.py +293 -0
  18. data_diff/databases/presto.py +222 -0
  19. data_diff/databases/redis.py +93 -0
  20. data_diff/databases/redshift.py +233 -0
  21. data_diff/databases/snowflake.py +222 -0
  22. data_diff/databases/sybase.py +720 -0
  23. data_diff/databases/trino.py +73 -0
  24. data_diff/databases/vertica.py +174 -0
  25. data_diff/diff_tables.py +489 -0
  26. data_diff/errors.py +17 -0
  27. data_diff/format.py +369 -0
  28. data_diff/hashdiff_tables.py +1026 -0
  29. data_diff/info_tree.py +76 -0
  30. data_diff/joindiff_tables.py +434 -0
  31. data_diff/lexicographic_space.py +253 -0
  32. data_diff/parse_time.py +88 -0
  33. data_diff/py.typed +0 -0
  34. data_diff/queries/__init__.py +13 -0
  35. data_diff/queries/api.py +213 -0
  36. data_diff/queries/ast_classes.py +811 -0
  37. data_diff/queries/base.py +38 -0
  38. data_diff/queries/extras.py +43 -0
  39. data_diff/query_utils.py +70 -0
  40. data_diff/schema.py +67 -0
  41. data_diff/table_segment.py +583 -0
  42. data_diff/thread_utils.py +112 -0
  43. data_diff/utils.py +1022 -0
  44. data_diff/version.py +15 -0
  45. dcs_core/__init__.py +13 -0
  46. dcs_core/__main__.py +17 -0
  47. dcs_core/__version__.py +15 -0
  48. dcs_core/cli/__init__.py +13 -0
  49. dcs_core/cli/cli.py +165 -0
  50. dcs_core/core/__init__.py +19 -0
  51. dcs_core/core/common/__init__.py +13 -0
  52. dcs_core/core/common/errors.py +50 -0
  53. dcs_core/core/common/models/__init__.py +13 -0
  54. dcs_core/core/common/models/configuration.py +284 -0
  55. dcs_core/core/common/models/dashboard.py +24 -0
  56. dcs_core/core/common/models/data_source_resource.py +75 -0
  57. dcs_core/core/common/models/metric.py +160 -0
  58. dcs_core/core/common/models/profile.py +75 -0
  59. dcs_core/core/common/models/validation.py +216 -0
  60. dcs_core/core/common/models/widget.py +44 -0
  61. dcs_core/core/configuration/__init__.py +13 -0
  62. dcs_core/core/configuration/config_loader.py +139 -0
  63. dcs_core/core/configuration/configuration_parser.py +262 -0
  64. dcs_core/core/configuration/configuration_parser_arc.py +328 -0
  65. dcs_core/core/datasource/__init__.py +13 -0
  66. dcs_core/core/datasource/base.py +62 -0
  67. dcs_core/core/datasource/manager.py +112 -0
  68. dcs_core/core/datasource/search_datasource.py +421 -0
  69. dcs_core/core/datasource/sql_datasource.py +1094 -0
  70. dcs_core/core/inspect.py +163 -0
  71. dcs_core/core/logger/__init__.py +13 -0
  72. dcs_core/core/logger/base.py +32 -0
  73. dcs_core/core/logger/default_logger.py +94 -0
  74. dcs_core/core/metric/__init__.py +13 -0
  75. dcs_core/core/metric/base.py +220 -0
  76. dcs_core/core/metric/combined_metric.py +98 -0
  77. dcs_core/core/metric/custom_metric.py +34 -0
  78. dcs_core/core/metric/manager.py +137 -0
  79. dcs_core/core/metric/numeric_metric.py +403 -0
  80. dcs_core/core/metric/reliability_metric.py +90 -0
  81. dcs_core/core/profiling/__init__.py +13 -0
  82. dcs_core/core/profiling/datasource_profiling.py +136 -0
  83. dcs_core/core/profiling/numeric_field_profiling.py +72 -0
  84. dcs_core/core/profiling/text_field_profiling.py +67 -0
  85. dcs_core/core/repository/__init__.py +13 -0
  86. dcs_core/core/repository/metric_repository.py +77 -0
  87. dcs_core/core/utils/__init__.py +13 -0
  88. dcs_core/core/utils/log.py +29 -0
  89. dcs_core/core/utils/tracking.py +105 -0
  90. dcs_core/core/utils/utils.py +44 -0
  91. dcs_core/core/validation/__init__.py +13 -0
  92. dcs_core/core/validation/base.py +230 -0
  93. dcs_core/core/validation/completeness_validation.py +153 -0
  94. dcs_core/core/validation/custom_query_validation.py +24 -0
  95. dcs_core/core/validation/manager.py +282 -0
  96. dcs_core/core/validation/numeric_validation.py +276 -0
  97. dcs_core/core/validation/reliability_validation.py +91 -0
  98. dcs_core/core/validation/uniqueness_validation.py +61 -0
  99. dcs_core/core/validation/validity_validation.py +738 -0
  100. dcs_core/integrations/__init__.py +13 -0
  101. dcs_core/integrations/databases/__init__.py +13 -0
  102. dcs_core/integrations/databases/bigquery.py +187 -0
  103. dcs_core/integrations/databases/databricks.py +51 -0
  104. dcs_core/integrations/databases/db2.py +652 -0
  105. dcs_core/integrations/databases/elasticsearch.py +61 -0
  106. dcs_core/integrations/databases/mssql.py +829 -0
  107. dcs_core/integrations/databases/mysql.py +409 -0
  108. dcs_core/integrations/databases/opensearch.py +64 -0
  109. dcs_core/integrations/databases/oracle.py +719 -0
  110. dcs_core/integrations/databases/postgres.py +482 -0
  111. dcs_core/integrations/databases/redshift.py +53 -0
  112. dcs_core/integrations/databases/snowflake.py +48 -0
  113. dcs_core/integrations/databases/spark_df.py +111 -0
  114. dcs_core/integrations/databases/sybase.py +1069 -0
  115. dcs_core/integrations/storage/__init__.py +13 -0
  116. dcs_core/integrations/storage/local_file.py +149 -0
  117. dcs_core/integrations/utils/__init__.py +13 -0
  118. dcs_core/integrations/utils/utils.py +36 -0
  119. dcs_core/report/__init__.py +13 -0
  120. dcs_core/report/dashboard.py +211 -0
  121. dcs_core/report/models.py +88 -0
  122. dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
  123. dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
  124. dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
  125. dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
  126. dcs_core/report/static/assets/images/docs.svg +6 -0
  127. dcs_core/report/static/assets/images/github.svg +4 -0
  128. dcs_core/report/static/assets/images/logo.svg +7 -0
  129. dcs_core/report/static/assets/images/slack.svg +13 -0
  130. dcs_core/report/static/index.js +2 -0
  131. dcs_core/report/static/index.js.LICENSE.txt +3971 -0
  132. dcs_sdk/__init__.py +13 -0
  133. dcs_sdk/__main__.py +18 -0
  134. dcs_sdk/__version__.py +15 -0
  135. dcs_sdk/cli/__init__.py +13 -0
  136. dcs_sdk/cli/cli.py +163 -0
  137. dcs_sdk/sdk/__init__.py +58 -0
  138. dcs_sdk/sdk/config/__init__.py +13 -0
  139. dcs_sdk/sdk/config/config_loader.py +491 -0
  140. dcs_sdk/sdk/data_diff/__init__.py +13 -0
  141. dcs_sdk/sdk/data_diff/data_differ.py +821 -0
  142. dcs_sdk/sdk/rules/__init__.py +15 -0
  143. dcs_sdk/sdk/rules/rules_mappping.py +31 -0
  144. dcs_sdk/sdk/rules/rules_repository.py +214 -0
  145. dcs_sdk/sdk/rules/schema_rules.py +65 -0
  146. dcs_sdk/sdk/utils/__init__.py +13 -0
  147. dcs_sdk/sdk/utils/serializer.py +25 -0
  148. dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
  149. dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
  150. dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
  151. dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
  152. dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
  153. dcs_sdk/sdk/utils/table.py +475 -0
  154. dcs_sdk/sdk/utils/themes.py +40 -0
  155. dcs_sdk/sdk/utils/utils.py +349 -0
  156. dcs_sdk-1.6.5.dist-info/METADATA +150 -0
  157. dcs_sdk-1.6.5.dist-info/RECORD +159 -0
  158. dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
  159. dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
data_diff/table_segment.py
@@ -0,0 +1,583 @@
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import time
+ from decimal import Decimal
+ from itertools import product
+ from typing import Container, Dict, List, Optional, Sequence, Tuple
+
+ import attrs
+ import numpy as np
+ from loguru import logger
+ from typing_extensions import Self
+
+ from data_diff.abcs.database_types import DbKey, DbPath, DbTime, IKey, NumericType
+ from data_diff.databases.base import Database
+ from data_diff.databases.redis import RedisBackend
+ from data_diff.queries.api import (
+     SKIP,
+     Code,
+     Count,
+     Expr,
+     and_,
+     max_,
+     min_,
+     or_,
+     table,
+     this,
+ )
+ from data_diff.queries.extras import (
+     ApplyFuncAndNormalizeAsString,
+     Checksum,
+     NormalizeAsString,
+ )
+ from data_diff.schema import RawColumnInfo, Schema, create_schema
+ from data_diff.utils import (
+     ArithDate,
+     ArithDateTime,
+     ArithString,
+     ArithTimestamp,
+     ArithTimestampTZ,
+     ArithUnicodeString,
+     JobCancelledError,
+     Vector,
+     safezip,
+     split_space,
+ )
+
+ # logger = logging.getLogger("table_segment")
+
+ RECOMMENDED_CHECKSUM_DURATION = 20
+
+
+ def split_key_space(min_key: DbKey, max_key: DbKey, count: int) -> List[DbKey]:
+     assert min_key < max_key
+
+     if max_key - min_key <= count:
+         count = 1
+
+     # Handle arithmetic string types (including temporal types)
+     if isinstance(
+         min_key, (ArithString, ArithUnicodeString, ArithDateTime, ArithDate, ArithTimestamp, ArithTimestampTZ)
+     ):
+         assert type(min_key) is type(max_key)
+         checkpoints = min_key.range(max_key, count)
+     else:
+         # Handle numeric types
+         if isinstance(min_key, Decimal):
+             min_key = float(min_key)
+         if isinstance(max_key, Decimal):
+             max_key = float(max_key)
+         checkpoints = split_space(min_key, max_key, count)
+
+     assert all(min_key < x < max_key for x in checkpoints)
+     return [min_key] + checkpoints + [max_key]
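# Editor's note — a property-level sketch of the contract above (illustrative,
# not part of the released file). For integer keys:
#
#   points = split_key_space(0, 100, 4)
#   assert points[0] == 0 and points[-1] == 100    # bounds are included
#   assert all(0 < p < 100 for p in points[1:-1])  # checkpoints fall strictly inside
#   assert points == sorted(points)                # ascending, ready for meshing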
+
+
+ def int_product(nums: List[int]) -> int:
+     p = 1
+     for n in nums:
+         p *= n
+     return p
+
+
+ def split_compound_key_space(mn: Vector, mx: Vector, count: int) -> List[List[DbKey]]:
+     """Returns a list of split-points for each key dimension, essentially returning an N-dimensional grid of split points."""
+     return [split_key_space(mn_k, mx_k, count) for mn_k, mx_k in safezip(mn, mx)]
+
+
+ def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector]]:
+     """Given a list of values along each axis of N-dimensional space,
+     return an array of boxes whose start-points & end-points align with the given values,
+     and together constitute a mesh filling that space entirely (within the bounds of the given values).
+
+     Assumes given values are already ordered ascending.
+
+     len(boxes) == ∏i( len(i)-1 )
+
+     Example:
+         ::
+             >>> d1 = 'a', 'b', 'c'
+             >>> d2 = 1, 2, 3
+             >>> d3 = 'X', 'Y'
+             >>> create_mesh_from_points(d1, d2, d3)
+             [
+                 [('a', 1, 'X'), ('b', 2, 'Y')],
+                 [('a', 2, 'X'), ('b', 3, 'Y')],
+                 [('b', 1, 'X'), ('c', 2, 'Y')],
+                 [('b', 2, 'X'), ('c', 3, 'Y')]
+             ]
+     """
+     assert all(len(v) >= 2 for v in values_per_dim), values_per_dim
+
+     # Create tuples of (v1, v2) for each pair of adjacent values
+     ranges = [list(zip(values[:-1], values[1:])) for values in values_per_dim]
+
+     assert all(a <= b for r in ranges for a, b in r)
+
+     # Create a product of all the ranges
+     res = [tuple(Vector(a) for a in safezip(*r)) for r in product(*ranges)]
+
+     expected_len = int_product(len(v) - 1 for v in values_per_dim)
+     assert len(res) == expected_len, (len(res), expected_len)
+     return res
+
+
+ @attrs.define(frozen=True)
+ class TableSegment:
+     """Signifies a segment of rows (and selected columns) within a table
+
+     Parameters:
+         database (Database): Database instance. See :meth:`connect`
+         table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')`
+         key_columns (Tuple[str]): Names of the key columns, which uniquely identify each row (usually the id)
+         update_column (str, optional): Name of the updated column, which signals that rows have changed.
+             Usually updated_at or last_update. Used by `min_update` and `max_update`.
+         extra_columns (Tuple[str, ...], optional): Extra columns to compare
+         transform_columns (Dict[str, str], optional): A dictionary mapping column names to SQL transformation expressions.
+             These expressions are applied directly to the specified columns within the
+             comparison query, *before* the data is hashed or compared. Useful for
+             on-the-fly normalization (e.g., type casting, timezone conversions) without
+             requiring intermediate views or staging tables. Defaults to an empty dict.
+         min_key (:data:`Vector`, optional): Lowest key value, used to restrict the segment
+         max_key (:data:`Vector`, optional): Highest key value, used to restrict the segment
+         min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
+         max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
+         where (str, optional): An additional 'where' expression to restrict the search space.
+
+         case_sensitive (bool): If false, the casing of column names is adjusted to match the schema. Default is true.
+
+     """
+
+     # Location of table
+     database: Database
+     table_path: DbPath
+
+     # Columns
+     key_columns: Tuple[str, ...]
+     update_column: Optional[str] = None
+     extra_columns: Tuple[str, ...] = ()
+     transform_columns: Dict[str, str] = {}
+     ignored_columns: Container[str] = frozenset()
+
+     # Restrict the segment
+     min_key: Optional[Vector] = None
+     max_key: Optional[Vector] = None
+     min_update: Optional[DbTime] = None
+     max_update: Optional[DbTime] = None
+     where: Optional[str] = None
+
+     case_sensitive: Optional[bool] = True
+     _schema: Optional[Schema] = None
+
+     query_stats: Dict = attrs.Factory(
+         lambda: {
+             "count_queries_stats": {
+                 "total_queries": 0,
+                 "min_time_ms": 0.0,
+                 "max_time_ms": 0.0,
+                 "avg_time_ms": 0.0,
+                 "p90_time_ms": 0.0,
+                 "total_time_taken_ms": 0.0,
+                 "_query_times": [],
+             },
+             "checksum_queries_stats": {
+                 "total_queries": 0,
+                 "min_time_ms": 0.0,
+                 "max_time_ms": 0.0,
+                 "avg_time_ms": 0.0,
+                 "p90_time_ms": 0.0,
+                 "total_time_taken_ms": 0.0,
+                 "_query_times": [],
+             },
+             "row_fetch_queries_stats": {
+                 "total_queries": 0,
+                 "min_time_ms": 0.0,
+                 "max_time_ms": 0.0,
+                 "avg_time_ms": 0.0,
+                 "p90_time_ms": 0.0,
+                 "total_time_taken_ms": 0.0,
+                 "_query_times": [],
+             },
+             "schema_queries_stats": {
+                 "total_queries": 0,
+                 "min_time_ms": 0.0,
+                 "max_time_ms": 0.0,
+                 "avg_time_ms": 0.0,
+                 "p90_time_ms": 0.0,
+                 "total_time_taken_ms": 0.0,
+                 "_query_times": [],
+             },
+         }
+     )
+     job_id: Optional[int] = None
+
+     def __attrs_post_init__(self) -> None:
+         if not self.update_column and (self.min_update or self.max_update):
+             raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.")
+
+         if self.min_key is not None and self.max_key is not None and self.min_key >= self.max_key:
+             raise ValueError(f"Error: min_key expected to be smaller than max_key! ({self.min_key} >= {self.max_key})")
+
+         if self.min_update is not None and self.max_update is not None and self.min_update >= self.max_update:
+             raise ValueError(
+                 f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
+             )
+
+     def _update_stats(self, stats_key: str, query_time_ms: float) -> None:
+         logger.info(f"Query time for {stats_key.replace('_stats', '')}: {query_time_ms:.2f} ms")
+         stats = self.query_stats[stats_key]
+         stats["total_queries"] += 1
+         stats["_query_times"].append(query_time_ms)
+         stats["total_time_taken_ms"] += query_time_ms
+         stats["min_time_ms"] = min(stats["_query_times"]) if stats["_query_times"] else 0.0
+         stats["max_time_ms"] = max(stats["_query_times"]) if stats["_query_times"] else 0.0
+         if stats["_query_times"]:
+             times = np.array(stats["_query_times"])
+             stats["avg_time_ms"] = float(np.mean(times))
+             stats["p90_time_ms"] = float(np.percentile(times, 90, method="linear"))
+         else:
+             stats["avg_time_ms"] = 0.0
+             stats["p90_time_ms"] = 0.0
+
+     def _where(self) -> Optional[str]:
+         return f"({self.where})" if self.where else None
+
+     def _column_expr(self, column_name: str) -> Expr:
+         """Return expression for a column, applying configured SQL transform if present."""
+         quoted_column_name = self.database.quote(s=column_name)
+         if self.transform_columns and column_name in self.transform_columns:
+             transform_expr = self.transform_columns[column_name]
+             quoted_expr = transform_expr.format(column=quoted_column_name)
+             return Code(quoted_expr)
+         return this[column_name]
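# Editor's note — an illustrative sketch of how a transform template is applied
# (hypothetical column names; the quoting style varies by dialect). Given
# transform_columns == {"created_at": "CAST({column} AS DATE)"}:
#
#   seg._column_expr("created_at")  ->  Code('CAST("created_at" AS DATE)')
#   seg._column_expr("id")          ->  this["id"]   # no transform configured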
+
+     def _with_raw_schema(self, raw_schema: Dict[str, RawColumnInfo]) -> Self:
+         schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
+         # return self.new(schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive))
+         return self.new(
+             schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive),
+             transform_columns=self.transform_columns,
+         )
+
+     def with_schema(self) -> Self:
+         "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
+         if self._schema:
+             return self
+
+         start_time = time.monotonic()
+         raw_schema = self.database.query_table_schema(self.table_path)
+         query_time_ms = (time.monotonic() - start_time) * 1000
+         self._update_stats("schema_queries_stats", query_time_ms)
+         return self._with_raw_schema(raw_schema)
+
+     def get_schema(self) -> Dict[str, RawColumnInfo]:
+         return self.database.query_table_schema(self.table_path)
+
+     def _make_key_range(self):
+         if self.min_key is not None:
+             for mn, k in safezip(self.min_key, self.key_columns):
+                 quoted = self.database.dialect.quote(k)
+                 base_expr_sql = (
+                     self.transform_columns[k].format(column=quoted)
+                     if self.transform_columns and k in self.transform_columns
+                     else quoted
+                 )
+                 constant_val = self.database.dialect._constant_value(mn)
+                 yield Code(f"{base_expr_sql} >= {constant_val}")
+         if self.max_key is not None:
+             for k, mx in safezip(self.key_columns, self.max_key):
+                 quoted = self.database.dialect.quote(k)
+                 base_expr_sql = (
+                     self.transform_columns[k].format(column=quoted)
+                     if self.transform_columns and k in self.transform_columns
+                     else quoted
+                 )
+                 constant_val = self.database.dialect._constant_value(mx)
+                 yield Code(f"{base_expr_sql} < {constant_val}")
+
+     def _make_update_range(self):
+         if self.min_update is not None:
+             yield self.min_update <= this[self.update_column]
+         if self.max_update is not None:
+             yield this[self.update_column] < self.max_update
+
+     @property
+     def source_table(self):
+         return table(*self.table_path, schema=self._schema)
+
+     def make_select(self):
+         return self.source_table.where(
+             *self._make_key_range(), *self._make_update_range(), Code(self._where()) if self.where else SKIP
+         )
+
+     def get_values(self) -> list:
+         "Download all the relevant values of the segment from the database"
+         # Fetch all the original columns, even if some were later excluded from checking.
+
+         # fetched_cols = [NormalizeAsString(this[c]) for c in self.relevant_columns]
+         # select = self.make_select().select(*fetched_cols)
+         if self._is_cancelled():
+             raise JobCancelledError(self.job_id)
+         select = self.make_select().select(*self._relevant_columns_repr)
+         start_time = time.monotonic()
+         result = self.database.query(select, List[Tuple])
+         query_time_ms = (time.monotonic() - start_time) * 1000
+         self._update_stats("row_fetch_queries_stats", query_time_ms)
+
+         return result
+
+     # def get_sample_data(self, limit: int = 100) -> list:
+     #     "Download all the relevant values of the segment from the database"
+
+     #     exprs = []
+     #     for c in self.key_columns:
+     #         quoted = self.database.dialect.quote(c)
+     #         exprs.append(NormalizeAsString(Code(quoted), self._schema[c]))
+     #     if self.where:
+     #         select = self.source_table.select(*self._relevant_columns_repr).where(Code(self._where())).limit(limit)
+     #         self.key_columns
+     #     else:
+     #         select = self.source_table.select(*self._relevant_columns_repr).limit(limit)
+
+     #     start_time = time.monotonic()
+     #     result = self.database.query(select, List[Tuple])
+     #     query_time_ms = (time.monotonic() - start_time) * 1000
+     #     self._update_stats("row_fetch_queries_stats", query_time_ms)
+
+     def get_sample_data(self, limit: int = 100, sample_keys: Optional[List[List[DbKey]]] = None) -> list:
+         """
+         Download relevant values of the segment from the database.
+         If `sample_keys` is provided, only rows matching those composite keys are returned.
+
+         Parameters:
+             limit (int): Maximum number of rows to return (default: 100).
+             sample_keys (Optional[List[List[DbKey]]]): List of composite keys to filter rows.
+                 Each inner list must match the number of key_columns.
+
+         Returns:
+             list: List of tuples containing the queried row data.
+         """
+         if self._is_cancelled():
+             raise JobCancelledError(self.job_id)
+         select = self.make_select().select(*self._relevant_columns_repr)
+
+         filters = []
+
+         if sample_keys:
+             key_exprs = []
+             for key_values in sample_keys:
+                 and_exprs = []
+                 for col, val in safezip(self.key_columns, key_values):
+                     quoted = self.database.dialect.quote(col)
+                     base_expr_sql = (
+                         self.transform_columns[col].format(column=quoted)
+                         if self.transform_columns and col in self.transform_columns
+                         else quoted
+                     )
+                     schema = self._schema[col]
+                     if val is None:
+                         and_exprs.append(Code(base_expr_sql + " IS NULL"))
+                         continue
+                     mk_v = schema.make_value(val)
+                     constant_val = self.database.dialect._constant_value(mk_v)
+
+                     # Special handling for Sybase timestamp equality to handle precision mismatches
+                     if hasattr(self.database.dialect, "timestamp_equality_condition") and hasattr(
+                         mk_v, "_dt"
+                     ):  # Check if it's a datetime-like object
+                         where_expr = self.database.dialect.timestamp_equality_condition(base_expr_sql, constant_val)
+                     else:
+                         where_expr = f"{base_expr_sql} = {constant_val}"
+
+                     and_exprs.append(Code(where_expr))
+                 if and_exprs:
+                     key_exprs.append(and_(*and_exprs))
+             if key_exprs:
+                 filters.append(or_(*key_exprs))
+         if filters or self.where:
+             select = select.where(*filters)
+         else:
+             logger.warning("No filters applied; fetching up to {} rows without key restrictions", limit)
+
+         select = select.limit(limit)
+
+         start_time = time.monotonic()
+         result = self.database.query(select, List[Tuple])
+         query_time_ms = (time.monotonic() - start_time) * 1000
+         self._update_stats("row_fetch_queries_stats", query_time_ms)
+
+         return result
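# Editor's note — an illustrative sketch of the filter built above (hypothetical
# key values; exact SQL varies by dialect). With key_columns == ("id",) and
# sample_keys == [[1], [2], [None]], the generated restriction is roughly:
#
#   WHERE (("id" = 1) OR ("id" = 2) OR ("id" IS NULL))  ... LIMIT 100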
+
+     def choose_checkpoints(self, count: int) -> List[List[DbKey]]:
+         "Suggests a bunch of evenly-spaced checkpoints to split by, including start, end."
+
+         assert self.is_bounded
+
+         # Take the Nth root of count, to approximate the appropriate box size
+         count = int(count ** (1 / len(self.key_columns))) or 1
+
+         return split_compound_key_space(self.min_key, self.max_key, count)
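# Editor's note — a worked example of the Nth-root heuristic above (illustrative
# only). Requesting count=100 checkpoints over a 2-column compound key gives
# int(100 ** (1 / 2)) == 10 split points per dimension, i.e. roughly a 10 x 10
# grid of boxes rather than 100 splits along each axis.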
+
+     def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableSegment"]:
+         "Split the current TableSegment into a bunch of smaller ones, separated by the given checkpoints"
+         return [self.new_key_bounds(min_key=s, max_key=e) for s, e in create_mesh_from_points(*checkpoints)]
+
+     def new(self, **kwargs) -> Self:
+         """Creates a copy of the instance, overriding the given fields (via `attrs.evolve()`)"""
+         return attrs.evolve(self, **kwargs)
+
+     def new_key_bounds(self, min_key: Vector, max_key: Vector, *, key_types: Optional[Sequence[IKey]] = None) -> Self:
+         if self.min_key is not None:
+             assert self.min_key <= min_key, (self.min_key, min_key)
+             assert self.min_key < max_key
+
+         if self.max_key is not None:
+             assert min_key < self.max_key
+             assert max_key <= self.max_key
+
+         # If asked, coerce the PKs to the proper types, mainly the meta-params of the relevant side,
+         # so that we do not leak e.g. casing of UUIDs from side A to side B and vice versa.
+         # If not asked, keep the meta-params of the keys as-is (assume they are already cast).
+         if key_types is not None:
+             min_key = Vector(type.make_value(val) for type, val in safezip(key_types, min_key))
+             max_key = Vector(type.make_value(val) for type, val in safezip(key_types, max_key))
+
+         return attrs.evolve(self, min_key=min_key, max_key=max_key)
+
+     @property
+     def relevant_columns(self) -> List[str]:
+         extras = list(self.extra_columns)
+
+         if self.update_column and self.update_column not in extras:
+             extras = [self.update_column] + extras
+
+         return list(self.key_columns) + extras
+
+     @property
+     def _relevant_columns_repr(self) -> List[Expr]:
+         # return [NormalizeAsString(this[c]) for c in self.relevant_columns]
+         expressions = []
+         for c in self.relevant_columns:
+             schema = self._schema[c]
+             expressions.append(NormalizeAsString(self._column_expr(c), schema))
+         return expressions
+
+     def count(self) -> int:
+         """Count how many rows are in the segment, in one pass."""
+         if self._is_cancelled():
+             raise JobCancelledError(self.job_id)
+         start_time = time.monotonic()
+         result = self.database.query(self.make_select().select(Count()), int)
+         query_time_ms = (time.monotonic() - start_time) * 1000
+         self._update_stats("count_queries_stats", query_time_ms)
+
+         return result
+
+     def count_and_checksum(self) -> Tuple[int, Optional[int]]:
+         """Count and checksum the rows in the segment, in one pass. The checksum is None for an empty segment."""
+         if self._is_cancelled():
+             raise JobCancelledError(self.job_id)
+         checked_columns = [c for c in self.relevant_columns if c not in self.ignored_columns]
+         # Build transformed expressions for the checksum, honoring transforms and normalization
+         checksum_exprs: List[Expr] = []
+         for c in checked_columns:
+             schema = self._schema[c]
+             checksum_exprs.append(NormalizeAsString(self._column_expr(c), schema))
+
+         q = self.make_select().select(Count(), Checksum(checksum_exprs))
+         start_time = time.monotonic()
+         count, checksum = self.database.query(q, tuple)
+         query_time_ms = (time.monotonic() - start_time) * 1000
+
+         self._update_stats("checksum_queries_stats", query_time_ms)
+
+         duration = query_time_ms / 1000
+         if duration > RECOMMENDED_CHECKSUM_DURATION:
+             # loguru formats positional args with str.format, so use {} placeholders
+             logger.warning(
+                 "Checksum is taking longer than expected ({:.2f}s). "
+                 "We recommend increasing --bisection-factor or decreasing --threads.",
+                 duration,
+             )
+
+         if count:
+             assert checksum, (count, checksum)
+         return count or 0, int(checksum) if count else None
+
+     def query_key_range(self) -> Tuple[tuple, tuple]:
+         """Query database for minimum and maximum key. This is used for setting the initial bounds."""
+         # Normalizes the result (needed for UUIDs) after the min/max computation
+         select = self.make_select().select(
+             ApplyFuncAndNormalizeAsString(self._column_expr(k), f) for k in self.key_columns for f in (min_, max_)
+         )
+         result = tuple(self.database.query(select, tuple))
+
+         if any(i is None for i in result):
+             raise ValueError("Table appears to be empty")
+
+         # Min/max keys are interleaved
+         min_key, max_key = result[::2], result[1::2]
+         assert len(min_key) == len(max_key)
+
+         return min_key, max_key
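# Editor's note — an illustrative sketch of the de-interleaving above. For
# key_columns == ("a", "b"), the select emits min(a), max(a), min(b), max(b)
# in that order, so:
#
#   result       == (min_a, max_a, min_b, max_b)
#   result[::2]  == (min_a, min_b)   # the min_key vector
#   result[1::2] == (max_a, max_b)   # the max_key vector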
+
+     @property
+     def is_bounded(self):
+         return self.min_key is not None and self.max_key is not None
+
+     # def approximate_size(self, row_count: Optional[int] = None):
+     #     if not self.is_bounded:
+     #         raise RuntimeError("Cannot approximate the size of an unbounded segment. Must have min_key and max_key.")
+     #     diff = self.max_key - self.min_key
+     #     assert all(d > 0 for d in diff)
+     #     return int_product(diff)
+
+     def approximate_size(self, row_count: Optional[int] = None) -> int:
+         if not self.is_bounded:
+             raise RuntimeError("Cannot approximate the size of an unbounded segment. Must have min_key and max_key.")
+
+         schema = self.get_schema()
+         key_types = [schema[col].__class__ for col in self.key_columns]
+
+         if all(issubclass(t, NumericType) for t in key_types):
+             try:
+                 diff = [mx - mn for mn, mx in zip(self.min_key, self.max_key)]
+                 if not all(d > 0 for d in diff):
+                     return row_count if row_count is not None else self.count()
+                 return int_product(diff)
+             except (ValueError, TypeError):
+                 return row_count if row_count is not None else self.count()
+         else:
+             return row_count if row_count is not None else self.count()
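# Editor's note — a worked example of the numeric fast path above (illustrative
# only). For a compound key with min_key == (0, 0) and max_key == (100, 50):
#
#   diff == [100, 50]  ->  int_product(diff) == 5000
#
# Non-numeric keys (e.g. UUID strings) fall through to the exact row count.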
+
+     def _is_cancelled(self) -> bool:
+         run_id = self.job_id
+         if not run_id:
+             return False
+         run_id = f"revoke_job:{run_id}"
+         try:
+             backend = RedisBackend.get_instance()
+             val = backend.client.get(run_id)
+             if not val:
+                 return False
+             if isinstance(val, bytes):
+                 try:
+                     val = val.decode()
+                 except Exception:
+                     val = str(val)
+             return isinstance(val, str) and val.strip().lower() == "revoke"
+         except Exception:
+             # loguru formats positional args with str.format, so use a {} placeholder
+             logger.warning("Unable to query Redis for cancellation for run_id={}", run_id)
+             return False
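The cancellation path above polls a shared Redis key named `revoke_job:<job_id>` and treats the string value "revoke" as a stop signal. As a minimal illustrative sketch (editor's note, not part of the released file), a controlling process could request cancellation of a hypothetical job 42 like this, reusing the same `RedisBackend` singleton imported above:

    from data_diff.databases.redis import RedisBackend

    # Any TableSegment created with job_id=42 will observe this key on its next
    # _is_cancelled() poll and raise JobCancelledError.
    RedisBackend.get_instance().client.set("revoke_job:42", "revoke")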
data_diff/thread_utils.py
@@ -0,0 +1,112 @@
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import itertools
+ from collections import deque
+ from collections.abc import Iterable
+ from concurrent.futures import ThreadPoolExecutor
+ from concurrent.futures.thread import _WorkItem
+ from queue import PriorityQueue
+ from time import sleep
+ from typing import Any, Callable, Iterator, Optional
+
+ import attrs
+
+
+ class AutoPriorityQueue(PriorityQueue):
+     """Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs
+
+     We also assign a unique id to each item, to avoid making comparisons on _WorkItem.
+     As a side effect, items with the same priority are returned FIFO.
+     """
+
+     _counter = itertools.count().__next__
+
+     def put(self, item: Optional[_WorkItem], block=True, timeout=None) -> None:
+         # Default to priority 0 when the caller didn't pass one
+         priority = item.kwargs.pop("priority", 0) if item is not None else 0
+         super().put((-priority, self._counter(), item), block, timeout)
+
+     def get(self, block=True, timeout=None) -> Optional[_WorkItem]:
+         _p, _c, work_item = super().get(block, timeout)
+         return work_item
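# Editor's note — an illustrative trace of the ordering above (not part of the
# released file). Priorities are negated so the min-heap PriorityQueue pops the
# highest priority first, and the monotonic counter breaks ties FIFO:
#
#   put(priority=1) -> (-1, 0, item_a)
#   put(priority=5) -> (-5, 1, item_b)
#   put(priority=5) -> (-5, 2, item_c)
#   get() order: item_b, item_c, item_a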
+
+
+ class PriorityThreadPoolExecutor(ThreadPoolExecutor):
+     """Overrides ThreadPoolExecutor to use AutoPriorityQueue
+
+     XXX WARNING: Might break in future versions of Python
+     """
+
+     def __init__(self, *args) -> None:
+         super().__init__(*args)
+         self._work_queue = AutoPriorityQueue()
+
+
+ @attrs.define(frozen=False, init=False)
+ class ThreadedYielder(Iterable):
+     """Yields results from multiple threads into a single iterator, ordered by priority.
+
+     To add a source iterator, call ``submit()`` with a function that returns an iterator.
+     Priority for the iterator can be provided via the keyword argument 'priority' (higher runs first).
+     """
+
+     _pool: ThreadPoolExecutor
+     _futures: deque
+     _yield: deque = attrs.field(alias="_yield")  # alias needed: 'yield' is a Python keyword!
+     _exception: Optional[Exception]
+     yield_list: bool
+
+     def __init__(self, max_workers: Optional[int] = None, yield_list: bool = False) -> None:
+         super().__init__()
+         self._pool = PriorityThreadPoolExecutor(max_workers)
+         self._futures = deque()
+         self._yield = deque()
+         self._exception = None
+         self.yield_list = yield_list
+
+     def _worker(self, fn, *args, **kwargs) -> None:
+         try:
+             res = fn(*args, **kwargs)
+             if res is not None:
+                 if self.yield_list:
+                     self._yield.append(res)
+                 else:
+                     self._yield += res
+         except Exception as e:
+             self._exception = e
+
+     def submit(self, fn: Callable, *args, priority: int = 0, **kwargs) -> None:
+         self._futures.append(self._pool.submit(self._worker, fn, *args, priority=priority, **kwargs))
+
+     def __iter__(self) -> Iterator[Any]:
+         while True:
+             if self._exception:
+                 raise self._exception
+
+             while self._yield:
+                 yield self._yield.popleft()
+
+             if not self._futures:
+                 # No more tasks
+                 return
+
+             if self._futures[0].done():
+                 self._futures.popleft()
+             else:
+                 sleep(0.001)
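A brief usage sketch of ThreadedYielder (editor's illustration; the lambdas are hypothetical stand-ins for functions that return iterables of results):

    from data_diff.thread_utils import ThreadedYielder

    yielder = ThreadedYielder(max_workers=2)
    yielder.submit(lambda: [1, 2], priority=0)   # lower priority: scheduled later
    yielder.submit(lambda: [3, 4], priority=10)  # higher priority: scheduled first

    # Iteration drains results as workers produce them, returns once all
    # submitted futures are done, and re-raises any exception caught by a worker.
    for item in yielder:
        print(item)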