dcs-sdk 1.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. data_diff/__init__.py +221 -0
  2. data_diff/__main__.py +517 -0
  3. data_diff/abcs/__init__.py +13 -0
  4. data_diff/abcs/compiler.py +27 -0
  5. data_diff/abcs/database_types.py +402 -0
  6. data_diff/config.py +141 -0
  7. data_diff/databases/__init__.py +38 -0
  8. data_diff/databases/_connect.py +323 -0
  9. data_diff/databases/base.py +1417 -0
  10. data_diff/databases/bigquery.py +376 -0
  11. data_diff/databases/clickhouse.py +217 -0
  12. data_diff/databases/databricks.py +262 -0
  13. data_diff/databases/duckdb.py +207 -0
  14. data_diff/databases/mssql.py +343 -0
  15. data_diff/databases/mysql.py +189 -0
  16. data_diff/databases/oracle.py +238 -0
  17. data_diff/databases/postgresql.py +293 -0
  18. data_diff/databases/presto.py +222 -0
  19. data_diff/databases/redis.py +93 -0
  20. data_diff/databases/redshift.py +233 -0
  21. data_diff/databases/snowflake.py +222 -0
  22. data_diff/databases/sybase.py +720 -0
  23. data_diff/databases/trino.py +73 -0
  24. data_diff/databases/vertica.py +174 -0
  25. data_diff/diff_tables.py +489 -0
  26. data_diff/errors.py +17 -0
  27. data_diff/format.py +369 -0
  28. data_diff/hashdiff_tables.py +1026 -0
  29. data_diff/info_tree.py +76 -0
  30. data_diff/joindiff_tables.py +434 -0
  31. data_diff/lexicographic_space.py +253 -0
  32. data_diff/parse_time.py +88 -0
  33. data_diff/py.typed +0 -0
  34. data_diff/queries/__init__.py +13 -0
  35. data_diff/queries/api.py +213 -0
  36. data_diff/queries/ast_classes.py +811 -0
  37. data_diff/queries/base.py +38 -0
  38. data_diff/queries/extras.py +43 -0
  39. data_diff/query_utils.py +70 -0
  40. data_diff/schema.py +67 -0
  41. data_diff/table_segment.py +583 -0
  42. data_diff/thread_utils.py +112 -0
  43. data_diff/utils.py +1022 -0
  44. data_diff/version.py +15 -0
  45. dcs_core/__init__.py +13 -0
  46. dcs_core/__main__.py +17 -0
  47. dcs_core/__version__.py +15 -0
  48. dcs_core/cli/__init__.py +13 -0
  49. dcs_core/cli/cli.py +165 -0
  50. dcs_core/core/__init__.py +19 -0
  51. dcs_core/core/common/__init__.py +13 -0
  52. dcs_core/core/common/errors.py +50 -0
  53. dcs_core/core/common/models/__init__.py +13 -0
  54. dcs_core/core/common/models/configuration.py +284 -0
  55. dcs_core/core/common/models/dashboard.py +24 -0
  56. dcs_core/core/common/models/data_source_resource.py +75 -0
  57. dcs_core/core/common/models/metric.py +160 -0
  58. dcs_core/core/common/models/profile.py +75 -0
  59. dcs_core/core/common/models/validation.py +216 -0
  60. dcs_core/core/common/models/widget.py +44 -0
  61. dcs_core/core/configuration/__init__.py +13 -0
  62. dcs_core/core/configuration/config_loader.py +139 -0
  63. dcs_core/core/configuration/configuration_parser.py +262 -0
  64. dcs_core/core/configuration/configuration_parser_arc.py +328 -0
  65. dcs_core/core/datasource/__init__.py +13 -0
  66. dcs_core/core/datasource/base.py +62 -0
  67. dcs_core/core/datasource/manager.py +112 -0
  68. dcs_core/core/datasource/search_datasource.py +421 -0
  69. dcs_core/core/datasource/sql_datasource.py +1094 -0
  70. dcs_core/core/inspect.py +163 -0
  71. dcs_core/core/logger/__init__.py +13 -0
  72. dcs_core/core/logger/base.py +32 -0
  73. dcs_core/core/logger/default_logger.py +94 -0
  74. dcs_core/core/metric/__init__.py +13 -0
  75. dcs_core/core/metric/base.py +220 -0
  76. dcs_core/core/metric/combined_metric.py +98 -0
  77. dcs_core/core/metric/custom_metric.py +34 -0
  78. dcs_core/core/metric/manager.py +137 -0
  79. dcs_core/core/metric/numeric_metric.py +403 -0
  80. dcs_core/core/metric/reliability_metric.py +90 -0
  81. dcs_core/core/profiling/__init__.py +13 -0
  82. dcs_core/core/profiling/datasource_profiling.py +136 -0
  83. dcs_core/core/profiling/numeric_field_profiling.py +72 -0
  84. dcs_core/core/profiling/text_field_profiling.py +67 -0
  85. dcs_core/core/repository/__init__.py +13 -0
  86. dcs_core/core/repository/metric_repository.py +77 -0
  87. dcs_core/core/utils/__init__.py +13 -0
  88. dcs_core/core/utils/log.py +29 -0
  89. dcs_core/core/utils/tracking.py +105 -0
  90. dcs_core/core/utils/utils.py +44 -0
  91. dcs_core/core/validation/__init__.py +13 -0
  92. dcs_core/core/validation/base.py +230 -0
  93. dcs_core/core/validation/completeness_validation.py +153 -0
  94. dcs_core/core/validation/custom_query_validation.py +24 -0
  95. dcs_core/core/validation/manager.py +282 -0
  96. dcs_core/core/validation/numeric_validation.py +276 -0
  97. dcs_core/core/validation/reliability_validation.py +91 -0
  98. dcs_core/core/validation/uniqueness_validation.py +61 -0
  99. dcs_core/core/validation/validity_validation.py +738 -0
  100. dcs_core/integrations/__init__.py +13 -0
  101. dcs_core/integrations/databases/__init__.py +13 -0
  102. dcs_core/integrations/databases/bigquery.py +187 -0
  103. dcs_core/integrations/databases/databricks.py +51 -0
  104. dcs_core/integrations/databases/db2.py +652 -0
  105. dcs_core/integrations/databases/elasticsearch.py +61 -0
  106. dcs_core/integrations/databases/mssql.py +829 -0
  107. dcs_core/integrations/databases/mysql.py +409 -0
  108. dcs_core/integrations/databases/opensearch.py +64 -0
  109. dcs_core/integrations/databases/oracle.py +719 -0
  110. dcs_core/integrations/databases/postgres.py +482 -0
  111. dcs_core/integrations/databases/redshift.py +53 -0
  112. dcs_core/integrations/databases/snowflake.py +48 -0
  113. dcs_core/integrations/databases/spark_df.py +111 -0
  114. dcs_core/integrations/databases/sybase.py +1069 -0
  115. dcs_core/integrations/storage/__init__.py +13 -0
  116. dcs_core/integrations/storage/local_file.py +149 -0
  117. dcs_core/integrations/utils/__init__.py +13 -0
  118. dcs_core/integrations/utils/utils.py +36 -0
  119. dcs_core/report/__init__.py +13 -0
  120. dcs_core/report/dashboard.py +211 -0
  121. dcs_core/report/models.py +88 -0
  122. dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
  123. dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
  124. dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
  125. dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
  126. dcs_core/report/static/assets/images/docs.svg +6 -0
  127. dcs_core/report/static/assets/images/github.svg +4 -0
  128. dcs_core/report/static/assets/images/logo.svg +7 -0
  129. dcs_core/report/static/assets/images/slack.svg +13 -0
  130. dcs_core/report/static/index.js +2 -0
  131. dcs_core/report/static/index.js.LICENSE.txt +3971 -0
  132. dcs_sdk/__init__.py +13 -0
  133. dcs_sdk/__main__.py +18 -0
  134. dcs_sdk/__version__.py +15 -0
  135. dcs_sdk/cli/__init__.py +13 -0
  136. dcs_sdk/cli/cli.py +163 -0
  137. dcs_sdk/sdk/__init__.py +58 -0
  138. dcs_sdk/sdk/config/__init__.py +13 -0
  139. dcs_sdk/sdk/config/config_loader.py +491 -0
  140. dcs_sdk/sdk/data_diff/__init__.py +13 -0
  141. dcs_sdk/sdk/data_diff/data_differ.py +821 -0
  142. dcs_sdk/sdk/rules/__init__.py +15 -0
  143. dcs_sdk/sdk/rules/rules_mappping.py +31 -0
  144. dcs_sdk/sdk/rules/rules_repository.py +214 -0
  145. dcs_sdk/sdk/rules/schema_rules.py +65 -0
  146. dcs_sdk/sdk/utils/__init__.py +13 -0
  147. dcs_sdk/sdk/utils/serializer.py +25 -0
  148. dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
  149. dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
  150. dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
  151. dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
  152. dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
  153. dcs_sdk/sdk/utils/table.py +475 -0
  154. dcs_sdk/sdk/utils/themes.py +40 -0
  155. dcs_sdk/sdk/utils/utils.py +349 -0
  156. dcs_sdk-1.6.5.dist-info/METADATA +150 -0
  157. dcs_sdk-1.6.5.dist-info/RECORD +159 -0
  158. dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
  159. dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
data_diff/info_tree.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import attrs
18
+ from typing_extensions import Self
19
+
20
+ from data_diff.table_segment import TableSegment
21
+
22
+
23
@attrs.define(frozen=False)
class SegmentInfo:
    """Diff results and row statistics for one pair of table segments.

    A node is filled either directly, via ``set_diff`` (leaf segments), or by
    rolling up its children, via ``update_from_children``.
    """

    # The compared segments, as [table1_segment, table2_segment].
    tables: List[TableSegment]

    # Differing rows, their (column name, type) schema, whether any difference
    # was found, and how many differing rows there are.
    diff: Optional[List[Union[Tuple[Any, ...], List[Any]]]] = None
    diff_schema: Optional[Tuple[Tuple[str, type], ...]] = None
    is_diff: Optional[bool] = None
    diff_count: Optional[int] = None

    # Row counts keyed by table number (1 or 2).
    rowcounts: Dict[int, int] = attrs.field(factory=dict)
    max_rows: Optional[int] = None

    def set_diff(
        self,
        diff: List[Union[Tuple[Any, ...], List[Any]]],
        schema: Optional[Tuple[Tuple[str, type], ...]] = None,
    ) -> None:
        """Record the differing rows found for this segment pair.

        Args:
            diff: The differing rows.
            schema: Optional ``(column name, type)`` pairs describing each row of ``diff``.
        """
        self.diff_schema = schema
        self.diff = diff
        self.diff_count = len(diff)
        self.is_diff = self.diff_count > 0

    def update_from_children(self, child_infos) -> None:
        """Aggregate the results of child nodes into this node.

        Args:
            child_infos: Iterable of child ``SegmentInfo`` objects; must not be empty.

        Children whose fields are still ``None`` are ignored for that field. Children with
        no row counts are skipped; a child that recorded only one side's count contributes
        0 for the missing side instead of raising ``KeyError``.
        """
        child_infos = list(child_infos)
        assert child_infos, "update_from_children() requires at least one child"

        self.diff_count = sum(c.diff_count for c in child_infos if c.diff_count is not None)
        self.is_diff = any(c.is_diff for c in child_infos)
        self.diff_schema = next((c.diff_schema for c in child_infos if c.diff_schema is not None), None)
        # Flatten in one linear pass (sum(lists, []) re-copies the running list per child).
        self.diff = [row for c in child_infos if c.diff is not None for row in c.diff]

        self.rowcounts = {
            side: sum(c.rowcounts.get(side, 0) for c in child_infos if c.rowcounts)
            for side in (1, 2)
        }
57
+
58
+
59
@attrs.define(frozen=True)
class InfoTree:
    """A tree of ``SegmentInfo`` nodes.

    Each child is created by ``add_node`` for one (table1, table2) segment pair.
    ``aggregate_info`` rolls child results up into their parents.
    """

    # Info class instantiated by add_node(); subclasses may override it.
    SEGMENT_INFO_CLASS = SegmentInfo

    info: SegmentInfo
    children: List["InfoTree"] = attrs.field(factory=list)

    def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: Optional[int] = None) -> Self:
        """Create a child node for the segment pair, attach it to this node, and return it."""
        tree_type = type(self)
        child_info = tree_type.SEGMENT_INFO_CLASS([table1, table2], max_rows=max_rows)
        child = tree_type(child_info)
        self.children.append(child)
        return child

    def aggregate_info(self) -> None:
        """Recursively aggregate children first, then fold their info into this node's info.

        Leaf nodes are left unchanged.
        """
        if not self.children:
            return
        for child in self.children:
            child.aggregate_info()
        self.info.update_from_children(child.info for child in self.children)
data_diff/joindiff_tables.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Provides classes for performing a table diff using JOIN"""
16
+
17
+ import logging
18
+ from decimal import Decimal
19
+ from functools import partial
20
+ from itertools import chain
21
+ from typing import List, Optional
22
+
23
+ import attrs
24
+
25
+ from data_diff.abcs.database_types import DbPath, NumericType
26
+ from data_diff.databases import (
27
+ BigQuery,
28
+ Database,
29
+ DuckDB,
30
+ MsSQL,
31
+ MySQL,
32
+ Oracle,
33
+ Presto,
34
+ Snowflake,
35
+ )
36
+ from data_diff.databases.base import Compiler
37
+ from data_diff.diff_tables import DiffResult, TableDiffer
38
+ from data_diff.info_tree import InfoTree
39
+ from data_diff.queries.api import (
40
+ and_,
41
+ if_,
42
+ leftjoin,
43
+ or_,
44
+ outerjoin,
45
+ rightjoin,
46
+ sum_,
47
+ table,
48
+ this,
49
+ when,
50
+ )
51
+ from data_diff.queries.ast_classes import (
52
+ Code,
53
+ Concat,
54
+ Count,
55
+ Expr,
56
+ ITable,
57
+ Join,
58
+ Random,
59
+ TablePath,
60
+ )
61
+ from data_diff.queries.extras import NormalizeAsString
62
+ from data_diff.query_utils import append_to_table, drop_table
63
+ from data_diff.table_segment import TableSegment
64
+ from data_diff.thread_utils import ThreadedYielder
65
+ from data_diff.utils import safezip
66
+
67
# Module-level logger, registered under a fixed name rather than __name__.
logger = logging.getLogger(name="joindiff_tables")

# Default per-thread cap on rows written when materializing diffs or staging
# exclusive rows (JoinDiffer.table_write_limit).
TABLE_WRITE_LIMIT = 1_000
70
+
71
+
72
def merge_dicts(dicts):
    """Merge an iterable of dicts into a new dict; later dicts win on key clashes.

    None of the input dicts is modified (previously the first one was updated in place).

    Args:
        dicts: Iterable of mappings.

    Returns:
        A new dict; empty when ``dicts`` is empty.
    """
    merged = {}
    for d in dicts:
        merged.update(d)
    return merged
82
+
83
+
84
def sample(table_expr):
    """Return a query over ``table_expr`` that orders rows randomly and keeps at most 10."""
    shuffled = table_expr.order_by(Random())
    return shuffled.limit(10)
86
+
87
+
88
def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
    """Build SQL that creates the table at ``path`` from the query ``expr``.

    Dialect handling:
      * BigQuery: ``create table`` with a one-day ``expiration_timestamp`` option.
      * Presto: plain ``create table``.
      * Oracle: ``create global temporary table``.
      * Anything else: ``create temporary table``.
    """
    db = c.database
    # Both parts are compiled as fragments, not as complete root queries.
    fragment_compiler: Compiler = attrs.evolve(c, root=False)
    target_sql = fragment_compiler.dialect.compile(fragment_compiler, path)
    source_sql = fragment_compiler.dialect.compile(fragment_compiler, expr)

    if isinstance(db, BigQuery):
        return f"create table {target_sql} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {source_sql}"
    if isinstance(db, Presto):
        return f"create table {target_sql} as {source_sql}"
    if isinstance(db, Oracle):
        return f"create global temporary table {target_sql} as {source_sql}"
    return f"create temporary table {target_sql} as {source_sql}"
99
+
100
+
101
def bool_to_int(x):
    """Wrap the boolean expression ``x`` in a conditional that yields 1 when true and 0 otherwise."""
    when_true, when_false = 1, 0
    return if_(x, when_true, when_false)
103
+
104
+
105
def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable:
    """Outer-join ``a`` and ``b`` on pairwise-equal key columns.

    In addition to ``select_fields``, the result selects two flags:
    ``is_exclusive_a`` (every key of ``b`` is NULL) and ``is_exclusive_b``
    (every key of ``a`` is NULL). On MsSQL and Oracle the flags are rendered as
    1/0 values. MySQL is handled as a LEFT JOIN unioned with a RIGHT JOIN.
    """
    join_conditions = [a[left_key] == b[right_key] for left_key, right_key in safezip(keys1, keys2)]

    # `== None` is deliberate: on query expressions it builds a NULL test.
    only_in_a = and_(b[key] == None for key in keys2)  # noqa: E711
    only_in_b = and_(a[key] == None for key in keys1)  # noqa: E711

    if isinstance(db, MsSQL):
        # There is no "IS NULL" or "ISNULL()" as expressions, only as conditions.
        only_in_a = when(only_in_a).then(1).else_(0)
        only_in_b = when(only_in_b).then(1).else_(0)

    if isinstance(db, Oracle):
        only_in_a = bool_to_int(only_in_a)
        only_in_b = bool_to_int(only_in_b)

    if not isinstance(db, MySQL):
        return outerjoin(a, b).on(*join_conditions).select(
            is_exclusive_a=only_in_a, is_exclusive_b=only_in_b, **select_fields
        )

    # No outer join: emulate it with LEFT JOIN ... UNION ... RIGHT JOIN.
    left_side: Join = leftjoin(a, b).on(*join_conditions).select(
        is_exclusive_a=only_in_a, is_exclusive_b=False, **select_fields
    )
    right_side: Join = rightjoin(a, b).on(*join_conditions).select(
        is_exclusive_a=False, is_exclusive_b=only_in_b, **select_fields
    )
    return left_side.union(right_side)
127
+
128
+
129
+ def _slice_tuple(t, *sizes):
130
+ i = 0
131
+ for size in sizes:
132
+ yield t[i : i + size]
133
+ i += size
134
+ assert i == len(t)
135
+
136
+
137
def json_friendly_value(v):
    """Convert a ``Decimal`` to ``float``; return any other value unchanged."""
    return float(v) if isinstance(v, Decimal) else v
141
+
142
+
143
@attrs.define(frozen=False)
class JoinDiffer(TableDiffer):
    """Finds the diff between two SQL tables in the same database, using JOINs.

    The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics.
    The two tables must reside in the same database, and their primary keys must be unique and not null.

    All parameters are optional.

    Parameters:
        threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.
        max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto.
                                   Only relevant when `threaded` is ``True``.
                                   There may be many pools, so number of actual threads can be a lot higher.
        validate_unique_key (bool): Enable/disable validating that the key columns are unique. (default: True)
                                    If there are no UNIQUE constraints in the schema, it is done in a single query,
                                    and can't be threaded, so it's very slow on non-cloud dbs.
        sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. (default: False)
                                      Creates a temporary table.
        materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided.
        materialize_all_rows (bool): Materialize every row, not just those that are different. (default: False)
        table_write_limit (int): Maximum number of rows to write when materializing, per thread.
        skip_null_keys (bool): Skips diffing any rows with null PKs (displays a warning if any are null) (default: False)

    Counters in ``stats`` (row counts, column sums, per-column diff counts, exclusive-row
    counts) are summed over every diffed segment.
    """

    validate_unique_key: bool = True
    sample_exclusive_rows: bool = False
    materialize_to_table: Optional[DbPath] = None
    materialize_all_rows: bool = False
    table_write_limit: int = TABLE_WRITE_LIMIT
    skip_null_keys: bool = False

    # Statistics written by the background checks in _diff_segments.
    stats: dict = attrs.field(factory=dict)

    def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
        """Diff two tables of one database, yielding ``("-", row)`` / ``("+", row)`` pairs.

        Raises:
            ValueError: if the two segments do not share the same database object.
        """
        db = table1.database

        if table1.database is not table2.database:
            raise ValueError("Join-diff only works when both tables are in the same database")

        table1, table2 = self._threaded_call("with_schema", [table1, table2])

        bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else []
        if self.materialize_to_table:
            # Segments append to this table, so start from a clean slate.
            drop_table(db, self.materialize_to_table)

        with self._run_in_background(*bg_funcs):
            if isinstance(db, (Snowflake, BigQuery, DuckDB)):
                # Don't segment the table; let the database handle parallelization
                yield from self._diff_segments(None, table1, table2, info_tree, None)
            else:
                yield from self._bisect_and_diff_tables(table1, table2, info_tree)
            logger.info(f"Diffing complete: {table1.table_path} <> {table2.table_path}")
            if self.materialize_to_table:
                logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))

    def _diff_segments(
        self,
        ti: ThreadedYielder,
        table1: TableSegment,
        table2: TableSegment,
        info_tree: InfoTree,
        max_rows: Optional[int],
        level=0,
        segment_index=None,
        segment_count=None,
    ):
        """Diff one pair of segments with a single outer-join query.

        Stats collection, the NULL-key check, exclusive-row counting, per-column diff
        counting and (optionally) materialization are submitted via _run_in_background.

        Yields:
            ``("-", row)`` for each table1 row that is missing from or differs in table2, and
            ``("+", row)`` for each table2 row that is missing from or differs in table1.

        Raises:
            ValueError: if a result row is exclusive to both sides (NULL key) and
                ``skip_null_keys`` is False.
        """
        assert table1.database is table2.database

        if segment_index or table1.min_key or max_rows:
            logger.info(
                ". " * level + f"Diffing segment {segment_index}/{segment_count}, "
                f"key-range: {table1.min_key}..{table1.max_key}, "
                f"size <= {max_rows}"
            )

        db = table1.database
        diff_rows, a_cols, b_cols, is_diff_cols, all_rows = self._create_outer_join(table1, table2)

        with self._run_in_background(
            partial(self._collect_stats, 1, table1, info_tree),
            partial(self._collect_stats, 2, table2, info_tree),
            partial(self._test_null_keys, table1, table2),
            partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols, table1, table2),
            partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols, table1, table2),
            (
                partial(
                    self._materialize_diff,
                    db,
                    all_rows if self.materialize_all_rows else diff_rows,
                    segment_index=segment_index,
                )
                if self.materialize_to_table
                else None
            ),
        ):
            assert len(a_cols) == len(b_cols)
            logger.debug(f"Querying for different rows: {table1.table_path}")
            diff = db.query(diff_rows, list, log_message=table1.table_path)
            info_tree.info.set_diff(diff, schema=tuple(diff_rows.schema.items()))
            for is_xa, is_xb, *x in diff:
                if is_xa and is_xb:
                    # Can't both be exclusive, meaning a pk is NULL
                    # This can happen if the explicit null test didn't finish running yet
                    if self.skip_null_keys:
                        # warning is thrown in explicit null test
                        continue
                    raise ValueError("NULL values in one or more primary keys")
                # Remaining columns: the is_diff flags, then values interleaved as
                # col1_a, col1_b, col2_a, col2_b, ... (see _create_outer_join).
                _is_diff, ab_row = _slice_tuple(x, len(is_diff_cols), len(a_cols) + len(b_cols))
                a_row, b_row = ab_row[::2], ab_row[1::2]
                assert len(a_row) == len(b_row)
                if not is_xb:
                    yield "-", tuple(a_row)
                if not is_xa:
                    yield "+", tuple(b_row)

    def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
        """Raise ``ValueError`` if either table has duplicate key values.

        Key columns returned by ``query_table_unique_columns`` (only queried when the
        database sets ``SUPPORTS_UNIQUE_CONSTAINT``) are skipped. The rest are checked by
        comparing COUNT(*) with COUNT(DISTINCT concat(columns)).
        """
        logger.debug(f"Testing for duplicate keys: {table1.table_path} <> {table2.table_path}")

        for ts in [table1, table2]:
            unique = (
                ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else []
            )

            t = ts.make_select()
            key_columns = ts.key_columns

            # NOTE(review): for composite keys, removing individually-constrained columns and
            # checking only the rest jointly may not match the real uniqueness guarantee; also,
            # Concat with no visible separator could make distinct tuples collide (e.g. '1','23'
            # vs '12','3'). Confirm against query_table_unique_columns / Concat semantics.
            unvalidated = list(set(key_columns) - set(unique))
            if unvalidated:
                logger.info(f"Validating that there are no duplicate keys in columns: {unvalidated} for {ts.table_path}")
                self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
                q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
                total, total_distinct = ts.database.query(q, tuple, log_message=ts.table_path)
                if total != total_distinct:
                    raise ValueError("Duplicate primary keys")

    def _test_null_keys(self, table1, table2):
        """Check both tables for NULL key values.

        If any are found, log a warning when ``skip_null_keys`` is set, or raise
        ``ValueError`` otherwise.
        """
        logger.debug(f"Testing for null keys: {table1.table_path} <> {table2.table_path}")

        for ts in [table1, table2]:
            t = ts.make_select()
            key_columns = ts.key_columns

            q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))  # noqa: E711
            nulls = ts.database.query(q, list, log_message=ts.table_path)
            if nulls:
                if self.skip_null_keys:
                    logger.warning(
                        f"NULL values in one or more primary keys of {ts.table_path}. Skipping rows with NULL keys."
                    )
                else:
                    raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")

    def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
        """Count rows and sum numeric non-key columns for table #``i`` of the segment pair.

        The row count is stored in ``info_tree.info.rowcounts[i]``. Every non-NULL result is
        added to ``self.stats[f"table{i}_{name}"]``, where ``name`` is ``count`` or ``sum_<col>``.
        """
        logger.debug(f"Collecting stats for table #{i}: {table_seg.table_path}")
        db = table_seg.database

        col_exprs = {
            f"sum_{c}": sum_(this[c])
            for c in table_seg.relevant_columns
            if isinstance(table_seg._schema[c], NumericType) and c not in table_seg.key_columns
        }
        col_exprs["count"] = Count()

        res = db.query(table_seg.make_select().select(**col_exprs), tuple, log_message=table_seg.table_path)

        for col_name, value in safezip(col_exprs, res):
            if value is None:
                continue
            value = json_friendly_value(value)
            stat_name = f"table{i}_{col_name}"

            if col_name == "count":
                info_tree.info.rowcounts[i] = value

            if stat_name in self.stats:
                self.stats[stat_name] += value
            else:
                self.stats[stat_name] = value

        logger.debug("Done collecting stats for table #%s: %s", i, table_seg.table_path)

    def _create_outer_join(self, table1, table2):
        """Build the outer-join queries used to diff ``table1`` against ``table2``.

        Returns:
            ``(diff_rows, a_cols, b_cols, is_diff_cols, all_rows)``:
            ``all_rows`` is the full outer join; ``diff_rows`` keeps only rows where some
            ``is_diff_<col>`` equals 1; ``a_cols`` / ``b_cols`` map ``<col>_a`` / ``<col>_b``
            to string-normalized values; ``is_diff_cols`` maps ``is_diff_<col>`` to 1/0 flags.

        Raises:
            ValueError: if the tables are in different databases, or their key or relevant
                column counts differ.
        """
        db = table1.database
        if db is not table2.database:
            raise ValueError("Joindiff only applies to tables within the same database")

        keys1 = table1.key_columns
        keys2 = table2.key_columns
        if len(keys1) != len(keys2):
            raise ValueError("The provided key columns are of a different count")

        cols1 = table1.relevant_columns
        cols2 = table2.relevant_columns
        if len(cols1) != len(cols2):
            raise ValueError("The provided columns are of a different count")

        a = table1.make_select()
        b = table2.make_select()

        # 1 when the paired columns are distinct (IS DISTINCT FROM semantics), else 0.
        is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)}

        a_cols = {f"{c}_a": NormalizeAsString(a[c]) for c in cols1}
        b_cols = {f"{c}_b": NormalizeAsString(b[c]) for c in cols2}
        # Order columns as col1_a, col1_b, col2_a, col2_b, etc.
        cols = dict(chain.from_iterable(zip(a_cols.items(), b_cols.items())))

        all_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **cols})
        diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
        return diff_rows, a_cols, b_cols, is_diff_cols, all_rows

    def _count_diff_per_column(
        self,
        db,
        diff_rows,
        cols,
        is_diff_cols,
        table1: Optional[TableSegment] = None,
        table2: Optional[TableSegment] = None,
    ):
        """Add this segment's per-column difference counts to ``stats["diff_counts"]``.

        ``cols`` names the counters and pairs positionally with ``is_diff_cols``. Counts
        accumulate across segments rather than being replaced by the last segment.
        """
        logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}")
        is_diff_cols_counts = db.query(
            diff_rows.select(sum_(this[c]) for c in is_diff_cols),
            tuple,
            log_message=f"{table1.table_path} <> {table2.table_path}",
        )
        totals = self.stats.setdefault("diff_counts", {})
        for name, count in safezip(cols, is_diff_cols_counts):
            # A sum can come back as None (e.g. no differing rows), so treat it as 0.
            totals[name] = totals.get(name, 0) + (count or 0)

    def _sample_and_count_exclusive(
        self,
        db,
        diff_rows,
        a_cols,
        b_cols,
        table1: Optional[TableSegment] = None,
        table2: Optional[TableSegment] = None,
    ):
        """Add this segment's exclusive-row count to ``stats["exclusive_count"]``.

        When ``sample_exclusive_rows`` is set, also append up to 10 random exclusive rows
        to ``stats["exclusive_sample"]``. The sampling path stages rows (at most
        ``table_write_limit``) in a table created via create_temp_table and dropped afterwards.
        """
        if isinstance(db, (Oracle, MsSQL)):
            # On these dialects the flags are 1/0 values (see _outerjoin).
            exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
        else:
            exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)

        if not self.sample_exclusive_rows:
            logger.debug(f"Counting exclusive rows: {table1.table_path} <> {table2.table_path}")
            segment_count = db.query(
                exclusive_rows_query.count(), int, log_message=f"{table1.table_path} <> {table2.table_path}"
            )
            # Accumulate, matching the sampling path below.
            self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + segment_count
            return

        logger.info("Counting and sampling exclusive rows")

        def exclusive_rows(expr):
            c = Compiler(db)
            name = c.new_unique_table_name("temp_table")
            temp_table = table(name, schema=expr.source_table.schema)
            yield Code(create_temp_table(c, temp_table, expr.limit(self.table_write_limit)))

            count = yield temp_table.count()
            self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0]
            sample_rows = yield sample(temp_table.select(*this[list(a_cols)], *this[list(b_cols)]))
            self.stats["exclusive_sample"] = self.stats.get("exclusive_sample", []) + sample_rows

            # Only drops if create table succeeded (meaning, the table didn't already exist)
            yield temp_table.drop()

        # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter)
        db.query(exclusive_rows(exclusive_rows_query), None)

    def _materialize_diff(self, db, diff_rows, segment_index=None):
        """Append up to ``table_write_limit`` rows of ``diff_rows`` to ``materialize_to_table``."""
        assert self.materialize_to_table

        append_to_table(db, self.materialize_to_table, diff_rows.limit(self.table_write_limit))