fleet-python 0.2.74__py3-none-any.whl → 0.2.74b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -255,50 +255,51 @@ class AsyncSnapshotQueryBuilder:
255
255
 
256
256
  class AsyncSnapshotDiff:
257
257
  """Compute & validate changes between two snapshots fetched via API."""
258
-
258
+
259
259
  def __init__(
260
260
  self,
261
261
  before: AsyncDatabaseSnapshot,
262
262
  after: AsyncDatabaseSnapshot,
263
- ignore_config: Optional[IgnoreConfig] = None,
263
+ ignore_config: IgnoreConfig | None = None,
264
264
  ):
265
265
  self.before = before
266
266
  self.after = after
267
267
  self.ignore_config = ignore_config or IgnoreConfig()
268
- self._cached: Optional[Dict[str, Any]] = None
269
- self._targeted_mode = False # Flag to use targeted queries
270
-
271
- async def _get_primary_key_columns(self, table: str) -> List[str]:
268
+ self._cached: dict[str, Any] | None = None
269
+
270
+ async def _get_primary_key_columns(self, table: str) -> list[str]:
272
271
  """Get primary key columns for a table."""
273
272
  # Try to get from schema
274
273
  schema_response = await self.after.resource.query(f"PRAGMA table_info({table})")
275
274
  if not schema_response.rows:
276
275
  return ["id"] # Default fallback
277
-
276
+
278
277
  pk_columns = []
279
278
  for row in schema_response.rows:
280
279
  # row format: (cid, name, type, notnull, dflt_value, pk)
281
280
  if row[5] > 0: # pk > 0 means it's part of primary key
282
281
  pk_columns.append((row[5], row[1])) # (pk_position, column_name)
283
-
282
+
284
283
  if not pk_columns:
285
284
  # Try common defaults
286
285
  all_columns = [row[1] for row in schema_response.rows]
287
286
  if "id" in all_columns:
288
287
  return ["id"]
289
288
  return ["rowid"]
290
-
289
+
291
290
  # Sort by primary key position and return just the column names
292
291
  pk_columns.sort(key=lambda x: x[0])
293
292
  return [col[1] for col in pk_columns]
294
-
293
+
295
294
  async def _collect(self):
296
295
  """Collect all differences between snapshots."""
297
296
  if self._cached is not None:
298
297
  return self._cached
299
-
300
- all_tables = set(await self.before.tables()) | set(await self.after.tables())
301
- diff: Dict[str, Dict[str, Any]] = {}
298
+
299
+ before_tables = set(await self.before.tables())
300
+ after_tables = set(await self.after.tables())
301
+ all_tables = before_tables | after_tables
302
+ diff: dict[str, dict[str, Any]] = {}
302
303
 
303
304
  for tbl in all_tables:
304
305
  if self.ignore_config.should_ignore_table(tbl):
@@ -307,26 +308,31 @@ class AsyncSnapshotDiff:
307
308
  # Get primary key columns
308
309
  pk_columns = await self._get_primary_key_columns(tbl)
309
310
 
310
- # Ensure data is fetched for this table
311
- await self.before._ensure_table_data(tbl)
312
- await self.after._ensure_table_data(tbl)
313
-
314
- # Get data from both snapshots
315
- before_data = self.before._data.get(tbl, [])
316
- after_data = self.after._data.get(tbl, [])
311
+ # Get data from both snapshots, fetching table contents on demand
312
+ if tbl in before_tables:
313
+ await self.before._ensure_table_data(tbl)
314
+ before_data = self.before._data.get(tbl, [])
315
+ else:
316
+ before_data = []
317
317
 
318
+ if tbl in after_tables:
319
+ await self.after._ensure_table_data(tbl)
320
+ after_data = self.after._data.get(tbl, [])
321
+ else:
322
+ after_data = []
323
+
318
324
  # Create indexes by primary key
319
- def make_key(row: dict, pk_cols: List[str]) -> Any:
325
+ def make_key(row: dict, pk_cols: list[str]) -> Any:
320
326
  if len(pk_cols) == 1:
321
327
  return row.get(pk_cols[0])
322
328
  return tuple(row.get(col) for col in pk_cols)
323
-
329
+
324
330
  before_index = {make_key(row, pk_columns): row for row in before_data}
325
331
  after_index = {make_key(row, pk_columns): row for row in after_data}
326
-
332
+
327
333
  before_keys = set(before_index.keys())
328
334
  after_keys = set(after_index.keys())
329
-
335
+
330
336
  # Find changes
331
337
  result = {
332
338
  "table_name": tbl,
@@ -337,23 +343,27 @@ class AsyncSnapshotDiff:
337
343
  "unchanged_count": 0,
338
344
  "total_changes": 0,
339
345
  }
340
-
346
+
341
347
  # Added rows
342
348
  for key in after_keys - before_keys:
343
- result["added_rows"].append({"row_id": key, "data": after_index[key]})
344
-
349
+ result["added_rows"].append({
350
+ "row_id": key,
351
+ "data": after_index[key]
352
+ })
353
+
345
354
  # Removed rows
346
355
  for key in before_keys - after_keys:
347
- result["removed_rows"].append(
348
- {"row_id": key, "data": before_index[key]}
349
- )
350
-
356
+ result["removed_rows"].append({
357
+ "row_id": key,
358
+ "data": before_index[key]
359
+ })
360
+
351
361
  # Modified rows
352
362
  for key in before_keys & after_keys:
353
363
  before_row = before_index[key]
354
364
  after_row = after_index[key]
355
365
  changes = {}
356
-
366
+
357
367
  for field in set(before_row.keys()) | set(after_row.keys()):
358
368
  if self.ignore_config.should_ignore_field(tbl, field):
359
369
  continue
@@ -361,413 +371,33 @@ class AsyncSnapshotDiff:
361
371
  after_val = after_row.get(field)
362
372
  if not _values_equivalent(before_val, after_val):
363
373
  changes[field] = {"before": before_val, "after": after_val}
364
-
374
+
365
375
  if changes:
366
- result["modified_rows"].append(
367
- {
368
- "row_id": key,
369
- "changes": changes,
370
- "data": after_row, # Current state
371
- }
372
- )
376
+ result["modified_rows"].append({
377
+ "row_id": key,
378
+ "changes": changes,
379
+ "data": after_row # Current state
380
+ })
373
381
  else:
374
382
  result["unchanged_count"] += 1
375
-
383
+
376
384
  result["total_changes"] = (
377
- len(result["added_rows"])
378
- + len(result["removed_rows"])
379
- + len(result["modified_rows"])
385
+ len(result["added_rows"]) +
386
+ len(result["removed_rows"]) +
387
+ len(result["modified_rows"])
380
388
  )
381
-
389
+
382
390
  diff[tbl] = result
383
-
391
+
384
392
  self._cached = diff
385
393
  return diff
386
-
387
- @property
388
- def changes(self) -> Dict[str, Dict[str, Any]]:
389
- """Expose cached changes; ensure callers awaited a diff-producing method first."""
390
- if self._cached is None:
391
- raise RuntimeError(
392
- "Diff not collected yet; await an operation like expect_only() first."
393
- )
394
- return self._cached
395
-
396
- def _can_use_targeted_queries(self, allowed_changes: List[Dict[str, Any]]) -> bool:
397
- """Check if we can use targeted queries for optimization."""
398
- # We can use targeted queries if all allowed changes specify table and pk
399
- for change in allowed_changes:
400
- if "table" not in change or "pk" not in change:
401
- return False
402
- return True
403
-
404
- def _build_pk_where_clause(self, pk_columns: List[str], pk_value: Any) -> str:
405
- """Build WHERE clause for primary key lookup."""
406
- # Escape single quotes in values to prevent SQL injection
407
- def escape_value(val: Any) -> str:
408
- if val is None:
409
- return "NULL"
410
- elif isinstance(val, str):
411
- escaped = str(val).replace("'", "''")
412
- return f"'{escaped}'"
413
- else:
414
- return f"'{val}'"
415
-
416
- if len(pk_columns) == 1:
417
- return f"{pk_columns[0]} = {escape_value(pk_value)}"
418
- else:
419
- # Composite key
420
- if isinstance(pk_value, tuple):
421
- conditions = [
422
- f"{col} = {escape_value(val)}"
423
- for col, val in zip(pk_columns, pk_value)
424
- ]
425
- return " AND ".join(conditions)
426
- else:
427
- # Shouldn't happen if data is consistent
428
- return f"{pk_columns[0]} = {escape_value(pk_value)}"
429
-
430
- async def _expect_no_changes(self):
431
- """Efficiently verify that no changes occurred between snapshots using row counts."""
432
- try:
433
- import asyncio
434
-
435
- # Get all tables from both snapshots
436
- before_tables = set(await self.before.tables())
437
- after_tables = set(await self.after.tables())
438
-
439
- # Check for added/removed tables (excluding ignored ones)
440
- added_tables = after_tables - before_tables
441
- removed_tables = before_tables - after_tables
442
-
443
- for table in added_tables:
444
- if not self.ignore_config.should_ignore_table(table):
445
- raise AssertionError(f"Unexpected table added: {table}")
446
-
447
- for table in removed_tables:
448
- if not self.ignore_config.should_ignore_table(table):
449
- raise AssertionError(f"Unexpected table removed: {table}")
450
-
451
- # Prepare tables to check
452
- tables_to_check = []
453
- all_tables = before_tables | after_tables
454
- for table in all_tables:
455
- if not self.ignore_config.should_ignore_table(table):
456
- tables_to_check.append(table)
457
-
458
- # If no tables to check, we're done
459
- if not tables_to_check:
460
- return self
461
-
462
- # Track errors and tables needing verification
463
- errors = []
464
- tables_needing_verification = []
465
-
466
- async def check_table_counts(table: str):
467
- """Check row counts for a single table."""
468
- try:
469
- # Get row counts from both snapshots
470
- before_count = 0
471
- after_count = 0
472
-
473
- if table in before_tables:
474
- before_count_response = await self.before.resource.query(
475
- f"SELECT COUNT(*) FROM {table}"
476
- )
477
- before_count = (
478
- before_count_response.rows[0][0]
479
- if before_count_response.rows
480
- else 0
481
- )
482
-
483
- if table in after_tables:
484
- after_count_response = await self.after.resource.query(
485
- f"SELECT COUNT(*) FROM {table}"
486
- )
487
- after_count = (
488
- after_count_response.rows[0][0]
489
- if after_count_response.rows
490
- else 0
491
- )
492
-
493
- if before_count != after_count:
494
- error_msg = (
495
- f"Unexpected change in table '{table}': "
496
- f"row count changed from {before_count} to {after_count}"
497
- )
498
- errors.append(AssertionError(error_msg))
499
- elif before_count > 0 and before_count <= 1000:
500
- # Mark for detailed verification
501
- tables_needing_verification.append(table)
502
-
503
- except Exception as e:
504
- errors.append(e)
505
-
506
- # Execute count checks in parallel
507
- await asyncio.gather(*[check_table_counts(table) for table in tables_to_check])
508
-
509
- # Check if any errors occurred during count checking
510
- if errors:
511
- raise errors[0]
512
-
513
- # Now verify small tables for data changes (also in parallel)
514
- if tables_needing_verification:
515
- verification_errors = []
516
-
517
- async def verify_table(table: str):
518
- """Verify a single table's data hasn't changed."""
519
- try:
520
- await self._verify_table_unchanged(table)
521
- except AssertionError as e:
522
- verification_errors.append(e)
523
-
524
- await asyncio.gather(*[verify_table(table) for table in tables_needing_verification])
525
-
526
- # Check if any errors occurred during verification
527
- if verification_errors:
528
- raise verification_errors[0]
529
-
530
- return self
531
-
532
- except AssertionError:
533
- # Re-raise assertion errors (these are expected failures)
534
- raise
535
- except Exception as e:
536
- # If the optimized check fails for other reasons, fall back to full diff
537
- print(f"Warning: Optimized no-changes check failed: {e}")
538
- print("Falling back to full diff...")
539
- return await self._validate_diff_against_allowed_changes(
540
- await self._collect(), []
541
- )
542
-
543
- async def _verify_table_unchanged(self, table: str):
544
- """Verify that a table's data hasn't changed (for small tables)."""
545
- # Get primary key columns
546
- pk_columns = await self._get_primary_key_columns(table)
547
-
548
- # Get sorted data from both snapshots
549
- order_by = ", ".join(pk_columns) if pk_columns else "rowid"
550
-
551
- before_response = await self.before.resource.query(
552
- f"SELECT * FROM {table} ORDER BY {order_by}"
553
- )
554
- after_response = await self.after.resource.query(
555
- f"SELECT * FROM {table} ORDER BY {order_by}"
556
- )
557
-
558
- # Quick check: if column counts differ, there's a schema change
559
- if before_response.columns != after_response.columns:
560
- raise AssertionError(f"Schema changed in table '{table}'")
561
-
562
- # Compare row by row
563
- if len(before_response.rows) != len(after_response.rows):
564
- raise AssertionError(
565
- f"Row count mismatch in table '{table}': "
566
- f"{len(before_response.rows)} vs {len(after_response.rows)}"
567
- )
568
-
569
- for i, (before_row, after_row) in enumerate(
570
- zip(before_response.rows, after_response.rows)
571
- ):
572
- before_dict = dict(zip(before_response.columns, before_row))
573
- after_dict = dict(zip(after_response.columns, after_row))
574
-
575
- # Compare fields, ignoring those in ignore config
576
- for field in before_response.columns:
577
- if self.ignore_config.should_ignore_field(table, field):
578
- continue
579
-
580
- if not _values_equivalent(
581
- before_dict.get(field), after_dict.get(field)
582
- ):
583
- pk_val = before_dict.get(pk_columns[0]) if pk_columns else i
584
- raise AssertionError(
585
- f"Unexpected change in table '{table}', row {pk_val}, "
586
- f"field '{field}': {repr(before_dict.get(field))} -> {repr(after_dict.get(field))}"
587
- )
588
-
589
- def _is_field_change_allowed(
590
- self, table_changes: List[Dict[str, Any]], pk: Any, field: str, after_val: Any
591
- ) -> bool:
592
- """Check if a specific field change is allowed."""
593
- for change in table_changes:
594
- if (
595
- str(change.get("pk")) == str(pk)
596
- and change.get("field") == field
597
- and _values_equivalent(change.get("after"), after_val)
598
- ):
599
- return True
600
- return False
601
-
602
- def _is_row_change_allowed(
603
- self, table_changes: List[Dict[str, Any]], pk: Any, change_type: str
604
- ) -> bool:
605
- """Check if a row addition/deletion is allowed."""
606
- for change in table_changes:
607
- if str(change.get("pk")) == str(pk) and change.get("after") == change_type:
608
- return True
609
- return False
610
-
611
- async def _expect_only_targeted(self, allowed_changes: List[Dict[str, Any]]):
612
- """Optimized version that only queries specific rows mentioned in allowed_changes."""
613
- import asyncio
614
-
615
- # Group allowed changes by table
616
- changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
617
- for change in allowed_changes:
618
- table = change["table"]
619
- if table not in changes_by_table:
620
- changes_by_table[table] = []
621
- changes_by_table[table].append(change)
622
-
623
- errors = []
624
-
625
- # Function to check a single row
626
- async def check_row(
627
- table: str,
628
- pk: Any,
629
- table_changes: List[Dict[str, Any]],
630
- pk_columns: List[str],
631
- ):
632
- try:
633
- # Build WHERE clause for this PK
634
- where_sql = self._build_pk_where_clause(pk_columns, pk)
635
-
636
- # Query before snapshot
637
- before_query = f"SELECT * FROM {table} WHERE {where_sql}"
638
- before_response = await self.before.resource.query(before_query)
639
- before_row = (
640
- dict(zip(before_response.columns, before_response.rows[0]))
641
- if before_response.rows
642
- else None
643
- )
644
-
645
- # Query after snapshot
646
- after_response = await self.after.resource.query(before_query)
647
- after_row = (
648
- dict(zip(after_response.columns, after_response.rows[0]))
649
- if after_response.rows
650
- else None
651
- )
652
-
653
- # Check changes for this row
654
- if before_row and after_row:
655
- # Modified row - check fields
656
- for field in set(before_row.keys()) | set(after_row.keys()):
657
- if self.ignore_config.should_ignore_field(table, field):
658
- continue
659
- before_val = before_row.get(field)
660
- after_val = after_row.get(field)
661
- if not _values_equivalent(before_val, after_val):
662
- # Check if this change is allowed
663
- if not self._is_field_change_allowed(
664
- table_changes, pk, field, after_val
665
- ):
666
- error_msg = (
667
- f"Unexpected change in table '{table}', "
668
- f"row {pk}, field '{field}': "
669
- f"{repr(before_val)} -> {repr(after_val)}"
670
- )
671
- errors.append(AssertionError(error_msg))
672
- return # Stop checking this row
673
- elif not before_row and after_row:
674
- # Added row
675
- if not self._is_row_change_allowed(table_changes, pk, "__added__"):
676
- error_msg = f"Unexpected row added in table '{table}': {pk}"
677
- errors.append(AssertionError(error_msg))
678
- elif before_row and not after_row:
679
- # Removed row
680
- if not self._is_row_change_allowed(table_changes, pk, "__removed__"):
681
- error_msg = f"Unexpected row removed from table '{table}': {pk}"
682
- errors.append(AssertionError(error_msg))
683
- except Exception as e:
684
- errors.append(e)
685
-
686
- # Prepare all row checks
687
- row_checks = []
688
- for table, table_changes in changes_by_table.items():
689
- if self.ignore_config.should_ignore_table(table):
690
- continue
691
-
692
- # Get primary key columns once per table
693
- pk_columns = await self._get_primary_key_columns(table)
694
-
695
- # Extract unique PKs to check
696
- pks_to_check = {change["pk"] for change in table_changes}
697
-
698
- for pk in pks_to_check:
699
- row_checks.append((table, pk, table_changes, pk_columns))
700
-
701
- # Execute row checks in parallel
702
- if row_checks:
703
- await asyncio.gather(
704
- *[
705
- check_row(table, pk, table_changes, pk_columns)
706
- for table, pk, table_changes, pk_columns in row_checks
707
- ]
708
- )
709
-
710
- # Check for errors from row checks
711
- if errors:
712
- raise errors[0]
713
-
714
- # Now check tables not mentioned in allowed_changes to ensure no changes
715
- all_tables = set(await self.before.tables()) | set(await self.after.tables())
716
- tables_to_verify = []
717
-
718
- for table in all_tables:
719
- if (
720
- table not in changes_by_table
721
- and not self.ignore_config.should_ignore_table(table)
722
- ):
723
- tables_to_verify.append(table)
724
-
725
- # Function to verify no changes in a table
726
- async def verify_no_changes(table: str):
727
- try:
728
- # For tables with no allowed changes, just check row counts
729
- before_count_response = await self.before.resource.query(
730
- f"SELECT COUNT(*) FROM {table}"
731
- )
732
- before_count = (
733
- before_count_response.rows[0][0]
734
- if before_count_response.rows
735
- else 0
736
- )
737
-
738
- after_count_response = await self.after.resource.query(
739
- f"SELECT COUNT(*) FROM {table}"
740
- )
741
- after_count = (
742
- after_count_response.rows[0][0] if after_count_response.rows else 0
743
- )
744
-
745
- if before_count != after_count:
746
- error_msg = (
747
- f"Unexpected change in table '{table}': "
748
- f"row count changed from {before_count} to {after_count}"
749
- )
750
- errors.append(AssertionError(error_msg))
751
- except Exception as e:
752
- errors.append(e)
753
-
754
- # Execute table verification in parallel
755
- if tables_to_verify:
756
- await asyncio.gather(*[verify_no_changes(table) for table in tables_to_verify])
757
-
758
- # Final error check
759
- if errors:
760
- raise errors[0]
761
-
762
- return self
763
-
764
- async def _validate_diff_against_allowed_changes(
765
- self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
766
- ):
767
- """Validate a collected diff against allowed changes."""
768
-
394
+
395
+ async def expect_only(self, allowed_changes: list[dict[str, Any]]):
396
+ """Ensure only specified changes occurred."""
397
+ diff = await self._collect()
398
+
769
399
  def _is_change_allowed(
770
- table: str, row_id: Any, field: Optional[str], after_value: Any
400
+ table: str, row_id: Any, field: str | None, after_value: Any
771
401
  ) -> bool:
772
402
  """Check if a change is in the allowed list using semantic comparison."""
773
403
  for allowed in allowed_changes:
@@ -776,7 +406,7 @@ class AsyncSnapshotDiff:
776
406
  pk_match = (
777
407
  str(allowed_pk) == str(row_id) if allowed_pk is not None else False
778
408
  )
779
-
409
+
780
410
  if (
781
411
  allowed["table"] == table
782
412
  and pk_match
@@ -785,65 +415,57 @@ class AsyncSnapshotDiff:
785
415
  ):
786
416
  return True
787
417
  return False
788
-
418
+
789
419
  # Collect all unexpected changes
790
420
  unexpected_changes = []
791
-
421
+
792
422
  for tbl, report in diff.items():
793
423
  for row in report.get("modified_rows", []):
794
424
  for f, vals in row["changes"].items():
795
425
  if self.ignore_config.should_ignore_field(tbl, f):
796
426
  continue
797
427
  if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
798
- unexpected_changes.append(
799
- {
800
- "type": "modification",
801
- "table": tbl,
802
- "row_id": row["row_id"],
803
- "field": f,
804
- "before": vals.get("before"),
805
- "after": vals["after"],
806
- "full_row": row,
807
- }
808
- )
809
-
810
- for row in report.get("added_rows", []):
811
- if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
812
- unexpected_changes.append(
813
- {
814
- "type": "insertion",
428
+ unexpected_changes.append({
429
+ "type": "modification",
815
430
  "table": tbl,
816
431
  "row_id": row["row_id"],
817
- "field": None,
818
- "after": "__added__",
432
+ "field": f,
433
+ "before": vals.get("before"),
434
+ "after": vals["after"],
819
435
  "full_row": row,
820
- }
821
- )
822
-
436
+ })
437
+
438
+ for row in report.get("added_rows", []):
439
+ if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
440
+ unexpected_changes.append({
441
+ "type": "insertion",
442
+ "table": tbl,
443
+ "row_id": row["row_id"],
444
+ "field": None,
445
+ "after": "__added__",
446
+ "full_row": row,
447
+ })
448
+
823
449
  for row in report.get("removed_rows", []):
824
450
  if not _is_change_allowed(tbl, row["row_id"], None, "__removed__"):
825
- unexpected_changes.append(
826
- {
827
- "type": "deletion",
828
- "table": tbl,
829
- "row_id": row["row_id"],
830
- "field": None,
831
- "after": "__removed__",
832
- "full_row": row,
833
- }
834
- )
835
-
451
+ unexpected_changes.append({
452
+ "type": "deletion",
453
+ "table": tbl,
454
+ "row_id": row["row_id"],
455
+ "field": None,
456
+ "after": "__removed__",
457
+ "full_row": row,
458
+ })
459
+
836
460
  if unexpected_changes:
837
461
  # Build comprehensive error message
838
462
  error_lines = ["Unexpected database changes detected:"]
839
463
  error_lines.append("")
840
-
464
+
841
465
  for i, change in enumerate(unexpected_changes[:5], 1):
842
- error_lines.append(
843
- f"{i}. {change['type'].upper()} in table '{change['table']}':"
844
- )
466
+ error_lines.append(f"{i}. {change['type'].upper()} in table '{change['table']}':")
845
467
  error_lines.append(f" Row ID: {change['row_id']}")
846
-
468
+
847
469
  if change["type"] == "modification":
848
470
  error_lines.append(f" Field: {change['field']}")
849
471
  error_lines.append(f" Before: {repr(change['before'])}")
@@ -852,7 +474,7 @@ class AsyncSnapshotDiff:
852
474
  error_lines.append(" New row added")
853
475
  elif change["type"] == "deletion":
854
476
  error_lines.append(" Row deleted")
855
-
477
+
856
478
  # Show some context from the row
857
479
  if "full_row" in change and change["full_row"]:
858
480
  row_data = change["full_row"]
@@ -861,15 +483,13 @@ class AsyncSnapshotDiff:
861
483
  row_data.get("data", {}), max_fields=5
862
484
  )
863
485
  error_lines.append(f" Row data: {formatted_row}")
864
-
486
+
865
487
  error_lines.append("")
866
-
488
+
867
489
  if len(unexpected_changes) > 5:
868
- error_lines.append(
869
- f"... and {len(unexpected_changes) - 5} more unexpected changes"
870
- )
490
+ error_lines.append(f"... and {len(unexpected_changes) - 5} more unexpected changes")
871
491
  error_lines.append("")
872
-
492
+
873
493
  # Show what changes were allowed
874
494
  error_lines.append("Allowed changes were:")
875
495
  if allowed_changes:
@@ -881,30 +501,14 @@ class AsyncSnapshotDiff:
881
501
  f"After: {repr(allowed.get('after'))}"
882
502
  )
883
503
  if len(allowed_changes) > 3:
884
- error_lines.append(
885
- f" ... and {len(allowed_changes) - 3} more allowed changes"
886
- )
504
+ error_lines.append(f" ... and {len(allowed_changes) - 3} more allowed changes")
887
505
  else:
888
506
  error_lines.append(" (No changes were allowed)")
889
-
507
+
890
508
  raise AssertionError("\n".join(error_lines))
891
-
509
+
892
510
  return self
893
511
 
894
- async def expect_only(self, allowed_changes: List[Dict[str, Any]]):
895
- """Ensure only specified changes occurred."""
896
- # Special case: empty allowed_changes means no changes should have occurred
897
- if not allowed_changes:
898
- return await self._expect_no_changes()
899
-
900
- # For expect_only, we can optimize by only checking the specific rows mentioned
901
- if self._can_use_targeted_queries(allowed_changes):
902
- return await self._expect_only_targeted(allowed_changes)
903
-
904
- # Fall back to full diff for complex cases
905
- diff = await self._collect()
906
- return await self._validate_diff_against_allowed_changes(diff, allowed_changes)
907
-
908
512
 
909
513
  class AsyncQueryBuilder:
910
514
  """Async query builder that translates DSL to SQL and executes through the API."""