dbdiff 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. dbdiff/__init__.py +7 -0
  2. dbdiff/__main__.py +15 -0
  3. dbdiff/cli.py +491 -0
  4. dbdiff/logging.json +24 -0
  5. dbdiff/main.py +728 -0
  6. dbdiff/report.py +165 -0
  7. dbdiff/templates/all_keys_count.sql +6 -0
  8. dbdiff/templates/all_keys_sample.sql +10 -0
  9. dbdiff/templates/create_dedup.sql +11 -0
  10. dbdiff/templates/create_dup.sql +11 -0
  11. dbdiff/templates/create_joined_table.sql +13 -0
  12. dbdiff/templates/create_joined_table_from_selectinto.sql +14 -0
  13. dbdiff/templates/create_temp_table.sql +1 -0
  14. dbdiff/templates/diff_column.sql +28 -0
  15. dbdiff/templates/diff_column_hier.sql +5 -0
  16. dbdiff/templates/diff_column_numeric_diffs_binned.sql +15 -0
  17. dbdiff/templates/diff_column_numeric_diffs_sorted.sql +6 -0
  18. dbdiff/templates/diff_column_raw.sql +11 -0
  19. dbdiff/templates/diff_column_summary.sql +5 -0
  20. dbdiff/templates/diff_rows_sample.sql +8 -0
  21. dbdiff/templates/first_key_base.sql +18 -0
  22. dbdiff/templates/first_key_count.sql +4 -0
  23. dbdiff/templates/first_key_sample.sql +10 -0
  24. dbdiff/templates/html/base.html +34 -0
  25. dbdiff/templates/html/report.html +241 -0
  26. dbdiff/templates/insert_diff.sql +9 -0
  27. dbdiff/templates/insert_joined_table.sql +26 -0
  28. dbdiff/templates/joined_column.sql +7 -0
  29. dbdiff/templates/joined_column_hier.sql +11 -0
  30. dbdiff/templates/joined_column_numeric_diffs_binned.sql +15 -0
  31. dbdiff/templates/joined_column_numeric_diffs_sorted.sql +6 -0
  32. dbdiff/templates/joined_column_raw.sql +6 -0
  33. dbdiff/templates/joined_count.sql +3 -0
  34. dbdiff/templates/joined_rows_count.sql +3 -0
  35. dbdiff/templates/joined_rows_sample.sql +3 -0
  36. dbdiff/templates/sub_keys_base.sql +33 -0
  37. dbdiff/templates/sub_keys_count.sql +4 -0
  38. dbdiff/templates/sub_keys_grouped.sql +14 -0
  39. dbdiff/templates/sub_keys_sample.sql +12 -0
  40. dbdiff/templates/table_columns.sql +5 -0
  41. dbdiff/templates/table_drop.sql +1 -0
  42. dbdiff/templates/table_exists.sql +4 -0
  43. dbdiff/templates/table_rows.sql +2 -0
  44. dbdiff/templates/table_rows_uniq.sql +6 -0
  45. dbdiff/vertica.py +126 -0
  46. dbdiff-0.7.1.dist-info/METADATA +200 -0
  47. dbdiff-0.7.1.dist-info/RECORD +51 -0
  48. dbdiff-0.7.1.dist-info/WHEEL +4 -0
  49. dbdiff-0.7.1.dist-info/entry_points.txt +2 -0
  50. dbdiff-0.7.1.dist-info/licenses/AUTHORS.md +4 -0
  51. dbdiff-0.7.1.dist-info/licenses/LICENSE +9 -0
dbdiff/main.py ADDED
@@ -0,0 +1,728 @@
1
+ import logging
2
+ import logging.config
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import pandas as pd
7
+ from jinja2 import Environment, PackageLoader
8
+ from vertica_python.vertica.cursor import Cursor
9
+
10
+ from dbdiff.vertica import get_column_info_lookup, implicit_dtype_comparison
11
+
12
+ JINJA_ENV = Environment(loader=PackageLoader("dbdiff", "templates"))
13
+ LOGGER = logging.getLogger(__name__)
14
+
15
+
16
def is_numeric_like(dtype: str):
    """Return True if *dtype* names a numeric SQL type (int/float/numeric)."""
    lowered = dtype.lower()
    return ("int" in lowered) or ("float" in lowered) or ("numeric" in lowered)
19
+
20
+
21
def is_date_like(dtype: str):
    """Return True if *dtype* names a date-like SQL type (contains 'date')."""
    return "date" in dtype.lower()
24
+
25
+
26
def check_primary_key(cur: Cursor, schema: str, table: str, join_cols: list) -> int:
    """Given a list of columns return the # of records for which they are NOT
    a primary key.

    A return value of 0 means ``join_cols`` uniquely identify every row.
    """
    total_q = JINJA_ENV.get_template("table_rows.sql").render(
        schema_name=schema, table_name=table
    )
    cur.execute(total_q)
    n_rows = cur.fetchall()[0]["COUNT"]

    uniq_q = JINJA_ENV.get_template("table_rows_uniq.sql").render(
        schema_name=schema, table_name=table, join_cols=", ".join(join_cols)
    )
    cur.execute(uniq_q)
    n_distinct_rows = cur.fetchall()[0]["COUNT"]

    # Rows beyond the distinct count are duplicates on the key columns.
    return n_rows - n_distinct_rows
42
+
43
+
44
def get_all_col_info(
    cur: Cursor,
    schema,
    x_table,
    y_schema,
    y_table,
    exclude_columns_set,
    save_column_summary,
    save_column_summary_format,
) -> pd.DataFrame:
    """Collect per-column dtype info for both tables and flag comparability.

    :param exclude_columns_set: columns the user asked to skip; marked via
        the ``exclude`` flag rather than dropped.
    :param save_column_summary: if truthy, write the summary to disk
        (``<x_table>_col_info.csv`` / ``.pkl``); otherwise log the summary.
    :param save_column_summary_format: 'csv' or 'pickle' (case-insensitive).
    :return: DataFrame indexed by column name with columns
        x_dtype, y_dtype, comparable, exclude.
    """
    LOGGER.info("Getting column info for both tables.")
    x_table_info_lookup = get_column_info_lookup(cur, schema, x_table)
    y_table_info_lookup = get_column_info_lookup(cur, y_schema, y_table)

    def comparable_(x, y) -> bool:
        # This doesn't capture the case where they both could be converted to
        # float to be compared (two-hop conversions).
        if (x is None) or (y is None):
            return False
        # Comparable if either dtype implicitly converts to the other.
        return implicit_dtype_comparison(x, y) or implicit_dtype_comparison(y, x)

    # Union of column names from both tables; columns missing on one side
    # simply get a None dtype for that side.
    all_keys = list(x_table_info_lookup.keys()) + list(y_table_info_lookup.keys())
    all_col_info = {
        col: {
            "x_dtype": x_table_info_lookup.get(col, None),
            "y_dtype": y_table_info_lookup.get(col, None),
            "comparable": comparable_(
                x_table_info_lookup.get(col, None), y_table_info_lookup.get(col, None)
            ),
            "exclude": (col in exclude_columns_set),
        }
        for col in all_keys
    }
    LOGGER.debug(all_col_info)
    all_col_info_df = pd.DataFrame(all_col_info).transpose()

    if save_column_summary:
        # Hoist the lower() call; the two output formats are mutually
        # exclusive, so use if/elif instead of two independent checks.
        fmt = save_column_summary_format.lower()
        if fmt == "csv":
            all_col_info_df.to_csv(Path(x_table + "_col_info.csv"))
        elif fmt == "pickle":
            all_col_info_df.to_pickle(Path(x_table + "_col_info.pkl"))
    else:
        LOGGER.info("All column info:\n" + all_col_info_df.to_string())
        LOGGER.info(
            "Missing columns in x:\n"
            + all_col_info_df.loc[all_col_info_df.x_dtype.isnull(), :].to_string()
        )
        LOGGER.info(
            "Missing columns in y:\n"
            + all_col_info_df.loc[all_col_info_df.y_dtype.isnull(), :].to_string()
        )
        LOGGER.debug(all_col_info_df.comparable)
        LOGGER.debug(~all_col_info_df.comparable)
        LOGGER.info(
            "These columns have incompatible dtypes, specifically neither of them can be implicitly converted to the other:\n"
            + all_col_info_df.loc[(~all_col_info_df.comparable).astype("bool"), :].to_string()
        )

    return all_col_info_df
103
+
104
+
105
def select_distinct_rows(
    cur: Cursor, schema: str, table: str, join_cols: list, use_temp_tables: bool = False
) -> tuple[str, str]:
    """Select only the rows that are distinct on join_cols.
    *Instead of deleting the rows, we'll select those without duplicates into a
    new table, and return the name of that new table.
    Delete is inefficient, see: https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/AnalyzingData/Optimizations/PerformanceConsiderationsForDELETEAndUPDATEQueries.htm
    And: https://www.vertica.com/blog/another-way-to-de-duplicate-table-rows-quick-tip/
    """
    # Both the *_dedup and *_dup tables are built the same way; only the
    # template and the name of its output-table parameter differ.
    null_safe_join = " AND ".join("x.{0} <=> y.{0}".format(col) for col in join_cols)
    group_cols = ", ".join(join_cols)

    for template_name, out_param, suffix in (
        ("create_dedup.sql", "table_name_dedup", "_dedup"),
        ("create_dup.sql", "table_name_dup", "_dup"),
    ):
        drop_q = JINJA_ENV.get_template("table_drop.sql").render(
            schema_name=schema, table_name=(table + suffix)
        )
        LOGGER.info(drop_q)
        cur.execute(drop_q)

        ctx = {
            "schema_name": schema,
            "table_name": table,
            out_param: table + suffix,
            "group_cols": group_cols,
            "join_cols": null_safe_join,
            "use_temp_table": use_temp_tables,
        }
        q = JINJA_ENV.get_template(template_name).render(ctx)
        if use_temp_tables:
            # Wrap the CREATE in a temp-table shell when requested.
            q = JINJA_ENV.get_template("create_temp_table.sql").render(
                table_name=(table + suffix), query=q
            )
        LOGGER.info(q)
        cur.execute(q)
        LOGGER.info("COMMIT;")
        cur.execute("COMMIT;")

    # Temp tables land in v_temp_schema; otherwise the original schema.
    return (schema, "v_temp_schema")[use_temp_tables], f"{table}_dedup"
158
+
159
+
160
def create_joined_table(cur: Cursor, create_insert=False, **kwargs):
    """Join tables x and y into ``kwargs['joined_schema'].kwargs['joined_table']``.

    :param cur: vertica python Cursor
    :param create_insert: if True, run CREATE TABLE then a separate
        INSERT INTO; otherwise a single SELECT INTO statement.
    :return: int - number of rows in the joined table.
    """

    def _run(template_name: str) -> None:
        # Render, log, and execute one templated statement.
        statement = JINJA_ENV.get_template(template_name).render(kwargs)
        LOGGER.info(statement)
        cur.execute(statement)

    drop_q = JINJA_ENV.get_template("table_drop.sql").render(
        schema_name=kwargs["joined_schema"], table_name=kwargs["joined_table"]
    )
    LOGGER.info(drop_q)
    cur.execute(drop_q)

    if create_insert:
        # these separately do CREATE TABLE and then INSERT INTO
        _run("create_joined_table.sql")
        _run("insert_joined_table.sql")
    else:
        # this does a SELECT INTO
        _run("create_joined_table_from_selectinto.sql")

    LOGGER.info("COMMIT;")
    cur.execute("COMMIT;")

    count_q = JINJA_ENV.get_template("table_rows.sql").render(
        schema_name=kwargs["joined_schema"], table_name=kwargs["joined_table"]
    )
    LOGGER.info(count_q)
    cur.execute(count_q)
    return cur.fetchall()[0]["COUNT"]
198
+
199
+
200
def get_unmatched_rows_straight(
    cur: Cursor,
    x_schema: str,
    y_schema: str,
    x_table: str,
    y_table: str,
    join_cols: list,
    max_rows_column: int,
) -> dict[str, dict[str, Any]]:
    """
    Get rows that don't match on a join using all of the keys ("straight").

    Returns, per side ('x'/'y'): the unmatched-row count, the sample query
    text, and a DataFrame sample of the unmatched rows.
    """
    count_template = JINJA_ENV.get_template("all_keys_count.sql")
    sample_template = JINJA_ENV.get_template("all_keys_sample.sql")

    results: dict[str, dict[str, Any]] = {}
    for side in ("x", "y"):
        # The 'x' flag tells the template which table's unmatched rows to pull.
        ctx = {
            "x_schema": x_schema,
            "y_schema": y_schema,
            "x_table": x_table,
            "y_table": y_table,
            "join_cols": join_cols,
            "x": (side == "x"),
            "max_rows_column": max_rows_column,
        }
        cur.execute(count_template.render(ctx))
        unmatched_count = cur.fetchall()[0]["COUNT"]
        sample_query = sample_template.render(ctx)
        cur.execute(sample_query)
        results[side] = {
            "count": unmatched_count,
            "query": sample_query,
            "sample": pd.DataFrame(cur.fetchall()),
        }

    return results
239
+
240
+
241
def get_unmatched_rows(
    cur: Cursor,
    x_schema: str,
    y_schema: str,
    x_table: str,
    y_table: str,
    join_cols: list,
    max_rows_column: int,
) -> dict[Any, dict[str, dict[str, Any]]]:
    """
    Pull out rows that are unmatched between the two tables on the join columns.
    If looking at this hierarchically, we consider the join by
    key a, then key a+b (where a matched), then key a+b+c (where a+b matched), etc
    to see at what level we're missing things.

    Returns: {join_col: {'x'/'y': {'count', 'query', 'sample',
    and (for all but the first key) 'query_grouped'/'sample_grouped'}}}.
    """
    # One result slot per join column; each side gets placeholder values that
    # are overwritten below.
    results = {
        col: {
            "x": {"count": 0, "query": "select ...", "sample": pd.DataFrame()},
            "y": {"count": 0, "query": "select ...", "sample": pd.DataFrame()},
        }
        for col in join_cols
    }

    first_key_count = JINJA_ENV.get_template("first_key_count.sql")
    first_key_t = JINJA_ENV.get_template("first_key_sample.sql")
    sub_keys_count = JINJA_ENV.get_template("sub_keys_count.sql")
    sub_keys_t = JINJA_ENV.get_template("sub_keys_sample.sql")
    sub_keys_g = JINJA_ENV.get_template("sub_keys_grouped.sql")

    LOGGER.info(
        "Getting rows that did not match on only the first join column: " + join_cols[0] + "."
    )
    # The first key uses the first_key_* templates (single join_col param).
    for side in {"x", "y"}:
        # 'x' flag tells the template which side's unmatched rows to select.
        d = {
            "x_schema": x_schema,
            "y_schema": y_schema,
            "x_table": x_table,
            "y_table": y_table,
            "join_col": join_cols[0],
            "x": (side == "x"),
            "max_rows_column": max_rows_column,
        }
        q = first_key_count.render(d)
        cur.execute(q)
        r = cur.fetchall()
        results[join_cols[0]][side]["count"] = r[0]["COUNT"]
        results[join_cols[0]][side]["query"] = first_key_t.render(d)
        cur.execute(results[join_cols[0]][side]["query"])
        results[join_cols[0]][side]["sample"] = pd.DataFrame(cur.fetchall())

    # Remaining keys use the sub_keys_* templates over the key prefix
    # join_cols[:i+1], i.e. rows where all but the last key already matched.
    for i in range(1, len(join_cols)):
        LOGGER.info(
            "Getting rows that did not match on the "
            + str(i + 1)
            + "-nd/rd/th join column: "
            + join_cols[i]
            + "."
        )
        LOGGER.info(
            "This is equivalent to joining the tables on unique rows of "
            + ",".join(join_cols[: (i + 1)])
            + " where all but the last already exist."
        )

        for side in {"x", "y"}:
            d = {
                "x_schema": x_schema,
                "y_schema": y_schema,
                "x_table": x_table,
                "y_table": y_table,
                "join_cols": join_cols[: (i + 1)],
                "x": (side == "x"),
                "max_rows_column": max_rows_column,
            }
            q = sub_keys_count.render(d)
            cur.execute(q)
            r = cur.fetchall()
            results[join_cols[i]][side]["count"] = r[0]["COUNT"]
            results[join_cols[i]][side]["query"] = sub_keys_t.render(d)
            cur.execute(results[join_cols[i]][side]["query"])
            results[join_cols[i]][side]["sample"] = pd.DataFrame(cur.fetchall())
            # Sub-keys also get a grouped view of the unmatched rows.
            results[join_cols[i]][side]["query_grouped"] = sub_keys_g.render(d)
            cur.execute(results[join_cols[i]][side]["query_grouped"])
            results[join_cols[i]][side]["sample_grouped"] = pd.DataFrame(cur.fetchall())

    return results
327
+
328
+
329
def create_diff_table(
    cur: Cursor, schema: str, table: str, join_cols: list, all_col_info_df: pd.DataFrame
) -> str:
    """Drop and recreate the diff table: the join-key columns (with the
    x-side dtypes) plus a ``column_name VARCHAR(255)`` marker column.

    Returns the CREATE TABLE statement that was executed.
    """
    drop_q = JINJA_ENV.get_template("table_drop.sql").render(schema_name=schema, table_name=table)

    # so simple that putting into a template would make this harder to follow...
    key_rows = all_col_info_df.loc[all_col_info_df.index.isin(join_cols)]
    column_defs = key_rows.apply(lambda r: " ".join([r.name, r.x_dtype]), axis=1).values
    create_q = "CREATE TABLE {schema}.{table} ( {columns}, column_name VARCHAR(255) );".format(
        schema=schema,
        table=table,
        columns=", ".join(column_defs),
    )

    cur.execute(drop_q)
    cur.execute(create_q)
    return create_q
346
+
347
+
348
def insert_diff_table(cur: Cursor, **kwargs) -> None:
    """Render insert_diff.sql with *kwargs*, execute it, and commit."""
    insert_q = JINJA_ENV.get_template("insert_diff.sql").render(kwargs)
    cur.execute(insert_q)
    cur.execute("COMMIT;")
351
+
352
+
353
def get_diff_rows(
    cur: Cursor,
    output_schema: str,
    x_table: str,
    join_cols: list,
    max_rows_all: int,
    skip_row_total: bool = False,
) -> dict:
    """Summarize the <x_table>_DIFF table.

    When skip_row_total is True, returns only {'total_count': <cell diffs>}.
    Otherwise also returns 'count' (rows with >0 differences), 'query'
    (sample query text), and 'sample' (DataFrame of sample rows).
    """
    LOGGER.debug("Getting diff rows")
    diff_table = x_table + "_DIFF"

    # first get the count: one _DIFF row per differing cell.
    total_q = JINJA_ENV.get_template("table_rows.sql").render(
        schema_name=output_schema, table_name=diff_table
    )
    LOGGER.info(total_q)
    cur.execute(total_q)
    diff_total_count = cur.fetchall()[0]["COUNT"]
    if skip_row_total:
        LOGGER.debug(
            "Skipping sample of rows with differences, query to get that sample, and the total # of rows with > 0 differences. Returning only 'total_count', the sum of cell-by-cell differences."
        )
        return {"total_count": diff_total_count}

    # Distinct key combinations = number of rows with at least one difference.
    uniq_q = JINJA_ENV.get_template("table_rows_uniq.sql").render(
        schema_name=output_schema, table_name=diff_table, join_cols=", ".join(join_cols)
    )
    LOGGER.info(uniq_q)
    cur.execute(uniq_q)
    diff_row_count = cur.fetchall()[0]["COUNT"]

    # we'll pull all columns from the joined table
    sample_q = JINJA_ENV.get_template("diff_rows_sample.sql").render(
        schema_name=output_schema,
        joined_table=(x_table + "_JOINED"),
        diff_table=diff_table,
        group_cols=", ".join(join_cols),
        join_cols=" AND ".join("x.{0} <=> joined.{0}".format(col) for col in join_cols),
    )
    LOGGER.info(sample_q)
    cur.execute(sample_q + " LIMIT " + str(max_rows_all))
    sample_df = pd.DataFrame(cur.fetchall())

    return {
        "query": sample_q,
        "sample": sample_df,
        "count": diff_row_count,
        "total_count": diff_total_count,
    }
400
+
401
+
402
def get_diff_rows_from_joined(
    cur: Cursor,
    grouped_column_diffs: dict,
    output_schema: str,
    x_table: str,
    join_cols: list,
    max_rows_all: int,
    skip_row_total: bool = False,
) -> dict:
    """Get diff rows from joined table.

    Non self-explanatory argument specifics:

    - grouped_column_diffs: per-column diff info; each value must carry a
      'count' entry and the keys are the differing column names used to
      filter the joined table.
    - max_rows_all: number of rows to get for the sample (only relevant if skip_row_total=F)
    - skip_row_total: skip sample of rows with differences, query to get that sample, and the total # of rows with > 0 differences. Return only 'total_count', the sum of cell-by-cell differences.

    Returned data specifics:

    - dict with 4 keys:
        - total_count: total number of cell-by-cell differences between the two tables.
        - query: query to get a sample of rows with >0 differences.
        - sample: dataframe of those sample rows
        - count: count of rows with >0 differences.
    - NOTE(review): on the skip/empty path the dict instead has keys
      'sample' (empty list), 'count' (0), and 'total_count' — no 'query'.
    """
    LOGGER.debug("Getting diff rows: get_diff_rows_from_joined()")

    # Total cell-by-cell differences = sum of per-column counts.
    diff_total_count = sum([info["count"] for info in grouped_column_diffs.values()])
    if skip_row_total or (len(grouped_column_diffs) == 0):
        LOGGER.debug(
            "Skipping sample of rows with differences, query to get that sample, and the total # of rows with > 0 differences. Returning only 'total_count', the sum of cell-by-cell differences."
        )
        return {"sample": [], "count": 0, "total_count": diff_total_count}

    LOGGER.info(grouped_column_diffs)
    # Count rows of the joined table where any of the differing columns differ.
    q = JINJA_ENV.get_template("joined_rows_count.sql").render(
        joined_schema=output_schema,
        joined_table=(x_table + "_JOINED"),
        columns=grouped_column_diffs.keys(),
    )
    LOGGER.info(q)
    cur.execute(q)
    diff_row_count = cur.fetchall()[0]["COUNT"]

    # we'll pull all columns from the joined table
    q = JINJA_ENV.get_template("joined_rows_sample.sql").render(
        joined_schema=output_schema,
        joined_table=(x_table + "_JOINED"),
        columns=grouped_column_diffs.keys(),
    )
    LOGGER.info(q)
    # The LIMIT is appended here rather than templated.
    cur.execute(q + " LIMIT " + str(max_rows_all))
    diff_rows = pd.DataFrame(cur.fetchall())

    return {
        "query": q,
        "sample": diff_rows,
        "count": diff_row_count,
        "total_count": diff_total_count,
    }
462
+
463
+
464
def get_diff_columns(cur: Cursor, output_schema: str, x_table: str) -> pd.DataFrame:
    """Return per-column difference counts from the <x_table>_DIFF table."""
    LOGGER.debug("Getting diff columns")
    # The # of columns has a hard limit (~1600 in Vertica?) so don't worry about
    # pulling the count first or limiting the results
    summary_q = JINJA_ENV.get_template("diff_column_summary.sql").render(
        schema_name=output_schema, table_name=(x_table + "_DIFF")
    )
    cur.execute(summary_q)
    return pd.DataFrame(cur.fetchall())
473
+
474
+
475
def get_column_diffs(
    diff_columns: pd.DataFrame,
    cur: Cursor,
    output_schema: str,
    x_schema: str,
    x_table: str,
    y_schema: str,
    y_table: str,
    join_cols: list,
    max_rows_column: int,
    all_col_info_df: pd.DataFrame,
    hierarchical: bool = False,
) -> dict:
    """Build per-column diff detail from the <x_table>_DIFF table.

    For each column in *diff_columns* (must have 'column_name' and 'COUNT'
    columns), collects: 'q'/'df' (grouped differing pairs), 'q_raw'/'df_raw'
    (raw differing rows); if *hierarchical*, 'q_h_{x,y}'/'df_h_{x,y}'; and for
    numeric/date columns, 'q_n'/'df_n' (binned) and 'q_n_sample'/'df_n_sample'
    (largest diffs). Samples are capped at *max_rows_column* rows.
    """
    LOGGER.debug("Getting column diffs")
    # get total count, list of most common differing pairs for each column
    # list of (count, query, df)
    grouped_column_diffs = {
        row.column_name: {"count": row["COUNT"]} for i, row in diff_columns.iterrows()
    }

    for column_name, info in grouped_column_diffs.items():
        LOGGER.info(
            "Getting detailed diff for column: "
            + str(column_name)
            + " with "
            + str(info["count"])
            + " differences."
        )
        # Grouped x/y value pairs for this column; <=> is Vertica's
        # NULL-safe equality.
        q = JINJA_ENV.get_template("diff_column.sql").render(
            column=column_name,
            joined_schema=output_schema,
            joined_table=(x_table + "_JOINED"),
            diff_schema=output_schema,
            diff_table=(x_table + "_DIFF"),
            group_cols=", ".join(join_cols),
            join_cols=" AND ".join(["diff.{0} <=> joined.{0}".format(col) for col in join_cols]),
        )
        info["q"] = q
        # Raw (ungrouped) differing rows including the join keys.
        q_raw = JINJA_ENV.get_template("diff_column_raw.sql").render(
            column=column_name,
            joined_schema=output_schema,
            joined_table=(x_table + "_JOINED"),
            diff_schema=output_schema,
            diff_table=(x_table + "_DIFF"),
            join_cols=join_cols,
            join_cols_join=" AND ".join(
                ["diff.{0} <=> joined.{0}".format(col) for col in join_cols]
            ),
        )
        info["q_raw"] = q_raw
        cur.execute(q + " LIMIT " + str(max_rows_column))
        info["df"] = pd.DataFrame(cur.fetchall())
        cur.execute(q_raw + " LIMIT " + str(max_rows_column))
        info["df_raw"] = pd.DataFrame(cur.fetchall())
        if hierarchical:
            # Per-side hierarchical sample; limit=None pass stores the
            # query text only, the limited pass executes and stores the df.
            for schema, table, side in ((x_schema, x_table, "x"), (y_schema, y_table, "y")):
                for limit in (None, max_rows_column):
                    q_h = JINJA_ENV.get_template("diff_column_hier.sql").render(
                        column=column_name,
                        diff_schema=output_schema,
                        diff_table=(x_table + "_DIFF"),
                        join_cols=", ".join(join_cols),
                        first_join_col=join_cols[0],
                        schema=schema,
                        table=table,
                        limit=limit,
                    )
                    if limit is None:
                        info["q_h_" + side] = q_h
                    else:
                        cur.execute(q_h)
                        info["df_h_" + side] = pd.DataFrame(cur.fetchall())
        # Extra numeric/date treatment only when both sides are numeric-like
        # (or both date-like).
        row = all_col_info_df.loc[column_name, :]
        is_numeric = is_numeric_like(row.x_dtype) and is_numeric_like(row.y_dtype)
        is_date = is_date_like(row.x_dtype) and is_date_like(row.y_dtype)
        if is_numeric or is_date:
            info["q_n"] = JINJA_ENV.get_template("diff_column_numeric_diffs_binned.sql").render(
                column=column_name,
                joined_schema=output_schema,
                joined_table=(x_table + "_JOINED"),
                diff_schema=output_schema,
                diff_table=(x_table + "_DIFF"),
                group_cols=", ".join(join_cols),
                join_cols=" AND ".join(
                    ["diff.{0} <=> joined.{0}".format(col) for col in join_cols]
                ),
                # Clamp the number of bins to [1, 10] based on the diff count.
                tiles=min({max({1, info["count"]}), 10}),
            )
            cur.execute(info["q_n"])
            info["df_n"] = pd.DataFrame(cur.fetchall())
            info["q_n_sample"] = JINJA_ENV.get_template(
                "diff_column_numeric_diffs_sorted.sql"
            ).render(
                column=column_name,
                joined_schema=output_schema,
                joined_table=(x_table + "_JOINED"),
                diff_schema=output_schema,
                diff_table=(x_table + "_DIFF"),
                join_cols=join_cols,
                join_cols_join=" AND ".join(
                    ["diff.{0} <=> joined.{0}".format(col) for col in join_cols]
                ),
            )
            cur.execute(info["q_n_sample"] + " LIMIT " + str(max_rows_column))
            info["df_n_sample"] = pd.DataFrame(cur.fetchall())
    return grouped_column_diffs
581
+
582
+
583
def get_column_diffs_from_joined(
    cur: Cursor,
    output_schema: str,
    x_schema: str,
    x_table: str,
    y_schema: str,
    y_table: str,
    join_cols: list,
    max_rows_column: int,
    all_col_info_df: pd.DataFrame,
    comparable_filter,
    hierarchical: bool = False,
) -> dict:
    """Get column-by-column diffs directly from the joined table.

    Non self-explanatory argument specifics:

    - max_rows_column: number of rows to pull for sample differing cells on each column.
    - all_col_info_df: dataframe with the following columns:
        - index: column names.
        - x_dtype.
        - y_dtype.
    - comparable_filter: an 0/1 index on all_col_info_df to filter on columns to compare.
        - in cli.py, this filter/index is set using datatype matching and the user-supplied list of columns to exclude.
    - hierarchical: if true, additional outputs are included for each columns that are samples with the join keys.

    Returned data specifics:
    - dict grouped_column_diffs:
        - each `key` is a string column name for columns that matches on name between x and y tables (and are comparable on dtype, not excluded by user-supplied list).
        - each `value` is a dictionary with the following keys:
            - {'count': diff_count, 'df': df, 'df_raw': df_raw, 'q': q, 'q_raw': q_raw}.
            - if `hierarchical` is true: `{q,d}_h_{x,y}` (q for query, d for dataframe sample) from x and y tables, respectively.
            - if numeric of date: `{q,df}_n{,_sample}` (q for query, df for dataframe), the `_sample` is the biggest diffs, while the former are the binned differences.
    """
    # Compare only columns passing the filter, excluding the join keys.
    column_list_to_compare = all_col_info_df.loc[
        comparable_filter & ~all_col_info_df.index.isin(join_cols), :
    ].index.values
    LOGGER.info("Getting column diffs for columns:")
    LOGGER.info(",".join(column_list_to_compare))
    grouped_column_diffs = {}

    for column in column_list_to_compare:
        LOGGER.info("=" * 80)
        LOGGER.info(column)
        # Cheap count first; skip the detailed queries when nothing differs.
        joined_count_q = JINJA_ENV.get_template("joined_count.sql").render(
            column=column, joined_schema=output_schema, joined_table=(x_table + "_JOINED")
        )
        LOGGER.info(joined_count_q)
        cur.execute(joined_count_q)
        diff_count = cur.fetchall()[0]["COUNT"]
        if diff_count > 0:
            LOGGER.info(
                "Getting detailed diff for column: "
                + str(column)
                + " with "
                + str(diff_count)
                + " differences."
            )
            # Grouped differing x/y value pairs.
            q = JINJA_ENV.get_template("joined_column.sql").render(
                column=column, joined_schema=output_schema, joined_table=(x_table + "_JOINED")
            )
            # Raw differing rows including the join keys.
            q_raw = JINJA_ENV.get_template("joined_column_raw.sql").render(
                column=column,
                joined_schema=output_schema,
                joined_table=(x_table + "_JOINED"),
                join_cols=join_cols,
            )
            LOGGER.info(q)
            cur.execute(q + " LIMIT " + str(max_rows_column))
            df = pd.DataFrame(cur.fetchall())
            LOGGER.info(q_raw)
            cur.execute(q_raw + " LIMIT " + str(max_rows_column))
            df_raw = pd.DataFrame(cur.fetchall())
            grouped_column_diffs[column] = {
                "count": diff_count,
                "df": df,
                "df_raw": df_raw,
                "q": q,
                "q_raw": q_raw,
            }
            LOGGER.info(grouped_column_diffs[column])

            if hierarchical:
                # Per-side hierarchical sample; the limit=None pass stores the
                # query text only, the limited pass executes and stores the df.
                for schema, table, side in ((x_schema, x_table, "x"), (y_schema, y_table, "y")):
                    for limit in (None, max_rows_column):
                        q_h = JINJA_ENV.get_template("joined_column_hier.sql").render(
                            column=column,
                            joined_schema=output_schema,
                            joined_table=(x_table + "_JOINED"),
                            join_cols=join_cols,
                            schema=schema,
                            table=table,
                            limit=limit,
                        )
                        if limit is None:
                            grouped_column_diffs[column]["q_h_" + side] = q_h
                        else:
                            cur.execute(q_h)
                            grouped_column_diffs[column]["df_h_" + side] = pd.DataFrame(
                                cur.fetchall()
                            )
            # Extra numeric/date treatment only when both sides are
            # numeric-like (or both date-like).
            row = all_col_info_df.loc[column, :]
            is_numeric = is_numeric_like(row.x_dtype) and is_numeric_like(row.y_dtype)
            is_date = is_date_like(row.x_dtype) and is_date_like(row.y_dtype)
            if is_numeric or is_date:
                grouped_column_diffs[column]["q_n"] = JINJA_ENV.get_template(
                    "joined_column_numeric_diffs_binned.sql"
                ).render(
                    column=column,
                    joined_schema=output_schema,
                    joined_table=(x_table + "_JOINED"),
                    # Clamp the number of bins to [1, 10] based on the count.
                    tiles=min({max({1, grouped_column_diffs[column]["count"]}), 10}),
                )
                cur.execute(grouped_column_diffs[column]["q_n"])
                grouped_column_diffs[column]["df_n"] = pd.DataFrame(cur.fetchall())
                grouped_column_diffs[column]["q_n_sample"] = JINJA_ENV.get_template(
                    "joined_column_numeric_diffs_sorted.sql"
                ).render(
                    column=column,
                    joined_schema=output_schema,
                    joined_table=(x_table + "_JOINED"),
                    join_cols=join_cols,
                )
                cur.execute(
                    grouped_column_diffs[column]["q_n_sample"] + " LIMIT " + str(max_rows_column)
                )
                grouped_column_diffs[column]["df_n_sample"] = pd.DataFrame(cur.fetchall())
        else:
            LOGGER.info(
                "NOT getting detailed diff for column: "
                + str(column)
                + " with "
                + str(diff_count)
                + " differences."
            )
    LOGGER.info(len(grouped_column_diffs))
    # Order the result by descending diff count for reporting.
    grouped_column_diffs_sorted = {
        x: grouped_column_diffs[x]
        for x in sorted(
            grouped_column_diffs.keys(),
            key=lambda x: grouped_column_diffs[x]["count"],
            reverse=True,
        )
    }
    LOGGER.info(len(grouped_column_diffs_sorted))
    return grouped_column_diffs_sorted