fleet-python 0.2.66b2__py3-none-any.whl → 0.2.105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/export_tasks.py +16 -5
- examples/export_tasks_filtered.py +245 -0
- examples/fetch_tasks.py +230 -0
- examples/import_tasks.py +140 -8
- examples/iterate_verifiers.py +725 -0
- fleet/__init__.py +128 -5
- fleet/_async/__init__.py +27 -3
- fleet/_async/base.py +24 -9
- fleet/_async/client.py +938 -41
- fleet/_async/env/client.py +60 -3
- fleet/_async/instance/client.py +52 -7
- fleet/_async/models.py +15 -0
- fleet/_async/resources/api.py +200 -0
- fleet/_async/resources/sqlite.py +1801 -46
- fleet/_async/tasks.py +122 -25
- fleet/_async/verifiers/bundler.py +22 -21
- fleet/_async/verifiers/verifier.py +25 -19
- fleet/agent/__init__.py +32 -0
- fleet/agent/gemini_cua/Dockerfile +45 -0
- fleet/agent/gemini_cua/__init__.py +10 -0
- fleet/agent/gemini_cua/agent.py +759 -0
- fleet/agent/gemini_cua/mcp/main.py +108 -0
- fleet/agent/gemini_cua/mcp_server/__init__.py +5 -0
- fleet/agent/gemini_cua/mcp_server/main.py +105 -0
- fleet/agent/gemini_cua/mcp_server/tools.py +178 -0
- fleet/agent/gemini_cua/requirements.txt +5 -0
- fleet/agent/gemini_cua/start.sh +30 -0
- fleet/agent/orchestrator.py +854 -0
- fleet/agent/types.py +49 -0
- fleet/agent/utils.py +34 -0
- fleet/base.py +34 -9
- fleet/cli.py +1061 -0
- fleet/client.py +1060 -48
- fleet/config.py +1 -1
- fleet/env/__init__.py +16 -0
- fleet/env/client.py +60 -3
- fleet/eval/__init__.py +15 -0
- fleet/eval/uploader.py +231 -0
- fleet/exceptions.py +8 -0
- fleet/instance/client.py +53 -8
- fleet/instance/models.py +1 -0
- fleet/models.py +303 -0
- fleet/proxy/__init__.py +25 -0
- fleet/proxy/proxy.py +453 -0
- fleet/proxy/whitelist.py +244 -0
- fleet/resources/api.py +200 -0
- fleet/resources/sqlite.py +1845 -46
- fleet/tasks.py +113 -20
- fleet/utils/__init__.py +7 -0
- fleet/utils/http_logging.py +178 -0
- fleet/utils/logging.py +13 -0
- fleet/utils/playwright.py +440 -0
- fleet/verifiers/bundler.py +22 -21
- fleet/verifiers/db.py +985 -1
- fleet/verifiers/decorator.py +1 -1
- fleet/verifiers/verifier.py +25 -19
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/METADATA +28 -1
- fleet_python-0.2.105.dist-info/RECORD +115 -0
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/WHEEL +1 -1
- fleet_python-0.2.105.dist-info/entry_points.txt +2 -0
- tests/test_app_method.py +85 -0
- tests/test_expect_exactly.py +4148 -0
- tests/test_expect_only.py +2593 -0
- tests/test_instance_dispatch.py +607 -0
- tests/test_sqlite_resource_dual_mode.py +263 -0
- tests/test_sqlite_shared_memory_behavior.py +117 -0
- fleet_python-0.2.66b2.dist-info/RECORD +0 -81
- tests/test_verifier_security.py +0 -427
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/top_level.txt +0 -0
fleet/_async/resources/sqlite.py
CHANGED
|
@@ -6,6 +6,9 @@ from datetime import datetime
|
|
|
6
6
|
import tempfile
|
|
7
7
|
import sqlite3
|
|
8
8
|
import os
|
|
9
|
+
import asyncio
|
|
10
|
+
import re
|
|
11
|
+
import json
|
|
9
12
|
|
|
10
13
|
from typing import TYPE_CHECKING
|
|
11
14
|
|
|
@@ -19,11 +22,23 @@ from fleet.verifiers.db import (
|
|
|
19
22
|
_get_row_identifier,
|
|
20
23
|
_format_row_for_error,
|
|
21
24
|
_values_equivalent,
|
|
25
|
+
validate_diff_expect_exactly,
|
|
22
26
|
)
|
|
23
27
|
|
|
24
28
|
|
|
29
|
+
def _quote_identifier(identifier: str) -> str:
|
|
30
|
+
"""Quote an identifier (table or column name) for SQLite.
|
|
31
|
+
|
|
32
|
+
SQLite uses double quotes for identifiers and escapes internal quotes by doubling them.
|
|
33
|
+
This handles reserved keywords like 'order', 'table', etc.
|
|
34
|
+
"""
|
|
35
|
+
# Escape any double quotes in the identifier by doubling them
|
|
36
|
+
escaped = identifier.replace('"', '""')
|
|
37
|
+
return f'"{escaped}"'
|
|
38
|
+
|
|
39
|
+
|
|
25
40
|
class AsyncDatabaseSnapshot:
|
|
26
|
-
"""
|
|
41
|
+
"""Lazy database snapshot that fetches data on-demand through API."""
|
|
27
42
|
|
|
28
43
|
def __init__(self, resource: "AsyncSQLiteResource", name: Optional[str] = None):
|
|
29
44
|
self.resource = resource
|
|
@@ -31,11 +46,12 @@ class AsyncDatabaseSnapshot:
|
|
|
31
46
|
self.created_at = datetime.utcnow()
|
|
32
47
|
self._data: Dict[str, List[Dict[str, Any]]] = {}
|
|
33
48
|
self._schemas: Dict[str, List[str]] = {}
|
|
34
|
-
self.
|
|
49
|
+
self._table_names: Optional[List[str]] = None
|
|
50
|
+
self._fetched_tables: set = set()
|
|
35
51
|
|
|
36
|
-
async def
|
|
37
|
-
"""Fetch
|
|
38
|
-
if self.
|
|
52
|
+
async def _ensure_tables_list(self):
|
|
53
|
+
"""Fetch just the list of table names if not already fetched."""
|
|
54
|
+
if self._table_names is not None:
|
|
39
55
|
return
|
|
40
56
|
|
|
41
57
|
# Get all tables
|
|
@@ -44,35 +60,36 @@ class AsyncDatabaseSnapshot:
|
|
|
44
60
|
)
|
|
45
61
|
|
|
46
62
|
if not tables_response.rows:
|
|
47
|
-
self.
|
|
63
|
+
self._table_names = []
|
|
48
64
|
return
|
|
49
65
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
for table
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
if schema_response.rows:
|
|
57
|
-
self._schemas[table] = [
|
|
58
|
-
row[1] for row in schema_response.rows
|
|
59
|
-
] # Column names
|
|
60
|
-
|
|
61
|
-
# Get all data
|
|
62
|
-
data_response = await self.resource.query(f"SELECT * FROM {table}")
|
|
63
|
-
if data_response.rows and data_response.columns:
|
|
64
|
-
self._data[table] = [
|
|
65
|
-
dict(zip(data_response.columns, row)) for row in data_response.rows
|
|
66
|
-
]
|
|
67
|
-
else:
|
|
68
|
-
self._data[table] = []
|
|
66
|
+
self._table_names = [row[0] for row in tables_response.rows]
|
|
67
|
+
|
|
68
|
+
async def _ensure_table_data(self, table: str):
    """Load and cache the column names and full contents of ``table``.

    The work happens at most once per table per snapshot. After a
    successful load, the table name is recorded in ``self._fetched_tables``
    and later calls return immediately without querying.
    """
    if table in self._fetched_tables:
        return

    quoted = _quote_identifier(table)

    # PRAGMA table_info yields one row per column; index 1 holds the name.
    pragma = await self.resource.query(f"PRAGMA table_info({quoted})")
    if pragma.rows:
        self._schemas[table] = [info[1] for info in pragma.rows]

    result = await self.resource.query(f"SELECT * FROM {quoted}")
    rows: List[Dict[str, Any]] = []
    if result.rows and result.columns:
        rows = [dict(zip(result.columns, values)) for values in result.rows]
    self._data[table] = rows

    self._fetched_tables.add(table)
|
|
71
88
|
|
|
72
89
|
async def tables(self) -> List[str]:
    """Return the snapshot's table names as a new list (empty if there are none)."""
    await self._ensure_tables_list()
    names = self._table_names
    if not names:
        return []
    return list(names)
|
|
76
93
|
|
|
77
94
|
def table(self, table_name: str) -> "AsyncSnapshotQueryBuilder":
|
|
78
95
|
"""Create a query builder for snapshot data."""
|
|
@@ -84,13 +101,12 @@ class AsyncDatabaseSnapshot:
|
|
|
84
101
|
ignore_config: Optional[IgnoreConfig] = None,
|
|
85
102
|
) -> "AsyncSnapshotDiff":
|
|
86
103
|
"""Compare this snapshot with another."""
|
|
87
|
-
|
|
88
|
-
await other._ensure_fetched()
|
|
104
|
+
# No need to fetch all data upfront - diff will fetch on demand
|
|
89
105
|
return AsyncSnapshotDiff(self, other, ignore_config)
|
|
90
106
|
|
|
91
107
|
|
|
92
108
|
class AsyncSnapshotQueryBuilder:
|
|
93
|
-
"""Query builder that works on
|
|
109
|
+
"""Query builder that works on snapshot data - can use targeted queries when possible."""
|
|
94
110
|
|
|
95
111
|
def __init__(self, snapshot: AsyncDatabaseSnapshot, table: str):
|
|
96
112
|
self._snapshot = snapshot
|
|
@@ -100,10 +116,63 @@ class AsyncSnapshotQueryBuilder:
|
|
|
100
116
|
self._limit: Optional[int] = None
|
|
101
117
|
self._order_by: Optional[str] = None
|
|
102
118
|
self._order_desc: bool = False
|
|
119
|
+
self._use_targeted_query = True # Try to use targeted queries when possible
|
|
120
|
+
|
|
121
|
+
def _can_use_targeted_query(self) -> bool:
|
|
122
|
+
"""Check if we can use a targeted query instead of loading all data."""
|
|
123
|
+
# We can use targeted query if:
|
|
124
|
+
# 1. We have simple equality conditions
|
|
125
|
+
# 2. No complex operations like joins
|
|
126
|
+
# 3. The query is selective (has conditions)
|
|
127
|
+
if not self._conditions:
|
|
128
|
+
return False
|
|
129
|
+
for col, op, val in self._conditions:
|
|
130
|
+
if op not in ["=", "IS", "IS NOT"]:
|
|
131
|
+
return False
|
|
132
|
+
return True
|
|
133
|
+
|
|
134
|
+
async def _execute_targeted_query(self) -> List[Dict[str, Any]]:
    """Evaluate this builder's filters with one SQL query instead of a full load.

    Only valid when ``_can_use_targeted_query()`` is True, i.e. every
    condition uses ``=``, ``IS`` or ``IS NOT``. ``IS`` and ``IS NOT`` are
    rendered as ``IS NULL`` and ``IS NOT NULL`` (their value is not used).
    ``= None`` becomes ``IS NULL``, because ``= NULL`` never matches in SQL.
    Ordering (including ``_order_desc``) and ``_limit`` are applied in SQL.

    Returns:
        Matching rows as ``{column: value}`` dicts, or an empty list.

    Raises:
        ValueError: if a condition uses any other operator.
    """
    import math

    def sql_literal(value: Any) -> str:
        # Check bool before int: bool is an int subclass. SQLite stores
        # booleans as 1/0, so a quoted 'True' could never match.
        if isinstance(value, bool):
            return "1" if value else "0"
        # Emit numbers unquoted. A quoted '5' only equals an integer 5 in
        # columns with numeric affinity.
        if isinstance(value, int):
            return str(int(value))
        if isinstance(value, float) and math.isfinite(value):
            return repr(float(value))
        if isinstance(value, (bytes, bytearray)):
            return f"X'{bytes(value).hex()}'"
        # Everything else is compared as text, with single quotes doubled.
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    where_parts: List[str] = []
    for col, op, val in self._conditions:
        column = _quote_identifier(col)
        if op == "IS NOT":
            where_parts.append(f"{column} IS NOT NULL")
        elif op == "IS" or (op == "=" and val is None):
            where_parts.append(f"{column} IS NULL")
        elif op == "=":
            where_parts.append(f"{column} = {sql_literal(val)}")
        else:
            # Silently skipping the condition would widen the result set.
            raise ValueError(f"Unsupported operator for targeted query: {op!r}")

    cols = ", ".join(self._select_cols)
    query = f"SELECT {cols} FROM {_quote_identifier(self._table)}"
    if where_parts:
        query += " WHERE " + " AND ".join(where_parts)

    if self._order_by:
        # NOTE(review): assumes _order_by holds a bare column name (as the
        # in-memory fallback presumably uses it as a row key) - confirm.
        query += f" ORDER BY {_quote_identifier(self._order_by)}"
        if self._order_desc:
            query += " DESC"
    if self._limit is not None:
        query += f" LIMIT {int(self._limit)}"

    response = await self._snapshot.resource.query(query)
    if response.rows and response.columns:
        return [dict(zip(response.columns, row)) for row in response.rows]
    return []
|
|
103
168
|
|
|
104
169
|
async def _get_data(self) -> List[Dict[str, Any]]:
|
|
105
|
-
"""Get table data
|
|
106
|
-
|
|
170
|
+
"""Get table data - use targeted query if possible, otherwise load all data."""
|
|
171
|
+
if self._use_targeted_query and self._can_use_targeted_query():
|
|
172
|
+
return await self._execute_targeted_query()
|
|
173
|
+
|
|
174
|
+
# Fall back to loading all data
|
|
175
|
+
await self._snapshot._ensure_table_data(self._table)
|
|
107
176
|
return self._snapshot._data.get(self._table, [])
|
|
108
177
|
|
|
109
178
|
def eq(self, column: str, value: Any) -> "AsyncSnapshotQueryBuilder":
|
|
@@ -142,6 +211,11 @@ class AsyncSnapshotQueryBuilder:
|
|
|
142
211
|
return rows[0] if rows else None
|
|
143
212
|
|
|
144
213
|
async def all(self) -> List[Dict[str, Any]]:
|
|
214
|
+
# If we can use targeted query, _get_data already applies filters
|
|
215
|
+
if self._use_targeted_query and self._can_use_targeted_query():
|
|
216
|
+
return await self._get_data()
|
|
217
|
+
|
|
218
|
+
# Otherwise, get all data and apply filters manually
|
|
145
219
|
data = await self._get_data()
|
|
146
220
|
|
|
147
221
|
# Apply filters
|
|
@@ -206,11 +280,12 @@ class AsyncSnapshotDiff:
|
|
|
206
280
|
self.after = after
|
|
207
281
|
self.ignore_config = ignore_config or IgnoreConfig()
|
|
208
282
|
self._cached: Optional[Dict[str, Any]] = None
|
|
283
|
+
self._targeted_mode = False # Flag to use targeted queries
|
|
209
284
|
|
|
210
285
|
async def _get_primary_key_columns(self, table: str) -> List[str]:
|
|
211
286
|
"""Get primary key columns for a table."""
|
|
212
287
|
# Try to get from schema
|
|
213
|
-
schema_response = await self.after.resource.query(f"PRAGMA table_info({table})")
|
|
288
|
+
schema_response = await self.after.resource.query(f"PRAGMA table_info({_quote_identifier(table)})")
|
|
214
289
|
if not schema_response.rows:
|
|
215
290
|
return ["id"] # Default fallback
|
|
216
291
|
|
|
@@ -246,6 +321,10 @@ class AsyncSnapshotDiff:
|
|
|
246
321
|
# Get primary key columns
|
|
247
322
|
pk_columns = await self._get_primary_key_columns(tbl)
|
|
248
323
|
|
|
324
|
+
# Ensure data is fetched for this table
|
|
325
|
+
await self.before._ensure_table_data(tbl)
|
|
326
|
+
await self.after._ensure_table_data(tbl)
|
|
327
|
+
|
|
249
328
|
# Get data from both snapshots
|
|
250
329
|
before_data = self.before._data.get(tbl, [])
|
|
251
330
|
after_data = self.after._data.get(tbl, [])
|
|
@@ -328,9 +407,378 @@ class AsyncSnapshotDiff:
|
|
|
328
407
|
)
|
|
329
408
|
return self._cached
|
|
330
409
|
|
|
331
|
-
|
|
332
|
-
"""
|
|
333
|
-
|
|
410
|
+
def _can_use_targeted_queries(self, allowed_changes: List[Dict[str, Any]]) -> bool:
|
|
411
|
+
"""Check if we can use targeted queries for optimization."""
|
|
412
|
+
# We can use targeted queries if all allowed changes specify table and pk
|
|
413
|
+
for change in allowed_changes:
|
|
414
|
+
if "table" not in change or "pk" not in change:
|
|
415
|
+
return False
|
|
416
|
+
return True
|
|
417
|
+
|
|
418
|
+
def _build_pk_where_clause(self, pk_columns: List[str], pk_value: Any) -> str:
    """Build a SQL predicate that selects one row by primary key.

    Args:
        pk_columns: Primary-key column names, one per key column.
        pk_value: The key value. For a composite key, pass a tuple or list
            with one value per column (lists typically come from JSON specs).
            A scalar paired with a composite key matches on the first column
            only, as before.

    Returns:
        A predicate such as ``"id" = 5`` or ``"a" = 'x' AND "b" IS NULL``.

    Raises:
        ValueError: if ``pk_columns`` is empty, or a composite key value has
            a different number of components than ``pk_columns``.
    """
    import math

    def sql_literal(val: Any) -> str:
        # bool is an int subclass; SQLite stores booleans as 1/0.
        if isinstance(val, bool):
            return "1" if val else "0"
        # Emit numbers unquoted so they match integer/real values even in
        # columns without numeric affinity.
        if isinstance(val, int):
            return str(int(val))
        if isinstance(val, float) and math.isfinite(val):
            return repr(float(val))
        if isinstance(val, (bytes, bytearray)):
            return f"X'{bytes(val).hex()}'"
        # Text (and any other type) is quoted with embedded quotes doubled,
        # so non-str values can no longer break out of the literal.
        escaped = str(val).replace("'", "''")
        return f"'{escaped}'"

    def predicate(column: str, val: Any) -> str:
        ident = _quote_identifier(column)
        if val is None:
            # "= NULL" is never true in SQL; NULL must be matched with IS.
            return f"{ident} IS NULL"
        return f"{ident} = {sql_literal(val)}"

    if not pk_columns:
        raise ValueError("Cannot build a primary-key WHERE clause without key columns")

    if len(pk_columns) > 1 and isinstance(pk_value, (tuple, list)):
        if len(pk_value) != len(pk_columns):
            raise ValueError(
                f"Composite key {pk_value!r} has {len(pk_value)} values but "
                f"{len(pk_columns)} key columns {pk_columns!r}"
            )
        return " AND ".join(
            predicate(column, val) for column, val in zip(pk_columns, pk_value)
        )

    return predicate(pk_columns[0], pk_value)
|
|
443
|
+
|
|
444
|
+
async def _expect_no_changes(self):
    """Assert that nothing changed between ``self.before`` and ``self.after``.

    The fast path first compares the sets of tables, reporting an added
    table before a removed one. It then runs ``COUNT(*)`` concurrently for
    every table not ignored by ``self.ignore_config``. Tables whose row count
    is unchanged and between 1 and 1000 are also compared row by row via
    ``_verify_table_unchanged``. Larger tables are compared by row count only.

    If the fast path fails with anything other than ``AssertionError``, a
    warning is printed and the full diff is validated against an empty list
    of allowed changes instead.

    Returns:
        ``self``, or the result of the full-diff fallback.

    Raises:
        AssertionError: describing the first unexpected change found.
    """
    try:
        tables_before = set(await self.before.tables())
        tables_after = set(await self.after.tables())
        ignored = self.ignore_config.should_ignore_table

        for name in tables_after - tables_before:
            if not ignored(name):
                raise AssertionError(f"Unexpected table added: {name}")
        for name in tables_before - tables_after:
            if not ignored(name):
                raise AssertionError(f"Unexpected table removed: {name}")

        candidates = [name for name in tables_before | tables_after if not ignored(name)]
        if not candidates:
            return self

        failures = []
        small_unchanged = []

        async def row_count(snapshot, name):
            resp = await snapshot.resource.query(
                f"SELECT COUNT(*) FROM {_quote_identifier(name)}"
            )
            return resp.rows[0][0] if resp.rows else 0

        async def compare_counts(name: str):
            """Record a count mismatch, or queue a small unchanged table for a row check."""
            try:
                n_before = await row_count(self.before, name) if name in tables_before else 0
                n_after = await row_count(self.after, name) if name in tables_after else 0

                if n_before != n_after:
                    failures.append(
                        AssertionError(
                            f"Unexpected change in table '{name}': "
                            f"row count changed from {n_before} to {n_after}"
                        )
                    )
                elif 0 < n_before <= 1000:
                    small_unchanged.append(name)
            except Exception as exc:
                failures.append(exc)

        await asyncio.gather(*(compare_counts(name) for name in candidates))
        if failures:
            raise failures[0]

        if small_unchanged:
            mismatches = []

            async def check_rows(name: str):
                try:
                    await self._verify_table_unchanged(name)
                except AssertionError as exc:
                    mismatches.append(exc)

            await asyncio.gather(*(check_rows(name) for name in small_unchanged))
            if mismatches:
                raise mismatches[0]

        return self

    except AssertionError:
        # Genuine verification failures propagate unchanged.
        raise
    except Exception as e:
        # Query/infrastructure problems: fall back to the exhaustive diff.
        print(f"Warning: Optimized no-changes check failed: {e}")
        print("Falling back to full diff...")
        return await self._validate_diff_against_allowed_changes(
            await self._collect(), []
        )
|
|
556
|
+
|
|
557
|
+
async def _verify_table_unchanged(self, table: str):
    """Assert that ``table`` holds identical data in both snapshots.

    Intended for small tables. Both snapshots are read in full, ordered by
    the primary key (or ``rowid`` if no key columns are reported), and
    compared row by row and field by field. Fields ignored by
    ``self.ignore_config`` are skipped.

    Raises:
        AssertionError: if the column lists differ, the row counts differ,
            or any non-ignored field value differs.
    """
    pk_columns = await self._get_primary_key_columns(table)

    # Quote key columns so reserved words (e.g. "order") or names with
    # spaces don't break the ORDER BY, matching how the table name is quoted.
    if pk_columns:
        order_by = ", ".join(_quote_identifier(col) for col in pk_columns)
    else:
        order_by = "rowid"
    query = f"SELECT * FROM {_quote_identifier(table)} ORDER BY {order_by}"

    before_response = await self.before.resource.query(query)
    after_response = await self.after.resource.query(query)

    # Differing column lists (names or order) mean the schema changed.
    if before_response.columns != after_response.columns:
        raise AssertionError(f"Schema changed in table '{table}'")

    before_rows = before_response.rows or []
    after_rows = after_response.rows or []
    if len(before_rows) != len(after_rows):
        raise AssertionError(
            f"Row count mismatch in table '{table}': "
            f"{len(before_rows)} vs {len(after_rows)}"
        )

    columns = before_response.columns or []
    for i, (before_row, after_row) in enumerate(zip(before_rows, after_rows)):
        before_dict = dict(zip(columns, before_row))
        after_dict = dict(zip(columns, after_row))

        for field in columns:
            if self.ignore_config.should_ignore_field(table, field):
                continue
            before_val = before_dict.get(field)
            after_val = after_dict.get(field)
            if _values_equivalent(before_val, after_val):
                continue

            # Report the full key for composite PKs, not just the first column.
            if len(pk_columns) > 1:
                pk_val = tuple(before_dict.get(col) for col in pk_columns)
            elif pk_columns:
                pk_val = before_dict.get(pk_columns[0])
            else:
                pk_val = i
            raise AssertionError(
                f"Unexpected change in table '{table}', row {pk_val}, "
                f"field '{field}': {repr(before_val)} -> {repr(after_val)}"
            )
|
|
602
|
+
|
|
603
|
+
def _is_field_change_allowed(
|
|
604
|
+
self, table_changes: List[Dict[str, Any]], pk: Any, field: str, after_val: Any
|
|
605
|
+
) -> bool:
|
|
606
|
+
"""Check if a specific field change is allowed."""
|
|
607
|
+
for change in table_changes:
|
|
608
|
+
if (
|
|
609
|
+
str(change.get("pk")) == str(pk)
|
|
610
|
+
and change.get("field") == field
|
|
611
|
+
and _values_equivalent(change.get("after"), after_val)
|
|
612
|
+
):
|
|
613
|
+
return True
|
|
614
|
+
return False
|
|
615
|
+
|
|
616
|
+
def _is_row_change_allowed(
|
|
617
|
+
self, table_changes: List[Dict[str, Any]], pk: Any, change_type: str
|
|
618
|
+
) -> bool:
|
|
619
|
+
"""Check if a row addition/deletion is allowed."""
|
|
620
|
+
for change in table_changes:
|
|
621
|
+
if str(change.get("pk")) == str(pk) and change.get("after") == change_type:
|
|
622
|
+
return True
|
|
623
|
+
return False
|
|
624
|
+
|
|
625
|
+
async def _expect_only_targeted(self, allowed_changes: List[Dict[str, Any]]):
    """Fast ``expect_only`` check that queries only the rows named in ``allowed_changes``.

    Every entry must carry ``"table"`` and ``"pk"``. For each table that is
    not ignored, each distinct primary key is fetched from both snapshots:

    * present in both: every differing, non-ignored field must be matched by
      an entry with the same ``pk`` and ``field`` and an equivalent ``after``;
    * present only in ``after``: requires an entry with ``after == "__added__"``;
    * present only in ``before``: requires an entry with ``after == "__removed__"``.

    Then every non-ignored table is checked with ``COUNT(*)``. A table that
    exists in only one snapshot (and is not mentioned) is reported as added
    or removed. A table's row count may change only by the net number of
    inserts and deletes validated above, which also catches stray inserts
    and deletes in tables that are mentioned.

    NOTE: in-place modifications of rows not listed in ``allowed_changes``
    leave row counts unchanged and are not detected by this fast path.

    Returns:
        ``self``, to allow chaining.

    Raises:
        AssertionError: for the first unexpected change found.
        Exception: any error raised by the underlying queries.
    """
    changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
    for change in allowed_changes:
        changes_by_table.setdefault(change["table"], []).append(change)

    errors: List[Exception] = []

    async def fetch_row(snapshot, table: str, where_sql: str) -> Optional[Dict[str, Any]]:
        response = await snapshot.resource.query(
            f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
        )
        if response.rows:
            return dict(zip(response.columns, response.rows[0]))
        return None

    async def count_rows(snapshot, table: str) -> int:
        response = await snapshot.resource.query(
            f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
        )
        return response.rows[0][0] if response.rows else 0

    async def check_row(
        table: str,
        pk: Any,
        table_changes: List[Dict[str, Any]],
        pk_columns: List[str],
    ) -> int:
        """Validate one row and return its row-count contribution (+1, -1 or 0)."""
        try:
            where_sql = self._build_pk_where_clause(pk_columns, pk)
            before_row = await fetch_row(self.before, table, where_sql)
            after_row = await fetch_row(self.after, table, where_sql)

            if before_row and after_row:
                # Modified (or untouched) row: check each differing field.
                for field in set(before_row) | set(after_row):
                    if self.ignore_config.should_ignore_field(table, field):
                        continue
                    before_val = before_row.get(field)
                    after_val = after_row.get(field)
                    if _values_equivalent(before_val, after_val):
                        continue
                    if not self._is_field_change_allowed(table_changes, pk, field, after_val):
                        errors.append(
                            AssertionError(
                                f"Unexpected change in table '{table}', "
                                f"row {pk}, field '{field}': "
                                f"{repr(before_val)} -> {repr(after_val)}"
                            )
                        )
                        return 0  # stop checking this row
                return 0

            if after_row:
                if self._is_row_change_allowed(table_changes, pk, "__added__"):
                    return 1
                errors.append(
                    AssertionError(f"Unexpected row added in table '{table}': {pk}")
                )
            elif before_row:
                if self._is_row_change_allowed(table_changes, pk, "__removed__"):
                    return -1
                errors.append(
                    AssertionError(f"Unexpected row removed from table '{table}': {pk}")
                )
        except Exception as e:
            errors.append(e)
        return 0

    # Phase 1: check every row mentioned in allowed_changes.
    row_checks = []
    for table, table_changes in changes_by_table.items():
        if self.ignore_config.should_ignore_table(table):
            continue
        pk_columns = await self._get_primary_key_columns(table)
        # Dedupe by string form: allowed changes are matched on str(pk), so
        # 1 and "1" are the same row (checking it twice would double-count
        # it below). This also tolerates unhashable list keys from JSON.
        unique_pks = {str(change["pk"]): change["pk"] for change in table_changes}
        for pk in unique_pks.values():
            row_checks.append((table, pk, table_changes, pk_columns))

    net_delta: Dict[str, int] = {}
    if row_checks:
        deltas = await asyncio.gather(*[check_row(*check) for check in row_checks])
        for (table, *_), delta in zip(row_checks, deltas):
            net_delta[table] = net_delta.get(table, 0) + delta

    if errors:
        raise errors[0]

    # Phase 2: table presence and row counts for every non-ignored table.
    before_tables = set(await self.before.tables())
    after_tables = set(await self.after.tables())

    count_checks = []
    for table in sorted(before_tables | after_tables):
        if self.ignore_config.should_ignore_table(table):
            continue
        if table in before_tables and table in after_tables:
            count_checks.append((table, net_delta.get(table, 0)))
        elif table not in changes_by_table:
            # Without this, the COUNT(*) on the missing side would fail with
            # an opaque "no such table" error instead of an AssertionError.
            if table in after_tables:
                errors.append(AssertionError(f"Unexpected table added: {table}"))
            else:
                errors.append(AssertionError(f"Unexpected table removed: {table}"))

    async def verify_count(table: str, expected_delta: int) -> None:
        try:
            before_count = await count_rows(self.before, table)
            after_count = await count_rows(self.after, table)
            if after_count - before_count != expected_delta:
                error_msg = (
                    f"Unexpected change in table '{table}': "
                    f"row count changed from {before_count} to {after_count}"
                )
                if table in changes_by_table:
                    error_msg += (
                        f" (allowed changes account for a net change of {expected_delta:+d})"
                    )
                errors.append(AssertionError(error_msg))
        except Exception as e:
            errors.append(e)

    if count_checks:
        await asyncio.gather(*[verify_count(table, delta) for table, delta in count_checks])

    if errors:
        raise errors[0]

    return self
|
|
777
|
+
|
|
778
|
+
async def _validate_diff_against_allowed_changes(
|
|
779
|
+
self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
|
|
780
|
+
):
|
|
781
|
+
"""Validate a collected diff against allowed changes."""
|
|
334
782
|
|
|
335
783
|
def _is_change_allowed(
|
|
336
784
|
table: str, row_id: Any, field: Optional[str], after_value: Any
|
|
@@ -457,6 +905,1053 @@ class AsyncSnapshotDiff:
|
|
|
457
905
|
|
|
458
906
|
return self
|
|
459
907
|
|
|
908
|
+
async def _expect_only_targeted_v2(self, allowed_changes: List[Dict[str, Any]]):
|
|
909
|
+
"""Optimized version that only queries specific rows mentioned in allowed_changes.
|
|
910
|
+
|
|
911
|
+
Supports v2 spec formats:
|
|
912
|
+
- {"table": "t", "pk": 1, "type": "insert", "fields": [...]}
|
|
913
|
+
- {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": bool}
|
|
914
|
+
- {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
|
|
915
|
+
- Legacy single-field specs: {"table": "t", "pk": 1, "field": "x", "after": val}
|
|
916
|
+
"""
|
|
917
|
+
import asyncio
|
|
918
|
+
|
|
919
|
+
# Helper functions for v2 spec validation
|
|
920
|
+
def _parse_fields_spec(
|
|
921
|
+
fields_spec: List[Tuple[str, Any]]
|
|
922
|
+
) -> Dict[str, Tuple[bool, Any]]:
|
|
923
|
+
"""Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
|
|
924
|
+
spec_map: Dict[str, Tuple[bool, Any]] = {}
|
|
925
|
+
for spec_tuple in fields_spec:
|
|
926
|
+
if len(spec_tuple) != 2:
|
|
927
|
+
raise ValueError(
|
|
928
|
+
f"Invalid field spec tuple: {spec_tuple}. "
|
|
929
|
+
f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
|
|
930
|
+
)
|
|
931
|
+
field_name, expected_value = spec_tuple
|
|
932
|
+
if expected_value is ...:
|
|
933
|
+
spec_map[field_name] = (False, None)
|
|
934
|
+
else:
|
|
935
|
+
spec_map[field_name] = (True, expected_value)
|
|
936
|
+
return spec_map
|
|
937
|
+
|
|
938
|
+
def _get_all_specs_for_pk(table: str, pk: Any) -> List[Dict[str, Any]]:
|
|
939
|
+
"""Get all specs for a given table/pk (for legacy multi-field specs)."""
|
|
940
|
+
specs = []
|
|
941
|
+
for allowed in allowed_changes:
|
|
942
|
+
if (
|
|
943
|
+
allowed["table"] == table
|
|
944
|
+
and str(allowed.get("pk")) == str(pk)
|
|
945
|
+
):
|
|
946
|
+
specs.append(allowed)
|
|
947
|
+
return specs
|
|
948
|
+
|
|
949
|
+
def _validate_insert_row(
|
|
950
|
+
table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
|
|
951
|
+
) -> Optional[str]:
|
|
952
|
+
"""Validate an inserted row against specs. Returns error message or None."""
|
|
953
|
+
# Check for type: "insert" spec with fields
|
|
954
|
+
for spec in specs:
|
|
955
|
+
if spec.get("type") == "insert":
|
|
956
|
+
fields_spec = spec.get("fields")
|
|
957
|
+
if fields_spec is not None:
|
|
958
|
+
# Validate each field
|
|
959
|
+
spec_map = _parse_fields_spec(fields_spec)
|
|
960
|
+
for field_name, field_value in row_data.items():
|
|
961
|
+
if field_name == "rowid":
|
|
962
|
+
continue
|
|
963
|
+
if self.ignore_config.should_ignore_field(table, field_name):
|
|
964
|
+
continue
|
|
965
|
+
if field_name not in spec_map:
|
|
966
|
+
return f"Field '{field_name}' not in insert spec for table '{table}' pk={pk}"
|
|
967
|
+
should_check, expected_value = spec_map[field_name]
|
|
968
|
+
if should_check and not _values_equivalent(expected_value, field_value):
|
|
969
|
+
return (
|
|
970
|
+
f"Insert mismatch in table '{table}' pk={pk}, "
|
|
971
|
+
f"field '{field_name}': expected {repr(expected_value)}, got {repr(field_value)}"
|
|
972
|
+
)
|
|
973
|
+
# type: "insert" found (with or without fields) - allowed
|
|
974
|
+
return None
|
|
975
|
+
|
|
976
|
+
# Check for legacy whole-row spec
|
|
977
|
+
for spec in specs:
|
|
978
|
+
if spec.get("fields") is None and spec.get("after") == "__added__":
|
|
979
|
+
return None
|
|
980
|
+
|
|
981
|
+
return f"Unexpected row added in table '{table}': pk={pk}"
|
|
982
|
+
|
|
983
|
+
def _validate_delete_row(
    table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
) -> Optional[str]:
    """Check a deleted row (its pre-deletion contents) against the specs for its pk.

    Args:
        table: Table the row was removed from.
        pk: Primary key of the removed row (used only in messages).
        row_data: Column -> value mapping of the row before deletion.
        specs: Allowed-change specs already filtered to this table/pk.

    Returns:
        ``None`` if the deletion is permitted, otherwise an error message
        describing the first problem found.
    """
    # Only the first {"type": "delete"} spec is consulted.
    delete_spec = next((s for s in specs if s.get("type") == "delete"), None)
    if delete_spec is not None:
        expected_fields = delete_spec.get("fields")
        if expected_fields is None:
            # A bare {"type": "delete"} permits the deletion outright.
            return None
        # Every field named in the spec must exist in the deleted row; columns
        # of the row that the spec does not mention are not checked here.
        for column, (check_value, wanted) in _parse_fields_spec(expected_fields).items():
            if column not in row_data:
                return f"Field '{column}' in delete spec not found in row for table '{table}' pk={pk}"
            actual = row_data[column]
            if check_value and not _values_equivalent(wanted, actual):
                return (
                    f"Delete mismatch in table '{table}' pk={pk}, "
                    f"field '{column}': expected {repr(wanted)}, got {repr(actual)}"
                )
        return None

    # Legacy whole-row form: {"fields": None, "after": "__removed__"}.
    legacy_ok = any(
        s.get("fields") is None and s.get("after") == "__removed__" for s in specs
    )
    if legacy_ok:
        return None
    return f"Unexpected row removed from table '{table}': pk={pk}"
|
|
1011
|
+
|
|
1012
|
+
def _validate_modify_row(
    table: str,
    pk: Any,
    before_row: Dict[str, Any],
    after_row: Dict[str, Any],
    specs: List[Dict[str, Any]],
) -> Optional[str]:
    """Check a row present in both snapshots against the specs for its pk.

    Args:
        table: Table containing the row.
        pk: Primary key of the row (used only in messages).
        before_row: Column -> value mapping from the "before" snapshot.
        after_row: Column -> value mapping from the "after" snapshot.
        specs: Allowed-change specs already filtered to this table/pk.

    Returns:
        ``None`` if the row is unchanged or every change is permitted,
        otherwise an error message for the first disallowed change.

    Raises:
        ValueError: If a ``{"type": "modify"}`` spec has ``resulting_fields``
            but ``no_other_changes`` is missing or not a bool.
    """
    # Field-level diff, skipping ignored columns. The union is built exactly as
    # set(...) | set(...) so iteration order (and thus the first error) is stable.
    diffs: Dict[str, Dict[str, Any]] = {}
    for column in set(before_row.keys()) | set(after_row.keys()):
        if self.ignore_config.should_ignore_field(table, column):
            continue
        old = before_row.get(column)
        new = after_row.get(column)
        if not _values_equivalent(old, new):
            diffs[column] = {"before": old, "after": new}

    if not diffs:
        return None  # No changes

    # Only the first {"type": "modify"} spec is consulted.
    modify_spec = next((s for s in specs if s.get("type") == "modify"), None)
    if modify_spec is not None:
        expected = modify_spec.get("resulting_fields")
        if expected is None:
            # A bare {"type": "modify"} accepts any modification of this row.
            return None
        if "no_other_changes" not in modify_spec:
            raise ValueError(
                f"Modify spec for table '{table}' pk={pk} "
                f"has 'resulting_fields' but missing required 'no_other_changes' field."
            )
        strict = modify_spec["no_other_changes"]
        if not isinstance(strict, bool):
            raise ValueError(
                f"Modify spec for table '{table}' pk={pk} "
                f"'no_other_changes' must be boolean, got {type(strict).__name__}"
            )
        rules = _parse_fields_spec(expected)
        for column, change in diffs.items():
            new = change["after"]
            if column in rules:
                check_value, wanted = rules[column]
                if check_value and not _values_equivalent(wanted, new):
                    return (
                        f"Modify mismatch in table '{table}' pk={pk}, "
                        f"field '{column}': expected {repr(wanted)}, got {repr(new)}"
                    )
            elif strict:
                # Strict mode: every changed column must be listed.
                return (
                    f"Unexpected field change in table '{table}' pk={pk}: "
                    f"field '{column}' not in resulting_fields"
                )
        return None  # Validation passed

    # Legacy single-field specs: each changed column needs a spec whose
    # "field" names it and whose "after" matches the new value.
    for column, change in diffs.items():
        new = change["after"]
        permitted = any(
            s.get("field") == column and _values_equivalent(s.get("after"), new)
            for s in specs
        )
        if not permitted:
            return (
                f"Unexpected change in table '{table}' pk={pk}, "
                f"field '{column}': {repr(change['before'])} -> {repr(new)}"
            )

    return None
|
|
1093
|
+
|
|
1094
|
+
# Group allowed changes by table
|
|
1095
|
+
changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
|
|
1096
|
+
for change in allowed_changes:
|
|
1097
|
+
table = change["table"]
|
|
1098
|
+
if table not in changes_by_table:
|
|
1099
|
+
changes_by_table[table] = []
|
|
1100
|
+
changes_by_table[table].append(change)
|
|
1101
|
+
|
|
1102
|
+
errors: List[Exception] = []
|
|
1103
|
+
|
|
1104
|
+
# Async function to check a single row
|
|
1105
|
+
async def check_row(
|
|
1106
|
+
table: str,
|
|
1107
|
+
pk: Any,
|
|
1108
|
+
pk_columns: List[str],
|
|
1109
|
+
):
|
|
1110
|
+
try:
|
|
1111
|
+
# Build WHERE clause for this PK
|
|
1112
|
+
where_sql = self._build_pk_where_clause(pk_columns, pk)
|
|
1113
|
+
|
|
1114
|
+
# Query before snapshot
|
|
1115
|
+
before_query = f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
|
|
1116
|
+
before_response = await self.before.resource.query(before_query)
|
|
1117
|
+
before_row = (
|
|
1118
|
+
dict(zip(before_response.columns, before_response.rows[0]))
|
|
1119
|
+
if before_response.rows
|
|
1120
|
+
else None
|
|
1121
|
+
)
|
|
1122
|
+
|
|
1123
|
+
# Query after snapshot
|
|
1124
|
+
after_response = await self.after.resource.query(before_query)
|
|
1125
|
+
after_row = (
|
|
1126
|
+
dict(zip(after_response.columns, after_response.rows[0]))
|
|
1127
|
+
if after_response.rows
|
|
1128
|
+
else None
|
|
1129
|
+
)
|
|
1130
|
+
|
|
1131
|
+
# Get all specs for this table/pk
|
|
1132
|
+
specs = _get_all_specs_for_pk(table, pk)
|
|
1133
|
+
|
|
1134
|
+
# Check changes for this row
|
|
1135
|
+
if before_row and after_row:
|
|
1136
|
+
# Modified row
|
|
1137
|
+
error = _validate_modify_row(table, pk, before_row, after_row, specs)
|
|
1138
|
+
if error:
|
|
1139
|
+
errors.append(AssertionError(error))
|
|
1140
|
+
elif not before_row and after_row:
|
|
1141
|
+
# Added row
|
|
1142
|
+
error = _validate_insert_row(table, pk, after_row, specs)
|
|
1143
|
+
if error:
|
|
1144
|
+
errors.append(AssertionError(error))
|
|
1145
|
+
elif before_row and not after_row:
|
|
1146
|
+
# Removed row
|
|
1147
|
+
error = _validate_delete_row(table, pk, before_row, specs)
|
|
1148
|
+
if error:
|
|
1149
|
+
errors.append(AssertionError(error))
|
|
1150
|
+
|
|
1151
|
+
except Exception as e:
|
|
1152
|
+
errors.append(e)
|
|
1153
|
+
|
|
1154
|
+
# Prepare all row checks
|
|
1155
|
+
row_tasks = []
|
|
1156
|
+
for table, table_changes in changes_by_table.items():
|
|
1157
|
+
if self.ignore_config.should_ignore_table(table):
|
|
1158
|
+
continue
|
|
1159
|
+
|
|
1160
|
+
# Get primary key columns once per table
|
|
1161
|
+
pk_columns = self._get_primary_key_columns(table)
|
|
1162
|
+
|
|
1163
|
+
# Extract unique PKs to check
|
|
1164
|
+
pks_to_check = {change["pk"] for change in table_changes}
|
|
1165
|
+
|
|
1166
|
+
for pk in pks_to_check:
|
|
1167
|
+
row_tasks.append(check_row(table, pk, pk_columns))
|
|
1168
|
+
|
|
1169
|
+
# Execute row checks concurrently
|
|
1170
|
+
if row_tasks:
|
|
1171
|
+
await asyncio.gather(*row_tasks)
|
|
1172
|
+
|
|
1173
|
+
# Check for errors from row checks
|
|
1174
|
+
if errors:
|
|
1175
|
+
raise errors[0]
|
|
1176
|
+
|
|
1177
|
+
# Now check tables not mentioned in allowed_changes to ensure no changes
|
|
1178
|
+
all_tables = set(await self.before.tables()) | set(await self.after.tables())
|
|
1179
|
+
tables_to_verify = []
|
|
1180
|
+
|
|
1181
|
+
for table in all_tables:
|
|
1182
|
+
if (
|
|
1183
|
+
table not in changes_by_table
|
|
1184
|
+
and not self.ignore_config.should_ignore_table(table)
|
|
1185
|
+
):
|
|
1186
|
+
tables_to_verify.append(table)
|
|
1187
|
+
|
|
1188
|
+
# Async function to verify no changes in a table
|
|
1189
|
+
async def verify_no_changes(table: str):
|
|
1190
|
+
try:
|
|
1191
|
+
# For tables with no allowed changes, just check row counts
|
|
1192
|
+
before_count_response = await self.before.resource.query(
|
|
1193
|
+
f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
|
|
1194
|
+
)
|
|
1195
|
+
before_count = (
|
|
1196
|
+
before_count_response.rows[0][0]
|
|
1197
|
+
if before_count_response.rows
|
|
1198
|
+
else 0
|
|
1199
|
+
)
|
|
1200
|
+
|
|
1201
|
+
after_count_response = await self.after.resource.query(
|
|
1202
|
+
f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
|
|
1203
|
+
)
|
|
1204
|
+
after_count = (
|
|
1205
|
+
after_count_response.rows[0][0] if after_count_response.rows else 0
|
|
1206
|
+
)
|
|
1207
|
+
|
|
1208
|
+
if before_count != after_count:
|
|
1209
|
+
error_msg = (
|
|
1210
|
+
f"Unexpected change in table '{table}': "
|
|
1211
|
+
f"row count changed from {before_count} to {after_count}"
|
|
1212
|
+
)
|
|
1213
|
+
errors.append(AssertionError(error_msg))
|
|
1214
|
+
except Exception as e:
|
|
1215
|
+
errors.append(e)
|
|
1216
|
+
|
|
1217
|
+
# Execute table verification concurrently
|
|
1218
|
+
if tables_to_verify:
|
|
1219
|
+
verify_tasks = [verify_no_changes(table) for table in tables_to_verify]
|
|
1220
|
+
await asyncio.gather(*verify_tasks)
|
|
1221
|
+
|
|
1222
|
+
# Final error check
|
|
1223
|
+
if errors:
|
|
1224
|
+
raise errors[0]
|
|
1225
|
+
|
|
1226
|
+
return self
|
|
1227
|
+
|
|
1228
|
+
async def _validate_diff_against_allowed_changes_v2(
    self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
):
    """Validate a collected diff against allowed changes with field-level spec support.

    This version supports explicit change types via the "type" field:
    1. Insert specs: {"table": "t", "pk": 1, "type": "insert", "fields": [("name", "value"), ("status", ...)]}
       - ("name", value): check that field equals value
       - ("name", None): check that field is SQL NULL
       - ("name", ...): don't check the value, just acknowledge the field exists
    2. Modify specs: {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": True/False}
       - Uses "resulting_fields" (not "fields") to be explicit about what's being checked
       - "no_other_changes" is REQUIRED and must be True or False:
         - True: Every changed field must be in resulting_fields (strict mode)
         - False: Only check fields in resulting_fields match, ignore other changes
       - ("field_name", value): check that after value equals value
       - ("field_name", None): check that after value is SQL NULL
       - ("field_name", ...): don't check value, just acknowledge field changed
    3. Delete specs:
       - Without field validation: {"table": "t", "pk": 1, "type": "delete"}
       - With field validation: {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
    4. Whole-row specs (legacy):
       - For additions: {"table": "t", "pk": 1, "fields": None, "after": "__added__"}
       - For deletions: {"table": "t", "pk": 1, "fields": None, "after": "__removed__"}
    Modified rows without a "modify" spec fall back to legacy single-field specs
    ({"table", "pk", "field", "after"}).

    When using "fields" for inserts OR deletes, every column of the row (except
    ``rowid`` and ignored columns) must be accounted for in the list, and listed
    values are checked. Spec entries naming columns absent from the row are not
    reported. For modifications, use "resulting_fields" with explicit
    "no_other_changes".

    Args:
        diff: Mapping of table name -> report with optional "modified_rows",
            "added_rows" and "removed_rows" lists. Modified rows carry
            "row_id" and "changes" (field -> {"before", "after"}); added and
            removed rows carry "row_id" and "data".
        allowed_changes: The spec dicts described above.

    Returns:
        ``self`` when every change in ``diff`` is allowed.

    Raises:
        AssertionError: With a readable report when unexpected changes exist.
        ValueError: For malformed specs (missing/non-bool "no_other_changes",
            field spec entries that are not 2-tuples).
    """

    def _is_change_allowed(
        table: str, row_id: Any, field: Optional[str], after_value: Any
    ) -> bool:
        """Check if a change is in the allowed list using semantic comparison."""
        for allowed in allowed_changes:
            allowed_pk = allowed.get("pk")
            # Handle type conversion for primary key comparison
            pk_match = (
                str(allowed_pk) == str(row_id) if allowed_pk is not None else False
            )

            # For whole-row specs, check "fields": None; for field-level, check "field"
            field_match = (
                ("fields" in allowed and allowed.get("fields") is None)
                if field is None
                else allowed.get("field") == field
            )
            if (
                allowed["table"] == table
                and pk_match
                and field_match
                and _values_equivalent(allowed.get("after"), after_value)
            ):
                return True
        return False

    def _get_fields_spec_for_type(
        table: str, row_id: Any, change_type: str
    ) -> Optional[List[Tuple[str, Any]]]:
        """Get the bulk fields spec for a given table/row/type if it exists.

        Args:
            table: The table name
            row_id: The primary key value
            change_type: One of "insert" or "delete"

        Note: For "modify" type, use _get_modify_spec instead.
        """
        for allowed in allowed_changes:
            allowed_pk = allowed.get("pk")
            pk_match = (
                str(allowed_pk) == str(row_id) if allowed_pk is not None else False
            )
            if (
                allowed["table"] == table
                and pk_match
                and allowed.get("type") == change_type
                and "fields" in allowed
            ):
                return allowed["fields"]
        return None

    def _get_modify_spec(table: str, row_id: Any) -> Optional[Dict[str, Any]]:
        """Get the modify spec for a given table/row if it exists.

        Returns the full spec dict (which may contain "resulting_fields" and
        "no_other_changes"), or None if no modify spec found.
        """
        for allowed in allowed_changes:
            allowed_pk = allowed.get("pk")
            pk_match = (
                str(allowed_pk) == str(row_id) if allowed_pk is not None else False
            )
            if (
                allowed["table"] == table
                and pk_match
                and allowed.get("type") == "modify"
            ):
                return allowed
        return None

    def _is_type_allowed(table: str, row_id: Any, change_type: str) -> bool:
        """Check if a change type is allowed for the given table/row (with or without fields)."""
        for allowed in allowed_changes:
            allowed_pk = allowed.get("pk")
            pk_match = (
                str(allowed_pk) == str(row_id) if allowed_pk is not None else False
            )
            if (
                allowed["table"] == table
                and pk_match
                and allowed.get("type") == change_type
            ):
                return True
        return False

    def _parse_fields_spec(
        fields_spec: List[Tuple[str, Any]]
    ) -> Dict[str, Tuple[bool, Any]]:
        """Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
        spec_map: Dict[str, Tuple[bool, Any]] = {}
        for spec_tuple in fields_spec:
            if len(spec_tuple) != 2:
                raise ValueError(
                    f"Invalid field spec tuple: {spec_tuple}. "
                    f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
                )
            field_name, expected_value = spec_tuple
            if expected_value is ...:
                # Ellipsis: don't check value, just acknowledge field exists
                spec_map[field_name] = (False, None)
            else:
                # Any other value (including None for NULL check): check value
                spec_map[field_name] = (True, expected_value)
        return spec_map

    def _validate_row_with_fields_spec(
        table: str,
        row_id: Any,
        row_data: Dict[str, Any],
        fields_spec: List[Tuple[str, Any]],
    ) -> Optional[List[Tuple[str, Any, str]]]:
        """Validate a whole row (insert or delete) against a bulk fields spec.

        Returns None if validation passes, or a list of (field, actual_value, issue)
        tuples for mismatches. Every non-rowid, non-ignored column of the row
        must appear in the spec.
        """
        spec_map = _parse_fields_spec(fields_spec)
        unmatched_fields: List[Tuple[str, Any, str]] = []

        for field_name, field_value in row_data.items():
            # Skip rowid as it's internal
            if field_name == "rowid":
                continue
            # Skip ignored fields
            if self.ignore_config.should_ignore_field(table, field_name):
                continue

            if field_name not in spec_map:
                # Field not in spec - this is an error
                unmatched_fields.append(
                    (field_name, field_value, "NOT_IN_FIELDS_SPEC")
                )
            else:
                should_check, expected_value = spec_map[field_name]
                if should_check and not _values_equivalent(
                    expected_value, field_value
                ):
                    # Value doesn't match
                    unmatched_fields.append(
                        (field_name, field_value, f"expected {repr(expected_value)}")
                    )

        return unmatched_fields if unmatched_fields else None

    def _validate_modification_with_fields_spec(
        table: str,
        row_id: Any,
        row_changes: Dict[str, Dict[str, Any]],
        resulting_fields: List[Tuple[str, Any]],
        no_other_changes: bool,
    ) -> Optional[List[Tuple[str, Any, str]]]:
        """Validate a modification against a resulting_fields spec.

        Returns None if validation passes, or a list of (field, actual_value, issue)
        tuples for mismatches. With no_other_changes=True every changed field
        must be listed; with False, unlisted changed fields are ignored.
        """
        spec_map = _parse_fields_spec(resulting_fields)
        unmatched_fields: List[Tuple[str, Any, str]] = []

        for field_name, vals in row_changes.items():
            # Skip ignored fields
            if self.ignore_config.should_ignore_field(table, field_name):
                continue

            after_value = vals["after"]

            if field_name not in spec_map:
                # Changed field not in spec
                if no_other_changes:
                    # Strict mode: all changed fields must be accounted for
                    unmatched_fields.append(
                        (field_name, after_value, "NOT_IN_RESULTING_FIELDS")
                    )
                # If no_other_changes=False, ignore fields not in spec
            else:
                should_check, expected_value = spec_map[field_name]
                if should_check and not _values_equivalent(
                    expected_value, after_value
                ):
                    # Value doesn't match
                    unmatched_fields.append(
                        (field_name, after_value, f"expected {repr(expected_value)}")
                    )

        return unmatched_fields if unmatched_fields else None

    def _summarize_fields(fields: List[Tuple[str, Any]]) -> str:
        """Render up to three field specs for the error report (``...`` shown as NOT_CHECKED)."""
        summary = ", ".join(
            f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
            for f in fields[:3]
        )
        if len(fields) > 3:
            summary += f", ... +{len(fields) - 3} more"
        return summary

    # Collect all unexpected changes for detailed reporting
    unexpected_changes = []

    for tbl, report in diff.items():
        for row in report.get("modified_rows", []):
            row_changes = row["changes"]

            # Check for modify spec with resulting_fields
            modify_spec = _get_modify_spec(tbl, row["row_id"])
            if modify_spec is not None:
                resulting_fields = modify_spec.get("resulting_fields")
                if resulting_fields is None:
                    # Modify spec without resulting_fields - just allow the modification
                    continue  # Skip to next row

                # Validate that no_other_changes is provided
                if "no_other_changes" not in modify_spec:
                    raise ValueError(
                        f"Modify spec for table '{tbl}' pk={row['row_id']} "
                        f"has 'resulting_fields' but missing required 'no_other_changes' field. "
                        f"Set 'no_other_changes': True to verify no other fields changed, "
                        f"or 'no_other_changes': False to only check the specified fields."
                    )
                no_other_changes = modify_spec["no_other_changes"]
                if not isinstance(no_other_changes, bool):
                    raise ValueError(
                        f"Modify spec for table '{tbl}' pk={row['row_id']} "
                        f"has 'no_other_changes' but it must be a boolean (True or False), "
                        f"got {type(no_other_changes).__name__}: {repr(no_other_changes)}"
                    )

                unmatched = _validate_modification_with_fields_spec(
                    tbl, row["row_id"], row_changes, resulting_fields, no_other_changes
                )
                if unmatched:
                    unexpected_changes.append(
                        {
                            "type": "modification",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "before": None,
                            "after": None,
                            "full_row": row,
                            "unmatched_fields": unmatched,
                        }
                    )
                continue  # Skip to next row

            # Fall back to single-field specs (legacy)
            for f, vals in row_changes.items():
                if self.ignore_config.should_ignore_field(tbl, f):
                    continue
                if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
                    unexpected_changes.append(
                        {
                            "type": "modification",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": f,
                            "before": vals.get("before"),
                            "after": vals["after"],
                            "full_row": row,
                        }
                    )

        for row in report.get("added_rows", []):
            row_data = row.get("data", {})

            # Check for bulk fields spec (type: "insert")
            fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "insert")
            if fields_spec is not None:
                unmatched = _validate_row_with_fields_spec(
                    tbl, row["row_id"], row_data, fields_spec
                )
                if unmatched:
                    unexpected_changes.append(
                        {
                            "type": "insertion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__added__",
                            "full_row": row,
                            "unmatched_fields": unmatched,
                        }
                    )
                continue  # Skip to next row

            # Check if insertion is allowed without field validation
            if _is_type_allowed(tbl, row["row_id"], "insert"):
                continue  # Insertion is allowed, skip to next row

            # Check for whole-row spec (legacy)
            if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
                unexpected_changes.append(
                    {
                        "type": "insertion",
                        "table": tbl,
                        "row_id": row["row_id"],
                        "field": None,
                        "after": "__added__",
                        "full_row": row,
                    }
                )

        for row in report.get("removed_rows", []):
            row_data = row.get("data", {})

            # Check for bulk fields spec (type: "delete")
            fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "delete")
            if fields_spec is not None:
                unmatched = _validate_row_with_fields_spec(
                    tbl, row["row_id"], row_data, fields_spec
                )
                if unmatched:
                    unexpected_changes.append(
                        {
                            "type": "deletion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__removed__",
                            "full_row": row,
                            "unmatched_fields": unmatched,
                        }
                    )
                continue  # Skip to next row

            # Check if deletion is allowed without field validation
            if _is_type_allowed(tbl, row["row_id"], "delete"):
                continue  # Deletion is allowed, skip to next row

            # Check for whole-row spec (legacy)
            if not _is_change_allowed(tbl, row["row_id"], None, "__removed__"):
                unexpected_changes.append(
                    {
                        "type": "deletion",
                        "table": tbl,
                        "row_id": row["row_id"],
                        "field": None,
                        "after": "__removed__",
                        "full_row": row,
                    }
                )

    if unexpected_changes:
        # Single limit for the unmatched-field listing so the "... and N more"
        # tail always accounts for exactly the entries that were not shown.
        max_unmatched_shown = 5

        # Build comprehensive error message
        error_lines = ["Unexpected database changes detected:"]
        error_lines.append("")

        for i, change in enumerate(unexpected_changes[:5], 1):
            error_lines.append(
                f"{i}. {change['type'].upper()} in table '{change['table']}':"
            )
            error_lines.append(f" Row ID: {change['row_id']}")

            if change["type"] == "modification":
                # Bulk resulting_fields entries record field=None; their details
                # are listed under "Unmatched fields" instead.
                if change["field"] is not None:
                    error_lines.append(f" Field: {change['field']}")
                    error_lines.append(f" Before: {repr(change['before'])}")
                    error_lines.append(f" After: {repr(change['after'])}")
            elif change["type"] == "insertion":
                error_lines.append(" New row added")
            elif change["type"] == "deletion":
                error_lines.append(" Row deleted")

            # Show unmatched fields if present (from bulk fields spec validation)
            unmatched = change.get("unmatched_fields")
            if unmatched:
                error_lines.append(" Unmatched fields:")
                for field_name, actual_value, issue in unmatched[:max_unmatched_shown]:
                    error_lines.append(
                        f" - {field_name}: {repr(actual_value)} ({issue})"
                    )
                if len(unmatched) > max_unmatched_shown:
                    error_lines.append(
                        f" ... and {len(unmatched) - max_unmatched_shown} more"
                    )

            # Show some context from the row (current state for modifications,
            # the row contents for insertions/deletions)
            if change.get("full_row"):
                row_data = change["full_row"]
                if (
                    change["type"] in ("modification", "insertion", "deletion")
                    and "data" in row_data
                ):
                    formatted_row = _format_row_for_error(
                        row_data.get("data", {}), max_fields=5
                    )
                    error_lines.append(f" Row data: {formatted_row}")

            error_lines.append("")

        if len(unexpected_changes) > 5:
            error_lines.append(
                f"... and {len(unexpected_changes) - 5} more unexpected changes"
            )
            error_lines.append("")

        # Show what changes were allowed
        error_lines.append("Allowed changes were:")
        if allowed_changes:
            for i, allowed in enumerate(allowed_changes[:3], 1):
                change_type = allowed.get("type", "unspecified")

                # For modify type, use resulting_fields
                if change_type == "modify" and allowed.get("resulting_fields") is not None:
                    fields_summary = _summarize_fields(allowed["resulting_fields"])
                    no_other = allowed.get("no_other_changes", "NOT_SET")
                    error_lines.append(
                        f" {i}. Table: {allowed.get('table')}, "
                        f"ID: {allowed.get('pk')}, "
                        f"Type: {change_type}, "
                        f"resulting_fields: [{fields_summary}], "
                        f"no_other_changes: {no_other}"
                    )
                elif allowed.get("fields") is not None:
                    # Show bulk fields spec (for insert/delete)
                    fields_summary = _summarize_fields(allowed["fields"])
                    error_lines.append(
                        f" {i}. Table: {allowed.get('table')}, "
                        f"ID: {allowed.get('pk')}, "
                        f"Type: {change_type}, "
                        f"Fields: [{fields_summary}]"
                    )
                else:
                    error_lines.append(
                        f" {i}. Table: {allowed.get('table')}, "
                        f"ID: {allowed.get('pk')}, "
                        f"Type: {change_type}"
                    )
            if len(allowed_changes) > 3:
                error_lines.append(
                    f" ... and {len(allowed_changes) - 3} more allowed changes"
                )
        else:
            error_lines.append(" (No changes were allowed)")

        raise AssertionError("\n".join(error_lines))

    return self
|
|
1733
|
+
|
|
1734
|
+
async def expect_only(self, allowed_changes: List[Dict[str, Any]]):
    """Ensure that only the specified changes occurred.

    Args:
        allowed_changes: Change specs (dicts that may carry keys such as
            ``"table"``, ``"pk"`` and ``"type"``). A list-valued ``"pk"`` is
            converted to a tuple for hashability and consistent comparison.
            The conversion is applied to copies, so the caller's dicts are
            left untouched.

    Returns:
        Whatever the delegated check returns (``_expect_no_changes``,
        ``_expect_only_targeted`` or ``_validate_diff_against_allowed_changes``).
    """
    # Normalize list pks to tuples on copies instead of mutating the caller's specs.
    allowed_changes = [
        {**change, "pk": tuple(change["pk"])}
        if isinstance(change.get("pk"), list)
        else change
        for change in allowed_changes
    ]

    # An empty spec list means no changes at all are allowed.
    if not allowed_changes:
        return await self._expect_no_changes()

    # Optimization: when possible, only the specific rows mentioned are checked.
    if self._can_use_targeted_queries(allowed_changes):
        return await self._expect_only_targeted(allowed_changes)

    # Fall back to computing the full diff for complex cases.
    diff = await self._collect()
    return await self._validate_diff_against_allowed_changes(diff, allowed_changes)
|
|
1752
|
+
|
|
1753
|
+
async def expect_only_v2(self, allowed_changes: List[Dict[str, Any]]):
    """Ensure only specified changes occurred, with field-level spec support.

    This version supports field-level specifications for added/removed rows,
    allowing users to specify expected field values instead of just whole-row
    specs.

    Args:
        allowed_changes: Change specs. A list-valued ``"pk"`` is converted to a
            tuple (on a copy; the caller's dicts are not modified).

    Returns:
        Whatever the delegated check returns (``_expect_no_changes``,
        ``_expect_only_targeted_v2`` or
        ``_validate_diff_against_allowed_changes_v2``).
    """
    # Normalize list pks to tuples on copies instead of mutating the caller's specs.
    allowed_changes = [
        {**change, "pk": tuple(change["pk"])}
        if isinstance(change.get("pk"), list)
        else change
        for change in allowed_changes
    ]

    # An empty spec list means no changes at all are allowed.
    if not allowed_changes:
        return await self._expect_no_changes()

    # NOTE: this method does not consult the /diff/structured HTTP endpoint;
    # diffs are always computed locally here.

    # Optimization: when possible, only the specific rows mentioned are checked.
    if self._can_use_targeted_queries(allowed_changes):
        return await self._expect_only_targeted_v2(allowed_changes)

    # Fall back to computing the full diff for complex cases.
    diff = await self._collect()
    return await self._validate_diff_against_allowed_changes_v2(diff, allowed_changes)
|
|
1808
|
+
|
|
1809
|
+
async def expect_exactly(self, expected_changes: List[Dict[str, Any]]):
    """Assert that the diff contains EXACTLY the expected changes.

    This is stricter than ``expect_only_v2`` in two directions:

    * every change present in the diff must be covered by a spec (nothing
      unexpected), and
    * every spec must be matched by a change in the diff (nothing missing).

    Use it to confirm an agent did precisely what was expected - no more,
    no less.

    Args:
        expected_changes: Expected change specs. Each spec must provide:
            - "type": one of "insert", "modify" or "delete"
            - "table": the table name
            - "pk": the primary key value

            Shapes per type:
            - Insert: {"type": "insert", "table": "t", "pk": 1, "fields": [...]}
            - Modify: {"type": "modify", "table": "t", "pk": 1, "resulting_fields": [...], "no_other_changes": True/False}
            - Delete: {"type": "delete", "table": "t", "pk": 1}

            Field specs are ``(field_name, expected_value)`` pairs:
            - ("name", "Alice"): the field must equal "Alice"
            - ("name", ...): any value is accepted (Ellipsis)
            - ("name", None): the field must be SQL NULL

            Legacy specs lacking an explicit "type" are not supported.

    Returns:
        self, so calls can be chained.

    Raises:
        AssertionError: When unexpected changes exist or expected ones are absent.
        ValueError: When a spec lacks required keys or has an invalid format.
    """
    api_diff = None
    resource = self.after.resource

    # Prefer the server-side structured diff when talking to a remote resource.
    if resource.client is not None and resource._mode == "http":
        try:
            request_body = {}
            if self.ignore_config:
                cfg = self.ignore_config
                request_body["ignore_config"] = {
                    "tables": list(cfg.tables),
                    "fields": list(cfg.fields),
                    "table_fields": {
                        tbl: list(cols) for tbl, cols in cfg.table_fields.items()
                    },
                }
            reply = await resource.client.request(
                "POST", "/diff/structured", json=request_body
            )
            body = reply.json()
            if body.get("success") and "diff" in body:
                api_diff = body["diff"]
        except Exception as e:
            print(f"Warning: Failed to fetch structured diff from API: {e}")
            print("Falling back to local diff computation...")

    # Local computation covers non-HTTP resources and any API miss/failure.
    diff = api_diff if api_diff is not None else await self._collect()

    passed, error_msg, _ = validate_diff_expect_exactly(
        diff, expected_changes, self.ignore_config
    )
    if not passed:
        raise AssertionError(error_msg)
    return self
|
|
1883
|
+
|
|
1884
|
+
async def _ensure_all_fetched(self):
|
|
1885
|
+
"""Fetch ALL data from ALL tables upfront (non-lazy loading).
|
|
1886
|
+
|
|
1887
|
+
This is the old approach before lazy loading was introduced.
|
|
1888
|
+
Used by expect_only_v1 for simpler, non-optimized diffing.
|
|
1889
|
+
"""
|
|
1890
|
+
# Get all tables from before snapshot
|
|
1891
|
+
tables_response = await self.before.resource.query(
|
|
1892
|
+
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
|
1893
|
+
)
|
|
1894
|
+
|
|
1895
|
+
if tables_response.rows:
|
|
1896
|
+
before_tables = [row[0] for row in tables_response.rows]
|
|
1897
|
+
for table in before_tables:
|
|
1898
|
+
await self.before._ensure_table_data(table)
|
|
1899
|
+
|
|
1900
|
+
# Also fetch from after snapshot
|
|
1901
|
+
tables_response = await self.after.resource.query(
|
|
1902
|
+
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
|
1903
|
+
)
|
|
1904
|
+
|
|
1905
|
+
if tables_response.rows:
|
|
1906
|
+
after_tables = [row[0] for row in tables_response.rows]
|
|
1907
|
+
for table in after_tables:
|
|
1908
|
+
await self.after._ensure_table_data(table)
|
|
1909
|
+
|
|
1910
|
+
async def expect_only_v1(self, allowed_changes: List[Dict[str, Any]]):
    """Check that only the allowed changes happened, using the original strategy.

    For a remote resource (HTTP mode with a client), the structured-diff
    endpoint is tried first. It is requested as ``POST /diff/structured`` on
    the resource client and documented as /api/v1/env/diff/structured.

    If the endpoint call fails or yields no diff, or the resource is local,
    every table of both snapshots is loaded eagerly and the diff is computed
    locally.

    Choose this over the optimized variants when simple, predictable
    behaviour matters more than query efficiency.
    """
    resource = self.after.resource
    remote = resource.client is not None and resource._mode == "http"

    if remote:
        structured = None
        try:
            request_body = {}
            if self.ignore_config:
                cfg = self.ignore_config
                request_body["ignore_config"] = {
                    "tables": list(cfg.tables),
                    "fields": list(cfg.fields),
                    "table_fields": {
                        tbl: list(cols) for tbl, cols in cfg.table_fields.items()
                    },
                }
            reply = await resource.client.request(
                "POST", "/diff/structured", json=request_body
            )
            body = reply.json()
            if body.get("success") and "diff" in body:
                structured = body["diff"]
        except Exception as e:
            # Any API problem degrades to local diffing.
            print(f"Warning: Failed to fetch structured diff from API: {e}")
            print("Falling back to local diff computation...")

        # Validation stays outside the try so its AssertionError propagates.
        if structured is not None:
            return await self._validate_diff_against_allowed_changes(
                structured, allowed_changes
            )

    # Local path: load everything up front, then diff.
    await self._ensure_all_fetched()
    local_diff = await self._collect()
    return await self._validate_diff_against_allowed_changes(local_diff, allowed_changes)
|
|
1954
|
+
|
|
460
1955
|
|
|
461
1956
|
class AsyncQueryBuilder:
|
|
462
1957
|
"""Async query builder that translates DSL to SQL and executes through the API."""
|
|
@@ -555,13 +2050,13 @@ class AsyncQueryBuilder:
|
|
|
555
2050
|
# Compile to SQL
|
|
556
2051
|
def _compile(self) -> Tuple[str, List[Any]]:
|
|
557
2052
|
cols = ", ".join(self._select_cols)
|
|
558
|
-
sql = [f"SELECT {cols} FROM {self._table}"]
|
|
2053
|
+
sql = [f"SELECT {cols} FROM {_quote_identifier(self._table)}"]
|
|
559
2054
|
params: List[Any] = []
|
|
560
2055
|
|
|
561
2056
|
# Joins
|
|
562
2057
|
for tbl, onmap in self._joins:
|
|
563
|
-
join_clauses = [f"{self._table}.{l} = {tbl}.{r}" for l, r in onmap.items()]
|
|
564
|
-
sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")
|
|
2058
|
+
join_clauses = [f"{_quote_identifier(self._table)}.{_quote_identifier(l)} = {_quote_identifier(tbl)}.{_quote_identifier(r)}" for l, r in onmap.items()]
|
|
2059
|
+
sql.append(f"JOIN {_quote_identifier(tbl)} ON {' AND '.join(join_clauses)}")
|
|
565
2060
|
|
|
566
2061
|
# WHERE
|
|
567
2062
|
if self._conditions:
|
|
@@ -569,12 +2064,12 @@ class AsyncQueryBuilder:
|
|
|
569
2064
|
for col, op, val in self._conditions:
|
|
570
2065
|
if op in ("IN", "NOT IN") and isinstance(val, tuple):
|
|
571
2066
|
ph = ", ".join(["?" for _ in val])
|
|
572
|
-
placeholders.append(f"{col} {op} ({ph})")
|
|
2067
|
+
placeholders.append(f"{_quote_identifier(col)} {op} ({ph})")
|
|
573
2068
|
params.extend(val)
|
|
574
2069
|
elif op in ("IS", "IS NOT"):
|
|
575
|
-
placeholders.append(f"{col} {op} NULL")
|
|
2070
|
+
placeholders.append(f"{_quote_identifier(col)} {op} NULL")
|
|
576
2071
|
else:
|
|
577
|
-
placeholders.append(f"{col} {op} ?")
|
|
2072
|
+
placeholders.append(f"{_quote_identifier(col)} {op} ?")
|
|
578
2073
|
params.append(val)
|
|
579
2074
|
sql.append("WHERE " + " AND ".join(placeholders))
|
|
580
2075
|
|
|
@@ -679,16 +2174,106 @@ class AsyncQueryBuilder:
|
|
|
679
2174
|
|
|
680
2175
|
|
|
681
2176
|
class AsyncSQLiteResource(Resource):
|
|
682
|
-
def __init__(
    self,
    resource: ResourceModel,
    client: Optional["AsyncWrapper"] = None,
    db_path: Optional[str] = None,
):
    """Bind this SQLite resource to a remote client or to a local database.

    Args:
        resource: Resource metadata, passed through to the base ``Resource``.
        client: Async wrapper used for requests when operating in "http" mode.
        db_path: Path (or SQLite URI) of a local database. Any non-empty value
            switches the resource into "direct" mode.
    """
    super().__init__(resource)
    self.client = client
    self.db_path = db_path
    if db_path:
        self._mode = "direct"
    else:
        self._mode = "http"

@property
def mode(self) -> str:
    """Backend serving this resource: 'direct' (local file) or 'http' (remote API)."""
    current_mode = self._mode
    return current_mode
|
|
685
2192
|
|
|
686
2193
|
async def describe(self) -> DescribeResponse:
    """Return the SQLite database schema.

    Resources in "direct" mode are inspected in-process. Every other mode is
    described through the HTTP API.
    """
    if self._mode != "direct":
        return await self._describe_http()
    return await self._describe_direct()
|
|
2199
|
+
|
|
2200
|
+
async def _describe_http(self) -> DescribeResponse:
    """Fetch the database schema from the remote describe endpoint.

    Raises:
        ValueError: If the endpoint's body is not valid JSON. The message
            includes the HTTP status and the first 500 characters of the body.
    """
    endpoint = f"/resources/sqlite/{self.resource.name}/describe"
    response = await self.client.request("GET", endpoint)
    try:
        payload = response.json()
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Failed to parse JSON response from SQLite describe endpoint. "
            f"Status: {response.status_code}, "
            f"Response text: {response.text[:500]}"
        ) from e
    return DescribeResponse(**payload)
|
|
2213
|
+
|
|
2214
|
+
async def _describe_direct(self) -> DescribeResponse:
    """Describe the schema of the local file or in-memory SQLite database.

    The blocking sqlite3 work runs in a worker thread via
    ``asyncio.to_thread``. For each non-``sqlite_*`` table, the result lists
    its ``CREATE`` statement and its columns; column data comes from
    ``PRAGMA table_info``.

    Returns:
        DescribeResponse with ``success=True`` and the table list, or
        ``success=False`` with the error text if anything raised. Errors are
        reported in the response, not raised.
    """
    def _sync_describe():
        conn = None
        try:
            # Shared in-memory databases are addressed by URI ("...mode=memory...").
            # NOTE(review): SQLite frees a shared in-memory DB when its last
            # connection closes; this relies on another connection keeping it alive.
            use_uri = "mode=memory" in self.db_path
            conn = sqlite3.connect(self.db_path, uri=use_uri)
            cursor = conn.cursor()

            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )
            table_names = [row[0] for row in cursor.fetchall()]

            tables = []
            for table_name in table_names:
                # Columns: (cid, name, type, notnull, dflt_value, pk)
                cursor.execute(f"PRAGMA table_info({_quote_identifier(table_name)})")
                columns = cursor.fetchall()

                cursor.execute(
                    "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
                    (table_name,),
                )
                sql_row = cursor.fetchone()
                create_sql = sql_row[0] if sql_row else ""

                tables.append(
                    {
                        "name": table_name,
                        "sql": create_sql,
                        "columns": [
                            {
                                "name": col[1],
                                "type": col[2],
                                "notnull": bool(col[3]),
                                "default_value": col[4],
                                "primary_key": col[5] > 0,
                            }
                            for col in columns
                        ],
                    }
                )

            return DescribeResponse(
                success=True,
                resource_name=self.resource.name,
                tables=tables,
                message="Schema retrieved from local file",
            )
        except Exception as e:
            return DescribeResponse(
                success=False,
                resource_name=self.resource.name,
                tables=None,
                error=str(e),
                message=f"Failed to describe database: {str(e)}",
            )
        finally:
            # Close on every path; previously a failing statement leaked the connection.
            if conn is not None:
                conn.close()

    return await asyncio.to_thread(_sync_describe)
|
|
692
2277
|
|
|
693
2278
|
async def query(
|
|
694
2279
|
self, query: str, args: Optional[List[Any]] = None
|
|
@@ -701,6 +2286,121 @@ class AsyncSQLiteResource(Resource):
|
|
|
701
2286
|
async def _query(
    self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
) -> QueryResponse:
    """Route a query to the local database or to the HTTP API.

    In HTTP mode, statements starting with PRAGMA (case-insensitive, after
    stripping surrounding whitespace) are not sent to the query endpoint.
    They are emulated by ``_handle_pragma_query_http`` instead.
    """
    if self._mode == "direct":
        return await self._query_direct(query, args, read_only)
    is_pragma = query.strip().upper().startswith("PRAGMA")
    if is_pragma:
        return await self._handle_pragma_query_http(query, args)
    return await self._query_http(query, args, read_only)
|
|
2297
|
+
|
|
2298
|
+
async def _handle_pragma_query_http(
    self, query: str, args: Optional[List[Any]] = None
) -> QueryResponse:
    """Emulate ``PRAGMA table_info(<table>)`` in HTTP mode using ``describe()``.

    PRAGMA statements are not forwarded to the HTTP query endpoint. Instead,
    the schema returned by ``describe()`` is reshaped into SQLite's
    ``table_info`` row layout: ``(cid, name, type, notnull, dflt_value, pk)``.
    Any other PRAGMA gets an unsuccessful response.

    Args:
        query: The PRAGMA statement text.
        args: Unused; accepted so the signature mirrors ``_query``.

    Returns:
        A QueryResponse with the emulated rows. Otherwise it has
        ``success=False`` and an error when:
        - the schema or table cannot be resolved,
        - the table reports no columns, or
        - the PRAGMA is not ``table_info``.
    """

    def _table_info_target(sql: str) -> Optional[str]:
        # Return the unquoted table name inside table_info(...), or None.
        # Quoted forms are tried before the bare form so that quote/bracket
        # delimiters are not captured as part of the name. Doubled quote
        # characters inside a quoted name are SQL escapes and are collapsed.
        if "TABLE_INFO" not in sql.upper():
            return None
        patterns = (
            (r'TABLE_INFO\s*\(\s*"((?:[^"]|"")+)"\s*\)', '""', '"'),
            (r"TABLE_INFO\s*\(\s*'((?:[^']|'')+)'\s*\)", "''", "'"),
            (r"TABLE_INFO\s*\(\s*`((?:[^`]|``)+)`\s*\)", "``", "`"),
            (r"TABLE_INFO\s*\(\s*\[([^\]]+)\]\s*\)", None, None),
            (r"TABLE_INFO\s*\(\s*([^\s\)]+)\s*\)", None, None),
        )
        for pattern, escaped, quote_char in patterns:
            found = re.search(pattern, sql, re.IGNORECASE)
            if found:
                name = found.group(1)
                return name.replace(escaped, quote_char) if escaped else name
        return None

    table_name = _table_info_target(query)
    if table_name is not None:
        describe_response = await self.describe()
        if not describe_response.success or not describe_response.tables:
            return QueryResponse(
                success=False,
                columns=None,
                rows=None,
                error="Failed to get schema information",
                message="PRAGMA query failed: could not retrieve schema"
            )

        # Schema entries may be dicts or schema objects.
        table_schema = None
        for table in describe_response.tables:
            table_name_in_schema = table.name if hasattr(table, 'name') else table.get("name")
            if table_name_in_schema == table_name:
                table_schema = table
                break

        if not table_schema:
            return QueryResponse(
                success=False,
                columns=None,
                rows=None,
                error=f"Table '{table_name}' not found",
                message=f"PRAGMA query failed: table '{table_name}' not found"
            )

        columns = table_schema.columns if hasattr(table_schema, 'columns') else table_schema.get("columns")
        if not columns:
            return QueryResponse(
                success=False,
                columns=None,
                rows=None,
                error=f"Table '{table_name}' has no columns",
                message=f"PRAGMA query failed: table '{table_name}' has no columns"
            )

        rows = []
        pk_position = 0  # running 1-based index for columns flagged via "primary_key"
        for idx, col in enumerate(columns):
            if isinstance(col, dict):
                col_name = col["name"]
                col_type = col.get("type", "")
                col_notnull = col.get("notnull", False)
                col_default = col.get("default_value")
                col_pk = col.get("pk")
                pk_flag = col.get("primary_key")
            else:
                col_name = col.name if hasattr(col, 'name') else str(col)
                col_type = getattr(col, 'type', "")
                col_notnull = getattr(col, 'notnull', False)
                col_default = getattr(col, 'default_value', None)
                col_pk = getattr(col, 'pk', None)
                pk_flag = getattr(col, 'primary_key', None)

            # table_info's "pk" is 0 for non-key columns, otherwise the column's
            # 1-based position in the primary key. Some schemas only carry a
            # boolean "primary_key" flag (as the dicts built by _describe_direct
            # do; presumably the HTTP describe payload matches - confirm). For
            # those, positions are assigned in column order. That matches SQLite
            # unless a composite key lists its columns in a different order.
            if not col_pk and pk_flag:
                pk_position += 1
                col_pk = pk_position
            elif col_pk is None:
                col_pk = 0

            rows.append(
                (
                    idx,                       # cid
                    col_name,                  # name
                    col_type,                  # type
                    1 if col_notnull else 0,   # notnull
                    col_default,               # dflt_value
                    col_pk,                    # pk
                )
            )

        return QueryResponse(
            success=True,
            columns=["cid", "name", "type", "notnull", "dflt_value", "pk"],
            rows=rows,
            message="PRAGMA query executed successfully via describe endpoint"
        )

    # Any other PRAGMA (or an unparseable table_info) is unsupported over HTTP.
    return QueryResponse(
        success=False,
        columns=None,
        rows=None,
        error="PRAGMA query not supported in HTTP mode",
        message=f"PRAGMA query '{query}' is not supported via HTTP API"
    )
|
|
2399
|
+
|
|
2400
|
+
async def _query_http(
|
|
2401
|
+
self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
|
|
2402
|
+
) -> QueryResponse:
|
|
2403
|
+
"""Execute query via HTTP API."""
|
|
704
2404
|
request = QueryRequest(query=query, args=args, read_only=read_only)
|
|
705
2405
|
response = await self.client.request(
|
|
706
2406
|
"POST",
|
|
@@ -709,6 +2409,62 @@ class AsyncSQLiteResource(Resource):
|
|
|
709
2409
|
)
|
|
710
2410
|
return QueryResponse(**response.json())
|
|
711
2411
|
|
|
2412
|
+
async def _query_direct(
    self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
) -> QueryResponse:
    """Execute a statement on the local SQLite file or shared in-memory database.

    Each call opens its own connection in a worker thread (``asyncio.to_thread``)
    and always closes it. Result rows are fetched before any commit. The
    transaction is committed only when ``read_only`` is False.

    NOTE(review): with ``read_only=True``, uncommitted DML is discarded when
    the connection closes, yet the response still reports success and
    ``rows_affected``.

    Returns:
        QueryResponse with ``columns``/``rows`` for statements that produce
        rows. Otherwise it carries ``rows_affected``/``last_insert_id``.
        Empty or zero values are reported as None. Exceptions are reported as
        ``success=False`` rather than raised.
    """
    def _sync_query():
        conn = None
        try:
            # Shared in-memory databases are addressed by URI ("...mode=memory...").
            use_uri = "mode=memory" in self.db_path
            conn = sqlite3.connect(self.db_path, uri=use_uri)
            cursor = conn.cursor()

            if args:
                cursor.execute(query, args)
            else:
                cursor.execute(query)

            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            rows = []
            rows_affected = 0
            last_insert_id = None

            if cursor.description:  # statement produces rows (SELECT, ... RETURNING)
                # Drain the rows before committing. A statement that is still
                # producing rows (e.g. INSERT ... RETURNING) must be complete
                # when COMMIT runs.
                rows = cursor.fetchall()
            else:  # INSERT/UPDATE/DELETE without RETURNING
                rows_affected = cursor.rowcount
                last_insert_id = cursor.lastrowid if cursor.lastrowid else None

            if not read_only:
                conn.commit()

            return QueryResponse(
                success=True,
                columns=columns if columns else None,
                rows=rows if rows else None,
                rows_affected=rows_affected if rows_affected > 0 else None,
                last_insert_id=last_insert_id,
                message="Query executed successfully",
            )
        except Exception as e:
            return QueryResponse(
                success=False,
                columns=None,
                rows=None,
                error=str(e),
                message=f"Query failed: {str(e)}",
            )
        finally:
            # Close on every path; previously a failing statement leaked the connection.
            if conn is not None:
                conn.close()

    return await asyncio.to_thread(_sync_query)
|
|
2467
|
+
|
|
712
2468
|
def table(self, table_name: str) -> AsyncQueryBuilder:
    """Start a fluent AsyncQueryBuilder for ``table_name`` bound to this resource."""
    builder = AsyncQueryBuilder(self, table_name)
    return builder
|
|
@@ -716,7 +2472,6 @@ class AsyncSQLiteResource(Resource):
|
|
|
716
2472
|
async def snapshot(self, name: Optional[str] = None) -> AsyncDatabaseSnapshot:
    """Create a snapshot object representing the database's current state.

    Args:
        name: Optional label forwarded to ``AsyncDatabaseSnapshot``.

    Returns:
        The new snapshot. This method issues no queries itself.
    """
    return AsyncDatabaseSnapshot(self, name)
|
|
721
2476
|
|
|
722
2477
|
async def diff(
|