fleet-python 0.2.74b2__py3-none-any.whl → 0.2.75b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fleet/_async/resources/sqlite.py +1018 -100
- fleet/_async/verifiers/verifier.py +19 -1
- fleet/resources/sqlite.py +906 -313
- fleet/verifiers/db.py +502 -1
- fleet/verifiers/verifier.py +19 -1
- {fleet_python-0.2.74b2.dist-info → fleet_python-0.2.75b2.dist-info}/METADATA +15 -1
- {fleet_python-0.2.74b2.dist-info → fleet_python-0.2.75b2.dist-info}/RECORD +11 -10
- tests/test_expect_only.py +1921 -0
- {fleet_python-0.2.74b2.dist-info → fleet_python-0.2.75b2.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.74b2.dist-info → fleet_python-0.2.75b2.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.74b2.dist-info → fleet_python-0.2.75b2.dist-info}/top_level.txt +0 -0
fleet/_async/resources/sqlite.py
CHANGED
|
@@ -255,51 +255,50 @@ class AsyncSnapshotQueryBuilder:
|
|
|
255
255
|
|
|
256
256
|
class AsyncSnapshotDiff:
|
|
257
257
|
"""Compute & validate changes between two snapshots fetched via API."""
|
|
258
|
-
|
|
258
|
+
|
|
259
259
|
def __init__(
|
|
260
260
|
self,
|
|
261
261
|
before: AsyncDatabaseSnapshot,
|
|
262
262
|
after: AsyncDatabaseSnapshot,
|
|
263
|
-
ignore_config: IgnoreConfig
|
|
263
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
264
264
|
):
|
|
265
265
|
self.before = before
|
|
266
266
|
self.after = after
|
|
267
267
|
self.ignore_config = ignore_config or IgnoreConfig()
|
|
268
|
-
self._cached:
|
|
269
|
-
|
|
270
|
-
|
|
268
|
+
self._cached: Optional[Dict[str, Any]] = None
|
|
269
|
+
self._targeted_mode = False # Flag to use targeted queries
|
|
270
|
+
|
|
271
|
+
async def _get_primary_key_columns(self, table: str) -> List[str]:
|
|
271
272
|
"""Get primary key columns for a table."""
|
|
272
273
|
# Try to get from schema
|
|
273
274
|
schema_response = await self.after.resource.query(f"PRAGMA table_info({table})")
|
|
274
275
|
if not schema_response.rows:
|
|
275
276
|
return ["id"] # Default fallback
|
|
276
|
-
|
|
277
|
+
|
|
277
278
|
pk_columns = []
|
|
278
279
|
for row in schema_response.rows:
|
|
279
280
|
# row format: (cid, name, type, notnull, dflt_value, pk)
|
|
280
281
|
if row[5] > 0: # pk > 0 means it's part of primary key
|
|
281
282
|
pk_columns.append((row[5], row[1])) # (pk_position, column_name)
|
|
282
|
-
|
|
283
|
+
|
|
283
284
|
if not pk_columns:
|
|
284
285
|
# Try common defaults
|
|
285
286
|
all_columns = [row[1] for row in schema_response.rows]
|
|
286
287
|
if "id" in all_columns:
|
|
287
288
|
return ["id"]
|
|
288
289
|
return ["rowid"]
|
|
289
|
-
|
|
290
|
+
|
|
290
291
|
# Sort by primary key position and return just the column names
|
|
291
292
|
pk_columns.sort(key=lambda x: x[0])
|
|
292
293
|
return [col[1] for col in pk_columns]
|
|
293
|
-
|
|
294
|
+
|
|
294
295
|
async def _collect(self):
|
|
295
296
|
"""Collect all differences between snapshots."""
|
|
296
297
|
if self._cached is not None:
|
|
297
298
|
return self._cached
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
all_tables = before_tables | after_tables
|
|
302
|
-
diff: dict[str, dict[str, Any]] = {}
|
|
299
|
+
|
|
300
|
+
all_tables = set(await self.before.tables()) | set(await self.after.tables())
|
|
301
|
+
diff: Dict[str, Dict[str, Any]] = {}
|
|
303
302
|
|
|
304
303
|
for tbl in all_tables:
|
|
305
304
|
if self.ignore_config.should_ignore_table(tbl):
|
|
@@ -308,31 +307,26 @@ class AsyncSnapshotDiff:
|
|
|
308
307
|
# Get primary key columns
|
|
309
308
|
pk_columns = await self._get_primary_key_columns(tbl)
|
|
310
309
|
|
|
311
|
-
#
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
310
|
+
# Ensure data is fetched for this table
|
|
311
|
+
await self.before._ensure_table_data(tbl)
|
|
312
|
+
await self.after._ensure_table_data(tbl)
|
|
313
|
+
|
|
314
|
+
# Get data from both snapshots
|
|
315
|
+
before_data = self.before._data.get(tbl, [])
|
|
316
|
+
after_data = self.after._data.get(tbl, [])
|
|
317
317
|
|
|
318
|
-
if tbl in after_tables:
|
|
319
|
-
await self.after._ensure_table_data(tbl)
|
|
320
|
-
after_data = self.after._data.get(tbl, [])
|
|
321
|
-
else:
|
|
322
|
-
after_data = []
|
|
323
|
-
|
|
324
318
|
# Create indexes by primary key
|
|
325
|
-
def make_key(row: dict, pk_cols:
|
|
319
|
+
def make_key(row: dict, pk_cols: List[str]) -> Any:
|
|
326
320
|
if len(pk_cols) == 1:
|
|
327
321
|
return row.get(pk_cols[0])
|
|
328
322
|
return tuple(row.get(col) for col in pk_cols)
|
|
329
|
-
|
|
323
|
+
|
|
330
324
|
before_index = {make_key(row, pk_columns): row for row in before_data}
|
|
331
325
|
after_index = {make_key(row, pk_columns): row for row in after_data}
|
|
332
|
-
|
|
326
|
+
|
|
333
327
|
before_keys = set(before_index.keys())
|
|
334
328
|
after_keys = set(after_index.keys())
|
|
335
|
-
|
|
329
|
+
|
|
336
330
|
# Find changes
|
|
337
331
|
result = {
|
|
338
332
|
"table_name": tbl,
|
|
@@ -343,27 +337,23 @@ class AsyncSnapshotDiff:
|
|
|
343
337
|
"unchanged_count": 0,
|
|
344
338
|
"total_changes": 0,
|
|
345
339
|
}
|
|
346
|
-
|
|
340
|
+
|
|
347
341
|
# Added rows
|
|
348
342
|
for key in after_keys - before_keys:
|
|
349
|
-
result["added_rows"].append({
|
|
350
|
-
|
|
351
|
-
"data": after_index[key]
|
|
352
|
-
})
|
|
353
|
-
|
|
343
|
+
result["added_rows"].append({"row_id": key, "data": after_index[key]})
|
|
344
|
+
|
|
354
345
|
# Removed rows
|
|
355
346
|
for key in before_keys - after_keys:
|
|
356
|
-
result["removed_rows"].append(
|
|
357
|
-
"row_id": key,
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
347
|
+
result["removed_rows"].append(
|
|
348
|
+
{"row_id": key, "data": before_index[key]}
|
|
349
|
+
)
|
|
350
|
+
|
|
361
351
|
# Modified rows
|
|
362
352
|
for key in before_keys & after_keys:
|
|
363
353
|
before_row = before_index[key]
|
|
364
354
|
after_row = after_index[key]
|
|
365
355
|
changes = {}
|
|
366
|
-
|
|
356
|
+
|
|
367
357
|
for field in set(before_row.keys()) | set(after_row.keys()):
|
|
368
358
|
if self.ignore_config.should_ignore_field(tbl, field):
|
|
369
359
|
continue
|
|
@@ -371,33 +361,413 @@ class AsyncSnapshotDiff:
|
|
|
371
361
|
after_val = after_row.get(field)
|
|
372
362
|
if not _values_equivalent(before_val, after_val):
|
|
373
363
|
changes[field] = {"before": before_val, "after": after_val}
|
|
374
|
-
|
|
364
|
+
|
|
375
365
|
if changes:
|
|
376
|
-
result["modified_rows"].append(
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
366
|
+
result["modified_rows"].append(
|
|
367
|
+
{
|
|
368
|
+
"row_id": key,
|
|
369
|
+
"changes": changes,
|
|
370
|
+
"data": after_row, # Current state
|
|
371
|
+
}
|
|
372
|
+
)
|
|
381
373
|
else:
|
|
382
374
|
result["unchanged_count"] += 1
|
|
383
|
-
|
|
375
|
+
|
|
384
376
|
result["total_changes"] = (
|
|
385
|
-
len(result["added_rows"])
|
|
386
|
-
len(result["removed_rows"])
|
|
387
|
-
len(result["modified_rows"])
|
|
377
|
+
len(result["added_rows"])
|
|
378
|
+
+ len(result["removed_rows"])
|
|
379
|
+
+ len(result["modified_rows"])
|
|
388
380
|
)
|
|
389
|
-
|
|
381
|
+
|
|
390
382
|
diff[tbl] = result
|
|
391
|
-
|
|
383
|
+
|
|
392
384
|
self._cached = diff
|
|
393
385
|
return diff
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
diff
|
|
398
|
-
|
|
386
|
+
|
|
387
|
+
@property
|
|
388
|
+
def changes(self) -> Dict[str, Dict[str, Any]]:
|
|
389
|
+
"""Expose cached changes; ensure callers awaited a diff-producing method first."""
|
|
390
|
+
if self._cached is None:
|
|
391
|
+
raise RuntimeError(
|
|
392
|
+
"Diff not collected yet; await an operation like expect_only() first."
|
|
393
|
+
)
|
|
394
|
+
return self._cached
|
|
395
|
+
|
|
396
|
+
def _can_use_targeted_queries(self, allowed_changes: List[Dict[str, Any]]) -> bool:
|
|
397
|
+
"""Check if we can use targeted queries for optimization."""
|
|
398
|
+
# We can use targeted queries if all allowed changes specify table and pk
|
|
399
|
+
for change in allowed_changes:
|
|
400
|
+
if "table" not in change or "pk" not in change:
|
|
401
|
+
return False
|
|
402
|
+
return True
|
|
403
|
+
|
|
404
|
+
def _build_pk_where_clause(self, pk_columns: List[str], pk_value: Any) -> str:
|
|
405
|
+
"""Build WHERE clause for primary key lookup."""
|
|
406
|
+
# Escape single quotes in values to prevent SQL injection
|
|
407
|
+
def escape_value(val: Any) -> str:
|
|
408
|
+
if val is None:
|
|
409
|
+
return "NULL"
|
|
410
|
+
elif isinstance(val, str):
|
|
411
|
+
escaped = str(val).replace("'", "''")
|
|
412
|
+
return f"'{escaped}'"
|
|
413
|
+
else:
|
|
414
|
+
return f"'{val}'"
|
|
415
|
+
|
|
416
|
+
if len(pk_columns) == 1:
|
|
417
|
+
return f"{pk_columns[0]} = {escape_value(pk_value)}"
|
|
418
|
+
else:
|
|
419
|
+
# Composite key
|
|
420
|
+
if isinstance(pk_value, tuple):
|
|
421
|
+
conditions = [
|
|
422
|
+
f"{col} = {escape_value(val)}"
|
|
423
|
+
for col, val in zip(pk_columns, pk_value)
|
|
424
|
+
]
|
|
425
|
+
return " AND ".join(conditions)
|
|
426
|
+
else:
|
|
427
|
+
# Shouldn't happen if data is consistent
|
|
428
|
+
return f"{pk_columns[0]} = {escape_value(pk_value)}"
|
|
429
|
+
|
|
430
|
+
async def _expect_no_changes(self):
|
|
431
|
+
"""Efficiently verify that no changes occurred between snapshots using row counts."""
|
|
432
|
+
try:
|
|
433
|
+
import asyncio
|
|
434
|
+
|
|
435
|
+
# Get all tables from both snapshots
|
|
436
|
+
before_tables = set(await self.before.tables())
|
|
437
|
+
after_tables = set(await self.after.tables())
|
|
438
|
+
|
|
439
|
+
# Check for added/removed tables (excluding ignored ones)
|
|
440
|
+
added_tables = after_tables - before_tables
|
|
441
|
+
removed_tables = before_tables - after_tables
|
|
442
|
+
|
|
443
|
+
for table in added_tables:
|
|
444
|
+
if not self.ignore_config.should_ignore_table(table):
|
|
445
|
+
raise AssertionError(f"Unexpected table added: {table}")
|
|
446
|
+
|
|
447
|
+
for table in removed_tables:
|
|
448
|
+
if not self.ignore_config.should_ignore_table(table):
|
|
449
|
+
raise AssertionError(f"Unexpected table removed: {table}")
|
|
450
|
+
|
|
451
|
+
# Prepare tables to check
|
|
452
|
+
tables_to_check = []
|
|
453
|
+
all_tables = before_tables | after_tables
|
|
454
|
+
for table in all_tables:
|
|
455
|
+
if not self.ignore_config.should_ignore_table(table):
|
|
456
|
+
tables_to_check.append(table)
|
|
457
|
+
|
|
458
|
+
# If no tables to check, we're done
|
|
459
|
+
if not tables_to_check:
|
|
460
|
+
return self
|
|
461
|
+
|
|
462
|
+
# Track errors and tables needing verification
|
|
463
|
+
errors = []
|
|
464
|
+
tables_needing_verification = []
|
|
465
|
+
|
|
466
|
+
async def check_table_counts(table: str):
|
|
467
|
+
"""Check row counts for a single table."""
|
|
468
|
+
try:
|
|
469
|
+
# Get row counts from both snapshots
|
|
470
|
+
before_count = 0
|
|
471
|
+
after_count = 0
|
|
472
|
+
|
|
473
|
+
if table in before_tables:
|
|
474
|
+
before_count_response = await self.before.resource.query(
|
|
475
|
+
f"SELECT COUNT(*) FROM {table}"
|
|
476
|
+
)
|
|
477
|
+
before_count = (
|
|
478
|
+
before_count_response.rows[0][0]
|
|
479
|
+
if before_count_response.rows
|
|
480
|
+
else 0
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
if table in after_tables:
|
|
484
|
+
after_count_response = await self.after.resource.query(
|
|
485
|
+
f"SELECT COUNT(*) FROM {table}"
|
|
486
|
+
)
|
|
487
|
+
after_count = (
|
|
488
|
+
after_count_response.rows[0][0]
|
|
489
|
+
if after_count_response.rows
|
|
490
|
+
else 0
|
|
491
|
+
)
|
|
492
|
+
|
|
493
|
+
if before_count != after_count:
|
|
494
|
+
error_msg = (
|
|
495
|
+
f"Unexpected change in table '{table}': "
|
|
496
|
+
f"row count changed from {before_count} to {after_count}"
|
|
497
|
+
)
|
|
498
|
+
errors.append(AssertionError(error_msg))
|
|
499
|
+
elif before_count > 0 and before_count <= 1000:
|
|
500
|
+
# Mark for detailed verification
|
|
501
|
+
tables_needing_verification.append(table)
|
|
502
|
+
|
|
503
|
+
except Exception as e:
|
|
504
|
+
errors.append(e)
|
|
505
|
+
|
|
506
|
+
# Execute count checks in parallel
|
|
507
|
+
await asyncio.gather(*[check_table_counts(table) for table in tables_to_check])
|
|
508
|
+
|
|
509
|
+
# Check if any errors occurred during count checking
|
|
510
|
+
if errors:
|
|
511
|
+
raise errors[0]
|
|
512
|
+
|
|
513
|
+
# Now verify small tables for data changes (also in parallel)
|
|
514
|
+
if tables_needing_verification:
|
|
515
|
+
verification_errors = []
|
|
516
|
+
|
|
517
|
+
async def verify_table(table: str):
|
|
518
|
+
"""Verify a single table's data hasn't changed."""
|
|
519
|
+
try:
|
|
520
|
+
await self._verify_table_unchanged(table)
|
|
521
|
+
except AssertionError as e:
|
|
522
|
+
verification_errors.append(e)
|
|
523
|
+
|
|
524
|
+
await asyncio.gather(*[verify_table(table) for table in tables_needing_verification])
|
|
525
|
+
|
|
526
|
+
# Check if any errors occurred during verification
|
|
527
|
+
if verification_errors:
|
|
528
|
+
raise verification_errors[0]
|
|
529
|
+
|
|
530
|
+
return self
|
|
531
|
+
|
|
532
|
+
except AssertionError:
|
|
533
|
+
# Re-raise assertion errors (these are expected failures)
|
|
534
|
+
raise
|
|
535
|
+
except Exception as e:
|
|
536
|
+
# If the optimized check fails for other reasons, fall back to full diff
|
|
537
|
+
print(f"Warning: Optimized no-changes check failed: {e}")
|
|
538
|
+
print("Falling back to full diff...")
|
|
539
|
+
return await self._validate_diff_against_allowed_changes(
|
|
540
|
+
await self._collect(), []
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
async def _verify_table_unchanged(self, table: str):
|
|
544
|
+
"""Verify that a table's data hasn't changed (for small tables)."""
|
|
545
|
+
# Get primary key columns
|
|
546
|
+
pk_columns = await self._get_primary_key_columns(table)
|
|
547
|
+
|
|
548
|
+
# Get sorted data from both snapshots
|
|
549
|
+
order_by = ", ".join(pk_columns) if pk_columns else "rowid"
|
|
550
|
+
|
|
551
|
+
before_response = await self.before.resource.query(
|
|
552
|
+
f"SELECT * FROM {table} ORDER BY {order_by}"
|
|
553
|
+
)
|
|
554
|
+
after_response = await self.after.resource.query(
|
|
555
|
+
f"SELECT * FROM {table} ORDER BY {order_by}"
|
|
556
|
+
)
|
|
557
|
+
|
|
558
|
+
# Quick check: if column counts differ, there's a schema change
|
|
559
|
+
if before_response.columns != after_response.columns:
|
|
560
|
+
raise AssertionError(f"Schema changed in table '{table}'")
|
|
561
|
+
|
|
562
|
+
# Compare row by row
|
|
563
|
+
if len(before_response.rows) != len(after_response.rows):
|
|
564
|
+
raise AssertionError(
|
|
565
|
+
f"Row count mismatch in table '{table}': "
|
|
566
|
+
f"{len(before_response.rows)} vs {len(after_response.rows)}"
|
|
567
|
+
)
|
|
568
|
+
|
|
569
|
+
for i, (before_row, after_row) in enumerate(
|
|
570
|
+
zip(before_response.rows, after_response.rows)
|
|
571
|
+
):
|
|
572
|
+
before_dict = dict(zip(before_response.columns, before_row))
|
|
573
|
+
after_dict = dict(zip(after_response.columns, after_row))
|
|
574
|
+
|
|
575
|
+
# Compare fields, ignoring those in ignore config
|
|
576
|
+
for field in before_response.columns:
|
|
577
|
+
if self.ignore_config.should_ignore_field(table, field):
|
|
578
|
+
continue
|
|
579
|
+
|
|
580
|
+
if not _values_equivalent(
|
|
581
|
+
before_dict.get(field), after_dict.get(field)
|
|
582
|
+
):
|
|
583
|
+
pk_val = before_dict.get(pk_columns[0]) if pk_columns else i
|
|
584
|
+
raise AssertionError(
|
|
585
|
+
f"Unexpected change in table '{table}', row {pk_val}, "
|
|
586
|
+
f"field '{field}': {repr(before_dict.get(field))} -> {repr(after_dict.get(field))}"
|
|
587
|
+
)
|
|
588
|
+
|
|
589
|
+
def _is_field_change_allowed(
|
|
590
|
+
self, table_changes: List[Dict[str, Any]], pk: Any, field: str, after_val: Any
|
|
591
|
+
) -> bool:
|
|
592
|
+
"""Check if a specific field change is allowed."""
|
|
593
|
+
for change in table_changes:
|
|
594
|
+
if (
|
|
595
|
+
str(change.get("pk")) == str(pk)
|
|
596
|
+
and change.get("field") == field
|
|
597
|
+
and _values_equivalent(change.get("after"), after_val)
|
|
598
|
+
):
|
|
599
|
+
return True
|
|
600
|
+
return False
|
|
601
|
+
|
|
602
|
+
def _is_row_change_allowed(
|
|
603
|
+
self, table_changes: List[Dict[str, Any]], pk: Any, change_type: str
|
|
604
|
+
) -> bool:
|
|
605
|
+
"""Check if a row addition/deletion is allowed."""
|
|
606
|
+
for change in table_changes:
|
|
607
|
+
if str(change.get("pk")) == str(pk) and change.get("after") == change_type:
|
|
608
|
+
return True
|
|
609
|
+
return False
|
|
610
|
+
|
|
611
|
+
async def _expect_only_targeted(self, allowed_changes: List[Dict[str, Any]]):
|
|
612
|
+
"""Optimized version that only queries specific rows mentioned in allowed_changes."""
|
|
613
|
+
import asyncio
|
|
614
|
+
|
|
615
|
+
# Group allowed changes by table
|
|
616
|
+
changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
|
|
617
|
+
for change in allowed_changes:
|
|
618
|
+
table = change["table"]
|
|
619
|
+
if table not in changes_by_table:
|
|
620
|
+
changes_by_table[table] = []
|
|
621
|
+
changes_by_table[table].append(change)
|
|
622
|
+
|
|
623
|
+
errors = []
|
|
624
|
+
|
|
625
|
+
# Function to check a single row
|
|
626
|
+
async def check_row(
|
|
627
|
+
table: str,
|
|
628
|
+
pk: Any,
|
|
629
|
+
table_changes: List[Dict[str, Any]],
|
|
630
|
+
pk_columns: List[str],
|
|
631
|
+
):
|
|
632
|
+
try:
|
|
633
|
+
# Build WHERE clause for this PK
|
|
634
|
+
where_sql = self._build_pk_where_clause(pk_columns, pk)
|
|
635
|
+
|
|
636
|
+
# Query before snapshot
|
|
637
|
+
before_query = f"SELECT * FROM {table} WHERE {where_sql}"
|
|
638
|
+
before_response = await self.before.resource.query(before_query)
|
|
639
|
+
before_row = (
|
|
640
|
+
dict(zip(before_response.columns, before_response.rows[0]))
|
|
641
|
+
if before_response.rows
|
|
642
|
+
else None
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
# Query after snapshot
|
|
646
|
+
after_response = await self.after.resource.query(before_query)
|
|
647
|
+
after_row = (
|
|
648
|
+
dict(zip(after_response.columns, after_response.rows[0]))
|
|
649
|
+
if after_response.rows
|
|
650
|
+
else None
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
# Check changes for this row
|
|
654
|
+
if before_row and after_row:
|
|
655
|
+
# Modified row - check fields
|
|
656
|
+
for field in set(before_row.keys()) | set(after_row.keys()):
|
|
657
|
+
if self.ignore_config.should_ignore_field(table, field):
|
|
658
|
+
continue
|
|
659
|
+
before_val = before_row.get(field)
|
|
660
|
+
after_val = after_row.get(field)
|
|
661
|
+
if not _values_equivalent(before_val, after_val):
|
|
662
|
+
# Check if this change is allowed
|
|
663
|
+
if not self._is_field_change_allowed(
|
|
664
|
+
table_changes, pk, field, after_val
|
|
665
|
+
):
|
|
666
|
+
error_msg = (
|
|
667
|
+
f"Unexpected change in table '{table}', "
|
|
668
|
+
f"row {pk}, field '{field}': "
|
|
669
|
+
f"{repr(before_val)} -> {repr(after_val)}"
|
|
670
|
+
)
|
|
671
|
+
errors.append(AssertionError(error_msg))
|
|
672
|
+
return # Stop checking this row
|
|
673
|
+
elif not before_row and after_row:
|
|
674
|
+
# Added row
|
|
675
|
+
if not self._is_row_change_allowed(table_changes, pk, "__added__"):
|
|
676
|
+
error_msg = f"Unexpected row added in table '{table}': {pk}"
|
|
677
|
+
errors.append(AssertionError(error_msg))
|
|
678
|
+
elif before_row and not after_row:
|
|
679
|
+
# Removed row
|
|
680
|
+
if not self._is_row_change_allowed(table_changes, pk, "__removed__"):
|
|
681
|
+
error_msg = f"Unexpected row removed from table '{table}': {pk}"
|
|
682
|
+
errors.append(AssertionError(error_msg))
|
|
683
|
+
except Exception as e:
|
|
684
|
+
errors.append(e)
|
|
685
|
+
|
|
686
|
+
# Prepare all row checks
|
|
687
|
+
row_checks = []
|
|
688
|
+
for table, table_changes in changes_by_table.items():
|
|
689
|
+
if self.ignore_config.should_ignore_table(table):
|
|
690
|
+
continue
|
|
691
|
+
|
|
692
|
+
# Get primary key columns once per table
|
|
693
|
+
pk_columns = await self._get_primary_key_columns(table)
|
|
694
|
+
|
|
695
|
+
# Extract unique PKs to check
|
|
696
|
+
pks_to_check = {change["pk"] for change in table_changes}
|
|
697
|
+
|
|
698
|
+
for pk in pks_to_check:
|
|
699
|
+
row_checks.append((table, pk, table_changes, pk_columns))
|
|
700
|
+
|
|
701
|
+
# Execute row checks in parallel
|
|
702
|
+
if row_checks:
|
|
703
|
+
await asyncio.gather(
|
|
704
|
+
*[
|
|
705
|
+
check_row(table, pk, table_changes, pk_columns)
|
|
706
|
+
for table, pk, table_changes, pk_columns in row_checks
|
|
707
|
+
]
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
# Check for errors from row checks
|
|
711
|
+
if errors:
|
|
712
|
+
raise errors[0]
|
|
713
|
+
|
|
714
|
+
# Now check tables not mentioned in allowed_changes to ensure no changes
|
|
715
|
+
all_tables = set(await self.before.tables()) | set(await self.after.tables())
|
|
716
|
+
tables_to_verify = []
|
|
717
|
+
|
|
718
|
+
for table in all_tables:
|
|
719
|
+
if (
|
|
720
|
+
table not in changes_by_table
|
|
721
|
+
and not self.ignore_config.should_ignore_table(table)
|
|
722
|
+
):
|
|
723
|
+
tables_to_verify.append(table)
|
|
724
|
+
|
|
725
|
+
# Function to verify no changes in a table
|
|
726
|
+
async def verify_no_changes(table: str):
|
|
727
|
+
try:
|
|
728
|
+
# For tables with no allowed changes, just check row counts
|
|
729
|
+
before_count_response = await self.before.resource.query(
|
|
730
|
+
f"SELECT COUNT(*) FROM {table}"
|
|
731
|
+
)
|
|
732
|
+
before_count = (
|
|
733
|
+
before_count_response.rows[0][0]
|
|
734
|
+
if before_count_response.rows
|
|
735
|
+
else 0
|
|
736
|
+
)
|
|
737
|
+
|
|
738
|
+
after_count_response = await self.after.resource.query(
|
|
739
|
+
f"SELECT COUNT(*) FROM {table}"
|
|
740
|
+
)
|
|
741
|
+
after_count = (
|
|
742
|
+
after_count_response.rows[0][0] if after_count_response.rows else 0
|
|
743
|
+
)
|
|
744
|
+
|
|
745
|
+
if before_count != after_count:
|
|
746
|
+
error_msg = (
|
|
747
|
+
f"Unexpected change in table '{table}': "
|
|
748
|
+
f"row count changed from {before_count} to {after_count}"
|
|
749
|
+
)
|
|
750
|
+
errors.append(AssertionError(error_msg))
|
|
751
|
+
except Exception as e:
|
|
752
|
+
errors.append(e)
|
|
753
|
+
|
|
754
|
+
# Execute table verification in parallel
|
|
755
|
+
if tables_to_verify:
|
|
756
|
+
await asyncio.gather(*[verify_no_changes(table) for table in tables_to_verify])
|
|
757
|
+
|
|
758
|
+
# Final error check
|
|
759
|
+
if errors:
|
|
760
|
+
raise errors[0]
|
|
761
|
+
|
|
762
|
+
return self
|
|
763
|
+
|
|
764
|
+
async def _validate_diff_against_allowed_changes(
|
|
765
|
+
self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
|
|
766
|
+
):
|
|
767
|
+
"""Validate a collected diff against allowed changes."""
|
|
768
|
+
|
|
399
769
|
def _is_change_allowed(
|
|
400
|
-
table: str, row_id: Any, field: str
|
|
770
|
+
table: str, row_id: Any, field: Optional[str], after_value: Any
|
|
401
771
|
) -> bool:
|
|
402
772
|
"""Check if a change is in the allowed list using semantic comparison."""
|
|
403
773
|
for allowed in allowed_changes:
|
|
@@ -406,7 +776,7 @@ class AsyncSnapshotDiff:
|
|
|
406
776
|
pk_match = (
|
|
407
777
|
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
408
778
|
)
|
|
409
|
-
|
|
779
|
+
|
|
410
780
|
if (
|
|
411
781
|
allowed["table"] == table
|
|
412
782
|
and pk_match
|
|
@@ -415,57 +785,65 @@ class AsyncSnapshotDiff:
|
|
|
415
785
|
):
|
|
416
786
|
return True
|
|
417
787
|
return False
|
|
418
|
-
|
|
788
|
+
|
|
419
789
|
# Collect all unexpected changes
|
|
420
790
|
unexpected_changes = []
|
|
421
|
-
|
|
791
|
+
|
|
422
792
|
for tbl, report in diff.items():
|
|
423
793
|
for row in report.get("modified_rows", []):
|
|
424
794
|
for f, vals in row["changes"].items():
|
|
425
795
|
if self.ignore_config.should_ignore_field(tbl, f):
|
|
426
796
|
continue
|
|
427
797
|
if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
|
|
428
|
-
unexpected_changes.append(
|
|
429
|
-
|
|
798
|
+
unexpected_changes.append(
|
|
799
|
+
{
|
|
800
|
+
"type": "modification",
|
|
801
|
+
"table": tbl,
|
|
802
|
+
"row_id": row["row_id"],
|
|
803
|
+
"field": f,
|
|
804
|
+
"before": vals.get("before"),
|
|
805
|
+
"after": vals["after"],
|
|
806
|
+
"full_row": row,
|
|
807
|
+
}
|
|
808
|
+
)
|
|
809
|
+
|
|
810
|
+
for row in report.get("added_rows", []):
|
|
811
|
+
if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
|
|
812
|
+
unexpected_changes.append(
|
|
813
|
+
{
|
|
814
|
+
"type": "insertion",
|
|
430
815
|
"table": tbl,
|
|
431
816
|
"row_id": row["row_id"],
|
|
432
|
-
"field":
|
|
433
|
-
"
|
|
434
|
-
"after": vals["after"],
|
|
817
|
+
"field": None,
|
|
818
|
+
"after": "__added__",
|
|
435
819
|
"full_row": row,
|
|
436
|
-
}
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
|
|
440
|
-
unexpected_changes.append({
|
|
441
|
-
"type": "insertion",
|
|
442
|
-
"table": tbl,
|
|
443
|
-
"row_id": row["row_id"],
|
|
444
|
-
"field": None,
|
|
445
|
-
"after": "__added__",
|
|
446
|
-
"full_row": row,
|
|
447
|
-
})
|
|
448
|
-
|
|
820
|
+
}
|
|
821
|
+
)
|
|
822
|
+
|
|
449
823
|
for row in report.get("removed_rows", []):
|
|
450
824
|
if not _is_change_allowed(tbl, row["row_id"], None, "__removed__"):
|
|
451
|
-
unexpected_changes.append(
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
825
|
+
unexpected_changes.append(
|
|
826
|
+
{
|
|
827
|
+
"type": "deletion",
|
|
828
|
+
"table": tbl,
|
|
829
|
+
"row_id": row["row_id"],
|
|
830
|
+
"field": None,
|
|
831
|
+
"after": "__removed__",
|
|
832
|
+
"full_row": row,
|
|
833
|
+
}
|
|
834
|
+
)
|
|
835
|
+
|
|
460
836
|
if unexpected_changes:
|
|
461
837
|
# Build comprehensive error message
|
|
462
838
|
error_lines = ["Unexpected database changes detected:"]
|
|
463
839
|
error_lines.append("")
|
|
464
|
-
|
|
840
|
+
|
|
465
841
|
for i, change in enumerate(unexpected_changes[:5], 1):
|
|
466
|
-
error_lines.append(
|
|
842
|
+
error_lines.append(
|
|
843
|
+
f"{i}. {change['type'].upper()} in table '{change['table']}':"
|
|
844
|
+
)
|
|
467
845
|
error_lines.append(f" Row ID: {change['row_id']}")
|
|
468
|
-
|
|
846
|
+
|
|
469
847
|
if change["type"] == "modification":
|
|
470
848
|
error_lines.append(f" Field: {change['field']}")
|
|
471
849
|
error_lines.append(f" Before: {repr(change['before'])}")
|
|
@@ -474,7 +852,7 @@ class AsyncSnapshotDiff:
|
|
|
474
852
|
error_lines.append(" New row added")
|
|
475
853
|
elif change["type"] == "deletion":
|
|
476
854
|
error_lines.append(" Row deleted")
|
|
477
|
-
|
|
855
|
+
|
|
478
856
|
# Show some context from the row
|
|
479
857
|
if "full_row" in change and change["full_row"]:
|
|
480
858
|
row_data = change["full_row"]
|
|
@@ -483,13 +861,15 @@ class AsyncSnapshotDiff:
|
|
|
483
861
|
row_data.get("data", {}), max_fields=5
|
|
484
862
|
)
|
|
485
863
|
error_lines.append(f" Row data: {formatted_row}")
|
|
486
|
-
|
|
864
|
+
|
|
487
865
|
error_lines.append("")
|
|
488
|
-
|
|
866
|
+
|
|
489
867
|
if len(unexpected_changes) > 5:
|
|
490
|
-
error_lines.append(
|
|
868
|
+
error_lines.append(
|
|
869
|
+
f"... and {len(unexpected_changes) - 5} more unexpected changes"
|
|
870
|
+
)
|
|
491
871
|
error_lines.append("")
|
|
492
|
-
|
|
872
|
+
|
|
493
873
|
# Show what changes were allowed
|
|
494
874
|
error_lines.append("Allowed changes were:")
|
|
495
875
|
if allowed_changes:
|
|
@@ -501,14 +881,552 @@ class AsyncSnapshotDiff:
|
|
|
501
881
|
f"After: {repr(allowed.get('after'))}"
|
|
502
882
|
)
|
|
503
883
|
if len(allowed_changes) > 3:
|
|
504
|
-
error_lines.append(
|
|
884
|
+
error_lines.append(
|
|
885
|
+
f" ... and {len(allowed_changes) - 3} more allowed changes"
|
|
886
|
+
)
|
|
505
887
|
else:
|
|
506
888
|
error_lines.append(" (No changes were allowed)")
|
|
507
|
-
|
|
889
|
+
|
|
508
890
|
raise AssertionError("\n".join(error_lines))
|
|
891
|
+
|
|
892
|
+
return self
|
|
893
|
+
|
|
894
|
+
async def _validate_diff_against_allowed_changes_v2(
|
|
895
|
+
self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
|
|
896
|
+
):
|
|
897
|
+
"""Validate a collected diff against allowed changes with field-level spec support.
|
|
898
|
+
|
|
899
|
+
This version supports explicit change types via the "type" field:
|
|
900
|
+
1. Insert specs: {"table": "t", "pk": 1, "type": "insert", "fields": [("name", "value"), ("status", ...)]}
|
|
901
|
+
- ("name", value): check that field equals value
|
|
902
|
+
- ("name", None): check that field is SQL NULL
|
|
903
|
+
- ("name", ...): don't check the value, just acknowledge the field exists
|
|
904
|
+
2. Modify specs: {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": True/False}
|
|
905
|
+
- Uses "resulting_fields" (not "fields") to be explicit about what's being checked
|
|
906
|
+
- "no_other_changes" is REQUIRED and must be True or False:
|
|
907
|
+
- True: Every changed field must be in resulting_fields (strict mode)
|
|
908
|
+
- False: Only check fields in resulting_fields match, ignore other changes
|
|
909
|
+
- ("field_name", value): check that after value equals value
|
|
910
|
+
- ("field_name", None): check that after value is SQL NULL
|
|
911
|
+
- ("field_name", ...): don't check value, just acknowledge field changed
|
|
912
|
+
3. Delete specs:
|
|
913
|
+
- Without field validation: {"table": "t", "pk": 1, "type": "delete"}
|
|
914
|
+
- With field validation: {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
|
|
915
|
+
4. Whole-row specs (legacy):
|
|
916
|
+
- For additions: {"table": "t", "pk": 1, "fields": None, "after": "__added__"}
|
|
917
|
+
- For deletions: {"table": "t", "pk": 1, "fields": None, "after": "__removed__"}
|
|
918
|
+
|
|
919
|
+
When using "fields" for inserts, every field must be accounted for in the list.
|
|
920
|
+
For modifications, use "resulting_fields" with explicit "no_other_changes".
|
|
921
|
+
For deletions with "fields", all specified fields are validated against the deleted row.
|
|
922
|
+
"""
|
|
923
|
+
|
|
924
|
+
def _is_change_allowed(
|
|
925
|
+
table: str, row_id: Any, field: Optional[str], after_value: Any
|
|
926
|
+
) -> bool:
|
|
927
|
+
"""Check if a change is in the allowed list using semantic comparison."""
|
|
928
|
+
for allowed in allowed_changes:
|
|
929
|
+
allowed_pk = allowed.get("pk")
|
|
930
|
+
# Handle type conversion for primary key comparison
|
|
931
|
+
pk_match = (
|
|
932
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
933
|
+
)
|
|
934
|
+
|
|
935
|
+
# For whole-row specs, check "fields": None; for field-level, check "field"
|
|
936
|
+
field_match = (
|
|
937
|
+
("fields" in allowed and allowed.get("fields") is None)
|
|
938
|
+
if field is None
|
|
939
|
+
else allowed.get("field") == field
|
|
940
|
+
)
|
|
941
|
+
if (
|
|
942
|
+
allowed["table"] == table
|
|
943
|
+
and pk_match
|
|
944
|
+
and field_match
|
|
945
|
+
and _values_equivalent(allowed.get("after"), after_value)
|
|
946
|
+
):
|
|
947
|
+
return True
|
|
948
|
+
return False
|
|
949
|
+
|
|
950
|
+
def _get_fields_spec_for_type(
|
|
951
|
+
table: str, row_id: Any, change_type: str
|
|
952
|
+
) -> Optional[List[Tuple[str, Any]]]:
|
|
953
|
+
"""Get the bulk fields spec for a given table/row/type if it exists.
|
|
954
|
+
|
|
955
|
+
Args:
|
|
956
|
+
table: The table name
|
|
957
|
+
row_id: The primary key value
|
|
958
|
+
change_type: One of "insert", "modify", or "delete"
|
|
959
|
+
|
|
960
|
+
Note: For "modify" type, use _get_modify_spec instead.
|
|
961
|
+
"""
|
|
962
|
+
for allowed in allowed_changes:
|
|
963
|
+
allowed_pk = allowed.get("pk")
|
|
964
|
+
pk_match = (
|
|
965
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
966
|
+
)
|
|
967
|
+
if (
|
|
968
|
+
allowed["table"] == table
|
|
969
|
+
and pk_match
|
|
970
|
+
and allowed.get("type") == change_type
|
|
971
|
+
and "fields" in allowed
|
|
972
|
+
):
|
|
973
|
+
return allowed["fields"]
|
|
974
|
+
return None
|
|
975
|
+
|
|
976
|
+
def _get_modify_spec(table: str, row_id: Any) -> Optional[Dict[str, Any]]:
|
|
977
|
+
"""Get the modify spec for a given table/row if it exists.
|
|
978
|
+
|
|
979
|
+
Returns the full spec dict containing:
|
|
980
|
+
- resulting_fields: List of field tuples
|
|
981
|
+
- no_other_changes: Boolean (required)
|
|
982
|
+
|
|
983
|
+
Returns None if no modify spec found.
|
|
984
|
+
"""
|
|
985
|
+
for allowed in allowed_changes:
|
|
986
|
+
allowed_pk = allowed.get("pk")
|
|
987
|
+
pk_match = (
|
|
988
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
989
|
+
)
|
|
990
|
+
if (
|
|
991
|
+
allowed["table"] == table
|
|
992
|
+
and pk_match
|
|
993
|
+
and allowed.get("type") == "modify"
|
|
994
|
+
):
|
|
995
|
+
return allowed
|
|
996
|
+
return None
|
|
997
|
+
|
|
998
|
+
def _is_type_allowed(table: str, row_id: Any, change_type: str) -> bool:
|
|
999
|
+
"""Check if a change type is allowed for the given table/row (with or without fields)."""
|
|
1000
|
+
for allowed in allowed_changes:
|
|
1001
|
+
allowed_pk = allowed.get("pk")
|
|
1002
|
+
pk_match = (
|
|
1003
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
1004
|
+
)
|
|
1005
|
+
if (
|
|
1006
|
+
allowed["table"] == table
|
|
1007
|
+
and pk_match
|
|
1008
|
+
and allowed.get("type") == change_type
|
|
1009
|
+
):
|
|
1010
|
+
return True
|
|
1011
|
+
return False
|
|
1012
|
+
|
|
1013
|
+
def _parse_fields_spec(
|
|
1014
|
+
fields_spec: List[Tuple[str, Any]]
|
|
1015
|
+
) -> Dict[str, Tuple[bool, Any]]:
|
|
1016
|
+
"""Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
|
|
1017
|
+
spec_map: Dict[str, Tuple[bool, Any]] = {}
|
|
1018
|
+
for spec_tuple in fields_spec:
|
|
1019
|
+
if len(spec_tuple) != 2:
|
|
1020
|
+
raise ValueError(
|
|
1021
|
+
f"Invalid field spec tuple: {spec_tuple}. "
|
|
1022
|
+
f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
|
|
1023
|
+
)
|
|
1024
|
+
field_name, expected_value = spec_tuple
|
|
1025
|
+
if expected_value is ...:
|
|
1026
|
+
# Ellipsis: don't check value, just acknowledge field exists
|
|
1027
|
+
spec_map[field_name] = (False, None)
|
|
1028
|
+
else:
|
|
1029
|
+
# Any other value (including None for NULL check): check value
|
|
1030
|
+
spec_map[field_name] = (True, expected_value)
|
|
1031
|
+
return spec_map
|
|
1032
|
+
|
|
1033
|
+
def _validate_row_with_fields_spec(
    table: str,
    row_id: Any,
    row_data: Dict[str, Any],
    fields_spec: List[Tuple[str, Any]],
) -> Optional[List[Tuple[str, Any, str]]]:
    """Check an inserted/deleted row's columns against a bulk fields spec.

    Every column in ``row_data`` must be named in ``fields_spec``. The
    exceptions are ``"rowid"`` and columns that
    ``self.ignore_config.should_ignore_field`` reports as ignored.

    Spec entry semantics:
    - ("field_name", value): the column must equal ``value``
      (per ``_values_equivalent``)
    - ("field_name", None): the column must be SQL NULL
    - ("field_name", ...): the column is acknowledged; its value is not checked

    Spec entries that name columns absent from ``row_data`` are not reported.
    ``row_id`` is accepted for signature symmetry but not used.

    Returns:
        None when the row conforms. Otherwise a list of
        ``(field, actual_value, issue)`` tuples, where ``issue`` is
        ``"NOT_IN_FIELDS_SPEC"`` or ``"expected <repr>"``.
    """
    expectations = _parse_fields_spec(fields_spec)
    problems: List[Tuple[str, Any, str]] = []

    for column, actual in row_data.items():
        # "rowid" is internal; ignored columns are never validated.
        if column == "rowid" or self.ignore_config.should_ignore_field(table, column):
            continue

        expectation = expectations.get(column)
        if expectation is None:
            problems.append((column, actual, "NOT_IN_FIELDS_SPEC"))
            continue

        check_value, expected = expectation
        if check_value and not _values_equivalent(expected, actual):
            problems.append((column, actual, f"expected {repr(expected)}"))

    return problems or None
|
|
1076
|
+
|
|
1077
|
+
def _validate_modification_with_fields_spec(
    table: str,
    row_id: Any,
    row_changes: Dict[str, Dict[str, Any]],
    resulting_fields: List[Tuple[str, Any]],
    no_other_changes: bool,
) -> Optional[List[Tuple[str, Any, str]]]:
    """Check a modified row's changed columns against ``resulting_fields``.

    Only the columns present in ``row_changes`` are examined. Columns listed
    in ``resulting_fields`` that did not change are not checked. Columns that
    ``self.ignore_config`` ignores are skipped.

    Args:
        table: The table name.
        row_id: The row primary key (not used by the check itself).
        row_changes: Mapping of column name to ``{"before": ..., "after": ...}``.
        resulting_fields: Field tuples describing the expected post-change values.
        no_other_changes: If True, every changed column must appear in
            ``resulting_fields``. If False, changed columns not in the spec
            are tolerated.

    Spec entry semantics:
    - ("field_name", value): the after value must equal ``value``
    - ("field_name", None): the after value must be SQL NULL
    - ("field_name", ...): the change is acknowledged; its value is not checked

    Returns:
        None when the modification conforms. Otherwise a list of
        ``(field, after_value, issue)`` tuples, where ``issue`` is
        ``"NOT_IN_RESULTING_FIELDS"`` or ``"expected <repr>"``.
    """
    expectations = _parse_fields_spec(resulting_fields)
    problems: List[Tuple[str, Any, str]] = []

    for column, change in row_changes.items():
        if self.ignore_config.should_ignore_field(table, column):
            continue

        new_value = change["after"]
        expectation = expectations.get(column)
        if expectation is None:
            # A changed column the spec does not mention is only an error in strict mode.
            if no_other_changes:
                problems.append((column, new_value, "NOT_IN_RESULTING_FIELDS"))
            continue

        check_value, expected = expectation
        if check_value and not _values_equivalent(expected, new_value):
            problems.append((column, new_value, f"expected {repr(expected)}"))

    return problems or None
|
|
1131
|
+
|
|
1132
|
+
|
|
1133
|
+
# Collect all unexpected changes for detailed reporting
|
|
1134
|
+
unexpected_changes = []
|
|
1135
|
+
|
|
1136
|
+
for tbl, report in diff.items():
|
|
1137
|
+
for row in report.get("modified_rows", []):
|
|
1138
|
+
row_changes = row["changes"]
|
|
1139
|
+
|
|
1140
|
+
# Check for modify spec with resulting_fields
|
|
1141
|
+
modify_spec = _get_modify_spec(tbl, row["row_id"])
|
|
1142
|
+
if modify_spec is not None:
|
|
1143
|
+
resulting_fields = modify_spec.get("resulting_fields")
|
|
1144
|
+
if resulting_fields is not None:
|
|
1145
|
+
# Validate that no_other_changes is provided
|
|
1146
|
+
if "no_other_changes" not in modify_spec:
|
|
1147
|
+
raise ValueError(
|
|
1148
|
+
f"Modify spec for table '{tbl}' pk={row['row_id']} "
|
|
1149
|
+
f"has 'resulting_fields' but missing required 'no_other_changes' field. "
|
|
1150
|
+
f"Set 'no_other_changes': True to verify no other fields changed, "
|
|
1151
|
+
f"or 'no_other_changes': False to only check the specified fields."
|
|
1152
|
+
)
|
|
1153
|
+
no_other_changes = modify_spec["no_other_changes"]
|
|
1154
|
+
if not isinstance(no_other_changes, bool):
|
|
1155
|
+
raise ValueError(
|
|
1156
|
+
f"Modify spec for table '{tbl}' pk={row['row_id']} "
|
|
1157
|
+
f"has 'no_other_changes' but it must be a boolean (True or False), "
|
|
1158
|
+
f"got {type(no_other_changes).__name__}: {repr(no_other_changes)}"
|
|
1159
|
+
)
|
|
1160
|
+
|
|
1161
|
+
unmatched = _validate_modification_with_fields_spec(
|
|
1162
|
+
tbl, row["row_id"], row_changes, resulting_fields, no_other_changes
|
|
1163
|
+
)
|
|
1164
|
+
if unmatched:
|
|
1165
|
+
unexpected_changes.append(
|
|
1166
|
+
{
|
|
1167
|
+
"type": "modification",
|
|
1168
|
+
"table": tbl,
|
|
1169
|
+
"row_id": row["row_id"],
|
|
1170
|
+
"field": None,
|
|
1171
|
+
"before": None,
|
|
1172
|
+
"after": None,
|
|
1173
|
+
"full_row": row,
|
|
1174
|
+
"unmatched_fields": unmatched,
|
|
1175
|
+
}
|
|
1176
|
+
)
|
|
1177
|
+
continue # Skip to next row
|
|
1178
|
+
else:
|
|
1179
|
+
# Modify spec without resulting_fields - just allow the modification
|
|
1180
|
+
continue # Skip to next row
|
|
1181
|
+
|
|
1182
|
+
# Fall back to single-field specs (legacy)
|
|
1183
|
+
for f, vals in row_changes.items():
|
|
1184
|
+
if self.ignore_config.should_ignore_field(tbl, f):
|
|
1185
|
+
continue
|
|
1186
|
+
if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
|
|
1187
|
+
unexpected_changes.append(
|
|
1188
|
+
{
|
|
1189
|
+
"type": "modification",
|
|
1190
|
+
"table": tbl,
|
|
1191
|
+
"row_id": row["row_id"],
|
|
1192
|
+
"field": f,
|
|
1193
|
+
"before": vals.get("before"),
|
|
1194
|
+
"after": vals["after"],
|
|
1195
|
+
"full_row": row,
|
|
1196
|
+
}
|
|
1197
|
+
)
|
|
1198
|
+
|
|
1199
|
+
for row in report.get("added_rows", []):
|
|
1200
|
+
row_data = row.get("data", {})
|
|
1201
|
+
|
|
1202
|
+
# Check for bulk fields spec (type: "insert")
|
|
1203
|
+
fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "insert")
|
|
1204
|
+
if fields_spec is not None:
|
|
1205
|
+
unmatched = _validate_row_with_fields_spec(
|
|
1206
|
+
tbl, row["row_id"], row_data, fields_spec
|
|
1207
|
+
)
|
|
1208
|
+
if unmatched:
|
|
1209
|
+
unexpected_changes.append(
|
|
1210
|
+
{
|
|
1211
|
+
"type": "insertion",
|
|
1212
|
+
"table": tbl,
|
|
1213
|
+
"row_id": row["row_id"],
|
|
1214
|
+
"field": None,
|
|
1215
|
+
"after": "__added__",
|
|
1216
|
+
"full_row": row,
|
|
1217
|
+
"unmatched_fields": unmatched,
|
|
1218
|
+
}
|
|
1219
|
+
)
|
|
1220
|
+
continue # Skip to next row
|
|
1221
|
+
|
|
1222
|
+
# Check if insertion is allowed without field validation
|
|
1223
|
+
if _is_type_allowed(tbl, row["row_id"], "insert"):
|
|
1224
|
+
continue # Insertion is allowed, skip to next row
|
|
1225
|
+
|
|
1226
|
+
# Check for whole-row spec (legacy)
|
|
1227
|
+
whole_row_allowed = _is_change_allowed(
|
|
1228
|
+
tbl, row["row_id"], None, "__added__"
|
|
1229
|
+
)
|
|
1230
|
+
|
|
1231
|
+
if not whole_row_allowed:
|
|
1232
|
+
unexpected_changes.append(
|
|
1233
|
+
{
|
|
1234
|
+
"type": "insertion",
|
|
1235
|
+
"table": tbl,
|
|
1236
|
+
"row_id": row["row_id"],
|
|
1237
|
+
"field": None,
|
|
1238
|
+
"after": "__added__",
|
|
1239
|
+
"full_row": row,
|
|
1240
|
+
}
|
|
1241
|
+
)
|
|
1242
|
+
|
|
1243
|
+
for row in report.get("removed_rows", []):
|
|
1244
|
+
row_data = row.get("data", {})
|
|
1245
|
+
|
|
1246
|
+
# Check for bulk fields spec (type: "delete")
|
|
1247
|
+
fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "delete")
|
|
1248
|
+
if fields_spec is not None:
|
|
1249
|
+
unmatched = _validate_row_with_fields_spec(
|
|
1250
|
+
tbl, row["row_id"], row_data, fields_spec
|
|
1251
|
+
)
|
|
1252
|
+
if unmatched:
|
|
1253
|
+
unexpected_changes.append(
|
|
1254
|
+
{
|
|
1255
|
+
"type": "deletion",
|
|
1256
|
+
"table": tbl,
|
|
1257
|
+
"row_id": row["row_id"],
|
|
1258
|
+
"field": None,
|
|
1259
|
+
"after": "__removed__",
|
|
1260
|
+
"full_row": row,
|
|
1261
|
+
"unmatched_fields": unmatched,
|
|
1262
|
+
}
|
|
1263
|
+
)
|
|
1264
|
+
continue # Skip to next row
|
|
1265
|
+
|
|
1266
|
+
# Check if deletion is allowed without field validation
|
|
1267
|
+
if _is_type_allowed(tbl, row["row_id"], "delete"):
|
|
1268
|
+
continue # Deletion is allowed, skip to next row
|
|
1269
|
+
|
|
1270
|
+
# Check for whole-row spec (legacy)
|
|
1271
|
+
whole_row_allowed = _is_change_allowed(
|
|
1272
|
+
tbl, row["row_id"], None, "__removed__"
|
|
1273
|
+
)
|
|
1274
|
+
|
|
1275
|
+
if not whole_row_allowed:
|
|
1276
|
+
unexpected_changes.append(
|
|
1277
|
+
{
|
|
1278
|
+
"type": "deletion",
|
|
1279
|
+
"table": tbl,
|
|
1280
|
+
"row_id": row["row_id"],
|
|
1281
|
+
"field": None,
|
|
1282
|
+
"after": "__removed__",
|
|
1283
|
+
"full_row": row,
|
|
1284
|
+
}
|
|
1285
|
+
)
|
|
1286
|
+
|
|
1287
|
+
if unexpected_changes:
|
|
1288
|
+
# Build comprehensive error message
|
|
1289
|
+
error_lines = ["Unexpected database changes detected:"]
|
|
1290
|
+
error_lines.append("")
|
|
1291
|
+
|
|
1292
|
+
for i, change in enumerate(unexpected_changes[:5], 1):
|
|
1293
|
+
error_lines.append(
|
|
1294
|
+
f"{i}. {change['type'].upper()} in table '{change['table']}':"
|
|
1295
|
+
)
|
|
1296
|
+
error_lines.append(f" Row ID: {change['row_id']}")
|
|
1297
|
+
|
|
1298
|
+
if change["type"] == "modification":
|
|
1299
|
+
error_lines.append(f" Field: {change['field']}")
|
|
1300
|
+
error_lines.append(f" Before: {repr(change['before'])}")
|
|
1301
|
+
error_lines.append(f" After: {repr(change['after'])}")
|
|
1302
|
+
elif change["type"] == "insertion":
|
|
1303
|
+
error_lines.append(" New row added")
|
|
1304
|
+
elif change["type"] == "deletion":
|
|
1305
|
+
error_lines.append(" Row deleted")
|
|
1306
|
+
|
|
1307
|
+
# Show unmatched fields if present (from bulk fields spec validation)
|
|
1308
|
+
if "unmatched_fields" in change and change["unmatched_fields"]:
|
|
1309
|
+
error_lines.append(" Unmatched fields:")
|
|
1310
|
+
for field_info in change["unmatched_fields"][:5]:
|
|
1311
|
+
field_name, actual_value, issue = field_info
|
|
1312
|
+
error_lines.append(
|
|
1313
|
+
f" - {field_name}: {repr(actual_value)} ({issue})"
|
|
1314
|
+
)
|
|
1315
|
+
if len(change["unmatched_fields"]) > 10:
|
|
1316
|
+
error_lines.append(
|
|
1317
|
+
f" ... and {len(change['unmatched_fields']) - 10} more"
|
|
1318
|
+
)
|
|
1319
|
+
|
|
1320
|
+
# Show some context from the row
|
|
1321
|
+
if "full_row" in change and change["full_row"]:
|
|
1322
|
+
row_data = change["full_row"]
|
|
1323
|
+
if change["type"] == "modification" and "data" in row_data:
|
|
1324
|
+
# For modifications, show the current state
|
|
1325
|
+
formatted_row = _format_row_for_error(
|
|
1326
|
+
row_data.get("data", {}), max_fields=5
|
|
1327
|
+
)
|
|
1328
|
+
error_lines.append(f" Row data: {formatted_row}")
|
|
1329
|
+
elif (
|
|
1330
|
+
change["type"] in ["insertion", "deletion"]
|
|
1331
|
+
and "data" in row_data
|
|
1332
|
+
):
|
|
1333
|
+
# For insertions/deletions, show the row data
|
|
1334
|
+
formatted_row = _format_row_for_error(
|
|
1335
|
+
row_data.get("data", {}), max_fields=5
|
|
1336
|
+
)
|
|
1337
|
+
error_lines.append(f" Row data: {formatted_row}")
|
|
1338
|
+
|
|
1339
|
+
error_lines.append("")
|
|
1340
|
+
|
|
1341
|
+
if len(unexpected_changes) > 5:
|
|
1342
|
+
error_lines.append(
|
|
1343
|
+
f"... and {len(unexpected_changes) - 5} more unexpected changes"
|
|
1344
|
+
)
|
|
1345
|
+
error_lines.append("")
|
|
1346
|
+
|
|
1347
|
+
# Show what changes were allowed
|
|
1348
|
+
error_lines.append("Allowed changes were:")
|
|
1349
|
+
if allowed_changes:
|
|
1350
|
+
for i, allowed in enumerate(allowed_changes[:3], 1):
|
|
1351
|
+
change_type = allowed.get("type", "unspecified")
|
|
1352
|
+
|
|
1353
|
+
# For modify type, use resulting_fields
|
|
1354
|
+
if change_type == "modify" and "resulting_fields" in allowed and allowed["resulting_fields"] is not None:
|
|
1355
|
+
fields_summary = ", ".join(
|
|
1356
|
+
f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
|
|
1357
|
+
for f in allowed["resulting_fields"][:3]
|
|
1358
|
+
)
|
|
1359
|
+
if len(allowed["resulting_fields"]) > 3:
|
|
1360
|
+
fields_summary += f", ... +{len(allowed['resulting_fields']) - 3} more"
|
|
1361
|
+
no_other = allowed.get("no_other_changes", "NOT_SET")
|
|
1362
|
+
error_lines.append(
|
|
1363
|
+
f" {i}. Table: {allowed.get('table')}, "
|
|
1364
|
+
f"ID: {allowed.get('pk')}, "
|
|
1365
|
+
f"Type: {change_type}, "
|
|
1366
|
+
f"resulting_fields: [{fields_summary}], "
|
|
1367
|
+
f"no_other_changes: {no_other}"
|
|
1368
|
+
)
|
|
1369
|
+
elif "fields" in allowed and allowed["fields"] is not None:
|
|
1370
|
+
# Show bulk fields spec (for insert/delete)
|
|
1371
|
+
fields_summary = ", ".join(
|
|
1372
|
+
f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
|
|
1373
|
+
for f in allowed["fields"][:3]
|
|
1374
|
+
)
|
|
1375
|
+
if len(allowed["fields"]) > 3:
|
|
1376
|
+
fields_summary += f", ... +{len(allowed['fields']) - 3} more"
|
|
1377
|
+
error_lines.append(
|
|
1378
|
+
f" {i}. Table: {allowed.get('table')}, "
|
|
1379
|
+
f"ID: {allowed.get('pk')}, "
|
|
1380
|
+
f"Type: {change_type}, "
|
|
1381
|
+
f"Fields: [{fields_summary}]"
|
|
1382
|
+
)
|
|
1383
|
+
else:
|
|
1384
|
+
error_lines.append(
|
|
1385
|
+
f" {i}. Table: {allowed.get('table')}, "
|
|
1386
|
+
f"ID: {allowed.get('pk')}, "
|
|
1387
|
+
f"Type: {change_type}"
|
|
1388
|
+
)
|
|
1389
|
+
if len(allowed_changes) > 3:
|
|
1390
|
+
error_lines.append(
|
|
1391
|
+
f" ... and {len(allowed_changes) - 3} more allowed changes"
|
|
1392
|
+
)
|
|
1393
|
+
else:
|
|
1394
|
+
error_lines.append(" (No changes were allowed)")
|
|
1395
|
+
|
|
1396
|
+
raise AssertionError("\n".join(error_lines))
|
|
1397
|
+
|
|
510
1398
|
return self
|
|
511
1399
|
|
|
1400
|
+
async def expect_only(self, allowed_changes: List[Dict[str, Any]]):
    """Assert that only the changes described by ``allowed_changes`` occurred.

    Dispatch:
    - an empty ``allowed_changes`` delegates to ``self._expect_no_changes()``;
    - if ``self._can_use_targeted_queries(allowed_changes)`` is truthy, the
      check goes to ``self._expect_only_targeted``, which only looks at the
      mentioned rows;
    - otherwise the full diff is gathered with ``self._collect()`` and passed
      to ``self._validate_diff_against_allowed_changes``.

    Returns whatever the delegated check returns.
    """
    if not allowed_changes:
        return await self._expect_no_changes()

    if not self._can_use_targeted_queries(allowed_changes):
        # General case: build the complete diff and validate it.
        full_diff = await self._collect()
        return await self._validate_diff_against_allowed_changes(
            full_diff, allowed_changes
        )

    return await self._expect_only_targeted(allowed_changes)
|
|
1413
|
+
|
|
1414
|
+
async def expect_only_v2(self, allowed_changes: List[Dict[str, Any]]):
    """Assert that only the listed changes occurred, allowing field-level specs.

    Unlike ``expect_only``, this variant has no targeted-query fast path.

    - An empty ``allowed_changes`` delegates to ``self._expect_no_changes()``.
    - Any other list triggers a full diff via ``self._collect()``, which is
      then validated by ``self._validate_diff_against_allowed_changes_v2``.
      Per this method's contract, that validator accepts field-level
      specifications for added/removed rows in addition to whole-row specs.

    Returns whatever the delegated check returns.
    """
    if allowed_changes:
        # No targeted optimisation exists for v2 yet, so always diff fully.
        snapshot_diff = await self._collect()
        return await self._validate_diff_against_allowed_changes_v2(
            snapshot_diff, allowed_changes
        )
    return await self._expect_no_changes()
|
|
1429
|
+
|
|
512
1430
|
|
|
513
1431
|
class AsyncQueryBuilder:
|
|
514
1432
|
"""Async query builder that translates DSL to SQL and executes through the API."""
|