fleet-python 0.2.66b2__py3-none-any.whl → 0.2.105__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. examples/export_tasks.py +16 -5
  2. examples/export_tasks_filtered.py +245 -0
  3. examples/fetch_tasks.py +230 -0
  4. examples/import_tasks.py +140 -8
  5. examples/iterate_verifiers.py +725 -0
  6. fleet/__init__.py +128 -5
  7. fleet/_async/__init__.py +27 -3
  8. fleet/_async/base.py +24 -9
  9. fleet/_async/client.py +938 -41
  10. fleet/_async/env/client.py +60 -3
  11. fleet/_async/instance/client.py +52 -7
  12. fleet/_async/models.py +15 -0
  13. fleet/_async/resources/api.py +200 -0
  14. fleet/_async/resources/sqlite.py +1801 -46
  15. fleet/_async/tasks.py +122 -25
  16. fleet/_async/verifiers/bundler.py +22 -21
  17. fleet/_async/verifiers/verifier.py +25 -19
  18. fleet/agent/__init__.py +32 -0
  19. fleet/agent/gemini_cua/Dockerfile +45 -0
  20. fleet/agent/gemini_cua/__init__.py +10 -0
  21. fleet/agent/gemini_cua/agent.py +759 -0
  22. fleet/agent/gemini_cua/mcp/main.py +108 -0
  23. fleet/agent/gemini_cua/mcp_server/__init__.py +5 -0
  24. fleet/agent/gemini_cua/mcp_server/main.py +105 -0
  25. fleet/agent/gemini_cua/mcp_server/tools.py +178 -0
  26. fleet/agent/gemini_cua/requirements.txt +5 -0
  27. fleet/agent/gemini_cua/start.sh +30 -0
  28. fleet/agent/orchestrator.py +854 -0
  29. fleet/agent/types.py +49 -0
  30. fleet/agent/utils.py +34 -0
  31. fleet/base.py +34 -9
  32. fleet/cli.py +1061 -0
  33. fleet/client.py +1060 -48
  34. fleet/config.py +1 -1
  35. fleet/env/__init__.py +16 -0
  36. fleet/env/client.py +60 -3
  37. fleet/eval/__init__.py +15 -0
  38. fleet/eval/uploader.py +231 -0
  39. fleet/exceptions.py +8 -0
  40. fleet/instance/client.py +53 -8
  41. fleet/instance/models.py +1 -0
  42. fleet/models.py +303 -0
  43. fleet/proxy/__init__.py +25 -0
  44. fleet/proxy/proxy.py +453 -0
  45. fleet/proxy/whitelist.py +244 -0
  46. fleet/resources/api.py +200 -0
  47. fleet/resources/sqlite.py +1845 -46
  48. fleet/tasks.py +113 -20
  49. fleet/utils/__init__.py +7 -0
  50. fleet/utils/http_logging.py +178 -0
  51. fleet/utils/logging.py +13 -0
  52. fleet/utils/playwright.py +440 -0
  53. fleet/verifiers/bundler.py +22 -21
  54. fleet/verifiers/db.py +985 -1
  55. fleet/verifiers/decorator.py +1 -1
  56. fleet/verifiers/verifier.py +25 -19
  57. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/METADATA +28 -1
  58. fleet_python-0.2.105.dist-info/RECORD +115 -0
  59. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/WHEEL +1 -1
  60. fleet_python-0.2.105.dist-info/entry_points.txt +2 -0
  61. tests/test_app_method.py +85 -0
  62. tests/test_expect_exactly.py +4148 -0
  63. tests/test_expect_only.py +2593 -0
  64. tests/test_instance_dispatch.py +607 -0
  65. tests/test_sqlite_resource_dual_mode.py +263 -0
  66. tests/test_sqlite_shared_memory_behavior.py +117 -0
  67. fleet_python-0.2.66b2.dist-info/RECORD +0 -81
  68. tests/test_verifier_security.py +0 -427
  69. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/licenses/LICENSE +0 -0
  70. {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,9 @@ from datetime import datetime
6
6
  import tempfile
7
7
  import sqlite3
8
8
  import os
9
+ import asyncio
10
+ import re
11
+ import json
9
12
 
10
13
  from typing import TYPE_CHECKING
11
14
 
@@ -19,11 +22,23 @@ from fleet.verifiers.db import (
19
22
  _get_row_identifier,
20
23
  _format_row_for_error,
21
24
  _values_equivalent,
25
+ validate_diff_expect_exactly,
22
26
  )
23
27
 
24
28
 
29
+ def _quote_identifier(identifier: str) -> str:
30
+ """Quote an identifier (table or column name) for SQLite.
31
+
32
+ SQLite uses double quotes for identifiers and escapes internal quotes by doubling them.
33
+ This handles reserved keywords like 'order', 'table', etc.
34
+ """
35
+ # Escape any double quotes in the identifier by doubling them
36
+ escaped = identifier.replace('"', '""')
37
+ return f'"{escaped}"'
38
+
39
+
25
40
  class AsyncDatabaseSnapshot:
26
- """Async database snapshot that fetches data through API and stores locally for diffing."""
41
+ """Lazy database snapshot that fetches data on-demand through API."""
27
42
 
28
43
  def __init__(self, resource: "AsyncSQLiteResource", name: Optional[str] = None):
29
44
  self.resource = resource
@@ -31,11 +46,12 @@ class AsyncDatabaseSnapshot:
31
46
  self.created_at = datetime.utcnow()
32
47
  self._data: Dict[str, List[Dict[str, Any]]] = {}
33
48
  self._schemas: Dict[str, List[str]] = {}
34
- self._fetched = False
49
+ self._table_names: Optional[List[str]] = None
50
+ self._fetched_tables: set = set()
35
51
 
36
- async def _ensure_fetched(self):
37
- """Fetch all data from remote database if not already fetched."""
38
- if self._fetched:
52
+ async def _ensure_tables_list(self):
53
+ """Fetch just the list of table names if not already fetched."""
54
+ if self._table_names is not None:
39
55
  return
40
56
 
41
57
  # Get all tables
@@ -44,35 +60,36 @@ class AsyncDatabaseSnapshot:
44
60
  )
45
61
 
46
62
  if not tables_response.rows:
47
- self._fetched = True
63
+ self._table_names = []
48
64
  return
49
65
 
50
- table_names = [row[0] for row in tables_response.rows]
51
-
52
- # Fetch data from each table
53
- for table in table_names:
54
- # Get table schema
55
- schema_response = await self.resource.query(f"PRAGMA table_info({table})")
56
- if schema_response.rows:
57
- self._schemas[table] = [
58
- row[1] for row in schema_response.rows
59
- ] # Column names
60
-
61
- # Get all data
62
- data_response = await self.resource.query(f"SELECT * FROM {table}")
63
- if data_response.rows and data_response.columns:
64
- self._data[table] = [
65
- dict(zip(data_response.columns, row)) for row in data_response.rows
66
- ]
67
- else:
68
- self._data[table] = []
66
+ self._table_names = [row[0] for row in tables_response.rows]
67
+
68
+ async def _ensure_table_data(self, table: str):
69
+ """Fetch data for a specific table on demand."""
70
+ if table in self._fetched_tables:
71
+ return
69
72
 
70
- self._fetched = True
73
+ # Get table schema
74
+ schema_response = await self.resource.query(f"PRAGMA table_info({_quote_identifier(table)})")
75
+ if schema_response.rows:
76
+ self._schemas[table] = [row[1] for row in schema_response.rows] # Column names
77
+
78
+ # Get all data for this table
79
+ data_response = await self.resource.query(f"SELECT * FROM {_quote_identifier(table)}")
80
+ if data_response.rows and data_response.columns:
81
+ self._data[table] = [
82
+ dict(zip(data_response.columns, row)) for row in data_response.rows
83
+ ]
84
+ else:
85
+ self._data[table] = []
86
+
87
+ self._fetched_tables.add(table)
71
88
 
72
89
  async def tables(self) -> List[str]:
73
90
  """Get list of all tables in the snapshot."""
74
- await self._ensure_fetched()
75
- return list(self._data.keys())
91
+ await self._ensure_tables_list()
92
+ return list(self._table_names) if self._table_names else []
76
93
 
77
94
  def table(self, table_name: str) -> "AsyncSnapshotQueryBuilder":
78
95
  """Create a query builder for snapshot data."""
@@ -84,13 +101,12 @@ class AsyncDatabaseSnapshot:
84
101
  ignore_config: Optional[IgnoreConfig] = None,
85
102
  ) -> "AsyncSnapshotDiff":
86
103
  """Compare this snapshot with another."""
87
- await self._ensure_fetched()
88
- await other._ensure_fetched()
104
+ # No need to fetch all data upfront - diff will fetch on demand
89
105
  return AsyncSnapshotDiff(self, other, ignore_config)
90
106
 
91
107
 
92
108
  class AsyncSnapshotQueryBuilder:
93
- """Query builder that works on local snapshot data."""
109
+ """Query builder that works on snapshot data - can use targeted queries when possible."""
94
110
 
95
111
  def __init__(self, snapshot: AsyncDatabaseSnapshot, table: str):
96
112
  self._snapshot = snapshot
@@ -100,10 +116,63 @@ class AsyncSnapshotQueryBuilder:
100
116
  self._limit: Optional[int] = None
101
117
  self._order_by: Optional[str] = None
102
118
  self._order_desc: bool = False
119
+ self._use_targeted_query = True # Try to use targeted queries when possible
120
+
121
+ def _can_use_targeted_query(self) -> bool:
122
+ """Check if we can use a targeted query instead of loading all data."""
123
+ # We can use targeted query if:
124
+ # 1. We have simple equality conditions
125
+ # 2. No complex operations like joins
126
+ # 3. The query is selective (has conditions)
127
+ if not self._conditions:
128
+ return False
129
+ for col, op, val in self._conditions:
130
+ if op not in ["=", "IS", "IS NOT"]:
131
+ return False
132
+ return True
133
+
134
+ async def _execute_targeted_query(self) -> List[Dict[str, Any]]:
135
+ """Execute a targeted query directly instead of loading all data."""
136
+ # Build WHERE clause
137
+ where_parts = []
138
+ for col, op, val in self._conditions:
139
+ if op == "=" and val is None:
140
+ where_parts.append(f"{_quote_identifier(col)} IS NULL")
141
+ elif op == "IS":
142
+ where_parts.append(f"{_quote_identifier(col)} IS NULL")
143
+ elif op == "IS NOT":
144
+ where_parts.append(f"{_quote_identifier(col)} IS NOT NULL")
145
+ elif op == "=":
146
+ if isinstance(val, str):
147
+ escaped_val = val.replace("'", "''")
148
+ where_parts.append(f"{_quote_identifier(col)} = '{escaped_val}'")
149
+ else:
150
+ where_parts.append(f"{_quote_identifier(col)} = '{val}'")
151
+
152
+ where_clause = " AND ".join(where_parts)
153
+
154
+ # Build full query
155
+ cols = ", ".join(self._select_cols)
156
+ query = f"SELECT {cols} FROM {_quote_identifier(self._table)} WHERE {where_clause}"
157
+
158
+ if self._order_by:
159
+ query += f" ORDER BY {self._order_by}"
160
+ if self._limit is not None:
161
+ query += f" LIMIT {self._limit}"
162
+
163
+ # Execute query
164
+ response = await self._snapshot.resource.query(query)
165
+ if response.rows and response.columns:
166
+ return [dict(zip(response.columns, row)) for row in response.rows]
167
+ return []
103
168
 
104
169
  async def _get_data(self) -> List[Dict[str, Any]]:
105
- """Get table data from snapshot."""
106
- await self._snapshot._ensure_fetched()
170
+ """Get table data - use targeted query if possible, otherwise load all data."""
171
+ if self._use_targeted_query and self._can_use_targeted_query():
172
+ return await self._execute_targeted_query()
173
+
174
+ # Fall back to loading all data
175
+ await self._snapshot._ensure_table_data(self._table)
107
176
  return self._snapshot._data.get(self._table, [])
108
177
 
109
178
  def eq(self, column: str, value: Any) -> "AsyncSnapshotQueryBuilder":
@@ -142,6 +211,11 @@ class AsyncSnapshotQueryBuilder:
142
211
  return rows[0] if rows else None
143
212
 
144
213
  async def all(self) -> List[Dict[str, Any]]:
214
+ # If we can use targeted query, _get_data already applies filters
215
+ if self._use_targeted_query and self._can_use_targeted_query():
216
+ return await self._get_data()
217
+
218
+ # Otherwise, get all data and apply filters manually
145
219
  data = await self._get_data()
146
220
 
147
221
  # Apply filters
@@ -206,11 +280,12 @@ class AsyncSnapshotDiff:
206
280
  self.after = after
207
281
  self.ignore_config = ignore_config or IgnoreConfig()
208
282
  self._cached: Optional[Dict[str, Any]] = None
283
+ self._targeted_mode = False # Flag to use targeted queries
209
284
 
210
285
  async def _get_primary_key_columns(self, table: str) -> List[str]:
211
286
  """Get primary key columns for a table."""
212
287
  # Try to get from schema
213
- schema_response = await self.after.resource.query(f"PRAGMA table_info({table})")
288
+ schema_response = await self.after.resource.query(f"PRAGMA table_info({_quote_identifier(table)})")
214
289
  if not schema_response.rows:
215
290
  return ["id"] # Default fallback
216
291
 
@@ -246,6 +321,10 @@ class AsyncSnapshotDiff:
246
321
  # Get primary key columns
247
322
  pk_columns = await self._get_primary_key_columns(tbl)
248
323
 
324
+ # Ensure data is fetched for this table
325
+ await self.before._ensure_table_data(tbl)
326
+ await self.after._ensure_table_data(tbl)
327
+
249
328
  # Get data from both snapshots
250
329
  before_data = self.before._data.get(tbl, [])
251
330
  after_data = self.after._data.get(tbl, [])
@@ -328,9 +407,378 @@ class AsyncSnapshotDiff:
328
407
  )
329
408
  return self._cached
330
409
 
331
- async def expect_only(self, allowed_changes: List[Dict[str, Any]]):
332
- """Ensure only specified changes occurred."""
333
- diff = await self._collect()
410
+ def _can_use_targeted_queries(self, allowed_changes: List[Dict[str, Any]]) -> bool:
411
+ """Check if we can use targeted queries for optimization."""
412
+ # We can use targeted queries if all allowed changes specify table and pk
413
+ for change in allowed_changes:
414
+ if "table" not in change or "pk" not in change:
415
+ return False
416
+ return True
417
+
418
+ def _build_pk_where_clause(self, pk_columns: List[str], pk_value: Any) -> str:
419
+ """Build WHERE clause for primary key lookup."""
420
+ # Escape single quotes in values to prevent SQL injection
421
+ def escape_value(val: Any) -> str:
422
+ if val is None:
423
+ return "NULL"
424
+ elif isinstance(val, str):
425
+ escaped = str(val).replace("'", "''")
426
+ return f"'{escaped}'"
427
+ else:
428
+ return f"'{val}'"
429
+
430
+ if len(pk_columns) == 1:
431
+ return f"{_quote_identifier(pk_columns[0])} = {escape_value(pk_value)}"
432
+ else:
433
+ # Composite key
434
+ if isinstance(pk_value, tuple):
435
+ conditions = [
436
+ f"{_quote_identifier(col)} = {escape_value(val)}"
437
+ for col, val in zip(pk_columns, pk_value)
438
+ ]
439
+ return " AND ".join(conditions)
440
+ else:
441
+ # Shouldn't happen if data is consistent
442
+ return f"{_quote_identifier(pk_columns[0])} = {escape_value(pk_value)}"
443
+
444
+ async def _expect_no_changes(self):
445
+ """Efficiently verify that no changes occurred between snapshots using row counts."""
446
+ try:
447
+ import asyncio
448
+
449
+ # Get all tables from both snapshots
450
+ before_tables = set(await self.before.tables())
451
+ after_tables = set(await self.after.tables())
452
+
453
+ # Check for added/removed tables (excluding ignored ones)
454
+ added_tables = after_tables - before_tables
455
+ removed_tables = before_tables - after_tables
456
+
457
+ for table in added_tables:
458
+ if not self.ignore_config.should_ignore_table(table):
459
+ raise AssertionError(f"Unexpected table added: {table}")
460
+
461
+ for table in removed_tables:
462
+ if not self.ignore_config.should_ignore_table(table):
463
+ raise AssertionError(f"Unexpected table removed: {table}")
464
+
465
+ # Prepare tables to check
466
+ tables_to_check = []
467
+ all_tables = before_tables | after_tables
468
+ for table in all_tables:
469
+ if not self.ignore_config.should_ignore_table(table):
470
+ tables_to_check.append(table)
471
+
472
+ # If no tables to check, we're done
473
+ if not tables_to_check:
474
+ return self
475
+
476
+ # Track errors and tables needing verification
477
+ errors = []
478
+ tables_needing_verification = []
479
+
480
+ async def check_table_counts(table: str):
481
+ """Check row counts for a single table."""
482
+ try:
483
+ # Get row counts from both snapshots
484
+ before_count = 0
485
+ after_count = 0
486
+
487
+ if table in before_tables:
488
+ before_count_response = await self.before.resource.query(
489
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
490
+ )
491
+ before_count = (
492
+ before_count_response.rows[0][0]
493
+ if before_count_response.rows
494
+ else 0
495
+ )
496
+
497
+ if table in after_tables:
498
+ after_count_response = await self.after.resource.query(
499
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
500
+ )
501
+ after_count = (
502
+ after_count_response.rows[0][0]
503
+ if after_count_response.rows
504
+ else 0
505
+ )
506
+
507
+ if before_count != after_count:
508
+ error_msg = (
509
+ f"Unexpected change in table '{table}': "
510
+ f"row count changed from {before_count} to {after_count}"
511
+ )
512
+ errors.append(AssertionError(error_msg))
513
+ elif before_count > 0 and before_count <= 1000:
514
+ # Mark for detailed verification
515
+ tables_needing_verification.append(table)
516
+
517
+ except Exception as e:
518
+ errors.append(e)
519
+
520
+ # Execute count checks in parallel
521
+ await asyncio.gather(*[check_table_counts(table) for table in tables_to_check])
522
+
523
+ # Check if any errors occurred during count checking
524
+ if errors:
525
+ raise errors[0]
526
+
527
+ # Now verify small tables for data changes (also in parallel)
528
+ if tables_needing_verification:
529
+ verification_errors = []
530
+
531
+ async def verify_table(table: str):
532
+ """Verify a single table's data hasn't changed."""
533
+ try:
534
+ await self._verify_table_unchanged(table)
535
+ except AssertionError as e:
536
+ verification_errors.append(e)
537
+
538
+ await asyncio.gather(*[verify_table(table) for table in tables_needing_verification])
539
+
540
+ # Check if any errors occurred during verification
541
+ if verification_errors:
542
+ raise verification_errors[0]
543
+
544
+ return self
545
+
546
+ except AssertionError:
547
+ # Re-raise assertion errors (these are expected failures)
548
+ raise
549
+ except Exception as e:
550
+ # If the optimized check fails for other reasons, fall back to full diff
551
+ print(f"Warning: Optimized no-changes check failed: {e}")
552
+ print("Falling back to full diff...")
553
+ return await self._validate_diff_against_allowed_changes(
554
+ await self._collect(), []
555
+ )
556
+
557
+ async def _verify_table_unchanged(self, table: str):
558
+ """Verify that a table's data hasn't changed (for small tables)."""
559
+ # Get primary key columns
560
+ pk_columns = await self._get_primary_key_columns(table)
561
+
562
+ # Get sorted data from both snapshots
563
+ order_by = ", ".join(pk_columns) if pk_columns else "rowid"
564
+
565
+ before_response = await self.before.resource.query(
566
+ f"SELECT * FROM {_quote_identifier(table)} ORDER BY {order_by}"
567
+ )
568
+ after_response = await self.after.resource.query(
569
+ f"SELECT * FROM {_quote_identifier(table)} ORDER BY {order_by}"
570
+ )
571
+
572
+ # Quick check: if column counts differ, there's a schema change
573
+ if before_response.columns != after_response.columns:
574
+ raise AssertionError(f"Schema changed in table '{table}'")
575
+
576
+ # Compare row by row
577
+ if len(before_response.rows) != len(after_response.rows):
578
+ raise AssertionError(
579
+ f"Row count mismatch in table '{table}': "
580
+ f"{len(before_response.rows)} vs {len(after_response.rows)}"
581
+ )
582
+
583
+ for i, (before_row, after_row) in enumerate(
584
+ zip(before_response.rows, after_response.rows)
585
+ ):
586
+ before_dict = dict(zip(before_response.columns, before_row))
587
+ after_dict = dict(zip(after_response.columns, after_row))
588
+
589
+ # Compare fields, ignoring those in ignore config
590
+ for field in before_response.columns:
591
+ if self.ignore_config.should_ignore_field(table, field):
592
+ continue
593
+
594
+ if not _values_equivalent(
595
+ before_dict.get(field), after_dict.get(field)
596
+ ):
597
+ pk_val = before_dict.get(pk_columns[0]) if pk_columns else i
598
+ raise AssertionError(
599
+ f"Unexpected change in table '{table}', row {pk_val}, "
600
+ f"field '{field}': {repr(before_dict.get(field))} -> {repr(after_dict.get(field))}"
601
+ )
602
+
603
+ def _is_field_change_allowed(
604
+ self, table_changes: List[Dict[str, Any]], pk: Any, field: str, after_val: Any
605
+ ) -> bool:
606
+ """Check if a specific field change is allowed."""
607
+ for change in table_changes:
608
+ if (
609
+ str(change.get("pk")) == str(pk)
610
+ and change.get("field") == field
611
+ and _values_equivalent(change.get("after"), after_val)
612
+ ):
613
+ return True
614
+ return False
615
+
616
+ def _is_row_change_allowed(
617
+ self, table_changes: List[Dict[str, Any]], pk: Any, change_type: str
618
+ ) -> bool:
619
+ """Check if a row addition/deletion is allowed."""
620
+ for change in table_changes:
621
+ if str(change.get("pk")) == str(pk) and change.get("after") == change_type:
622
+ return True
623
+ return False
624
+
625
+ async def _expect_only_targeted(self, allowed_changes: List[Dict[str, Any]]):
626
+ """Optimized version that only queries specific rows mentioned in allowed_changes."""
627
+ import asyncio
628
+
629
+ # Group allowed changes by table
630
+ changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
631
+ for change in allowed_changes:
632
+ table = change["table"]
633
+ if table not in changes_by_table:
634
+ changes_by_table[table] = []
635
+ changes_by_table[table].append(change)
636
+
637
+ errors = []
638
+
639
+ # Function to check a single row
640
+ async def check_row(
641
+ table: str,
642
+ pk: Any,
643
+ table_changes: List[Dict[str, Any]],
644
+ pk_columns: List[str],
645
+ ):
646
+ try:
647
+ # Build WHERE clause for this PK
648
+ where_sql = self._build_pk_where_clause(pk_columns, pk)
649
+
650
+ # Query before snapshot
651
+ before_query = f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
652
+ before_response = await self.before.resource.query(before_query)
653
+ before_row = (
654
+ dict(zip(before_response.columns, before_response.rows[0]))
655
+ if before_response.rows
656
+ else None
657
+ )
658
+
659
+ # Query after snapshot
660
+ after_response = await self.after.resource.query(before_query)
661
+ after_row = (
662
+ dict(zip(after_response.columns, after_response.rows[0]))
663
+ if after_response.rows
664
+ else None
665
+ )
666
+
667
+ # Check changes for this row
668
+ if before_row and after_row:
669
+ # Modified row - check fields
670
+ for field in set(before_row.keys()) | set(after_row.keys()):
671
+ if self.ignore_config.should_ignore_field(table, field):
672
+ continue
673
+ before_val = before_row.get(field)
674
+ after_val = after_row.get(field)
675
+ if not _values_equivalent(before_val, after_val):
676
+ # Check if this change is allowed
677
+ if not self._is_field_change_allowed(
678
+ table_changes, pk, field, after_val
679
+ ):
680
+ error_msg = (
681
+ f"Unexpected change in table '{table}', "
682
+ f"row {pk}, field '{field}': "
683
+ f"{repr(before_val)} -> {repr(after_val)}"
684
+ )
685
+ errors.append(AssertionError(error_msg))
686
+ return # Stop checking this row
687
+ elif not before_row and after_row:
688
+ # Added row
689
+ if not self._is_row_change_allowed(table_changes, pk, "__added__"):
690
+ error_msg = f"Unexpected row added in table '{table}': {pk}"
691
+ errors.append(AssertionError(error_msg))
692
+ elif before_row and not after_row:
693
+ # Removed row
694
+ if not self._is_row_change_allowed(table_changes, pk, "__removed__"):
695
+ error_msg = f"Unexpected row removed from table '{table}': {pk}"
696
+ errors.append(AssertionError(error_msg))
697
+ except Exception as e:
698
+ errors.append(e)
699
+
700
+ # Prepare all row checks
701
+ row_checks = []
702
+ for table, table_changes in changes_by_table.items():
703
+ if self.ignore_config.should_ignore_table(table):
704
+ continue
705
+
706
+ # Get primary key columns once per table
707
+ pk_columns = await self._get_primary_key_columns(table)
708
+
709
+ # Extract unique PKs to check
710
+ pks_to_check = {change["pk"] for change in table_changes}
711
+
712
+ for pk in pks_to_check:
713
+ row_checks.append((table, pk, table_changes, pk_columns))
714
+
715
+ # Execute row checks in parallel
716
+ if row_checks:
717
+ await asyncio.gather(
718
+ *[
719
+ check_row(table, pk, table_changes, pk_columns)
720
+ for table, pk, table_changes, pk_columns in row_checks
721
+ ]
722
+ )
723
+
724
+ # Check for errors from row checks
725
+ if errors:
726
+ raise errors[0]
727
+
728
+ # Now check tables not mentioned in allowed_changes to ensure no changes
729
+ all_tables = set(await self.before.tables()) | set(await self.after.tables())
730
+ tables_to_verify = []
731
+
732
+ for table in all_tables:
733
+ if (
734
+ table not in changes_by_table
735
+ and not self.ignore_config.should_ignore_table(table)
736
+ ):
737
+ tables_to_verify.append(table)
738
+
739
+ # Function to verify no changes in a table
740
+ async def verify_no_changes(table: str):
741
+ try:
742
+ # For tables with no allowed changes, just check row counts
743
+ before_count_response = await self.before.resource.query(
744
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
745
+ )
746
+ before_count = (
747
+ before_count_response.rows[0][0]
748
+ if before_count_response.rows
749
+ else 0
750
+ )
751
+
752
+ after_count_response = await self.after.resource.query(
753
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
754
+ )
755
+ after_count = (
756
+ after_count_response.rows[0][0] if after_count_response.rows else 0
757
+ )
758
+
759
+ if before_count != after_count:
760
+ error_msg = (
761
+ f"Unexpected change in table '{table}': "
762
+ f"row count changed from {before_count} to {after_count}"
763
+ )
764
+ errors.append(AssertionError(error_msg))
765
+ except Exception as e:
766
+ errors.append(e)
767
+
768
+ # Execute table verification in parallel
769
+ if tables_to_verify:
770
+ await asyncio.gather(*[verify_no_changes(table) for table in tables_to_verify])
771
+
772
+ # Final error check
773
+ if errors:
774
+ raise errors[0]
775
+
776
+ return self
777
+
778
+ async def _validate_diff_against_allowed_changes(
779
+ self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
780
+ ):
781
+ """Validate a collected diff against allowed changes."""
334
782
 
335
783
  def _is_change_allowed(
336
784
  table: str, row_id: Any, field: Optional[str], after_value: Any
@@ -457,6 +905,1053 @@ class AsyncSnapshotDiff:
457
905
 
458
906
  return self
459
907
 
908
+ async def _expect_only_targeted_v2(self, allowed_changes: List[Dict[str, Any]]):
909
+ """Optimized version that only queries specific rows mentioned in allowed_changes.
910
+
911
+ Supports v2 spec formats:
912
+ - {"table": "t", "pk": 1, "type": "insert", "fields": [...]}
913
+ - {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": bool}
914
+ - {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
915
+ - Legacy single-field specs: {"table": "t", "pk": 1, "field": "x", "after": val}
916
+ """
917
+ import asyncio
918
+
919
+ # Helper functions for v2 spec validation
920
+ def _parse_fields_spec(
921
+ fields_spec: List[Tuple[str, Any]]
922
+ ) -> Dict[str, Tuple[bool, Any]]:
923
+ """Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
924
+ spec_map: Dict[str, Tuple[bool, Any]] = {}
925
+ for spec_tuple in fields_spec:
926
+ if len(spec_tuple) != 2:
927
+ raise ValueError(
928
+ f"Invalid field spec tuple: {spec_tuple}. "
929
+ f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
930
+ )
931
+ field_name, expected_value = spec_tuple
932
+ if expected_value is ...:
933
+ spec_map[field_name] = (False, None)
934
+ else:
935
+ spec_map[field_name] = (True, expected_value)
936
+ return spec_map
937
+
938
+ def _get_all_specs_for_pk(table: str, pk: Any) -> List[Dict[str, Any]]:
939
+ """Get all specs for a given table/pk (for legacy multi-field specs)."""
940
+ specs = []
941
+ for allowed in allowed_changes:
942
+ if (
943
+ allowed["table"] == table
944
+ and str(allowed.get("pk")) == str(pk)
945
+ ):
946
+ specs.append(allowed)
947
+ return specs
948
+
949
+ def _validate_insert_row(
950
+ table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
951
+ ) -> Optional[str]:
952
+ """Validate an inserted row against specs. Returns error message or None."""
953
+ # Check for type: "insert" spec with fields
954
+ for spec in specs:
955
+ if spec.get("type") == "insert":
956
+ fields_spec = spec.get("fields")
957
+ if fields_spec is not None:
958
+ # Validate each field
959
+ spec_map = _parse_fields_spec(fields_spec)
960
+ for field_name, field_value in row_data.items():
961
+ if field_name == "rowid":
962
+ continue
963
+ if self.ignore_config.should_ignore_field(table, field_name):
964
+ continue
965
+ if field_name not in spec_map:
966
+ return f"Field '{field_name}' not in insert spec for table '{table}' pk={pk}"
967
+ should_check, expected_value = spec_map[field_name]
968
+ if should_check and not _values_equivalent(expected_value, field_value):
969
+ return (
970
+ f"Insert mismatch in table '{table}' pk={pk}, "
971
+ f"field '{field_name}': expected {repr(expected_value)}, got {repr(field_value)}"
972
+ )
973
+ # type: "insert" found (with or without fields) - allowed
974
+ return None
975
+
976
+ # Check for legacy whole-row spec
977
+ for spec in specs:
978
+ if spec.get("fields") is None and spec.get("after") == "__added__":
979
+ return None
980
+
981
+ return f"Unexpected row added in table '{table}': pk={pk}"
982
+
983
+ def _validate_delete_row(
984
+ table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
985
+ ) -> Optional[str]:
986
+ """Validate a deleted row against specs. Returns error message or None."""
987
+ # Check for type: "delete" spec with optional fields
988
+ for spec in specs:
989
+ if spec.get("type") == "delete":
990
+ fields_spec = spec.get("fields")
991
+ if fields_spec is not None:
992
+ # Validate each field against the deleted row
993
+ spec_map = _parse_fields_spec(fields_spec)
994
+ for field_name, (should_check, expected_value) in spec_map.items():
995
+ if field_name not in row_data:
996
+ return f"Field '{field_name}' in delete spec not found in row for table '{table}' pk={pk}"
997
+ if should_check and not _values_equivalent(expected_value, row_data[field_name]):
998
+ return (
999
+ f"Delete mismatch in table '{table}' pk={pk}, "
1000
+ f"field '{field_name}': expected {repr(expected_value)}, got {repr(row_data[field_name])}"
1001
+ )
1002
+ # type: "delete" found (with or without fields) - allowed
1003
+ return None
1004
+
1005
+ # Check for legacy whole-row spec
1006
+ for spec in specs:
1007
+ if spec.get("fields") is None and spec.get("after") == "__removed__":
1008
+ return None
1009
+
1010
+ return f"Unexpected row removed from table '{table}': pk={pk}"
1011
+
1012
+ def _validate_modify_row(
1013
+ table: str,
1014
+ pk: Any,
1015
+ before_row: Dict[str, Any],
1016
+ after_row: Dict[str, Any],
1017
+ specs: List[Dict[str, Any]],
1018
+ ) -> Optional[str]:
1019
+ """Validate a modified row against specs. Returns error message or None."""
1020
+ # Collect actual changes
1021
+ changed_fields: Dict[str, Dict[str, Any]] = {}
1022
+ for field in set(before_row.keys()) | set(after_row.keys()):
1023
+ if self.ignore_config.should_ignore_field(table, field):
1024
+ continue
1025
+ before_val = before_row.get(field)
1026
+ after_val = after_row.get(field)
1027
+ if not _values_equivalent(before_val, after_val):
1028
+ changed_fields[field] = {"before": before_val, "after": after_val}
1029
+
1030
+ if not changed_fields:
1031
+ return None # No changes
1032
+
1033
+ # Check for type: "modify" spec with resulting_fields
1034
+ for spec in specs:
1035
+ if spec.get("type") == "modify":
1036
+ resulting_fields = spec.get("resulting_fields")
1037
+ if resulting_fields is not None:
1038
+ # Validate no_other_changes is provided
1039
+ if "no_other_changes" not in spec:
1040
+ raise ValueError(
1041
+ f"Modify spec for table '{table}' pk={pk} "
1042
+ f"has 'resulting_fields' but missing required 'no_other_changes' field."
1043
+ )
1044
+ no_other_changes = spec["no_other_changes"]
1045
+ if not isinstance(no_other_changes, bool):
1046
+ raise ValueError(
1047
+ f"Modify spec for table '{table}' pk={pk} "
1048
+ f"'no_other_changes' must be boolean, got {type(no_other_changes).__name__}"
1049
+ )
1050
+
1051
+ spec_map = _parse_fields_spec(resulting_fields)
1052
+
1053
+ # Validate changed fields
1054
+ for field_name, vals in changed_fields.items():
1055
+ after_val = vals["after"]
1056
+ if field_name not in spec_map:
1057
+ if no_other_changes:
1058
+ return (
1059
+ f"Unexpected field change in table '{table}' pk={pk}: "
1060
+ f"field '{field_name}' not in resulting_fields"
1061
+ )
1062
+ # no_other_changes=False: ignore this field
1063
+ else:
1064
+ should_check, expected_value = spec_map[field_name]
1065
+ if should_check and not _values_equivalent(expected_value, after_val):
1066
+ return (
1067
+ f"Modify mismatch in table '{table}' pk={pk}, "
1068
+ f"field '{field_name}': expected {repr(expected_value)}, got {repr(after_val)}"
1069
+ )
1070
+ return None # Validation passed
1071
+ else:
1072
+ # type: "modify" without resulting_fields - allow any modification
1073
+ return None
1074
+
1075
+ # Check for legacy single-field specs
1076
+ for field_name, vals in changed_fields.items():
1077
+ after_val = vals["after"]
1078
+ field_allowed = False
1079
+ for spec in specs:
1080
+ if (
1081
+ spec.get("field") == field_name
1082
+ and _values_equivalent(spec.get("after"), after_val)
1083
+ ):
1084
+ field_allowed = True
1085
+ break
1086
+ if not field_allowed:
1087
+ return (
1088
+ f"Unexpected change in table '{table}' pk={pk}, "
1089
+ f"field '{field_name}': {repr(vals['before'])} -> {repr(after_val)}"
1090
+ )
1091
+
1092
+ return None
1093
+
1094
+ # Group allowed changes by table
1095
+ changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
1096
+ for change in allowed_changes:
1097
+ table = change["table"]
1098
+ if table not in changes_by_table:
1099
+ changes_by_table[table] = []
1100
+ changes_by_table[table].append(change)
1101
+
1102
+ errors: List[Exception] = []
1103
+
1104
+ # Async function to check a single row
1105
+ async def check_row(
1106
+ table: str,
1107
+ pk: Any,
1108
+ pk_columns: List[str],
1109
+ ):
1110
+ try:
1111
+ # Build WHERE clause for this PK
1112
+ where_sql = self._build_pk_where_clause(pk_columns, pk)
1113
+
1114
+ # Query before snapshot
1115
+ before_query = f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
1116
+ before_response = await self.before.resource.query(before_query)
1117
+ before_row = (
1118
+ dict(zip(before_response.columns, before_response.rows[0]))
1119
+ if before_response.rows
1120
+ else None
1121
+ )
1122
+
1123
+ # Query after snapshot
1124
+ after_response = await self.after.resource.query(before_query)
1125
+ after_row = (
1126
+ dict(zip(after_response.columns, after_response.rows[0]))
1127
+ if after_response.rows
1128
+ else None
1129
+ )
1130
+
1131
+ # Get all specs for this table/pk
1132
+ specs = _get_all_specs_for_pk(table, pk)
1133
+
1134
+ # Check changes for this row
1135
+ if before_row and after_row:
1136
+ # Modified row
1137
+ error = _validate_modify_row(table, pk, before_row, after_row, specs)
1138
+ if error:
1139
+ errors.append(AssertionError(error))
1140
+ elif not before_row and after_row:
1141
+ # Added row
1142
+ error = _validate_insert_row(table, pk, after_row, specs)
1143
+ if error:
1144
+ errors.append(AssertionError(error))
1145
+ elif before_row and not after_row:
1146
+ # Removed row
1147
+ error = _validate_delete_row(table, pk, before_row, specs)
1148
+ if error:
1149
+ errors.append(AssertionError(error))
1150
+
1151
+ except Exception as e:
1152
+ errors.append(e)
1153
+
1154
+ # Prepare all row checks
1155
+ row_tasks = []
1156
+ for table, table_changes in changes_by_table.items():
1157
+ if self.ignore_config.should_ignore_table(table):
1158
+ continue
1159
+
1160
+ # Get primary key columns once per table
1161
+ pk_columns = self._get_primary_key_columns(table)
1162
+
1163
+ # Extract unique PKs to check
1164
+ pks_to_check = {change["pk"] for change in table_changes}
1165
+
1166
+ for pk in pks_to_check:
1167
+ row_tasks.append(check_row(table, pk, pk_columns))
1168
+
1169
+ # Execute row checks concurrently
1170
+ if row_tasks:
1171
+ await asyncio.gather(*row_tasks)
1172
+
1173
+ # Check for errors from row checks
1174
+ if errors:
1175
+ raise errors[0]
1176
+
1177
+ # Now check tables not mentioned in allowed_changes to ensure no changes
1178
+ all_tables = set(await self.before.tables()) | set(await self.after.tables())
1179
+ tables_to_verify = []
1180
+
1181
+ for table in all_tables:
1182
+ if (
1183
+ table not in changes_by_table
1184
+ and not self.ignore_config.should_ignore_table(table)
1185
+ ):
1186
+ tables_to_verify.append(table)
1187
+
1188
+ # Async function to verify no changes in a table
1189
+ async def verify_no_changes(table: str):
1190
+ try:
1191
+ # For tables with no allowed changes, just check row counts
1192
+ before_count_response = await self.before.resource.query(
1193
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
1194
+ )
1195
+ before_count = (
1196
+ before_count_response.rows[0][0]
1197
+ if before_count_response.rows
1198
+ else 0
1199
+ )
1200
+
1201
+ after_count_response = await self.after.resource.query(
1202
+ f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
1203
+ )
1204
+ after_count = (
1205
+ after_count_response.rows[0][0] if after_count_response.rows else 0
1206
+ )
1207
+
1208
+ if before_count != after_count:
1209
+ error_msg = (
1210
+ f"Unexpected change in table '{table}': "
1211
+ f"row count changed from {before_count} to {after_count}"
1212
+ )
1213
+ errors.append(AssertionError(error_msg))
1214
+ except Exception as e:
1215
+ errors.append(e)
1216
+
1217
+ # Execute table verification concurrently
1218
+ if tables_to_verify:
1219
+ verify_tasks = [verify_no_changes(table) for table in tables_to_verify]
1220
+ await asyncio.gather(*verify_tasks)
1221
+
1222
+ # Final error check
1223
+ if errors:
1224
+ raise errors[0]
1225
+
1226
+ return self
1227
+
1228
+ async def _validate_diff_against_allowed_changes_v2(
1229
+ self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
1230
+ ):
1231
+ """Validate a collected diff against allowed changes with field-level spec support.
1232
+
1233
+ This version supports explicit change types via the "type" field:
1234
+ 1. Insert specs: {"table": "t", "pk": 1, "type": "insert", "fields": [("name", "value"), ("status", ...)]}
1235
+ - ("name", value): check that field equals value
1236
+ - ("name", None): check that field is SQL NULL
1237
+ - ("name", ...): don't check the value, just acknowledge the field exists
1238
+ 2. Modify specs: {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": True/False}
1239
+ - Uses "resulting_fields" (not "fields") to be explicit about what's being checked
1240
+ - "no_other_changes" is REQUIRED and must be True or False:
1241
+ - True: Every changed field must be in resulting_fields (strict mode)
1242
+ - False: Only check fields in resulting_fields match, ignore other changes
1243
+ - ("field_name", value): check that after value equals value
1244
+ - ("field_name", None): check that after value is SQL NULL
1245
+ - ("field_name", ...): don't check value, just acknowledge field changed
1246
+ 3. Delete specs:
1247
+ - Without field validation: {"table": "t", "pk": 1, "type": "delete"}
1248
+ - With field validation: {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
1249
+ 4. Whole-row specs (legacy):
1250
+ - For additions: {"table": "t", "pk": 1, "fields": None, "after": "__added__"}
1251
+ - For deletions: {"table": "t", "pk": 1, "fields": None, "after": "__removed__"}
1252
+
1253
+ When using "fields" for inserts, every field must be accounted for in the list.
1254
+ For modifications, use "resulting_fields" with explicit "no_other_changes".
1255
+ For deletions with "fields", all specified fields are validated against the deleted row.
1256
+ """
1257
+
1258
+ def _is_change_allowed(
1259
+ table: str, row_id: Any, field: Optional[str], after_value: Any
1260
+ ) -> bool:
1261
+ """Check if a change is in the allowed list using semantic comparison."""
1262
+ for allowed in allowed_changes:
1263
+ allowed_pk = allowed.get("pk")
1264
+ # Handle type conversion for primary key comparison
1265
+ pk_match = (
1266
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1267
+ )
1268
+
1269
+ # For whole-row specs, check "fields": None; for field-level, check "field"
1270
+ field_match = (
1271
+ ("fields" in allowed and allowed.get("fields") is None)
1272
+ if field is None
1273
+ else allowed.get("field") == field
1274
+ )
1275
+ if (
1276
+ allowed["table"] == table
1277
+ and pk_match
1278
+ and field_match
1279
+ and _values_equivalent(allowed.get("after"), after_value)
1280
+ ):
1281
+ return True
1282
+ return False
1283
+
1284
+ def _get_fields_spec_for_type(
1285
+ table: str, row_id: Any, change_type: str
1286
+ ) -> Optional[List[Tuple[str, Any]]]:
1287
+ """Get the bulk fields spec for a given table/row/type if it exists.
1288
+
1289
+ Args:
1290
+ table: The table name
1291
+ row_id: The primary key value
1292
+ change_type: One of "insert", "modify", or "delete"
1293
+
1294
+ Note: For "modify" type, use _get_modify_spec instead.
1295
+ """
1296
+ for allowed in allowed_changes:
1297
+ allowed_pk = allowed.get("pk")
1298
+ pk_match = (
1299
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1300
+ )
1301
+ if (
1302
+ allowed["table"] == table
1303
+ and pk_match
1304
+ and allowed.get("type") == change_type
1305
+ and "fields" in allowed
1306
+ ):
1307
+ return allowed["fields"]
1308
+ return None
1309
+
1310
+ def _get_modify_spec(table: str, row_id: Any) -> Optional[Dict[str, Any]]:
1311
+ """Get the modify spec for a given table/row if it exists.
1312
+
1313
+ Returns the full spec dict containing:
1314
+ - resulting_fields: List of field tuples
1315
+ - no_other_changes: Boolean (required)
1316
+
1317
+ Returns None if no modify spec found.
1318
+ """
1319
+ for allowed in allowed_changes:
1320
+ allowed_pk = allowed.get("pk")
1321
+ pk_match = (
1322
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1323
+ )
1324
+ if (
1325
+ allowed["table"] == table
1326
+ and pk_match
1327
+ and allowed.get("type") == "modify"
1328
+ ):
1329
+ return allowed
1330
+ return None
1331
+
1332
+ def _is_type_allowed(table: str, row_id: Any, change_type: str) -> bool:
1333
+ """Check if a change type is allowed for the given table/row (with or without fields)."""
1334
+ for allowed in allowed_changes:
1335
+ allowed_pk = allowed.get("pk")
1336
+ pk_match = (
1337
+ str(allowed_pk) == str(row_id) if allowed_pk is not None else False
1338
+ )
1339
+ if (
1340
+ allowed["table"] == table
1341
+ and pk_match
1342
+ and allowed.get("type") == change_type
1343
+ ):
1344
+ return True
1345
+ return False
1346
+
1347
+ def _parse_fields_spec(
1348
+ fields_spec: List[Tuple[str, Any]]
1349
+ ) -> Dict[str, Tuple[bool, Any]]:
1350
+ """Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
1351
+ spec_map: Dict[str, Tuple[bool, Any]] = {}
1352
+ for spec_tuple in fields_spec:
1353
+ if len(spec_tuple) != 2:
1354
+ raise ValueError(
1355
+ f"Invalid field spec tuple: {spec_tuple}. "
1356
+ f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
1357
+ )
1358
+ field_name, expected_value = spec_tuple
1359
+ if expected_value is ...:
1360
+ # Ellipsis: don't check value, just acknowledge field exists
1361
+ spec_map[field_name] = (False, None)
1362
+ else:
1363
+ # Any other value (including None for NULL check): check value
1364
+ spec_map[field_name] = (True, expected_value)
1365
+ return spec_map
1366
+
1367
+ def _validate_row_with_fields_spec(
1368
+ table: str,
1369
+ row_id: Any,
1370
+ row_data: Dict[str, Any],
1371
+ fields_spec: List[Tuple[str, Any]],
1372
+ ) -> Optional[List[Tuple[str, Any, str]]]:
1373
+ """Validate a row against a bulk fields spec.
1374
+
1375
+ Returns None if validation passes, or a list of (field, actual_value, issue)
1376
+ tuples for mismatches.
1377
+
1378
+ Field spec semantics:
1379
+ - ("field_name", value): check that field equals value
1380
+ - ("field_name", None): check that field is SQL NULL
1381
+ - ("field_name", ...): don't check value (acknowledge field exists)
1382
+ """
1383
+ spec_map = _parse_fields_spec(fields_spec)
1384
+ unmatched_fields: List[Tuple[str, Any, str]] = []
1385
+
1386
+ for field_name, field_value in row_data.items():
1387
+ # Skip rowid as it's internal
1388
+ if field_name == "rowid":
1389
+ continue
1390
+ # Skip ignored fields
1391
+ if self.ignore_config.should_ignore_field(table, field_name):
1392
+ continue
1393
+
1394
+ if field_name not in spec_map:
1395
+ # Field not in spec - this is an error
1396
+ unmatched_fields.append(
1397
+ (field_name, field_value, "NOT_IN_FIELDS_SPEC")
1398
+ )
1399
+ else:
1400
+ should_check, expected_value = spec_map[field_name]
1401
+ if should_check and not _values_equivalent(
1402
+ expected_value, field_value
1403
+ ):
1404
+ # Value doesn't match
1405
+ unmatched_fields.append(
1406
+ (field_name, field_value, f"expected {repr(expected_value)}")
1407
+ )
1408
+
1409
+ return unmatched_fields if unmatched_fields else None
1410
+
1411
+ def _validate_modification_with_fields_spec(
1412
+ table: str,
1413
+ row_id: Any,
1414
+ row_changes: Dict[str, Dict[str, Any]],
1415
+ resulting_fields: List[Tuple[str, Any]],
1416
+ no_other_changes: bool,
1417
+ ) -> Optional[List[Tuple[str, Any, str]]]:
1418
+ """Validate a modification against a resulting_fields spec.
1419
+
1420
+ Returns None if validation passes, or a list of (field, actual_value, issue)
1421
+ tuples for mismatches.
1422
+
1423
+ Args:
1424
+ table: The table name
1425
+ row_id: The row primary key
1426
+ row_changes: Dict of field_name -> {"before": ..., "after": ...}
1427
+ resulting_fields: List of field tuples to validate
1428
+ no_other_changes: If True, all changed fields must be in resulting_fields.
1429
+ If False, only validate fields in resulting_fields, ignore others.
1430
+
1431
+ Field spec semantics for modifications:
1432
+ - ("field_name", value): check that after value equals value
1433
+ - ("field_name", None): check that after value is SQL NULL
1434
+ - ("field_name", ...): don't check value, just acknowledge field changed
1435
+ """
1436
+ spec_map = _parse_fields_spec(resulting_fields)
1437
+ unmatched_fields: List[Tuple[str, Any, str]] = []
1438
+
1439
+ for field_name, vals in row_changes.items():
1440
+ # Skip ignored fields
1441
+ if self.ignore_config.should_ignore_field(table, field_name):
1442
+ continue
1443
+
1444
+ after_value = vals["after"]
1445
+
1446
+ if field_name not in spec_map:
1447
+ # Changed field not in spec
1448
+ if no_other_changes:
1449
+ # Strict mode: all changed fields must be accounted for
1450
+ unmatched_fields.append(
1451
+ (field_name, after_value, "NOT_IN_RESULTING_FIELDS")
1452
+ )
1453
+ # If no_other_changes=False, ignore fields not in spec
1454
+ else:
1455
+ should_check, expected_value = spec_map[field_name]
1456
+ if should_check and not _values_equivalent(
1457
+ expected_value, after_value
1458
+ ):
1459
+ # Value doesn't match
1460
+ unmatched_fields.append(
1461
+ (field_name, after_value, f"expected {repr(expected_value)}")
1462
+ )
1463
+
1464
+ return unmatched_fields if unmatched_fields else None
1465
+
1466
+
1467
+ # Collect all unexpected changes for detailed reporting
1468
+ unexpected_changes = []
1469
+
1470
+ for tbl, report in diff.items():
1471
+ for row in report.get("modified_rows", []):
1472
+ row_changes = row["changes"]
1473
+
1474
+ # Check for modify spec with resulting_fields
1475
+ modify_spec = _get_modify_spec(tbl, row["row_id"])
1476
+ if modify_spec is not None:
1477
+ resulting_fields = modify_spec.get("resulting_fields")
1478
+ if resulting_fields is not None:
1479
+ # Validate that no_other_changes is provided
1480
+ if "no_other_changes" not in modify_spec:
1481
+ raise ValueError(
1482
+ f"Modify spec for table '{tbl}' pk={row['row_id']} "
1483
+ f"has 'resulting_fields' but missing required 'no_other_changes' field. "
1484
+ f"Set 'no_other_changes': True to verify no other fields changed, "
1485
+ f"or 'no_other_changes': False to only check the specified fields."
1486
+ )
1487
+ no_other_changes = modify_spec["no_other_changes"]
1488
+ if not isinstance(no_other_changes, bool):
1489
+ raise ValueError(
1490
+ f"Modify spec for table '{tbl}' pk={row['row_id']} "
1491
+ f"has 'no_other_changes' but it must be a boolean (True or False), "
1492
+ f"got {type(no_other_changes).__name__}: {repr(no_other_changes)}"
1493
+ )
1494
+
1495
+ unmatched = _validate_modification_with_fields_spec(
1496
+ tbl, row["row_id"], row_changes, resulting_fields, no_other_changes
1497
+ )
1498
+ if unmatched:
1499
+ unexpected_changes.append(
1500
+ {
1501
+ "type": "modification",
1502
+ "table": tbl,
1503
+ "row_id": row["row_id"],
1504
+ "field": None,
1505
+ "before": None,
1506
+ "after": None,
1507
+ "full_row": row,
1508
+ "unmatched_fields": unmatched,
1509
+ }
1510
+ )
1511
+ continue # Skip to next row
1512
+ else:
1513
+ # Modify spec without resulting_fields - just allow the modification
1514
+ continue # Skip to next row
1515
+
1516
+ # Fall back to single-field specs (legacy)
1517
+ for f, vals in row_changes.items():
1518
+ if self.ignore_config.should_ignore_field(tbl, f):
1519
+ continue
1520
+ if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
1521
+ unexpected_changes.append(
1522
+ {
1523
+ "type": "modification",
1524
+ "table": tbl,
1525
+ "row_id": row["row_id"],
1526
+ "field": f,
1527
+ "before": vals.get("before"),
1528
+ "after": vals["after"],
1529
+ "full_row": row,
1530
+ }
1531
+ )
1532
+
1533
+ for row in report.get("added_rows", []):
1534
+ row_data = row.get("data", {})
1535
+
1536
+ # Check for bulk fields spec (type: "insert")
1537
+ fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "insert")
1538
+ if fields_spec is not None:
1539
+ unmatched = _validate_row_with_fields_spec(
1540
+ tbl, row["row_id"], row_data, fields_spec
1541
+ )
1542
+ if unmatched:
1543
+ unexpected_changes.append(
1544
+ {
1545
+ "type": "insertion",
1546
+ "table": tbl,
1547
+ "row_id": row["row_id"],
1548
+ "field": None,
1549
+ "after": "__added__",
1550
+ "full_row": row,
1551
+ "unmatched_fields": unmatched,
1552
+ }
1553
+ )
1554
+ continue # Skip to next row
1555
+
1556
+ # Check if insertion is allowed without field validation
1557
+ if _is_type_allowed(tbl, row["row_id"], "insert"):
1558
+ continue # Insertion is allowed, skip to next row
1559
+
1560
+ # Check for whole-row spec (legacy)
1561
+ whole_row_allowed = _is_change_allowed(
1562
+ tbl, row["row_id"], None, "__added__"
1563
+ )
1564
+
1565
+ if not whole_row_allowed:
1566
+ unexpected_changes.append(
1567
+ {
1568
+ "type": "insertion",
1569
+ "table": tbl,
1570
+ "row_id": row["row_id"],
1571
+ "field": None,
1572
+ "after": "__added__",
1573
+ "full_row": row,
1574
+ }
1575
+ )
1576
+
1577
+ for row in report.get("removed_rows", []):
1578
+ row_data = row.get("data", {})
1579
+
1580
+ # Check for bulk fields spec (type: "delete")
1581
+ fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "delete")
1582
+ if fields_spec is not None:
1583
+ unmatched = _validate_row_with_fields_spec(
1584
+ tbl, row["row_id"], row_data, fields_spec
1585
+ )
1586
+ if unmatched:
1587
+ unexpected_changes.append(
1588
+ {
1589
+ "type": "deletion",
1590
+ "table": tbl,
1591
+ "row_id": row["row_id"],
1592
+ "field": None,
1593
+ "after": "__removed__",
1594
+ "full_row": row,
1595
+ "unmatched_fields": unmatched,
1596
+ }
1597
+ )
1598
+ continue # Skip to next row
1599
+
1600
+ # Check if deletion is allowed without field validation
1601
+ if _is_type_allowed(tbl, row["row_id"], "delete"):
1602
+ continue # Deletion is allowed, skip to next row
1603
+
1604
+ # Check for whole-row spec (legacy)
1605
+ whole_row_allowed = _is_change_allowed(
1606
+ tbl, row["row_id"], None, "__removed__"
1607
+ )
1608
+
1609
+ if not whole_row_allowed:
1610
+ unexpected_changes.append(
1611
+ {
1612
+ "type": "deletion",
1613
+ "table": tbl,
1614
+ "row_id": row["row_id"],
1615
+ "field": None,
1616
+ "after": "__removed__",
1617
+ "full_row": row,
1618
+ }
1619
+ )
1620
+
1621
+ if unexpected_changes:
1622
+ # Build comprehensive error message
1623
+ error_lines = ["Unexpected database changes detected:"]
1624
+ error_lines.append("")
1625
+
1626
+ for i, change in enumerate(unexpected_changes[:5], 1):
1627
+ error_lines.append(
1628
+ f"{i}. {change['type'].upper()} in table '{change['table']}':"
1629
+ )
1630
+ error_lines.append(f" Row ID: {change['row_id']}")
1631
+
1632
+ if change["type"] == "modification":
1633
+ error_lines.append(f" Field: {change['field']}")
1634
+ error_lines.append(f" Before: {repr(change['before'])}")
1635
+ error_lines.append(f" After: {repr(change['after'])}")
1636
+ elif change["type"] == "insertion":
1637
+ error_lines.append(" New row added")
1638
+ elif change["type"] == "deletion":
1639
+ error_lines.append(" Row deleted")
1640
+
1641
+ # Show unmatched fields if present (from bulk fields spec validation)
1642
+ if "unmatched_fields" in change and change["unmatched_fields"]:
1643
+ error_lines.append(" Unmatched fields:")
1644
+ for field_info in change["unmatched_fields"][:5]:
1645
+ field_name, actual_value, issue = field_info
1646
+ error_lines.append(
1647
+ f" - {field_name}: {repr(actual_value)} ({issue})"
1648
+ )
1649
+ if len(change["unmatched_fields"]) > 10:
1650
+ error_lines.append(
1651
+ f" ... and {len(change['unmatched_fields']) - 10} more"
1652
+ )
1653
+
1654
+ # Show some context from the row
1655
+ if "full_row" in change and change["full_row"]:
1656
+ row_data = change["full_row"]
1657
+ if change["type"] == "modification" and "data" in row_data:
1658
+ # For modifications, show the current state
1659
+ formatted_row = _format_row_for_error(
1660
+ row_data.get("data", {}), max_fields=5
1661
+ )
1662
+ error_lines.append(f" Row data: {formatted_row}")
1663
+ elif (
1664
+ change["type"] in ["insertion", "deletion"]
1665
+ and "data" in row_data
1666
+ ):
1667
+ # For insertions/deletions, show the row data
1668
+ formatted_row = _format_row_for_error(
1669
+ row_data.get("data", {}), max_fields=5
1670
+ )
1671
+ error_lines.append(f" Row data: {formatted_row}")
1672
+
1673
+ error_lines.append("")
1674
+
1675
+ if len(unexpected_changes) > 5:
1676
+ error_lines.append(
1677
+ f"... and {len(unexpected_changes) - 5} more unexpected changes"
1678
+ )
1679
+ error_lines.append("")
1680
+
1681
+ # Show what changes were allowed
1682
+ error_lines.append("Allowed changes were:")
1683
+ if allowed_changes:
1684
+ for i, allowed in enumerate(allowed_changes[:3], 1):
1685
+ change_type = allowed.get("type", "unspecified")
1686
+
1687
+ # For modify type, use resulting_fields
1688
+ if change_type == "modify" and "resulting_fields" in allowed and allowed["resulting_fields"] is not None:
1689
+ fields_summary = ", ".join(
1690
+ f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
1691
+ for f in allowed["resulting_fields"][:3]
1692
+ )
1693
+ if len(allowed["resulting_fields"]) > 3:
1694
+ fields_summary += f", ... +{len(allowed['resulting_fields']) - 3} more"
1695
+ no_other = allowed.get("no_other_changes", "NOT_SET")
1696
+ error_lines.append(
1697
+ f" {i}. Table: {allowed.get('table')}, "
1698
+ f"ID: {allowed.get('pk')}, "
1699
+ f"Type: {change_type}, "
1700
+ f"resulting_fields: [{fields_summary}], "
1701
+ f"no_other_changes: {no_other}"
1702
+ )
1703
+ elif "fields" in allowed and allowed["fields"] is not None:
1704
+ # Show bulk fields spec (for insert/delete)
1705
+ fields_summary = ", ".join(
1706
+ f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
1707
+ for f in allowed["fields"][:3]
1708
+ )
1709
+ if len(allowed["fields"]) > 3:
1710
+ fields_summary += f", ... +{len(allowed['fields']) - 3} more"
1711
+ error_lines.append(
1712
+ f" {i}. Table: {allowed.get('table')}, "
1713
+ f"ID: {allowed.get('pk')}, "
1714
+ f"Type: {change_type}, "
1715
+ f"Fields: [{fields_summary}]"
1716
+ )
1717
+ else:
1718
+ error_lines.append(
1719
+ f" {i}. Table: {allowed.get('table')}, "
1720
+ f"ID: {allowed.get('pk')}, "
1721
+ f"Type: {change_type}"
1722
+ )
1723
+ if len(allowed_changes) > 3:
1724
+ error_lines.append(
1725
+ f" ... and {len(allowed_changes) - 3} more allowed changes"
1726
+ )
1727
+ else:
1728
+ error_lines.append(" (No changes were allowed)")
1729
+
1730
+ raise AssertionError("\n".join(error_lines))
1731
+
1732
+ return self
1733
+
1734
+ async def expect_only(self, allowed_changes: List[Dict[str, Any]]):
1735
+ """Ensure only specified changes occurred."""
1736
+ # Normalize pk values: convert lists to tuples for hashability and consistency
1737
+ for change in allowed_changes:
1738
+ if "pk" in change and isinstance(change["pk"], list):
1739
+ change["pk"] = tuple(change["pk"])
1740
+
1741
+ # Special case: empty allowed_changes means no changes should have occurred
1742
+ if not allowed_changes:
1743
+ return await self._expect_no_changes()
1744
+
1745
+ # For expect_only, we can optimize by only checking the specific rows mentioned
1746
+ if self._can_use_targeted_queries(allowed_changes):
1747
+ return await self._expect_only_targeted(allowed_changes)
1748
+
1749
+ # Fall back to full diff for complex cases
1750
+ diff = await self._collect()
1751
+ return await self._validate_diff_against_allowed_changes(diff, allowed_changes)
1752
+
1753
+ async def expect_only_v2(self, allowed_changes: List[Dict[str, Any]]):
1754
+ """Ensure only specified changes occurred, with field-level spec support.
1755
+
1756
+ This version supports field-level specifications for added/removed rows,
1757
+ allowing users to specify expected field values instead of just whole-row specs.
1758
+ """
1759
+ # Normalize pk values: convert lists to tuples for hashability and consistency
1760
+ for change in allowed_changes:
1761
+ if "pk" in change and isinstance(change["pk"], list):
1762
+ change["pk"] = tuple(change["pk"])
1763
+
1764
+ # Special case: empty allowed_changes means no changes should have occurred
1765
+ if not allowed_changes:
1766
+ return await self._expect_no_changes()
1767
+
1768
+ resource = self.after.resource
1769
+ # Disabled: structured diff endpoint not yet available
1770
+ if False and resource.client is not None and resource._mode == "http":
1771
+ api_diff = None
1772
+ try:
1773
+ payload = {}
1774
+ if self.ignore_config:
1775
+ payload["ignore_config"] = {
1776
+ "tables": list(self.ignore_config.tables),
1777
+ "fields": list(self.ignore_config.fields),
1778
+ "table_fields": {
1779
+ table: list(fields) for table, fields in self.ignore_config.table_fields.items()
1780
+ }
1781
+ }
1782
+ response = await resource.client.request(
1783
+ "POST",
1784
+ "/diff/structured",
1785
+ json=payload,
1786
+ )
1787
+ result = response.json()
1788
+ if result.get("success") and "diff" in result:
1789
+ api_diff = result["diff"]
1790
+ except Exception as e:
1791
+ # Fall back to local diff if API call fails
1792
+ print(f"Warning: Failed to fetch structured diff from API: {e}")
1793
+ print("Falling back to local diff computation...")
1794
+
1795
+ # Validate outside try block so AssertionError propagates
1796
+ if api_diff is not None:
1797
+ return await self._validate_diff_against_allowed_changes_v2(api_diff, allowed_changes)
1798
+
1799
+ # For expect_only_v2, we can optimize by only checking the specific rows mentioned
1800
+ if self._can_use_targeted_queries(allowed_changes):
1801
+ return await self._expect_only_targeted_v2(allowed_changes)
1802
+
1803
+ # Fall back to full diff for complex cases
1804
+ diff = await self._collect()
1805
+ return await self._validate_diff_against_allowed_changes_v2(
1806
+ diff, allowed_changes
1807
+ )
1808
+
1809
    async def expect_exactly(self, expected_changes: List[Dict[str, Any]]):
        """Verify that EXACTLY the specified changes occurred.

        This is stricter than expect_only_v2:
        1. All changes in diff must match a spec (no unexpected changes)
        2. All specs must have a matching change in diff (no missing expected changes)

        This method is ideal for verifying that an agent performed exactly what was expected -
        not more, not less.

        Args:
            expected_changes: List of expected change specs. Each spec requires:
                - "type": "insert", "modify", or "delete" (required)
                - "table": table name (required)
                - "pk": primary key value (required)

            Spec formats by type:
                - Insert: {"type": "insert", "table": "t", "pk": 1, "fields": [...]}
                - Modify: {"type": "modify", "table": "t", "pk": 1, "resulting_fields": [...], "no_other_changes": True/False}
                - Delete: {"type": "delete", "table": "t", "pk": 1}

            Field specs are 2-tuples: (field_name, expected_value)
                - ("name", "Alice"): check field equals "Alice"
                - ("name", ...): accept any value (ellipsis)
                - ("name", None): check field is SQL NULL

            Note: Legacy specs without explicit "type" are not supported.

        Returns:
            self for method chaining

        Raises:
            AssertionError: If there are unexpected changes OR if expected changes are missing
            ValueError: If specs are missing required fields or have invalid format
        """
        # Get the diff (using HTTP if available, otherwise local)
        resource = self.after.resource
        diff = None

        if resource.client is not None and resource._mode == "http":
            try:
                # Serialize the ignore config so the server applies the same
                # table/field filters the local diff would.
                payload = {}
                if self.ignore_config:
                    payload["ignore_config"] = {
                        "tables": list(self.ignore_config.tables),
                        "fields": list(self.ignore_config.fields),
                        "table_fields": {
                            table: list(fields) for table, fields in self.ignore_config.table_fields.items()
                        }
                    }
                response = await resource.client.request(
                    "POST",
                    "/diff/structured",
                    json=payload,
                )
                result = response.json()
                # Only accept a well-formed success payload; anything else
                # falls through to the local computation below.
                if result.get("success") and "diff" in result:
                    diff = result["diff"]
            except Exception as e:
                # Best-effort: any API failure degrades to the local diff path.
                print(f"Warning: Failed to fetch structured diff from API: {e}")
                print("Falling back to local diff computation...")

        if diff is None:
            diff = await self._collect()

        # Use shared validation logic
        success, error_msg, _ = validate_diff_expect_exactly(
            diff, expected_changes, self.ignore_config
        )

        if not success:
            raise AssertionError(error_msg)

        return self
1883
+
1884
+ async def _ensure_all_fetched(self):
1885
+ """Fetch ALL data from ALL tables upfront (non-lazy loading).
1886
+
1887
+ This is the old approach before lazy loading was introduced.
1888
+ Used by expect_only_v1 for simpler, non-optimized diffing.
1889
+ """
1890
+ # Get all tables from before snapshot
1891
+ tables_response = await self.before.resource.query(
1892
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
1893
+ )
1894
+
1895
+ if tables_response.rows:
1896
+ before_tables = [row[0] for row in tables_response.rows]
1897
+ for table in before_tables:
1898
+ await self.before._ensure_table_data(table)
1899
+
1900
+ # Also fetch from after snapshot
1901
+ tables_response = await self.after.resource.query(
1902
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
1903
+ )
1904
+
1905
+ if tables_response.rows:
1906
+ after_tables = [row[0] for row in tables_response.rows]
1907
+ for table in after_tables:
1908
+ await self.after._ensure_table_data(table)
1909
+
1910
    async def expect_only_v1(self, allowed_changes: List[Dict[str, Any]]):
        """Ensure only specified changes occurred using the original (non-optimized) approach.

        This version attempts to use the /api/v1/env/diff/structured endpoint if available,
        falling back to local diff computation if the endpoint is not available.

        Use this when you want the simpler, more predictable behavior of the original
        implementation without any query optimizations.
        """
        # NOTE(review): unlike expect_only/expect_only_v2, list-valued pks are
        # NOT normalized to tuples here — confirm whether that is intentional.
        # Try to use the structured diff endpoint if we have an HTTP client
        resource = self.after.resource
        if resource.client is not None and resource._mode == "http":
            api_diff = None
            try:
                # Serialize the ignore config so the server applies the same filters.
                payload = {}
                if self.ignore_config:
                    payload["ignore_config"] = {
                        "tables": list(self.ignore_config.tables),
                        "fields": list(self.ignore_config.fields),
                        "table_fields": {
                            table: list(fields) for table, fields in self.ignore_config.table_fields.items()
                        }
                    }
                response = await resource.client.request(
                    "POST",
                    "/diff/structured",
                    json=payload,
                )
                result = response.json()
                if result.get("success") and "diff" in result:
                    api_diff = result["diff"]
            except Exception as e:
                # Fall back to local diff if API call fails
                print(f"Warning: Failed to fetch structured diff from API: {e}")
                print("Falling back to local diff computation...")

            # Validate outside try block so AssertionError propagates
            if api_diff is not None:
                return await self._validate_diff_against_allowed_changes(api_diff, allowed_changes)

        # Fall back to local diff computation: eagerly fetch everything first,
        # then diff the full snapshots.
        await self._ensure_all_fetched()
        diff = await self._collect()
        return await self._validate_diff_against_allowed_changes(diff, allowed_changes)
1954
+
460
1955
 
461
1956
  class AsyncQueryBuilder:
462
1957
  """Async query builder that translates DSL to SQL and executes through the API."""
@@ -555,13 +2050,13 @@ class AsyncQueryBuilder:
555
2050
  # Compile to SQL
556
2051
  def _compile(self) -> Tuple[str, List[Any]]:
557
2052
  cols = ", ".join(self._select_cols)
558
- sql = [f"SELECT {cols} FROM {self._table}"]
2053
+ sql = [f"SELECT {cols} FROM {_quote_identifier(self._table)}"]
559
2054
  params: List[Any] = []
560
2055
 
561
2056
  # Joins
562
2057
  for tbl, onmap in self._joins:
563
- join_clauses = [f"{self._table}.{l} = {tbl}.{r}" for l, r in onmap.items()]
564
- sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")
2058
+ join_clauses = [f"{_quote_identifier(self._table)}.{_quote_identifier(l)} = {_quote_identifier(tbl)}.{_quote_identifier(r)}" for l, r in onmap.items()]
2059
+ sql.append(f"JOIN {_quote_identifier(tbl)} ON {' AND '.join(join_clauses)}")
565
2060
 
566
2061
  # WHERE
567
2062
  if self._conditions:
@@ -569,12 +2064,12 @@ class AsyncQueryBuilder:
569
2064
  for col, op, val in self._conditions:
570
2065
  if op in ("IN", "NOT IN") and isinstance(val, tuple):
571
2066
  ph = ", ".join(["?" for _ in val])
572
- placeholders.append(f"{col} {op} ({ph})")
2067
+ placeholders.append(f"{_quote_identifier(col)} {op} ({ph})")
573
2068
  params.extend(val)
574
2069
  elif op in ("IS", "IS NOT"):
575
- placeholders.append(f"{col} {op} NULL")
2070
+ placeholders.append(f"{_quote_identifier(col)} {op} NULL")
576
2071
  else:
577
- placeholders.append(f"{col} {op} ?")
2072
+ placeholders.append(f"{_quote_identifier(col)} {op} ?")
578
2073
  params.append(val)
579
2074
  sql.append("WHERE " + " AND ".join(placeholders))
580
2075
 
@@ -679,16 +2174,106 @@ class AsyncQueryBuilder:
679
2174
 
680
2175
 
681
2176
  class AsyncSQLiteResource(Resource):
682
- def __init__(self, resource: ResourceModel, client: "AsyncWrapper"):
2177
    def __init__(
        self,
        resource: ResourceModel,
        client: Optional["AsyncWrapper"] = None,
        db_path: Optional[str] = None,
    ):
        """Wrap a SQLite resource backed either by an HTTP API or a local file.

        Args:
            resource: Resource metadata model for this database.
            client: Async HTTP wrapper used in "http" mode; may be None when
                operating directly on a local file.
            db_path: Path (or sqlite URI) to a local database; providing this
                switches the resource into "direct" mode.
        """
        super().__init__(resource)
        self.client = client
        self.db_path = db_path
        # A db_path takes precedence: any local path means direct sqlite3 access,
        # otherwise all operations go through the HTTP API via `client`.
        self._mode = "direct" if db_path else "http"
2187
+
2188
    @property
    def mode(self) -> str:
        """Return the mode of this resource: 'direct' (local file) or 'http' (remote API)."""
        # Read-only view of the _mode chosen in __init__ from db_path.
        return self._mode
685
2192
 
686
2193
  async def describe(self) -> DescribeResponse:
687
2194
  """Describe the SQLite database schema."""
2195
+ if self._mode == "direct":
2196
+ return await self._describe_direct()
2197
+ else:
2198
+ return await self._describe_http()
2199
+
2200
+ async def _describe_http(self) -> DescribeResponse:
2201
+ """Describe database schema via HTTP API."""
688
2202
  response = await self.client.request(
689
2203
  "GET", f"/resources/sqlite/{self.resource.name}/describe"
690
2204
  )
691
- return DescribeResponse(**response.json())
2205
+ try:
2206
+ return DescribeResponse(**response.json())
2207
+ except json.JSONDecodeError as e:
2208
+ raise ValueError(
2209
+ f"Failed to parse JSON response from SQLite describe endpoint. "
2210
+ f"Status: {response.status_code}, "
2211
+ f"Response text: {response.text[:500]}"
2212
+ ) from e
2213
+
2214
+ async def _describe_direct(self) -> DescribeResponse:
2215
+ """Describe database schema from local file or in-memory database."""
2216
+ def _sync_describe():
2217
+ try:
2218
+ # Check if we need URI mode (for shared memory databases)
2219
+ use_uri = 'mode=memory' in self.db_path
2220
+ conn = sqlite3.connect(self.db_path, uri=use_uri)
2221
+ cursor = conn.cursor()
2222
+
2223
+ # Get all tables
2224
+ cursor.execute(
2225
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
2226
+ )
2227
+ table_names = [row[0] for row in cursor.fetchall()]
2228
+
2229
+ tables = []
2230
+ for table_name in table_names:
2231
+ # Get table info
2232
+ cursor.execute(f"PRAGMA table_info({_quote_identifier(table_name)})")
2233
+ columns = cursor.fetchall()
2234
+
2235
+ # Get CREATE TABLE SQL
2236
+ cursor.execute(
2237
+ f"SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
2238
+ (table_name,)
2239
+ )
2240
+ sql_row = cursor.fetchone()
2241
+ create_sql = sql_row[0] if sql_row else ""
2242
+
2243
+ table_schema = {
2244
+ "name": table_name,
2245
+ "sql": create_sql,
2246
+ "columns": [
2247
+ {
2248
+ "name": col[1],
2249
+ "type": col[2],
2250
+ "notnull": bool(col[3]),
2251
+ "default_value": col[4],
2252
+ "primary_key": col[5] > 0,
2253
+ }
2254
+ for col in columns
2255
+ ],
2256
+ }
2257
+ tables.append(table_schema)
2258
+
2259
+ conn.close()
2260
+
2261
+ return DescribeResponse(
2262
+ success=True,
2263
+ resource_name=self.resource.name,
2264
+ tables=tables,
2265
+ message="Schema retrieved from local file",
2266
+ )
2267
+ except Exception as e:
2268
+ return DescribeResponse(
2269
+ success=False,
2270
+ resource_name=self.resource.name,
2271
+ tables=None,
2272
+ error=str(e),
2273
+ message=f"Failed to describe database: {str(e)}",
2274
+ )
2275
+
2276
+ return await asyncio.to_thread(_sync_describe)
692
2277
 
693
2278
  async def query(
694
2279
  self, query: str, args: Optional[List[Any]] = None
@@ -701,6 +2286,121 @@ class AsyncSQLiteResource(Resource):
701
2286
  async def _query(
702
2287
  self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
703
2288
  ) -> QueryResponse:
2289
+ if self._mode == "direct":
2290
+ return await self._query_direct(query, args, read_only)
2291
+ else:
2292
+ # Check if this is a PRAGMA query - HTTP endpoints don't support PRAGMA
2293
+ query_stripped = query.strip().upper()
2294
+ if query_stripped.startswith("PRAGMA"):
2295
+ return await self._handle_pragma_query_http(query, args)
2296
+ return await self._query_http(query, args, read_only)
2297
+
2298
    async def _handle_pragma_query_http(
        self, query: str, args: Optional[List[Any]] = None
    ) -> QueryResponse:
        """Handle PRAGMA queries in HTTP mode by using the describe endpoint.

        Only PRAGMA table_info(...) is emulated; any other PRAGMA returns an
        unsuccessful QueryResponse. `args` is accepted for signature parity
        with the other query paths but is not used here.
        """
        query_upper = query.strip().upper()

        # Extract table name from PRAGMA table_info(table_name)
        if "TABLE_INFO" in query_upper:
            # Match: PRAGMA table_info("table") or PRAGMA table_info(table)
            # Tried in order: double-quoted, single-quoted, then bare identifier.
            match = re.search(r'TABLE_INFO\s*\(\s*"([^"]+)"\s*\)', query, re.IGNORECASE)
            if not match:
                match = re.search(r"TABLE_INFO\s*\(\s*'([^']+)'\s*\)", query, re.IGNORECASE)
            if not match:
                match = re.search(r'TABLE_INFO\s*\(\s*([^\s\)]+)\s*\)', query, re.IGNORECASE)

            if match:
                table_name = match.group(1)

                # Use the describe endpoint to get schema
                describe_response = await self.describe()
                if not describe_response.success or not describe_response.tables:
                    return QueryResponse(
                        success=False,
                        columns=None,
                        rows=None,
                        error="Failed to get schema information",
                        message="PRAGMA query failed: could not retrieve schema"
                    )

                # Find the table in the schema
                table_schema = None
                for table in describe_response.tables:
                    # Handle both dict and TableSchema objects
                    table_name_in_schema = table.name if hasattr(table, 'name') else table.get("name")
                    if table_name_in_schema == table_name:
                        table_schema = table
                        break

                if not table_schema:
                    return QueryResponse(
                        success=False,
                        columns=None,
                        rows=None,
                        error=f"Table '{table_name}' not found",
                        message=f"PRAGMA query failed: table '{table_name}' not found"
                    )

                # Get columns from table schema
                columns = table_schema.columns if hasattr(table_schema, 'columns') else table_schema.get("columns")
                if not columns:
                    return QueryResponse(
                        success=False,
                        columns=None,
                        rows=None,
                        error=f"Table '{table_name}' has no columns",
                        message=f"PRAGMA query failed: table '{table_name}' has no columns"
                    )

                # Convert schema to PRAGMA table_info format
                # Format: (cid, name, type, notnull, dflt_value, pk)
                rows = []
                for idx, col in enumerate(columns):
                    # Handle both dict and object column definitions
                    # NOTE(review): this reads a "pk" key, but _describe_direct
                    # emits "primary_key" — confirm the HTTP schema actually
                    # uses "pk", otherwise the pk column silently defaults to 0.
                    if isinstance(col, dict):
                        col_name = col["name"]
                        col_type = col.get("type", "")
                        col_notnull = col.get("notnull", False)
                        col_default = col.get("default_value")
                        col_pk = col.get("pk", 0)
                    else:
                        col_name = col.name if hasattr(col, 'name') else str(col)
                        col_type = getattr(col, 'type', "")
                        col_notnull = getattr(col, 'notnull', False)
                        col_default = getattr(col, 'default_value', None)
                        col_pk = getattr(col, 'pk', 0)

                    row = (
                        idx,  # cid
                        col_name,  # name
                        col_type,  # type
                        1 if col_notnull else 0,  # notnull
                        col_default,  # dflt_value
                        col_pk  # pk
                    )
                    rows.append(row)

                return QueryResponse(
                    success=True,
                    columns=["cid", "name", "type", "notnull", "dflt_value", "pk"],
                    rows=rows,
                    message="PRAGMA query executed successfully via describe endpoint"
                )

        # For other PRAGMA queries, return an error indicating they're not supported
        return QueryResponse(
            success=False,
            columns=None,
            rows=None,
            error="PRAGMA query not supported in HTTP mode",
            message=f"PRAGMA query '{query}' is not supported via HTTP API"
        )
+ )
2399
+
2400
+ async def _query_http(
2401
+ self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
2402
+ ) -> QueryResponse:
2403
+ """Execute query via HTTP API."""
704
2404
  request = QueryRequest(query=query, args=args, read_only=read_only)
705
2405
  response = await self.client.request(
706
2406
  "POST",
@@ -709,6 +2409,62 @@ class AsyncSQLiteResource(Resource):
709
2409
  )
710
2410
  return QueryResponse(**response.json())
711
2411
 
2412
+ async def _query_direct(
2413
+ self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
2414
+ ) -> QueryResponse:
2415
+ """Execute query directly on local SQLite file or in-memory database."""
2416
+ def _sync_query():
2417
+ try:
2418
+ # Check if we need URI mode (for shared memory databases)
2419
+ use_uri = 'mode=memory' in self.db_path
2420
+ conn = sqlite3.connect(self.db_path, uri=use_uri)
2421
+ cursor = conn.cursor()
2422
+
2423
+ # Execute the query
2424
+ if args:
2425
+ cursor.execute(query, args)
2426
+ else:
2427
+ cursor.execute(query)
2428
+
2429
+ # For write operations, commit the transaction
2430
+ if not read_only:
2431
+ conn.commit()
2432
+
2433
+ # Get column names if available
2434
+ columns = [desc[0] for desc in cursor.description] if cursor.description else []
2435
+
2436
+ # Fetch results for SELECT queries
2437
+ rows = []
2438
+ rows_affected = 0
2439
+ last_insert_id = None
2440
+
2441
+ if cursor.description: # SELECT query
2442
+ rows = cursor.fetchall()
2443
+ else: # INSERT/UPDATE/DELETE
2444
+ rows_affected = cursor.rowcount
2445
+ last_insert_id = cursor.lastrowid if cursor.lastrowid else None
2446
+
2447
+ conn.close()
2448
+
2449
+ return QueryResponse(
2450
+ success=True,
2451
+ columns=columns if columns else None,
2452
+ rows=rows if rows else None,
2453
+ rows_affected=rows_affected if rows_affected > 0 else None,
2454
+ last_insert_id=last_insert_id,
2455
+ message="Query executed successfully",
2456
+ )
2457
+ except Exception as e:
2458
+ return QueryResponse(
2459
+ success=False,
2460
+ columns=None,
2461
+ rows=None,
2462
+ error=str(e),
2463
+ message=f"Query failed: {str(e)}",
2464
+ )
2465
+
2466
+ return await asyncio.to_thread(_sync_query)
2467
+
712
2468
    def table(self, table_name: str) -> AsyncQueryBuilder:
        """Create a query builder for the specified table."""
        # The builder executes through this resource, so it works in both
        # direct and HTTP modes.
        return AsyncQueryBuilder(self, table_name)
@@ -716,7 +2472,6 @@ class AsyncSQLiteResource(Resource):
716
2472
    async def snapshot(self, name: Optional[str] = None) -> AsyncDatabaseSnapshot:
        """Create a snapshot of the current database state."""
        # Lazy: table data is fetched on demand by later operations rather
        # than eagerly at snapshot creation time.
        snapshot = AsyncDatabaseSnapshot(self, name)
        return snapshot
721
2476
 
722
2477
  async def diff(