fleet-python 0.2.66b2__py3-none-any.whl → 0.2.105__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. examples/export_tasks.py +16 -5
  2. examples/export_tasks_filtered.py +245 -0
  3. examples/fetch_tasks.py +230 -0
  4. examples/import_tasks.py +140 -8
  5. examples/iterate_verifiers.py +725 -0
  6. fleet/__init__.py +128 -5
  7. fleet/_async/__init__.py +27 -3
  8. fleet/_async/base.py +24 -9
  9. fleet/_async/client.py +938 -41
  10. fleet/_async/env/client.py +60 -3
  11. fleet/_async/instance/client.py +52 -7
  12. fleet/_async/models.py +15 -0
  13. fleet/_async/resources/api.py +200 -0
  14. fleet/_async/resources/sqlite.py +1801 -46
  15. fleet/_async/tasks.py +122 -25
  16. fleet/_async/verifiers/bundler.py +22 -21
  17. fleet/_async/verifiers/verifier.py +25 -19
  18. fleet/agent/__init__.py +32 -0
  19. fleet/agent/gemini_cua/Dockerfile +45 -0
  20. fleet/agent/gemini_cua/__init__.py +10 -0
  21. fleet/agent/gemini_cua/agent.py +759 -0
  22. fleet/agent/gemini_cua/mcp/main.py +108 -0
  23. fleet/agent/gemini_cua/mcp_server/__init__.py +5 -0
  24. fleet/agent/gemini_cua/mcp_server/main.py +105 -0
  25. fleet/agent/gemini_cua/mcp_server/tools.py +178 -0
  26. fleet/agent/gemini_cua/requirements.txt +5 -0
  27. fleet/agent/gemini_cua/start.sh +30 -0
  28. fleet/agent/orchestrator.py +854 -0
  29. fleet/agent/types.py +49 -0
  30. fleet/agent/utils.py +34 -0
  31. fleet/base.py +34 -9
  32. fleet/cli.py +1061 -0
  33. fleet/client.py +1060 -48
  34. fleet/config.py +1 -1
  35. fleet/env/__init__.py +16 -0
  36. fleet/env/client.py +60 -3
  37. fleet/eval/__init__.py +15 -0
  38. fleet/eval/uploader.py +231 -0
  39. fleet/exceptions.py +8 -0
  40. fleet/instance/client.py +53 -8
  41. fleet/instance/models.py +1 -0
  42. fleet/models.py +303 -0
  43. fleet/proxy/__init__.py +25 -0
  44. fleet/proxy/proxy.py +453 -0
  45. fleet/proxy/whitelist.py +244 -0
  46. fleet/resources/api.py +200 -0
  47. fleet/resources/sqlite.py +1845 -46
  48. fleet/tasks.py +113 -20
  49. fleet/utils/__init__.py +7 -0
  50. fleet/utils/http_logging.py +178 -0
  51. fleet/utils/logging.py +13 -0
  52. fleet/utils/playwright.py +440 -0
  53. fleet/verifiers/bundler.py +22 -21
  54. fleet/verifiers/db.py +985 -1
  55. fleet/verifiers/decorator.py +1 -1
  56. fleet/verifiers/verifier.py +25 -19
  57. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/METADATA +28 -1
  58. fleet_python-0.2.105.dist-info/RECORD +115 -0
  59. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/WHEEL +1 -1
  60. fleet_python-0.2.105.dist-info/entry_points.txt +2 -0
  61. tests/test_app_method.py +85 -0
  62. tests/test_expect_exactly.py +4148 -0
  63. tests/test_expect_only.py +2593 -0
  64. tests/test_instance_dispatch.py +607 -0
  65. tests/test_sqlite_resource_dual_mode.py +263 -0
  66. tests/test_sqlite_shared_memory_behavior.py +117 -0
  67. fleet_python-0.2.66b2.dist-info/RECORD +0 -81
  68. tests/test_verifier_security.py +0 -427
  69. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/licenses/LICENSE +0 -0
  70. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/top_level.txt +0 -0
fleet/resources/sqlite.py CHANGED
@@ -6,6 +6,8 @@ from datetime import datetime
6
6
  import tempfile
7
7
  import sqlite3
8
8
  import os
9
+ import re
10
+ import json
9
11
 
10
12
  from typing import TYPE_CHECKING
11
13
 
@@ -19,11 +21,23 @@ from fleet.verifiers.db import (
19
21
  _get_row_identifier,
20
22
  _format_row_for_error,
21
23
  _values_equivalent,
24
+ validate_diff_expect_exactly,
22
25
  )
23
26
 
24
27
 
28
+ def _quote_identifier(identifier: str) -> str:
29
+ """Quote an identifier (table or column name) for SQLite.
30
+
31
+ SQLite uses double quotes for identifiers and escapes internal quotes by doubling them.
32
+ This handles reserved keywords like 'order', 'table', etc.
33
+ """
34
+ # Escape any double quotes in the identifier by doubling them
35
+ escaped = identifier.replace('"', '""')
36
+ return f'"{escaped}"'
37
+
38
+
25
39
  class SyncDatabaseSnapshot:
26
- """Async database snapshot that fetches data through API and stores locally for diffing."""
40
+ """Lazy database snapshot that fetches data on-demand through API."""
27
41
 
28
42
  def __init__(self, resource: "SQLiteResource", name: Optional[str] = None):
29
43
  self.resource = resource
@@ -31,11 +45,12 @@ class SyncDatabaseSnapshot:
31
45
  self.created_at = datetime.utcnow()
32
46
  self._data: Dict[str, List[Dict[str, Any]]] = {}
33
47
  self._schemas: Dict[str, List[str]] = {}
34
- self._fetched = False
48
+ self._table_names: Optional[List[str]] = None
49
+ self._fetched_tables: set = set()
35
50
 
36
- def _ensure_fetched(self):
37
- """Fetch all data from remote database if not already fetched."""
38
- if self._fetched:
51
+ def _ensure_tables_list(self):
52
+ """Fetch just the list of table names if not already fetched."""
53
+ if self._table_names is not None:
39
54
  return
40
55
 
41
56
  # Get all tables
@@ -44,35 +59,36 @@ class SyncDatabaseSnapshot:
44
59
  )
45
60
 
46
61
  if not tables_response.rows:
47
- self._fetched = True
62
+ self._table_names = []
48
63
  return
49
64
 
50
- table_names = [row[0] for row in tables_response.rows]
51
-
52
- # Fetch data from each table
53
- for table in table_names:
54
- # Get table schema
55
- schema_response = self.resource.query(f"PRAGMA table_info({table})")
56
- if schema_response.rows:
57
- self._schemas[table] = [
58
- row[1] for row in schema_response.rows
59
- ] # Column names
60
-
61
- # Get all data
62
- data_response = self.resource.query(f"SELECT * FROM {table}")
63
- if data_response.rows and data_response.columns:
64
- self._data[table] = [
65
- dict(zip(data_response.columns, row)) for row in data_response.rows
66
- ]
67
- else:
68
- self._data[table] = []
65
+ self._table_names = [row[0] for row in tables_response.rows]
66
+
67
+ def _ensure_table_data(self, table: str):
68
+ """Fetch data for a specific table on demand."""
69
+ if table in self._fetched_tables:
70
+ return
69
71
 
70
- self._fetched = True
72
+ # Get table schema
73
+ schema_response = self.resource.query(f"PRAGMA table_info({_quote_identifier(table)})")
74
+ if schema_response.rows:
75
+ self._schemas[table] = [row[1] for row in schema_response.rows] # Column names
76
+
77
+ # Get all data for this table
78
+ data_response = self.resource.query(f"SELECT * FROM {_quote_identifier(table)}")
79
+ if data_response.rows and data_response.columns:
80
+ self._data[table] = [
81
+ dict(zip(data_response.columns, row)) for row in data_response.rows
82
+ ]
83
+ else:
84
+ self._data[table] = []
85
+
86
+ self._fetched_tables.add(table)
71
87
 
72
88
  def tables(self) -> List[str]:
73
89
  """Get list of all tables in the snapshot."""
74
- self._ensure_fetched()
75
- return list(self._data.keys())
90
+ self._ensure_tables_list()
91
+ return list(self._table_names) if self._table_names else []
76
92
 
77
93
  def table(self, table_name: str) -> "SyncSnapshotQueryBuilder":
78
94
  """Create a query builder for snapshot data."""
@@ -84,13 +100,12 @@ class SyncDatabaseSnapshot:
84
100
  ignore_config: Optional[IgnoreConfig] = None,
85
101
  ) -> "SyncSnapshotDiff":
86
102
  """Compare this snapshot with another."""
87
- self._ensure_fetched()
88
- other._ensure_fetched()
103
+ # No need to fetch all data upfront - diff will fetch on demand
89
104
  return SyncSnapshotDiff(self, other, ignore_config)
90
105
 
91
106
 
92
107
  class SyncSnapshotQueryBuilder:
93
- """Query builder that works on local snapshot data."""
108
+ """Query builder that works on snapshot data - can use targeted queries when possible."""
94
109
 
95
110
  def __init__(self, snapshot: SyncDatabaseSnapshot, table: str):
96
111
  self._snapshot = snapshot
@@ -100,10 +115,63 @@ class SyncSnapshotQueryBuilder:
100
115
  self._limit: Optional[int] = None
101
116
  self._order_by: Optional[str] = None
102
117
  self._order_desc: bool = False
118
+ self._use_targeted_query = True # Try to use targeted queries when possible
119
+
120
+ def _can_use_targeted_query(self) -> bool:
121
+ """Check if we can use a targeted query instead of loading all data."""
122
+ # We can use targeted query if:
123
+ # 1. We have simple equality conditions
124
+ # 2. No complex operations like joins
125
+ # 3. The query is selective (has conditions)
126
+ if not self._conditions:
127
+ return False
128
+ for col, op, val in self._conditions:
129
+ if op not in ["=", "IS", "IS NOT"]:
130
+ return False
131
+ return True
132
+
133
+ def _execute_targeted_query(self) -> List[Dict[str, Any]]:
134
+ """Execute a targeted query directly instead of loading all data."""
135
+ # Build WHERE clause
136
+ where_parts = []
137
+ for col, op, val in self._conditions:
138
+ if op == "=" and val is None:
139
+ where_parts.append(f"{_quote_identifier(col)} IS NULL")
140
+ elif op == "IS":
141
+ where_parts.append(f"{_quote_identifier(col)} IS NULL")
142
+ elif op == "IS NOT":
143
+ where_parts.append(f"{_quote_identifier(col)} IS NOT NULL")
144
+ elif op == "=":
145
+ if isinstance(val, str):
146
+ escaped_val = val.replace("'", "''")
147
+ where_parts.append(f"{_quote_identifier(col)} = '{escaped_val}'")
148
+ else:
149
+ where_parts.append(f"{_quote_identifier(col)} = '{val}'")
150
+
151
+ where_clause = " AND ".join(where_parts)
152
+
153
+ # Build full query
154
+ cols = ", ".join(self._select_cols)
155
+ query = f"SELECT {cols} FROM {_quote_identifier(self._table)} WHERE {where_clause}"
156
+
157
+ if self._order_by:
158
+ query += f" ORDER BY {self._order_by}"
159
+ if self._limit is not None:
160
+ query += f" LIMIT {self._limit}"
161
+
162
+ # Execute query
163
+ response = self._snapshot.resource.query(query)
164
+ if response.rows and response.columns:
165
+ return [dict(zip(response.columns, row)) for row in response.rows]
166
+ return []
103
167
 
104
168
  def _get_data(self) -> List[Dict[str, Any]]:
105
- """Get table data from snapshot."""
106
- self._snapshot._ensure_fetched()
169
+ """Get table data - use targeted query if possible, otherwise load all data."""
170
+ if self._use_targeted_query and self._can_use_targeted_query():
171
+ return self._execute_targeted_query()
172
+
173
+ # Fall back to loading all data
174
+ self._snapshot._ensure_table_data(self._table)
107
175
  return self._snapshot._data.get(self._table, [])
108
176
 
109
177
  def eq(self, column: str, value: Any) -> "SyncSnapshotQueryBuilder":
@@ -142,6 +210,11 @@ class SyncSnapshotQueryBuilder:
142
210
  return rows[0] if rows else None
143
211
 
144
212
  def all(self) -> List[Dict[str, Any]]:
213
+ # If we can use targeted query, _get_data already applies filters
214
+ if self._use_targeted_query and self._can_use_targeted_query():
215
+ return self._get_data()
216
+
217
+ # Otherwise, get all data and apply filters manually
145
218
  data = self._get_data()
146
219
 
147
220
  # Apply filters
@@ -206,11 +279,12 @@ class SyncSnapshotDiff:
206
279
  self.after = after
207
280
  self.ignore_config = ignore_config or IgnoreConfig()
208
281
  self._cached: Optional[Dict[str, Any]] = None
282
+ self._targeted_mode = False # Flag to use targeted queries
209
283
 
210
284
  def _get_primary_key_columns(self, table: str) -> List[str]:
211
285
  """Get primary key columns for a table."""
212
286
  # Try to get from schema
213
- schema_response = self.after.resource.query(f"PRAGMA table_info({table})")
287
+ schema_response = self.after.resource.query(f"PRAGMA table_info({_quote_identifier(table)})")
214
288
  if not schema_response.rows:
215
289
  return ["id"] # Default fallback
216
290
 
@@ -246,6 +320,10 @@ class SyncSnapshotDiff:
246
320
  # Get primary key columns
247
321
  pk_columns = self._get_primary_key_columns(tbl)
248
322
 
323
+ # Ensure data is fetched for this table
324
+ self.before._ensure_table_data(tbl)
325
+ self.after._ensure_table_data(tbl)
326
+
249
327
  # Get data from both snapshots
250
328
  before_data = self.before._data.get(tbl, [])
251
329
  after_data = self.after._data.get(tbl, [])
@@ -324,9 +402,405 @@ class SyncSnapshotDiff:
324
402
  """Expose the computed diff so callers can introspect like the legacy API."""
325
403
  return self._collect()
326
404
 
327
- def expect_only(self, allowed_changes: List[Dict[str, Any]]):
328
- """Ensure only specified changes occurred."""
329
- diff = self._collect()
405
+ def _can_use_targeted_queries(self, allowed_changes: List[Dict[str, Any]]) -> bool:
406
+ """Check if we can use targeted queries for optimization."""
407
+ # We can use targeted queries if all allowed changes specify table and pk
408
+ for change in allowed_changes:
409
+ if "table" not in change or "pk" not in change:
410
+ return False
411
+ return True
412
+
413
+ def _build_pk_where_clause(self, pk_columns: List[str], pk_value: Any) -> str:
414
+ """Build WHERE clause for primary key lookup."""
415
+ # Escape single quotes in values to prevent SQL injection
416
+ def escape_value(val: Any) -> str:
417
+ if val is None:
418
+ return "NULL"
419
+ elif isinstance(val, str):
420
+ escaped = str(val).replace("'", "''")
421
+ return f"'{escaped}'"
422
+ else:
423
+ return f"'{val}'"
424
+
425
+ if len(pk_columns) == 1:
426
+ return f"{_quote_identifier(pk_columns[0])} = {escape_value(pk_value)}"
427
+ else:
428
+ # Composite key
429
+ if isinstance(pk_value, tuple):
430
+ conditions = [
431
+ f"{_quote_identifier(col)} = {escape_value(val)}"
432
+ for col, val in zip(pk_columns, pk_value)
433
+ ]
434
+ return " AND ".join(conditions)
435
+ else:
436
+ # Shouldn't happen if data is consistent
437
+ return f"{_quote_identifier(pk_columns[0])} = {escape_value(pk_value)}"
438
+
439
+ def _expect_no_changes(self):
440
+ """Efficiently verify that no changes occurred between snapshots using row counts."""
441
+ try:
442
+ import concurrent.futures
443
+ from threading import Lock
444
+
445
+ # Get all tables from both snapshots
446
+ before_tables = set(self.before.tables())
447
+ after_tables = set(self.after.tables())
448
+
449
+ # Check for added/removed tables (excluding ignored ones)
450
+ added_tables = after_tables - before_tables
451
+ removed_tables = before_tables - after_tables
452
+
453
+ for table in added_tables:
454
+ if not self.ignore_config.should_ignore_table(table):
455
+ raise AssertionError(f"Unexpected table added: {table}")
456
+
457
+ for table in removed_tables:
458
+ if not self.ignore_config.should_ignore_table(table):
459
+ raise AssertionError(f"Unexpected table removed: {table}")
460
+
461
+ # Prepare tables to check
462
+ tables_to_check = []
463
+ all_tables = before_tables | after_tables
464
+ for table in all_tables:
465
+ if not self.ignore_config.should_ignore_table(table):
466
+ tables_to_check.append(table)
467
+
468
+ # If no tables to check, we're done
469
+ if not tables_to_check:
470
+ return self
471
+
472
+ # Use ThreadPoolExecutor to parallelize count queries
473
+ errors = []
474
+ errors_lock = Lock()
475
+ tables_needing_verification = []
476
+ verification_lock = Lock()
477
+
478
+ def check_table_counts(table: str):
479
+ """Check row counts for a single table."""
480
+ try:
481
+ # Get row counts from both snapshots
482
+ before_count = 0
483
+ after_count = 0
484
+
485
+ if table in before_tables:
486
+ before_count_response = self.before.resource.query(
487
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
488
+ )
489
+ before_count = (
490
+ before_count_response.rows[0][0]
491
+ if before_count_response.rows
492
+ else 0
493
+ )
494
+
495
+ if table in after_tables:
496
+ after_count_response = self.after.resource.query(
497
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
498
+ )
499
+ after_count = (
500
+ after_count_response.rows[0][0]
501
+ if after_count_response.rows
502
+ else 0
503
+ )
504
+
505
+ if before_count != after_count:
506
+ error_msg = (
507
+ f"Unexpected change in table '{table}': "
508
+ f"row count changed from {before_count} to {after_count}"
509
+ )
510
+ with errors_lock:
511
+ errors.append(AssertionError(error_msg))
512
+ elif before_count > 0 and before_count <= 1000:
513
+ # Mark for detailed verification
514
+ with verification_lock:
515
+ tables_needing_verification.append(table)
516
+
517
+ except Exception as e:
518
+ with errors_lock:
519
+ errors.append(e)
520
+
521
+ # Execute count checks in parallel
522
+ with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
523
+ futures = [
524
+ executor.submit(check_table_counts, table)
525
+ for table in tables_to_check
526
+ ]
527
+ concurrent.futures.wait(futures)
528
+
529
+ # Check if any errors occurred during count checking
530
+ if errors:
531
+ raise errors[0]
532
+
533
+ # Now verify small tables for data changes (also in parallel)
534
+ if tables_needing_verification:
535
+ verification_errors = []
536
+
537
+ def verify_table(table: str):
538
+ """Verify a single table's data hasn't changed."""
539
+ try:
540
+ self._verify_table_unchanged(table)
541
+ except AssertionError as e:
542
+ with errors_lock:
543
+ verification_errors.append(e)
544
+
545
+ with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
546
+ futures = [
547
+ executor.submit(verify_table, table)
548
+ for table in tables_needing_verification
549
+ ]
550
+ concurrent.futures.wait(futures)
551
+
552
+ # Check if any errors occurred during verification
553
+ if verification_errors:
554
+ raise verification_errors[0]
555
+
556
+ return self
557
+
558
+ except AssertionError:
559
+ # Re-raise assertion errors (these are expected failures)
560
+ raise
561
+ except Exception as e:
562
+ # If the optimized check fails for other reasons, fall back to full diff
563
+ print(f"Warning: Optimized no-changes check failed: {e}")
564
+ print("Falling back to full diff...")
565
+ return self._validate_diff_against_allowed_changes(self._collect(), [])
566
+
567
+ def _verify_table_unchanged(self, table: str):
568
+ """Verify that a table's data hasn't changed (for small tables)."""
569
+ # Get primary key columns
570
+ pk_columns = self._get_primary_key_columns(table)
571
+
572
+ # Get sorted data from both snapshots
573
+ order_by = ", ".join(pk_columns) if pk_columns else "rowid"
574
+
575
+ before_response = self.before.resource.query(
576
+ f"SELECT * FROM {_quote_identifier(table)} ORDER BY {order_by}"
577
+ )
578
+ after_response = self.after.resource.query(
579
+ f"SELECT * FROM {_quote_identifier(table)} ORDER BY {order_by}"
580
+ )
581
+
582
+ # Quick check: if column counts differ, there's a schema change
583
+ if before_response.columns != after_response.columns:
584
+ raise AssertionError(f"Schema changed in table '{table}'")
585
+
586
+ # Compare row by row
587
+ if len(before_response.rows) != len(after_response.rows):
588
+ raise AssertionError(
589
+ f"Row count mismatch in table '{table}': "
590
+ f"{len(before_response.rows)} vs {len(after_response.rows)}"
591
+ )
592
+
593
+ for i, (before_row, after_row) in enumerate(
594
+ zip(before_response.rows, after_response.rows)
595
+ ):
596
+ before_dict = dict(zip(before_response.columns, before_row))
597
+ after_dict = dict(zip(after_response.columns, after_row))
598
+
599
+ # Compare fields, ignoring those in ignore config
600
+ for field in before_response.columns:
601
+ if self.ignore_config.should_ignore_field(table, field):
602
+ continue
603
+
604
+ if not _values_equivalent(
605
+ before_dict.get(field), after_dict.get(field)
606
+ ):
607
+ pk_val = before_dict.get(pk_columns[0]) if pk_columns else i
608
+ raise AssertionError(
609
+ f"Unexpected change in table '{table}', row {pk_val}, "
610
+ f"field '{field}': {repr(before_dict.get(field))} -> {repr(after_dict.get(field))}"
611
+ )
612
+
613
+ def _is_field_change_allowed(
614
+ self, table_changes: List[Dict[str, Any]], pk: Any, field: str, after_val: Any
615
+ ) -> bool:
616
+ """Check if a specific field change is allowed."""
617
+ for change in table_changes:
618
+ if (
619
+ str(change.get("pk")) == str(pk)
620
+ and change.get("field") == field
621
+ and _values_equivalent(change.get("after"), after_val)
622
+ ):
623
+ return True
624
+ return False
625
+
626
+ def _is_row_change_allowed(
627
+ self, table_changes: List[Dict[str, Any]], pk: Any, change_type: str
628
+ ) -> bool:
629
+ """Check if a row addition/deletion is allowed."""
630
+ for change in table_changes:
631
+ if str(change.get("pk")) == str(pk) and change.get("after") == change_type:
632
+ return True
633
+ return False
634
+
635
+ def _expect_only_targeted(self, allowed_changes: List[Dict[str, Any]]):
636
+ """Optimized version that only queries specific rows mentioned in allowed_changes."""
637
+ import concurrent.futures
638
+ from threading import Lock
639
+
640
+ # Group allowed changes by table
641
+ changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
642
+ for change in allowed_changes:
643
+ table = change["table"]
644
+ if table not in changes_by_table:
645
+ changes_by_table[table] = []
646
+ changes_by_table[table].append(change)
647
+
648
+ errors = []
649
+ errors_lock = Lock()
650
+
651
+ # Function to check a single row
652
+ def check_row(
653
+ table: str,
654
+ pk: Any,
655
+ table_changes: List[Dict[str, Any]],
656
+ pk_columns: List[str],
657
+ ):
658
+ try:
659
+ # Build WHERE clause for this PK
660
+ where_sql = self._build_pk_where_clause(pk_columns, pk)
661
+
662
+ # Query before snapshot
663
+ before_query = f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
664
+ before_response = self.before.resource.query(before_query)
665
+ before_row = (
666
+ dict(zip(before_response.columns, before_response.rows[0]))
667
+ if before_response.rows
668
+ else None
669
+ )
670
+
671
+ # Query after snapshot
672
+ after_response = self.after.resource.query(before_query)
673
+ after_row = (
674
+ dict(zip(after_response.columns, after_response.rows[0]))
675
+ if after_response.rows
676
+ else None
677
+ )
678
+
679
+ # Check changes for this row
680
+ if before_row and after_row:
681
+ # Modified row - check fields
682
+ for field in set(before_row.keys()) | set(after_row.keys()):
683
+ if self.ignore_config.should_ignore_field(table, field):
684
+ continue
685
+ before_val = before_row.get(field)
686
+ after_val = after_row.get(field)
687
+ if not _values_equivalent(before_val, after_val):
688
+ # Check if this change is allowed
689
+ if not self._is_field_change_allowed(
690
+ table_changes, pk, field, after_val
691
+ ):
692
+ error_msg = (
693
+ f"Unexpected change in table '{table}', "
694
+ f"row {pk}, field '{field}': "
695
+ f"{repr(before_val)} -> {repr(after_val)}"
696
+ )
697
+ with errors_lock:
698
+ errors.append(AssertionError(error_msg))
699
+ return # Stop checking this row
700
+ elif not before_row and after_row:
701
+ # Added row
702
+ if not self._is_row_change_allowed(table_changes, pk, "__added__"):
703
+ error_msg = f"Unexpected row added in table '{table}': {pk}"
704
+ with errors_lock:
705
+ errors.append(AssertionError(error_msg))
706
+ elif before_row and not after_row:
707
+ # Removed row
708
+ if not self._is_row_change_allowed(table_changes, pk, "__removed__"):
709
+ error_msg = f"Unexpected row removed from table '{table}': {pk}"
710
+ with errors_lock:
711
+ errors.append(AssertionError(error_msg))
712
+ except Exception as e:
713
+ with errors_lock:
714
+ errors.append(e)
715
+
716
+ # Prepare all row checks
717
+ row_checks = []
718
+ for table, table_changes in changes_by_table.items():
719
+ if self.ignore_config.should_ignore_table(table):
720
+ continue
721
+
722
+ # Get primary key columns once per table
723
+ pk_columns = self._get_primary_key_columns(table)
724
+
725
+ # Extract unique PKs to check
726
+ pks_to_check = {change["pk"] for change in table_changes}
727
+
728
+ for pk in pks_to_check:
729
+ row_checks.append((table, pk, table_changes, pk_columns))
730
+
731
+ # Execute row checks in parallel
732
+ if row_checks:
733
+ with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
734
+ futures = [
735
+ executor.submit(check_row, table, pk, table_changes, pk_columns)
736
+ for table, pk, table_changes, pk_columns in row_checks
737
+ ]
738
+ concurrent.futures.wait(futures)
739
+
740
+ # Check for errors from row checks
741
+ if errors:
742
+ raise errors[0]
743
+
744
+ # Now check tables not mentioned in allowed_changes to ensure no changes
745
+ all_tables = set(self.before.tables()) | set(self.after.tables())
746
+ tables_to_verify = []
747
+
748
+ for table in all_tables:
749
+ if (
750
+ table not in changes_by_table
751
+ and not self.ignore_config.should_ignore_table(table)
752
+ ):
753
+ tables_to_verify.append(table)
754
+
755
+ # Function to verify no changes in a table
756
+ def verify_no_changes(table: str):
757
+ try:
758
+ # For tables with no allowed changes, just check row counts
759
+ before_count_response = self.before.resource.query(
760
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
761
+ )
762
+ before_count = (
763
+ before_count_response.rows[0][0]
764
+ if before_count_response.rows
765
+ else 0
766
+ )
767
+
768
+ after_count_response = self.after.resource.query(
769
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
770
+ )
771
+ after_count = (
772
+ after_count_response.rows[0][0] if after_count_response.rows else 0
773
+ )
774
+
775
+ if before_count != after_count:
776
+ error_msg = (
777
+ f"Unexpected change in table '{table}': "
778
+ f"row count changed from {before_count} to {after_count}"
779
+ )
780
+ with errors_lock:
781
+ errors.append(AssertionError(error_msg))
782
+ except Exception as e:
783
+ with errors_lock:
784
+ errors.append(e)
785
+
786
+ # Execute table verification in parallel
787
+ if tables_to_verify:
788
+ with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
789
+ futures = [
790
+ executor.submit(verify_no_changes, table) for table in tables_to_verify
791
+ ]
792
+ concurrent.futures.wait(futures)
793
+
794
+ # Final error check
795
+ if errors:
796
+ raise errors[0]
797
+
798
+ return self
799
+
800
+ def _validate_diff_against_allowed_changes(
801
+ self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
802
+ ):
803
+ """Validate a collected diff against allowed changes."""
330
804
 
331
805
  def _is_change_allowed(
332
806
  table: str, row_id: Any, field: Optional[str], after_value: Any
@@ -453,6 +927,1077 @@ class SyncSnapshotDiff:
453
927
 
454
928
  return self
455
929
 
930
+ def _expect_only_targeted_v2(self, allowed_changes: List[Dict[str, Any]]):
931
+ """Optimized version that only queries specific rows mentioned in allowed_changes.
932
+
933
+ Supports v2 spec formats:
934
+ - {"table": "t", "pk": 1, "type": "insert", "fields": [...]}
935
+ - {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": bool}
936
+ - {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
937
+ - Legacy single-field specs: {"table": "t", "pk": 1, "field": "x", "after": val}
938
+ """
939
+ import concurrent.futures
940
+ from threading import Lock
941
+
942
+ # Helper functions for v2 spec validation
943
+ def _parse_fields_spec(
944
+ fields_spec: List[Tuple[str, Any]]
945
+ ) -> Dict[str, Tuple[bool, Any]]:
946
+ """Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
947
+ spec_map: Dict[str, Tuple[bool, Any]] = {}
948
+ for spec_tuple in fields_spec:
949
+ if len(spec_tuple) != 2:
950
+ raise ValueError(
951
+ f"Invalid field spec tuple: {spec_tuple}. "
952
+ f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
953
+ )
954
+ field_name, expected_value = spec_tuple
955
+ if expected_value is ...:
956
+ spec_map[field_name] = (False, None)
957
+ else:
958
+ spec_map[field_name] = (True, expected_value)
959
+ return spec_map
960
+
961
+ def _get_spec_for_pk(table: str, pk: Any) -> Optional[Dict[str, Any]]:
962
+ """Get the spec for a given table/pk."""
963
+ for allowed in allowed_changes:
964
+ if (
965
+ allowed["table"] == table
966
+ and str(allowed.get("pk")) == str(pk)
967
+ ):
968
+ return allowed
969
+ return None
970
+
971
+ def _get_all_specs_for_pk(table: str, pk: Any) -> List[Dict[str, Any]]:
972
+ """Get all specs for a given table/pk (for legacy multi-field specs)."""
973
+ specs = []
974
+ for allowed in allowed_changes:
975
+ if (
976
+ allowed["table"] == table
977
+ and str(allowed.get("pk")) == str(pk)
978
+ ):
979
+ specs.append(allowed)
980
+ return specs
981
+
982
+ def _validate_insert_row(
983
+ table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
984
+ ) -> Optional[str]:
985
+ """Validate an inserted row against specs. Returns error message or None."""
986
+ # Check for type: "insert" spec with fields
987
+ for spec in specs:
988
+ if spec.get("type") == "insert":
989
+ fields_spec = spec.get("fields")
990
+ if fields_spec is not None:
991
+ # Validate each field
992
+ spec_map = _parse_fields_spec(fields_spec)
993
+ for field_name, field_value in row_data.items():
994
+ if field_name == "rowid":
995
+ continue
996
+ if self.ignore_config.should_ignore_field(table, field_name):
997
+ continue
998
+ if field_name not in spec_map:
999
+ return f"Field '{field_name}' not in insert spec for table '{table}' pk={pk}"
1000
+ should_check, expected_value = spec_map[field_name]
1001
+ if should_check and not _values_equivalent(expected_value, field_value):
1002
+ return (
1003
+ f"Insert mismatch in table '{table}' pk={pk}, "
1004
+ f"field '{field_name}': expected {repr(expected_value)}, got {repr(field_value)}"
1005
+ )
1006
+ # type: "insert" found (with or without fields) - allowed
1007
+ return None
1008
+
1009
+ # Check for legacy whole-row spec
1010
+ for spec in specs:
1011
+ if spec.get("fields") is None and spec.get("after") == "__added__":
1012
+ return None
1013
+
1014
+ return f"Unexpected row added in table '{table}': pk={pk}"
1015
+
1016
+ def _validate_delete_row(
1017
+ table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
1018
+ ) -> Optional[str]:
1019
+ """Validate a deleted row against specs. Returns error message or None."""
1020
+ # Check for type: "delete" spec with optional fields
1021
+ for spec in specs:
1022
+ if spec.get("type") == "delete":
1023
+ fields_spec = spec.get("fields")
1024
+ if fields_spec is not None:
1025
+ # Validate each field against the deleted row
1026
+ spec_map = _parse_fields_spec(fields_spec)
1027
+ for field_name, (should_check, expected_value) in spec_map.items():
1028
+ if field_name not in row_data:
1029
+ return f"Field '{field_name}' in delete spec not found in row for table '{table}' pk={pk}"
1030
+ if should_check and not _values_equivalent(expected_value, row_data[field_name]):
1031
+ return (
1032
+ f"Delete mismatch in table '{table}' pk={pk}, "
1033
+ f"field '{field_name}': expected {repr(expected_value)}, got {repr(row_data[field_name])}"
1034
+ )
1035
+ # type: "delete" found (with or without fields) - allowed
1036
+ return None
1037
+
1038
+ # Check for legacy whole-row spec
1039
+ for spec in specs:
1040
+ if spec.get("fields") is None and spec.get("after") == "__removed__":
1041
+ return None
1042
+
1043
+ return f"Unexpected row removed from table '{table}': pk={pk}"
1044
+
1045
+ def _validate_modify_row(
1046
+ table: str,
1047
+ pk: Any,
1048
+ before_row: Dict[str, Any],
1049
+ after_row: Dict[str, Any],
1050
+ specs: List[Dict[str, Any]],
1051
+ ) -> Optional[str]:
1052
+ """Validate a modified row against specs. Returns error message or None."""
1053
+ # Collect actual changes
1054
+ changed_fields: Dict[str, Dict[str, Any]] = {}
1055
+ for field in set(before_row.keys()) | set(after_row.keys()):
1056
+ if self.ignore_config.should_ignore_field(table, field):
1057
+ continue
1058
+ before_val = before_row.get(field)
1059
+ after_val = after_row.get(field)
1060
+ if not _values_equivalent(before_val, after_val):
1061
+ changed_fields[field] = {"before": before_val, "after": after_val}
1062
+
1063
+ if not changed_fields:
1064
+ return None # No changes
1065
+
1066
+ # Check for type: "modify" spec with resulting_fields
1067
+ for spec in specs:
1068
+ if spec.get("type") == "modify":
1069
+ resulting_fields = spec.get("resulting_fields")
1070
+ if resulting_fields is not None:
1071
+ # Validate no_other_changes is provided
1072
+ if "no_other_changes" not in spec:
1073
+ raise ValueError(
1074
+ f"Modify spec for table '{table}' pk={pk} "
1075
+ f"has 'resulting_fields' but missing required 'no_other_changes' field."
1076
+ )
1077
+ no_other_changes = spec["no_other_changes"]
1078
+ if not isinstance(no_other_changes, bool):
1079
+ raise ValueError(
1080
+ f"Modify spec for table '{table}' pk={pk} "
1081
+ f"'no_other_changes' must be boolean, got {type(no_other_changes).__name__}"
1082
+ )
1083
+
1084
+ spec_map = _parse_fields_spec(resulting_fields)
1085
+
1086
+ # Validate changed fields
1087
+ for field_name, vals in changed_fields.items():
1088
+ after_val = vals["after"]
1089
+ if field_name not in spec_map:
1090
+ if no_other_changes:
1091
+ return (
1092
+ f"Unexpected field change in table '{table}' pk={pk}: "
1093
+ f"field '{field_name}' not in resulting_fields"
1094
+ )
1095
+ # no_other_changes=False: ignore this field
1096
+ else:
1097
+ should_check, expected_value = spec_map[field_name]
1098
+ if should_check and not _values_equivalent(expected_value, after_val):
1099
+ return (
1100
+ f"Modify mismatch in table '{table}' pk={pk}, "
1101
+ f"field '{field_name}': expected {repr(expected_value)}, got {repr(after_val)}"
1102
+ )
1103
+ return None # Validation passed
1104
+ else:
1105
+ # type: "modify" without resulting_fields - allow any modification
1106
+ return None
1107
+
1108
+ # Check for legacy single-field specs
1109
+ for field_name, vals in changed_fields.items():
1110
+ after_val = vals["after"]
1111
+ field_allowed = False
1112
+ for spec in specs:
1113
+ if (
1114
+ spec.get("field") == field_name
1115
+ and _values_equivalent(spec.get("after"), after_val)
1116
+ ):
1117
+ field_allowed = True
1118
+ break
1119
+ if not field_allowed:
1120
+ return (
1121
+ f"Unexpected change in table '{table}' pk={pk}, "
1122
+ f"field '{field_name}': {repr(vals['before'])} -> {repr(after_val)}"
1123
+ )
1124
+
1125
+ return None
1126
+
1127
+ # Group allowed changes by table
1128
+ changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
1129
+ for change in allowed_changes:
1130
+ table = change["table"]
1131
+ if table not in changes_by_table:
1132
+ changes_by_table[table] = []
1133
+ changes_by_table[table].append(change)
1134
+
1135
+ errors = []
1136
+ errors_lock = Lock()
1137
+
1138
+ # Function to check a single row
1139
+ def check_row(
1140
+ table: str,
1141
+ pk: Any,
1142
+ pk_columns: List[str],
1143
+ ):
1144
+ try:
1145
+ # Build WHERE clause for this PK
1146
+ where_sql = self._build_pk_where_clause(pk_columns, pk)
1147
+
1148
+ # Query before snapshot
1149
+ before_query = f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
1150
+ before_response = self.before.resource.query(before_query)
1151
+ before_row = (
1152
+ dict(zip(before_response.columns, before_response.rows[0]))
1153
+ if before_response.rows
1154
+ else None
1155
+ )
1156
+
1157
+ # Query after snapshot
1158
+ after_response = self.after.resource.query(before_query)
1159
+ after_row = (
1160
+ dict(zip(after_response.columns, after_response.rows[0]))
1161
+ if after_response.rows
1162
+ else None
1163
+ )
1164
+
1165
+ # Get all specs for this table/pk
1166
+ specs = _get_all_specs_for_pk(table, pk)
1167
+
1168
+ # Check changes for this row
1169
+ if before_row and after_row:
1170
+ # Modified row
1171
+ error = _validate_modify_row(table, pk, before_row, after_row, specs)
1172
+ if error:
1173
+ with errors_lock:
1174
+ errors.append(AssertionError(error))
1175
+ elif not before_row and after_row:
1176
+ # Added row
1177
+ error = _validate_insert_row(table, pk, after_row, specs)
1178
+ if error:
1179
+ with errors_lock:
1180
+ errors.append(AssertionError(error))
1181
+ elif before_row and not after_row:
1182
+ # Removed row
1183
+ error = _validate_delete_row(table, pk, before_row, specs)
1184
+ if error:
1185
+ with errors_lock:
1186
+ errors.append(AssertionError(error))
1187
+
1188
+ except Exception as e:
1189
+ with errors_lock:
1190
+ errors.append(e)
1191
+
1192
+ # Prepare all row checks
1193
+ row_checks = []
1194
+ for table, table_changes in changes_by_table.items():
1195
+ if self.ignore_config.should_ignore_table(table):
1196
+ continue
1197
+
1198
+ # Get primary key columns once per table
1199
+ pk_columns = self._get_primary_key_columns(table)
1200
+
1201
+ # Extract unique PKs to check
1202
+ pks_to_check = {change["pk"] for change in table_changes}
1203
+
1204
+ for pk in pks_to_check:
1205
+ row_checks.append((table, pk, pk_columns))
1206
+
1207
+ # Execute row checks in parallel
1208
+ if row_checks:
1209
+ with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
1210
+ futures = [
1211
+ executor.submit(check_row, table, pk, pk_columns)
1212
+ for table, pk, pk_columns in row_checks
1213
+ ]
1214
+ concurrent.futures.wait(futures)
1215
+
1216
+ # Check for errors from row checks
1217
+ if errors:
1218
+ raise errors[0]
1219
+
1220
+ # Now check tables not mentioned in allowed_changes to ensure no changes
1221
+ all_tables = set(self.before.tables()) | set(self.after.tables())
1222
+ tables_to_verify = []
1223
+
1224
+ for table in all_tables:
1225
+ if (
1226
+ table not in changes_by_table
1227
+ and not self.ignore_config.should_ignore_table(table)
1228
+ ):
1229
+ tables_to_verify.append(table)
1230
+
1231
+ # Function to verify no changes in a table
1232
+ def verify_no_changes(table: str):
1233
+ try:
1234
+ # For tables with no allowed changes, just check row counts
1235
+ before_count_response = self.before.resource.query(
1236
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
1237
+ )
1238
+ before_count = (
1239
+ before_count_response.rows[0][0]
1240
+ if before_count_response.rows
1241
+ else 0
1242
+ )
1243
+
1244
+ after_count_response = self.after.resource.query(
1245
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
1246
+ )
1247
+ after_count = (
1248
+ after_count_response.rows[0][0] if after_count_response.rows else 0
1249
+ )
1250
+
1251
+ if before_count != after_count:
1252
+ error_msg = (
1253
+ f"Unexpected change in table '{table}': "
1254
+ f"row count changed from {before_count} to {after_count}"
1255
+ )
1256
+ with errors_lock:
1257
+ errors.append(AssertionError(error_msg))
1258
+ except Exception as e:
1259
+ with errors_lock:
1260
+ errors.append(e)
1261
+
1262
+ # Execute table verification in parallel
1263
+ if tables_to_verify:
1264
+ with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
1265
+ futures = [
1266
+ executor.submit(verify_no_changes, table) for table in tables_to_verify
1267
+ ]
1268
+ concurrent.futures.wait(futures)
1269
+
1270
+ # Final error check
1271
+ if errors:
1272
+ raise errors[0]
1273
+
1274
+ return self
1275
+
1276
+ def _validate_diff_against_allowed_changes_v2(
1277
+ self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
1278
+ ):
1279
+ """Validate a collected diff against allowed changes with field-level spec support.
1280
+
1281
+ This version supports explicit change types via the "type" field:
1282
+ 1. Insert specs: {"table": "t", "pk": 1, "type": "insert", "fields": [("name", "value"), ("status", ...)]}
1283
+ - ("name", value): check that field equals value
1284
+ - ("name", None): check that field is SQL NULL
1285
+ - ("name", ...): don't check the value, just acknowledge the field exists
1286
+ 2. Modify specs: {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": True/False}
1287
+ - Uses "resulting_fields" (not "fields") to be explicit about what's being checked
1288
+ - "no_other_changes" is REQUIRED and must be True or False:
1289
+ - True: Every changed field must be in resulting_fields (strict mode)
1290
+ - False: Only check fields in resulting_fields match, ignore other changes
1291
+ - ("field_name", value): check that after value equals value
1292
+ - ("field_name", None): check that after value is SQL NULL
1293
+ - ("field_name", ...): don't check value, just acknowledge field changed
1294
+ 3. Delete specs:
1295
+ - Without field validation: {"table": "t", "pk": 1, "type": "delete"}
1296
+ - With field validation: {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
1297
+ 4. Whole-row specs (legacy):
1298
+ - For additions: {"table": "t", "pk": 1, "fields": None, "after": "__added__"}
1299
+ - For deletions: {"table": "t", "pk": 1, "fields": None, "after": "__removed__"}
1300
+
1301
+ When using "fields" for inserts, every field must be accounted for in the list.
1302
+ For modifications, use "resulting_fields" with explicit "no_other_changes".
1303
+ For deletions with "fields", all specified fields are validated against the deleted row.
1304
+ """
1305
+
1306
+ def _is_change_allowed(
1307
+ table: str, row_id: Any, field: Optional[str], after_value: Any
1308
+ ) -> bool:
1309
+ """Check if a change is in the allowed list using semantic comparison."""
1310
+ for allowed in allowed_changes:
1311
+ allowed_pk = allowed.get("pk")
1312
+ # Handle type conversion for primary key comparison
1313
+ pk_match = (
1314
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1315
+ )
1316
+
1317
+ # For whole-row specs, check "fields": None; for field-level, check "field"
1318
+ field_match = (
1319
+ ("fields" in allowed and allowed.get("fields") is None)
1320
+ if field is None
1321
+ else allowed.get("field") == field
1322
+ )
1323
+ if (
1324
+ allowed["table"] == table
1325
+ and pk_match
1326
+ and field_match
1327
+ and _values_equivalent(allowed.get("after"), after_value)
1328
+ ):
1329
+ return True
1330
+ return False
1331
+
1332
+ def _get_fields_spec_for_type(
1333
+ table: str, row_id: Any, change_type: str
1334
+ ) -> Optional[List[Tuple[str, Any]]]:
1335
+ """Get the bulk fields spec for a given table/row/type if it exists.
1336
+
1337
+ Args:
1338
+ table: The table name
1339
+ row_id: The primary key value
1340
+ change_type: One of "insert", "modify", or "delete"
1341
+
1342
+ Note: For "modify" type, use _get_modify_spec instead.
1343
+ """
1344
+ for allowed in allowed_changes:
1345
+ allowed_pk = allowed.get("pk")
1346
+ pk_match = (
1347
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1348
+ )
1349
+ if (
1350
+ allowed["table"] == table
1351
+ and pk_match
1352
+ and allowed.get("type") == change_type
1353
+ and "fields" in allowed
1354
+ ):
1355
+ return allowed["fields"]
1356
+ return None
1357
+
1358
+ def _get_modify_spec(table: str, row_id: Any) -> Optional[Dict[str, Any]]:
1359
+ """Get the modify spec for a given table/row if it exists.
1360
+
1361
+ Returns the full spec dict containing:
1362
+ - resulting_fields: List of field tuples
1363
+ - no_other_changes: Boolean (required)
1364
+
1365
+ Returns None if no modify spec found.
1366
+ """
1367
+ for allowed in allowed_changes:
1368
+ allowed_pk = allowed.get("pk")
1369
+ pk_match = (
1370
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1371
+ )
1372
+ if (
1373
+ allowed["table"] == table
1374
+ and pk_match
1375
+ and allowed.get("type") == "modify"
1376
+ ):
1377
+ return allowed
1378
+ return None
1379
+
1380
+ def _is_type_allowed(table: str, row_id: Any, change_type: str) -> bool:
1381
+ """Check if a change type is allowed for the given table/row (with or without fields)."""
1382
+ for allowed in allowed_changes:
1383
+ allowed_pk = allowed.get("pk")
1384
+ pk_match = (
1385
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1386
+ )
1387
+ if (
1388
+ allowed["table"] == table
1389
+ and pk_match
1390
+ and allowed.get("type") == change_type
1391
+ ):
1392
+ return True
1393
+ return False
1394
+
1395
+ def _parse_fields_spec(
1396
+ fields_spec: List[Tuple[str, Any]]
1397
+ ) -> Dict[str, Tuple[bool, Any]]:
1398
+ """Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
1399
+ spec_map: Dict[str, Tuple[bool, Any]] = {}
1400
+ for spec_tuple in fields_spec:
1401
+ if len(spec_tuple) != 2:
1402
+ raise ValueError(
1403
+ f"Invalid field spec tuple: {spec_tuple}. "
1404
+ f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
1405
+ )
1406
+ field_name, expected_value = spec_tuple
1407
+ if expected_value is ...:
1408
+ # Ellipsis: don't check value, just acknowledge field exists
1409
+ spec_map[field_name] = (False, None)
1410
+ else:
1411
+ # Any other value (including None for NULL check): check value
1412
+ spec_map[field_name] = (True, expected_value)
1413
+ return spec_map
1414
+
1415
+ def _validate_row_with_fields_spec(
1416
+ table: str,
1417
+ row_id: Any,
1418
+ row_data: Dict[str, Any],
1419
+ fields_spec: List[Tuple[str, Any]],
1420
+ ) -> Optional[List[Tuple[str, Any, str]]]:
1421
+ """Validate a row against a bulk fields spec.
1422
+
1423
+ Returns None if validation passes, or a list of (field, actual_value, issue)
1424
+ tuples for mismatches.
1425
+
1426
+ Field spec semantics:
1427
+ - ("field_name", value): check that field equals value
1428
+ - ("field_name", None): check that field is SQL NULL
1429
+ - ("field_name", ...): don't check value (acknowledge field exists)
1430
+ """
1431
+ spec_map = _parse_fields_spec(fields_spec)
1432
+ unmatched_fields: List[Tuple[str, Any, str]] = []
1433
+
1434
+ for field_name, field_value in row_data.items():
1435
+ # Skip rowid as it's internal
1436
+ if field_name == "rowid":
1437
+ continue
1438
+ # Skip ignored fields
1439
+ if self.ignore_config.should_ignore_field(table, field_name):
1440
+ continue
1441
+
1442
+ if field_name not in spec_map:
1443
+ # Field not in spec - this is an error
1444
+ unmatched_fields.append(
1445
+ (field_name, field_value, "NOT_IN_FIELDS_SPEC")
1446
+ )
1447
+ else:
1448
+ should_check, expected_value = spec_map[field_name]
1449
+ if should_check and not _values_equivalent(
1450
+ expected_value, field_value
1451
+ ):
1452
+ # Value doesn't match
1453
+ unmatched_fields.append(
1454
+ (field_name, field_value, f"expected {repr(expected_value)}")
1455
+ )
1456
+
1457
+ return unmatched_fields if unmatched_fields else None
1458
+
1459
+ def _validate_modification_with_fields_spec(
1460
+ table: str,
1461
+ row_id: Any,
1462
+ row_changes: Dict[str, Dict[str, Any]],
1463
+ resulting_fields: List[Tuple[str, Any]],
1464
+ no_other_changes: bool,
1465
+ ) -> Optional[List[Tuple[str, Any, str]]]:
1466
+ """Validate a modification against a resulting_fields spec.
1467
+
1468
+ Returns None if validation passes, or a list of (field, actual_value, issue)
1469
+ tuples for mismatches.
1470
+
1471
+ Args:
1472
+ table: The table name
1473
+ row_id: The row primary key
1474
+ row_changes: Dict of field_name -> {"before": ..., "after": ...}
1475
+ resulting_fields: List of field tuples to validate
1476
+ no_other_changes: If True, all changed fields must be in resulting_fields.
1477
+ If False, only validate fields in resulting_fields, ignore others.
1478
+
1479
+ Field spec semantics for modifications:
1480
+ - ("field_name", value): check that after value equals value
1481
+ - ("field_name", None): check that after value is SQL NULL
1482
+ - ("field_name", ...): don't check value, just acknowledge field changed
1483
+ """
1484
+ spec_map = _parse_fields_spec(resulting_fields)
1485
+ unmatched_fields: List[Tuple[str, Any, str]] = []
1486
+
1487
+ for field_name, vals in row_changes.items():
1488
+ # Skip ignored fields
1489
+ if self.ignore_config.should_ignore_field(table, field_name):
1490
+ continue
1491
+
1492
+ after_value = vals["after"]
1493
+
1494
+ if field_name not in spec_map:
1495
+ # Changed field not in spec
1496
+ if no_other_changes:
1497
+ # Strict mode: all changed fields must be accounted for
1498
+ unmatched_fields.append(
1499
+ (field_name, after_value, "NOT_IN_RESULTING_FIELDS")
1500
+ )
1501
+ # If no_other_changes=False, ignore fields not in spec
1502
+ else:
1503
+ should_check, expected_value = spec_map[field_name]
1504
+ if should_check and not _values_equivalent(
1505
+ expected_value, after_value
1506
+ ):
1507
+ # Value doesn't match
1508
+ unmatched_fields.append(
1509
+ (field_name, after_value, f"expected {repr(expected_value)}")
1510
+ )
1511
+
1512
+ return unmatched_fields if unmatched_fields else None
1513
+
1514
+
1515
+ # Collect all unexpected changes for detailed reporting
1516
+ unexpected_changes = []
1517
+
1518
+ for tbl, report in diff.items():
1519
+ for row in report.get("modified_rows", []):
1520
+ row_changes = row["changes"]
1521
+
1522
+ # Check for modify spec with resulting_fields
1523
+ modify_spec = _get_modify_spec(tbl, row["row_id"])
1524
+ if modify_spec is not None:
1525
+ resulting_fields = modify_spec.get("resulting_fields")
1526
+ if resulting_fields is not None:
1527
+ # Validate that no_other_changes is provided
1528
+ if "no_other_changes" not in modify_spec:
1529
+ raise ValueError(
1530
+ f"Modify spec for table '{tbl}' pk={row['row_id']} "
1531
+ f"has 'resulting_fields' but missing required 'no_other_changes' field. "
1532
+ f"Set 'no_other_changes': True to verify no other fields changed, "
1533
+ f"or 'no_other_changes': False to only check the specified fields."
1534
+ )
1535
+ no_other_changes = modify_spec["no_other_changes"]
1536
+ if not isinstance(no_other_changes, bool):
1537
+ raise ValueError(
1538
+ f"Modify spec for table '{tbl}' pk={row['row_id']} "
1539
+ f"has 'no_other_changes' but it must be a boolean (True or False), "
1540
+ f"got {type(no_other_changes).__name__}: {repr(no_other_changes)}"
1541
+ )
1542
+
1543
+ unmatched = _validate_modification_with_fields_spec(
1544
+ tbl, row["row_id"], row_changes, resulting_fields, no_other_changes
1545
+ )
1546
+ if unmatched:
1547
+ unexpected_changes.append(
1548
+ {
1549
+ "type": "modification",
1550
+ "table": tbl,
1551
+ "row_id": row["row_id"],
1552
+ "field": None,
1553
+ "before": None,
1554
+ "after": None,
1555
+ "full_row": row,
1556
+ "unmatched_fields": unmatched,
1557
+ }
1558
+ )
1559
+ continue # Skip to next row
1560
+ else:
1561
+ # Modify spec without resulting_fields - just allow the modification
1562
+ continue # Skip to next row
1563
+
1564
+ # Fall back to single-field specs (legacy)
1565
+ for f, vals in row_changes.items():
1566
+ if self.ignore_config.should_ignore_field(tbl, f):
1567
+ continue
1568
+ if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
1569
+ unexpected_changes.append(
1570
+ {
1571
+ "type": "modification",
1572
+ "table": tbl,
1573
+ "row_id": row["row_id"],
1574
+ "field": f,
1575
+ "before": vals.get("before"),
1576
+ "after": vals["after"],
1577
+ "full_row": row,
1578
+ }
1579
+ )
1580
+
1581
+ for row in report.get("added_rows", []):
1582
+ row_data = row.get("data", {})
1583
+
1584
+ # Check for bulk fields spec (type: "insert")
1585
+ fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "insert")
1586
+ if fields_spec is not None:
1587
+ unmatched = _validate_row_with_fields_spec(
1588
+ tbl, row["row_id"], row_data, fields_spec
1589
+ )
1590
+ if unmatched:
1591
+ unexpected_changes.append(
1592
+ {
1593
+ "type": "insertion",
1594
+ "table": tbl,
1595
+ "row_id": row["row_id"],
1596
+ "field": None,
1597
+ "after": "__added__",
1598
+ "full_row": row,
1599
+ "unmatched_fields": unmatched,
1600
+ }
1601
+ )
1602
+ continue # Skip to next row
1603
+
1604
+ # Check if insertion is allowed without field validation
1605
+ if _is_type_allowed(tbl, row["row_id"], "insert"):
1606
+ continue # Insertion is allowed, skip to next row
1607
+
1608
+ # Check for whole-row spec (legacy)
1609
+ whole_row_allowed = _is_change_allowed(
1610
+ tbl, row["row_id"], None, "__added__"
1611
+ )
1612
+
1613
+ if not whole_row_allowed:
1614
+ unexpected_changes.append(
1615
+ {
1616
+ "type": "insertion",
1617
+ "table": tbl,
1618
+ "row_id": row["row_id"],
1619
+ "field": None,
1620
+ "after": "__added__",
1621
+ "full_row": row,
1622
+ }
1623
+ )
1624
+
1625
+ for row in report.get("removed_rows", []):
1626
+ row_data = row.get("data", {})
1627
+
1628
+ # Check for bulk fields spec (type: "delete")
1629
+ fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "delete")
1630
+ if fields_spec is not None:
1631
+ unmatched = _validate_row_with_fields_spec(
1632
+ tbl, row["row_id"], row_data, fields_spec
1633
+ )
1634
+ if unmatched:
1635
+ unexpected_changes.append(
1636
+ {
1637
+ "type": "deletion",
1638
+ "table": tbl,
1639
+ "row_id": row["row_id"],
1640
+ "field": None,
1641
+ "after": "__removed__",
1642
+ "full_row": row,
1643
+ "unmatched_fields": unmatched,
1644
+ }
1645
+ )
1646
+ continue # Skip to next row
1647
+
1648
+ # Check if deletion is allowed without field validation
1649
+ if _is_type_allowed(tbl, row["row_id"], "delete"):
1650
+ continue # Deletion is allowed, skip to next row
1651
+
1652
+ # Check for whole-row spec (legacy)
1653
+ whole_row_allowed = _is_change_allowed(
1654
+ tbl, row["row_id"], None, "__removed__"
1655
+ )
1656
+
1657
+ if not whole_row_allowed:
1658
+ unexpected_changes.append(
1659
+ {
1660
+ "type": "deletion",
1661
+ "table": tbl,
1662
+ "row_id": row["row_id"],
1663
+ "field": None,
1664
+ "after": "__removed__",
1665
+ "full_row": row,
1666
+ }
1667
+ )
1668
+
1669
+ if unexpected_changes:
1670
+ # Build comprehensive error message
1671
+ error_lines = ["Unexpected database changes detected:"]
1672
+ error_lines.append("")
1673
+
1674
+ for i, change in enumerate(unexpected_changes[:5], 1):
1675
+ error_lines.append(
1676
+ f"{i}. {change['type'].upper()} in table '{change['table']}':"
1677
+ )
1678
+ error_lines.append(f" Row ID: {change['row_id']}")
1679
+
1680
+ if change["type"] == "modification":
1681
+ error_lines.append(f" Field: {change['field']}")
1682
+ error_lines.append(f" Before: {repr(change['before'])}")
1683
+ error_lines.append(f" After: {repr(change['after'])}")
1684
+ elif change["type"] == "insertion":
1685
+ error_lines.append(" New row added")
1686
+ elif change["type"] == "deletion":
1687
+ error_lines.append(" Row deleted")
1688
+
1689
+ # Show unmatched fields if present (from bulk fields spec validation)
1690
+ if "unmatched_fields" in change and change["unmatched_fields"]:
1691
+ error_lines.append(" Unmatched fields:")
1692
+ for field_info in change["unmatched_fields"][:5]:
1693
+ field_name, actual_value, issue = field_info
1694
+ error_lines.append(
1695
+ f" - {field_name}: {repr(actual_value)} ({issue})"
1696
+ )
1697
+ if len(change["unmatched_fields"]) > 10:
1698
+ error_lines.append(
1699
+ f" ... and {len(change['unmatched_fields']) - 10} more"
1700
+ )
1701
+
1702
+ # Show some context from the row
1703
+ if "full_row" in change and change["full_row"]:
1704
+ row_data = change["full_row"]
1705
+ if change["type"] == "modification" and "data" in row_data:
1706
+ # For modifications, show the current state
1707
+ formatted_row = _format_row_for_error(
1708
+ row_data.get("data", {}), max_fields=5
1709
+ )
1710
+ error_lines.append(f" Row data: {formatted_row}")
1711
+ elif (
1712
+ change["type"] in ["insertion", "deletion"]
1713
+ and "data" in row_data
1714
+ ):
1715
+ # For insertions/deletions, show the row data
1716
+ formatted_row = _format_row_for_error(
1717
+ row_data.get("data", {}), max_fields=5
1718
+ )
1719
+ error_lines.append(f" Row data: {formatted_row}")
1720
+
1721
+ error_lines.append("")
1722
+
1723
+ if len(unexpected_changes) > 5:
1724
+ error_lines.append(
1725
+ f"... and {len(unexpected_changes) - 5} more unexpected changes"
1726
+ )
1727
+ error_lines.append("")
1728
+
1729
+ # Show what changes were allowed
1730
+ error_lines.append("Allowed changes were:")
1731
+ if allowed_changes:
1732
+ for i, allowed in enumerate(allowed_changes[:3], 1):
1733
+ change_type = allowed.get("type", "unspecified")
1734
+
1735
+ # For modify type, use resulting_fields
1736
+ if change_type == "modify" and "resulting_fields" in allowed and allowed["resulting_fields"] is not None:
1737
+ fields_summary = ", ".join(
1738
+ f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
1739
+ for f in allowed["resulting_fields"][:3]
1740
+ )
1741
+ if len(allowed["resulting_fields"]) > 3:
1742
+ fields_summary += f", ... +{len(allowed['resulting_fields']) - 3} more"
1743
+ no_other = allowed.get("no_other_changes", "NOT_SET")
1744
+ error_lines.append(
1745
+ f" {i}. Table: {allowed.get('table')}, "
1746
+ f"ID: {allowed.get('pk')}, "
1747
+ f"Type: {change_type}, "
1748
+ f"resulting_fields: [{fields_summary}], "
1749
+ f"no_other_changes: {no_other}"
1750
+ )
1751
+ elif "fields" in allowed and allowed["fields"] is not None:
1752
+ # Show bulk fields spec (for insert/delete)
1753
+ fields_summary = ", ".join(
1754
+ f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
1755
+ for f in allowed["fields"][:3]
1756
+ )
1757
+ if len(allowed["fields"]) > 3:
1758
+ fields_summary += f", ... +{len(allowed['fields']) - 3} more"
1759
+ error_lines.append(
1760
+ f" {i}. Table: {allowed.get('table')}, "
1761
+ f"ID: {allowed.get('pk')}, "
1762
+ f"Type: {change_type}, "
1763
+ f"Fields: [{fields_summary}]"
1764
+ )
1765
+ else:
1766
+ error_lines.append(
1767
+ f" {i}. Table: {allowed.get('table')}, "
1768
+ f"ID: {allowed.get('pk')}, "
1769
+ f"Type: {change_type}"
1770
+ )
1771
+ if len(allowed_changes) > 3:
1772
+ error_lines.append(
1773
+ f" ... and {len(allowed_changes) - 3} more allowed changes"
1774
+ )
1775
+ else:
1776
+ error_lines.append(" (No changes were allowed)")
1777
+
1778
+ raise AssertionError("\n".join(error_lines))
1779
+
1780
+ return self
1781
+
1782
+ def expect_only(self, allowed_changes: List[Dict[str, Any]]):
1783
+ """Ensure only specified changes occurred."""
1784
+ # Normalize pk values: convert lists to tuples for hashability and consistency
1785
+ for change in allowed_changes:
1786
+ if "pk" in change and isinstance(change["pk"], list):
1787
+ change["pk"] = tuple(change["pk"])
1788
+
1789
+ # Special case: empty allowed_changes means no changes should have occurred
1790
+ if not allowed_changes:
1791
+ return self._expect_no_changes()
1792
+
1793
+ # For expect_only, we can optimize by only checking the specific rows mentioned
1794
+ if self._can_use_targeted_queries(allowed_changes):
1795
+ return self._expect_only_targeted(allowed_changes)
1796
+
1797
+ # Fall back to full diff for complex cases
1798
+ diff = self._collect()
1799
+ return self._validate_diff_against_allowed_changes(diff, allowed_changes)
1800
+
1801
+ def expect_only_v2(self, allowed_changes: List[Dict[str, Any]]):
1802
+ """Ensure only specified changes occurred, with field-level spec support.
1803
+
1804
+ This version supports field-level specifications for added/removed rows,
1805
+ allowing users to specify expected field values instead of just whole-row specs.
1806
+ """
1807
+ # Normalize pk values: convert lists to tuples for hashability and consistency
1808
+ for change in allowed_changes:
1809
+ if "pk" in change and isinstance(change["pk"], list):
1810
+ change["pk"] = tuple(change["pk"])
1811
+
1812
+ # Special case: empty allowed_changes means no changes should have occurred
1813
+ if not allowed_changes:
1814
+ return self._expect_no_changes()
1815
+
1816
+ resource = self.after.resource
1817
+ # Disabled: structured diff endpoint not yet available
1818
+ if False and resource.client is not None and resource._mode == "http":
1819
+ api_diff = None
1820
+ try:
1821
+ payload = {}
1822
+ if self.ignore_config:
1823
+ payload["ignore_config"] = {
1824
+ "tables": list(self.ignore_config.tables),
1825
+ "fields": list(self.ignore_config.fields),
1826
+ "table_fields": {
1827
+ table: list(fields) for table, fields in self.ignore_config.table_fields.items()
1828
+ }
1829
+ }
1830
+ response = resource.client.request(
1831
+ "POST",
1832
+ "/diff/structured",
1833
+ json=payload,
1834
+ )
1835
+ result = response.json()
1836
+ if result.get("success") and "diff" in result:
1837
+ api_diff = result["diff"]
1838
+ except Exception as e:
1839
+ # Fall back to local diff if API call fails
1840
+ print(f"Warning: Failed to fetch structured diff from API: {e}")
1841
+ print("Falling back to local diff computation...")
1842
+
1843
+ # Validate outside try block so AssertionError propagates
1844
+ if api_diff is not None:
1845
+ return self._validate_diff_against_allowed_changes_v2(api_diff, allowed_changes)
1846
+
1847
+ # For expect_only_v2, we can optimize by only checking the specific rows mentioned
1848
+ if self._can_use_targeted_queries(allowed_changes):
1849
+ return self._expect_only_targeted_v2(allowed_changes)
1850
+
1851
+ # Fall back to full diff for complex cases
1852
+ diff = self._collect()
1853
+ return self._validate_diff_against_allowed_changes_v2(diff, allowed_changes)
1854
+
1855
+ def expect_exactly(self, expected_changes: List[Dict[str, Any]]):
1856
+ """Verify that EXACTLY the specified changes occurred.
1857
+
1858
+ This is stricter than expect_only_v2:
1859
+ 1. All changes in diff must match a spec (no unexpected changes)
1860
+ 2. All specs must have a matching change in diff (no missing expected changes)
1861
+
1862
+ This method is ideal for verifying that an agent performed exactly what was expected -
1863
+ not more, not less.
1864
+
1865
+ Args:
1866
+ expected_changes: List of expected change specs. Each spec requires:
1867
+ - "type": "insert", "modify", or "delete" (required)
1868
+ - "table": table name (required)
1869
+ - "pk": primary key value (required)
1870
+
1871
+ Spec formats by type:
1872
+ - Insert: {"type": "insert", "table": "t", "pk": 1, "fields": [...]}
1873
+ - Modify: {"type": "modify", "table": "t", "pk": 1, "resulting_fields": [...], "no_other_changes": True/False}
1874
+ - Delete: {"type": "delete", "table": "t", "pk": 1}
1875
+
1876
+ Field specs are 2-tuples: (field_name, expected_value)
1877
+ - ("name", "Alice"): check field equals "Alice"
1878
+ - ("name", ...): accept any value (ellipsis)
1879
+ - ("name", None): check field is SQL NULL
1880
+
1881
+ Note: Legacy specs without explicit "type" are not supported.
1882
+
1883
+ Returns:
1884
+ self for method chaining
1885
+
1886
+ Raises:
1887
+ AssertionError: If there are unexpected changes OR if expected changes are missing
1888
+ ValueError: If specs are missing required fields or have invalid format
1889
+ """
1890
+ # Get the diff (using HTTP if available, otherwise local)
1891
+ resource = self.after.resource
1892
+ diff = None
1893
+
1894
+ if resource.client is not None and resource._mode == "http":
1895
+ try:
1896
+ payload = {}
1897
+ if self.ignore_config:
1898
+ payload["ignore_config"] = {
1899
+ "tables": list(self.ignore_config.tables),
1900
+ "fields": list(self.ignore_config.fields),
1901
+ "table_fields": {
1902
+ table: list(fields) for table, fields in self.ignore_config.table_fields.items()
1903
+ }
1904
+ }
1905
+ response = resource.client.request(
1906
+ "POST",
1907
+ "/diff/structured",
1908
+ json=payload,
1909
+ )
1910
+ result = response.json()
1911
+ if result.get("success") and "diff" in result:
1912
+ diff = result["diff"]
1913
+ except Exception as e:
1914
+ print(f"Warning: Failed to fetch structured diff from API: {e}")
1915
+ print("Falling back to local diff computation...")
1916
+
1917
+ if diff is None:
1918
+ diff = self._collect()
1919
+
1920
+ # Use shared validation logic
1921
+ success, error_msg, _ = validate_diff_expect_exactly(
1922
+ diff, expected_changes, self.ignore_config
1923
+ )
1924
+
1925
+ if not success:
1926
+ raise AssertionError(error_msg)
1927
+
1928
+ return self
1929
+
1930
+ def _ensure_all_fetched(self):
1931
+ """Fetch ALL data from ALL tables upfront (non-lazy loading).
1932
+
1933
+ This is the old approach before lazy loading was introduced.
1934
+ Used by expect_only_v1 for simpler, non-optimized diffing.
1935
+ """
1936
+ # Get all tables
1937
+ tables_response = self.before.resource.query(
1938
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
1939
+ )
1940
+
1941
+ if tables_response.rows:
1942
+ before_tables = [row[0] for row in tables_response.rows]
1943
+ for table in before_tables:
1944
+ self.before._ensure_table_data(table)
1945
+
1946
+ # Also fetch from after snapshot
1947
+ tables_response = self.after.resource.query(
1948
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
1949
+ )
1950
+
1951
+ if tables_response.rows:
1952
+ after_tables = [row[0] for row in tables_response.rows]
1953
+ for table in after_tables:
1954
+ self.after._ensure_table_data(table)
1955
+
1956
+ def expect_only_v1(self, allowed_changes: List[Dict[str, Any]]):
1957
+ """Ensure only specified changes occurred using the original (non-optimized) approach.
1958
+
1959
+ This version attempts to use the /api/v1/env/diff/structured endpoint if available,
1960
+ falling back to local diff computation if the endpoint is not available.
1961
+
1962
+ Use this when you want the simpler, more predictable behavior of the original
1963
+ implementation without any query optimizations.
1964
+ """
1965
+ # Try to use the structured diff endpoint if we have an HTTP client
1966
+ resource = self.after.resource
1967
+ if resource.client is not None and resource._mode == "http":
1968
+ api_diff = None
1969
+ try:
1970
+ payload = {}
1971
+ if self.ignore_config:
1972
+ payload["ignore_config"] = {
1973
+ "tables": list(self.ignore_config.tables),
1974
+ "fields": list(self.ignore_config.fields),
1975
+ "table_fields": {
1976
+ table: list(fields) for table, fields in self.ignore_config.table_fields.items()
1977
+ }
1978
+ }
1979
+ response = resource.client.request(
1980
+ "POST",
1981
+ "/diff/structured",
1982
+ json=payload,
1983
+ )
1984
+ result = response.json()
1985
+ if result.get("success") and "diff" in result:
1986
+ api_diff = result["diff"]
1987
+ except Exception as e:
1988
+ # Fall back to local diff if API call fails
1989
+ print(f"Warning: Failed to fetch structured diff from API: {e}")
1990
+ print("Falling back to local diff computation...")
1991
+
1992
+ # Validate outside try block so AssertionError propagates
1993
+ if api_diff is not None:
1994
+ return self._validate_diff_against_allowed_changes(api_diff, allowed_changes)
1995
+
1996
+ # Fall back to local diff computation
1997
+ self._ensure_all_fetched()
1998
+ diff = self._collect()
1999
+ return self._validate_diff_against_allowed_changes(diff, allowed_changes)
2000
+
456
2001
 
457
2002
  class SyncQueryBuilder:
458
2003
  """Async query builder that translates DSL to SQL and executes through the API."""
@@ -551,13 +2096,13 @@ class SyncQueryBuilder:
551
2096
  # Compile to SQL
552
2097
  def _compile(self) -> Tuple[str, List[Any]]:
553
2098
  cols = ", ".join(self._select_cols)
554
- sql = [f"SELECT {cols} FROM {self._table}"]
2099
+ sql = [f"SELECT {cols} FROM {_quote_identifier(self._table)}"]
555
2100
  params: List[Any] = []
556
2101
 
557
2102
  # Joins
558
2103
  for tbl, onmap in self._joins:
559
- join_clauses = [f"{self._table}.{l} = {tbl}.{r}" for l, r in onmap.items()]
560
- sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")
2104
+ join_clauses = [f"{_quote_identifier(self._table)}.{_quote_identifier(l)} = {_quote_identifier(tbl)}.{_quote_identifier(r)}" for l, r in onmap.items()]
2105
+ sql.append(f"JOIN {_quote_identifier(tbl)} ON {' AND '.join(join_clauses)}")
561
2106
 
562
2107
  # WHERE
563
2108
  if self._conditions:
@@ -565,12 +2110,12 @@ class SyncQueryBuilder:
565
2110
  for col, op, val in self._conditions:
566
2111
  if op in ("IN", "NOT IN") and isinstance(val, tuple):
567
2112
  ph = ", ".join(["?" for _ in val])
568
- placeholders.append(f"{col} {op} ({ph})")
2113
+ placeholders.append(f"{_quote_identifier(col)} {op} ({ph})")
569
2114
  params.extend(val)
570
2115
  elif op in ("IS", "IS NOT"):
571
- placeholders.append(f"{col} {op} NULL")
2116
+ placeholders.append(f"{_quote_identifier(col)} {op} NULL")
572
2117
  else:
573
- placeholders.append(f"{col} {op} ?")
2118
+ placeholders.append(f"{_quote_identifier(col)} {op} ?")
574
2119
  params.append(val)
575
2120
  sql.append("WHERE " + " AND ".join(placeholders))
576
2121
 
@@ -675,16 +2220,103 @@ class SyncQueryBuilder:
675
2220
 
676
2221
 
677
2222
  class SQLiteResource(Resource):
678
- def __init__(self, resource: ResourceModel, client: "SyncWrapper"):
2223
+ def __init__(
2224
+ self,
2225
+ resource: ResourceModel,
2226
+ client: Optional["SyncWrapper"] = None,
2227
+ db_path: Optional[str] = None,
2228
+ ):
679
2229
  super().__init__(resource)
680
2230
  self.client = client
2231
+ self.db_path = db_path
2232
+ self._mode = "direct" if db_path else "http"
2233
+
2234
+ @property
2235
+ def mode(self) -> str:
2236
+ """Return the mode of this resource: 'direct' (local file) or 'http' (remote API)."""
2237
+ return self._mode
681
2238
 
682
2239
  def describe(self) -> DescribeResponse:
683
2240
  """Describe the SQLite database schema."""
2241
+ if self._mode == "direct":
2242
+ return self._describe_direct()
2243
+ else:
2244
+ return self._describe_http()
2245
+
2246
+ def _describe_http(self) -> DescribeResponse:
2247
+ """Describe database schema via HTTP API."""
684
2248
  response = self.client.request(
685
2249
  "GET", f"/resources/sqlite/{self.resource.name}/describe"
686
2250
  )
687
- return DescribeResponse(**response.json())
2251
+ try:
2252
+ return DescribeResponse(**response.json())
2253
+ except json.JSONDecodeError as e:
2254
+ raise ValueError(
2255
+ f"Failed to parse JSON response from SQLite describe endpoint. "
2256
+ f"Status: {response.status_code}, "
2257
+ f"Response text: {response.text[:500]}"
2258
+ ) from e
2259
+
2260
+ def _describe_direct(self) -> DescribeResponse:
2261
+ """Describe database schema from local file or in-memory database."""
2262
+ try:
2263
+ # Check if we need URI mode (for shared memory databases)
2264
+ use_uri = 'mode=memory' in self.db_path
2265
+ conn = sqlite3.connect(self.db_path, uri=use_uri)
2266
+ cursor = conn.cursor()
2267
+
2268
+ # Get all tables
2269
+ cursor.execute(
2270
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
2271
+ )
2272
+ table_names = [row[0] for row in cursor.fetchall()]
2273
+
2274
+ tables = []
2275
+ for table_name in table_names:
2276
+ # Get table info
2277
+ cursor.execute(f"PRAGMA table_info({_quote_identifier(table_name)})")
2278
+ columns = cursor.fetchall()
2279
+
2280
+ # Get CREATE TABLE SQL
2281
+ cursor.execute(
2282
+ f"SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
2283
+ (table_name,)
2284
+ )
2285
+ sql_row = cursor.fetchone()
2286
+ create_sql = sql_row[0] if sql_row else ""
2287
+
2288
+ table_schema = {
2289
+ "name": table_name,
2290
+ "sql": create_sql,
2291
+ "columns": [
2292
+ {
2293
+ "name": col[1],
2294
+ "type": col[2],
2295
+ "notnull": bool(col[3]),
2296
+ "default_value": col[4],
2297
+ "primary_key": col[5] > 0,
2298
+ }
2299
+ for col in columns
2300
+ ],
2301
+ }
2302
+ tables.append(table_schema)
2303
+
2304
+ conn.close()
2305
+
2306
+ return DescribeResponse(
2307
+ success=True,
2308
+ resource_name=self.resource.name,
2309
+ tables=tables,
2310
+ message="Schema retrieved from local file",
2311
+ )
2312
+ except Exception as e:
2313
+ return DescribeResponse(
2314
+ success=False,
2315
+ resource_name=self.resource.name,
2316
+ tables=None,
2317
+ error=str(e),
2318
+ message=f"Failed to describe database: {str(e)}",
2319
+ )
688
2320
 
689
2321
  def query(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:
690
2322
  return self._query(query, args, read_only=True)
@@ -695,6 +2327,121 @@ class SQLiteResource(Resource):
695
2327
  def _query(
696
2328
  self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
697
2329
  ) -> QueryResponse:
2330
+ if self._mode == "direct":
2331
+ return self._query_direct(query, args, read_only)
2332
+ else:
2333
+ # Check if this is a PRAGMA query - HTTP endpoints don't support PRAGMA
2334
+ query_stripped = query.strip().upper()
2335
+ if query_stripped.startswith("PRAGMA"):
2336
+ return self._handle_pragma_query_http(query, args)
2337
+ return self._query_http(query, args, read_only)
2338
+
2339
+ def _handle_pragma_query_http(
2340
+ self, query: str, args: Optional[List[Any]] = None
2341
+ ) -> QueryResponse:
2342
+ """Handle PRAGMA queries in HTTP mode by using the describe endpoint."""
2343
+ query_upper = query.strip().upper()
2344
+
2345
+ # Extract table name from PRAGMA table_info(table_name)
2346
+ if "TABLE_INFO" in query_upper:
2347
+ # Match: PRAGMA table_info("table") or PRAGMA table_info(table)
2348
+ match = re.search(r'TABLE_INFO\s*\(\s*"([^"]+)"\s*\)', query, re.IGNORECASE)
2349
+ if not match:
2350
+ match = re.search(r"TABLE_INFO\s*\(\s*'([^']+)'\s*\)", query, re.IGNORECASE)
2351
+ if not match:
2352
+ match = re.search(r'TABLE_INFO\s*\(\s*([^\s\)]+)\s*\)', query, re.IGNORECASE)
2353
+
2354
+ if match:
2355
+ table_name = match.group(1)
2356
+
2357
+ # Use the describe endpoint to get schema
2358
+ describe_response = self.describe()
2359
+ if not describe_response.success or not describe_response.tables:
2360
+ return QueryResponse(
2361
+ success=False,
2362
+ columns=None,
2363
+ rows=None,
2364
+ error="Failed to get schema information",
2365
+ message="PRAGMA query failed: could not retrieve schema"
2366
+ )
2367
+
2368
+ # Find the table in the schema
2369
+ table_schema = None
2370
+ for table in describe_response.tables:
2371
+ # Handle both dict and TableSchema objects
2372
+ table_name_in_schema = table.name if hasattr(table, 'name') else table.get("name")
2373
+ if table_name_in_schema == table_name:
2374
+ table_schema = table
2375
+ break
2376
+
2377
+ if not table_schema:
2378
+ return QueryResponse(
2379
+ success=False,
2380
+ columns=None,
2381
+ rows=None,
2382
+ error=f"Table '{table_name}' not found",
2383
+ message=f"PRAGMA query failed: table '{table_name}' not found"
2384
+ )
2385
+
2386
+ # Get columns from table schema
2387
+ columns = table_schema.columns if hasattr(table_schema, 'columns') else table_schema.get("columns")
2388
+ if not columns:
2389
+ return QueryResponse(
2390
+ success=False,
2391
+ columns=None,
2392
+ rows=None,
2393
+ error=f"Table '{table_name}' has no columns",
2394
+ message=f"PRAGMA query failed: table '{table_name}' has no columns"
2395
+ )
2396
+
2397
+ # Convert schema to PRAGMA table_info format
2398
+ # Format: (cid, name, type, notnull, dflt_value, pk)
2399
+ rows = []
2400
+ for idx, col in enumerate(columns):
2401
+ # Handle both dict and object column definitions
2402
+ if isinstance(col, dict):
2403
+ col_name = col["name"]
2404
+ col_type = col.get("type", "")
2405
+ col_notnull = col.get("notnull", False)
2406
+ col_default = col.get("default_value")
2407
+ col_pk = col.get("pk", 0)
2408
+ else:
2409
+ col_name = col.name if hasattr(col, 'name') else str(col)
2410
+ col_type = getattr(col, 'type', "")
2411
+ col_notnull = getattr(col, 'notnull', False)
2412
+ col_default = getattr(col, 'default_value', None)
2413
+ col_pk = getattr(col, 'pk', 0)
2414
+
2415
+ row = (
2416
+ idx, # cid
2417
+ col_name, # name
2418
+ col_type, # type
2419
+ 1 if col_notnull else 0, # notnull
2420
+ col_default, # dflt_value
2421
+ col_pk # pk
2422
+ )
2423
+ rows.append(row)
2424
+
2425
+ return QueryResponse(
2426
+ success=True,
2427
+ columns=["cid", "name", "type", "notnull", "dflt_value", "pk"],
2428
+ rows=rows,
2429
+ message="PRAGMA query executed successfully via describe endpoint"
2430
+ )
2431
+
2432
+ # For other PRAGMA queries, return an error indicating they're not supported
2433
+ return QueryResponse(
2434
+ success=False,
2435
+ columns=None,
2436
+ rows=None,
2437
+ error="PRAGMA query not supported in HTTP mode",
2438
+ message=f"PRAGMA query '{query}' is not supported via HTTP API"
2439
+ )
2440
+
2441
+ def _query_http(
2442
+ self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
2443
+ ) -> QueryResponse:
2444
+ """Execute query via HTTP API."""
698
2445
  request = QueryRequest(query=query, args=args, read_only=read_only)
699
2446
  response = self.client.request(
700
2447
  "POST",
@@ -703,6 +2450,59 @@ class SQLiteResource(Resource):
703
2450
  )
704
2451
  return QueryResponse(**response.json())
705
2452
 
2453
+ def _query_direct(
2454
+ self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
2455
+ ) -> QueryResponse:
2456
+ """Execute query directly on local SQLite file or in-memory database."""
2457
+ try:
2458
+ # Check if we need URI mode (for shared memory databases)
2459
+ use_uri = 'mode=memory' in self.db_path
2460
+ conn = sqlite3.connect(self.db_path, uri=use_uri)
2461
+ cursor = conn.cursor()
2462
+
2463
+ # Execute the query
2464
+ if args:
2465
+ cursor.execute(query, args)
2466
+ else:
2467
+ cursor.execute(query)
2468
+
2469
+ # For write operations, commit the transaction
2470
+ if not read_only:
2471
+ conn.commit()
2472
+
2473
+ # Get column names if available
2474
+ columns = [desc[0] for desc in cursor.description] if cursor.description else []
2475
+
2476
+ # Fetch results for SELECT queries
2477
+ rows = []
2478
+ rows_affected = 0
2479
+ last_insert_id = None
2480
+
2481
+ if cursor.description: # SELECT query
2482
+ rows = cursor.fetchall()
2483
+ else: # INSERT/UPDATE/DELETE
2484
+ rows_affected = cursor.rowcount
2485
+ last_insert_id = cursor.lastrowid if cursor.lastrowid else None
2486
+
2487
+ conn.close()
2488
+
2489
+ return QueryResponse(
2490
+ success=True,
2491
+ columns=columns if columns else None,
2492
+ rows=rows if rows else None,
2493
+ rows_affected=rows_affected if rows_affected > 0 else None,
2494
+ last_insert_id=last_insert_id,
2495
+ message="Query executed successfully",
2496
+ )
2497
+ except Exception as e:
2498
+ return QueryResponse(
2499
+ success=False,
2500
+ columns=None,
2501
+ rows=None,
2502
+ error=str(e),
2503
+ message=f"Query failed: {str(e)}",
2504
+ )
2505
+
706
2506
  def table(self, table_name: str) -> SyncQueryBuilder:
707
2507
  """Create a query builder for the specified table."""
708
2508
  return SyncQueryBuilder(self, table_name)
@@ -710,7 +2510,6 @@ class SQLiteResource(Resource):
710
2510
  def snapshot(self, name: Optional[str] = None) -> SyncDatabaseSnapshot:
711
2511
  """Create a snapshot of the current database state."""
712
2512
  snapshot = SyncDatabaseSnapshot(self, name)
713
- snapshot._ensure_fetched()
714
2513
  return snapshot
715
2514
 
716
2515
  def diff(