fleet-python 0.2.66b2__py3-none-any.whl → 0.2.105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/export_tasks.py +16 -5
- examples/export_tasks_filtered.py +245 -0
- examples/fetch_tasks.py +230 -0
- examples/import_tasks.py +140 -8
- examples/iterate_verifiers.py +725 -0
- fleet/__init__.py +128 -5
- fleet/_async/__init__.py +27 -3
- fleet/_async/base.py +24 -9
- fleet/_async/client.py +938 -41
- fleet/_async/env/client.py +60 -3
- fleet/_async/instance/client.py +52 -7
- fleet/_async/models.py +15 -0
- fleet/_async/resources/api.py +200 -0
- fleet/_async/resources/sqlite.py +1801 -46
- fleet/_async/tasks.py +122 -25
- fleet/_async/verifiers/bundler.py +22 -21
- fleet/_async/verifiers/verifier.py +25 -19
- fleet/agent/__init__.py +32 -0
- fleet/agent/gemini_cua/Dockerfile +45 -0
- fleet/agent/gemini_cua/__init__.py +10 -0
- fleet/agent/gemini_cua/agent.py +759 -0
- fleet/agent/gemini_cua/mcp/main.py +108 -0
- fleet/agent/gemini_cua/mcp_server/__init__.py +5 -0
- fleet/agent/gemini_cua/mcp_server/main.py +105 -0
- fleet/agent/gemini_cua/mcp_server/tools.py +178 -0
- fleet/agent/gemini_cua/requirements.txt +5 -0
- fleet/agent/gemini_cua/start.sh +30 -0
- fleet/agent/orchestrator.py +854 -0
- fleet/agent/types.py +49 -0
- fleet/agent/utils.py +34 -0
- fleet/base.py +34 -9
- fleet/cli.py +1061 -0
- fleet/client.py +1060 -48
- fleet/config.py +1 -1
- fleet/env/__init__.py +16 -0
- fleet/env/client.py +60 -3
- fleet/eval/__init__.py +15 -0
- fleet/eval/uploader.py +231 -0
- fleet/exceptions.py +8 -0
- fleet/instance/client.py +53 -8
- fleet/instance/models.py +1 -0
- fleet/models.py +303 -0
- fleet/proxy/__init__.py +25 -0
- fleet/proxy/proxy.py +453 -0
- fleet/proxy/whitelist.py +244 -0
- fleet/resources/api.py +200 -0
- fleet/resources/sqlite.py +1845 -46
- fleet/tasks.py +113 -20
- fleet/utils/__init__.py +7 -0
- fleet/utils/http_logging.py +178 -0
- fleet/utils/logging.py +13 -0
- fleet/utils/playwright.py +440 -0
- fleet/verifiers/bundler.py +22 -21
- fleet/verifiers/db.py +985 -1
- fleet/verifiers/decorator.py +1 -1
- fleet/verifiers/verifier.py +25 -19
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/METADATA +28 -1
- fleet_python-0.2.105.dist-info/RECORD +115 -0
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/WHEEL +1 -1
- fleet_python-0.2.105.dist-info/entry_points.txt +2 -0
- tests/test_app_method.py +85 -0
- tests/test_expect_exactly.py +4148 -0
- tests/test_expect_only.py +2593 -0
- tests/test_instance_dispatch.py +607 -0
- tests/test_sqlite_resource_dual_mode.py +263 -0
- tests/test_sqlite_shared_memory_behavior.py +117 -0
- fleet_python-0.2.66b2.dist-info/RECORD +0 -81
- tests/test_verifier_security.py +0 -427
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.66b2.dist-info → fleet_python-0.2.105.dist-info}/top_level.txt +0 -0
fleet/resources/sqlite.py
CHANGED
@@ -6,6 +6,8 @@ from datetime import datetime
 import tempfile
 import sqlite3
 import os
+import re
+import json
 
 from typing import TYPE_CHECKING
 
@@ -19,11 +21,23 @@ from fleet.verifiers.db import (
     _get_row_identifier,
     _format_row_for_error,
     _values_equivalent,
+    validate_diff_expect_exactly,
 )
 
 
+def _quote_identifier(identifier: str) -> str:
+    """Quote an identifier (table or column name) for SQLite.
+
+    SQLite uses double quotes for identifiers and escapes internal quotes by doubling them.
+    This handles reserved keywords like 'order', 'table', etc.
+    """
+    # Escape any double quotes in the identifier by doubling them
+    escaped = identifier.replace('"', '""')
+    return f'"{escaped}"'
+
+
 class SyncDatabaseSnapshot:
-    """
+    """Lazy database snapshot that fetches data on-demand through API."""
 
     def __init__(self, resource: "SQLiteResource", name: Optional[str] = None):
         self.resource = resource
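Note: the new _quote_identifier helper applies SQLite's standard identifier-quoting rule (wrap in double quotes, double any embedded quote). A minimal standalone sketch of its behavior, copied from the hunk above:

def _quote_identifier(identifier: str) -> str:
    # Double quotes delimit SQLite identifiers; embedded quotes are doubled.
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'

assert _quote_identifier("order") == '"order"'        # reserved keyword, now safe in queries
assert _quote_identifier('a"b') == '"a""b"'           # embedded quote is escaped
print(f"SELECT * FROM {_quote_identifier('table')}")  # SELECT * FROM "table"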
@@ -31,11 +45,12 @@ class SyncDatabaseSnapshot:
         self.created_at = datetime.utcnow()
         self._data: Dict[str, List[Dict[str, Any]]] = {}
         self._schemas: Dict[str, List[str]] = {}
-        self.
+        self._table_names: Optional[List[str]] = None
+        self._fetched_tables: set = set()
 
-    def
-        """Fetch
-        if self.
+    def _ensure_tables_list(self):
+        """Fetch just the list of table names if not already fetched."""
+        if self._table_names is not None:
             return
 
         # Get all tables
@@ -44,35 +59,36 @@ class SyncDatabaseSnapshot:
         )
 
         if not tables_response.rows:
-            self.
+            self._table_names = []
             return
 
-
-
-
-        for table
-
-
-            if schema_response.rows:
-                self._schemas[table] = [
-                    row[1] for row in schema_response.rows
-                ]  # Column names
-
-            # Get all data
-            data_response = self.resource.query(f"SELECT * FROM {table}")
-            if data_response.rows and data_response.columns:
-                self._data[table] = [
-                    dict(zip(data_response.columns, row)) for row in data_response.rows
-                ]
-            else:
-                self._data[table] = []
+        self._table_names = [row[0] for row in tables_response.rows]
+
+    def _ensure_table_data(self, table: str):
+        """Fetch data for a specific table on demand."""
+        if table in self._fetched_tables:
+            return
 
-
+        # Get table schema
+        schema_response = self.resource.query(f"PRAGMA table_info({_quote_identifier(table)})")
+        if schema_response.rows:
+            self._schemas[table] = [row[1] for row in schema_response.rows]  # Column names
+
+        # Get all data for this table
+        data_response = self.resource.query(f"SELECT * FROM {_quote_identifier(table)}")
+        if data_response.rows and data_response.columns:
+            self._data[table] = [
+                dict(zip(data_response.columns, row)) for row in data_response.rows
+            ]
+        else:
+            self._data[table] = []
+
+        self._fetched_tables.add(table)
 
     def tables(self) -> List[str]:
         """Get list of all tables in the snapshot."""
-        self.
-        return list(self.
+        self._ensure_tables_list()
+        return list(self._table_names) if self._table_names else []
 
     def table(self, table_name: str) -> "SyncSnapshotQueryBuilder":
         """Create a query builder for snapshot data."""
@@ -84,13 +100,12 @@ class SyncDatabaseSnapshot:
         ignore_config: Optional[IgnoreConfig] = None,
     ) -> "SyncSnapshotDiff":
         """Compare this snapshot with another."""
-
-        other._ensure_fetched()
+        # No need to fetch all data upfront - diff will fetch on demand
         return SyncSnapshotDiff(self, other, ignore_config)
 
 
 class SyncSnapshotQueryBuilder:
-    """Query builder that works on
+    """Query builder that works on snapshot data - can use targeted queries when possible."""
 
     def __init__(self, snapshot: SyncDatabaseSnapshot, table: str):
         self._snapshot = snapshot
@@ -100,10 +115,63 @@ class SyncSnapshotQueryBuilder:
         self._limit: Optional[int] = None
         self._order_by: Optional[str] = None
         self._order_desc: bool = False
+        self._use_targeted_query = True  # Try to use targeted queries when possible
+
+    def _can_use_targeted_query(self) -> bool:
+        """Check if we can use a targeted query instead of loading all data."""
+        # We can use targeted query if:
+        # 1. We have simple equality conditions
+        # 2. No complex operations like joins
+        # 3. The query is selective (has conditions)
+        if not self._conditions:
+            return False
+        for col, op, val in self._conditions:
+            if op not in ["=", "IS", "IS NOT"]:
+                return False
+        return True
+
+    def _execute_targeted_query(self) -> List[Dict[str, Any]]:
+        """Execute a targeted query directly instead of loading all data."""
+        # Build WHERE clause
+        where_parts = []
+        for col, op, val in self._conditions:
+            if op == "=" and val is None:
+                where_parts.append(f"{_quote_identifier(col)} IS NULL")
+            elif op == "IS":
+                where_parts.append(f"{_quote_identifier(col)} IS NULL")
+            elif op == "IS NOT":
+                where_parts.append(f"{_quote_identifier(col)} IS NOT NULL")
+            elif op == "=":
+                if isinstance(val, str):
+                    escaped_val = val.replace("'", "''")
+                    where_parts.append(f"{_quote_identifier(col)} = '{escaped_val}'")
+                else:
+                    where_parts.append(f"{_quote_identifier(col)} = '{val}'")
+
+        where_clause = " AND ".join(where_parts)
+
+        # Build full query
+        cols = ", ".join(self._select_cols)
+        query = f"SELECT {cols} FROM {_quote_identifier(self._table)} WHERE {where_clause}"
+
+        if self._order_by:
+            query += f" ORDER BY {self._order_by}"
+        if self._limit is not None:
+            query += f" LIMIT {self._limit}"
+
+        # Execute query
+        response = self._snapshot.resource.query(query)
+        if response.rows and response.columns:
+            return [dict(zip(response.columns, row)) for row in response.rows]
+        return []
 
     def _get_data(self) -> List[Dict[str, Any]]:
-        """Get table data
-        self.
+        """Get table data - use targeted query if possible, otherwise load all data."""
+        if self._use_targeted_query and self._can_use_targeted_query():
+            return self._execute_targeted_query()
+
+        # Fall back to loading all data
+        self._snapshot._ensure_table_data(self._table)
         return self._snapshot._data.get(self._table, [])
 
     def eq(self, column: str, value: Any) -> "SyncSnapshotQueryBuilder":
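Note: with _can_use_targeted_query and _execute_targeted_query in place, simple equality filters are pushed down to SQL instead of materializing the whole table. A hedged usage sketch (the snapshot variable and table name are illustrative; it assumes eq() records ("column", "=", value) tuples in _conditions, as the loop in _execute_targeted_query suggests):

# before = ...  # a SyncDatabaseSnapshot obtained elsewhere
row = before.table("users").eq("email", "ada@example.com").first()
# With targeted queries enabled, _get_data() issues roughly:
#   SELECT * FROM "users" WHERE "email" = 'ada@example.com'
# rather than SELECT * FROM "users" followed by in-memory filtering.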
@@ -142,6 +210,11 @@ class SyncSnapshotQueryBuilder:
         return rows[0] if rows else None
 
     def all(self) -> List[Dict[str, Any]]:
+        # If we can use targeted query, _get_data already applies filters
+        if self._use_targeted_query and self._can_use_targeted_query():
+            return self._get_data()
+
+        # Otherwise, get all data and apply filters manually
         data = self._get_data()
 
         # Apply filters
@@ -206,11 +279,12 @@ class SyncSnapshotDiff:
         self.after = after
         self.ignore_config = ignore_config or IgnoreConfig()
         self._cached: Optional[Dict[str, Any]] = None
+        self._targeted_mode = False  # Flag to use targeted queries
 
     def _get_primary_key_columns(self, table: str) -> List[str]:
         """Get primary key columns for a table."""
         # Try to get from schema
-        schema_response = self.after.resource.query(f"PRAGMA table_info({table})")
+        schema_response = self.after.resource.query(f"PRAGMA table_info({_quote_identifier(table)})")
         if not schema_response.rows:
             return ["id"]  # Default fallback
 
@@ -246,6 +320,10 @@ class SyncSnapshotDiff:
         # Get primary key columns
         pk_columns = self._get_primary_key_columns(tbl)
 
+        # Ensure data is fetched for this table
+        self.before._ensure_table_data(tbl)
+        self.after._ensure_table_data(tbl)
+
         # Get data from both snapshots
         before_data = self.before._data.get(tbl, [])
         after_data = self.after._data.get(tbl, [])
@@ -324,9 +402,405 @@ class SyncSnapshotDiff:
         """Expose the computed diff so callers can introspect like the legacy API."""
         return self._collect()
 
-    def
-        """
-
+    def _can_use_targeted_queries(self, allowed_changes: List[Dict[str, Any]]) -> bool:
+        """Check if we can use targeted queries for optimization."""
+        # We can use targeted queries if all allowed changes specify table and pk
+        for change in allowed_changes:
+            if "table" not in change or "pk" not in change:
+                return False
+        return True
+
+    def _build_pk_where_clause(self, pk_columns: List[str], pk_value: Any) -> str:
+        """Build WHERE clause for primary key lookup."""
+        # Escape single quotes in values to prevent SQL injection
+        def escape_value(val: Any) -> str:
+            if val is None:
+                return "NULL"
+            elif isinstance(val, str):
+                escaped = str(val).replace("'", "''")
+                return f"'{escaped}'"
+            else:
+                return f"'{val}'"
+
+        if len(pk_columns) == 1:
+            return f"{_quote_identifier(pk_columns[0])} = {escape_value(pk_value)}"
+        else:
+            # Composite key
+            if isinstance(pk_value, tuple):
+                conditions = [
+                    f"{_quote_identifier(col)} = {escape_value(val)}"
+                    for col, val in zip(pk_columns, pk_value)
+                ]
+                return " AND ".join(conditions)
+            else:
+                # Shouldn't happen if data is consistent
+                return f"{_quote_identifier(pk_columns[0])} = {escape_value(pk_value)}"
+
+    def _expect_no_changes(self):
+        """Efficiently verify that no changes occurred between snapshots using row counts."""
+        try:
+            import concurrent.futures
+            from threading import Lock
+
+            # Get all tables from both snapshots
+            before_tables = set(self.before.tables())
+            after_tables = set(self.after.tables())
+
+            # Check for added/removed tables (excluding ignored ones)
+            added_tables = after_tables - before_tables
+            removed_tables = before_tables - after_tables
+
+            for table in added_tables:
+                if not self.ignore_config.should_ignore_table(table):
+                    raise AssertionError(f"Unexpected table added: {table}")
+
+            for table in removed_tables:
+                if not self.ignore_config.should_ignore_table(table):
+                    raise AssertionError(f"Unexpected table removed: {table}")
+
+            # Prepare tables to check
+            tables_to_check = []
+            all_tables = before_tables | after_tables
+            for table in all_tables:
+                if not self.ignore_config.should_ignore_table(table):
+                    tables_to_check.append(table)
+
+            # If no tables to check, we're done
+            if not tables_to_check:
+                return self
+
+            # Use ThreadPoolExecutor to parallelize count queries
+            errors = []
+            errors_lock = Lock()
+            tables_needing_verification = []
+            verification_lock = Lock()
+
+            def check_table_counts(table: str):
+                """Check row counts for a single table."""
+                try:
+                    # Get row counts from both snapshots
+                    before_count = 0
+                    after_count = 0
+
+                    if table in before_tables:
+                        before_count_response = self.before.resource.query(
+                            f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
+                        )
+                        before_count = (
+                            before_count_response.rows[0][0]
+                            if before_count_response.rows
+                            else 0
+                        )
+
+                    if table in after_tables:
+                        after_count_response = self.after.resource.query(
+                            f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
+                        )
+                        after_count = (
+                            after_count_response.rows[0][0]
+                            if after_count_response.rows
+                            else 0
+                        )
+
+                    if before_count != after_count:
+                        error_msg = (
+                            f"Unexpected change in table '{table}': "
+                            f"row count changed from {before_count} to {after_count}"
+                        )
+                        with errors_lock:
+                            errors.append(AssertionError(error_msg))
+                    elif before_count > 0 and before_count <= 1000:
+                        # Mark for detailed verification
+                        with verification_lock:
+                            tables_needing_verification.append(table)
+
+                except Exception as e:
+                    with errors_lock:
+                        errors.append(e)
+
+            # Execute count checks in parallel
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                futures = [
+                    executor.submit(check_table_counts, table)
+                    for table in tables_to_check
+                ]
+                concurrent.futures.wait(futures)
+
+            # Check if any errors occurred during count checking
+            if errors:
+                raise errors[0]
+
+            # Now verify small tables for data changes (also in parallel)
+            if tables_needing_verification:
+                verification_errors = []
+
+                def verify_table(table: str):
+                    """Verify a single table's data hasn't changed."""
+                    try:
+                        self._verify_table_unchanged(table)
+                    except AssertionError as e:
+                        with errors_lock:
+                            verification_errors.append(e)
+
+                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+                    futures = [
+                        executor.submit(verify_table, table)
+                        for table in tables_needing_verification
+                    ]
+                    concurrent.futures.wait(futures)
+
+                # Check if any errors occurred during verification
+                if verification_errors:
+                    raise verification_errors[0]
+
+            return self
+
+        except AssertionError:
+            # Re-raise assertion errors (these are expected failures)
+            raise
+        except Exception as e:
+            # If the optimized check fails for other reasons, fall back to full diff
+            print(f"Warning: Optimized no-changes check failed: {e}")
+            print("Falling back to full diff...")
+            return self._validate_diff_against_allowed_changes(self._collect(), [])
+
+    def _verify_table_unchanged(self, table: str):
+        """Verify that a table's data hasn't changed (for small tables)."""
+        # Get primary key columns
+        pk_columns = self._get_primary_key_columns(table)
+
+        # Get sorted data from both snapshots
+        order_by = ", ".join(pk_columns) if pk_columns else "rowid"
+
+        before_response = self.before.resource.query(
+            f"SELECT * FROM {_quote_identifier(table)} ORDER BY {order_by}"
+        )
+        after_response = self.after.resource.query(
+            f"SELECT * FROM {_quote_identifier(table)} ORDER BY {order_by}"
+        )
+
+        # Quick check: if column counts differ, there's a schema change
+        if before_response.columns != after_response.columns:
+            raise AssertionError(f"Schema changed in table '{table}'")
+
+        # Compare row by row
+        if len(before_response.rows) != len(after_response.rows):
+            raise AssertionError(
+                f"Row count mismatch in table '{table}': "
+                f"{len(before_response.rows)} vs {len(after_response.rows)}"
+            )
+
+        for i, (before_row, after_row) in enumerate(
+            zip(before_response.rows, after_response.rows)
+        ):
+            before_dict = dict(zip(before_response.columns, before_row))
+            after_dict = dict(zip(after_response.columns, after_row))
+
+            # Compare fields, ignoring those in ignore config
+            for field in before_response.columns:
+                if self.ignore_config.should_ignore_field(table, field):
+                    continue
+
+                if not _values_equivalent(
+                    before_dict.get(field), after_dict.get(field)
+                ):
+                    pk_val = before_dict.get(pk_columns[0]) if pk_columns else i
+                    raise AssertionError(
+                        f"Unexpected change in table '{table}', row {pk_val}, "
+                        f"field '{field}': {repr(before_dict.get(field))} -> {repr(after_dict.get(field))}"
+                    )
+
+    def _is_field_change_allowed(
+        self, table_changes: List[Dict[str, Any]], pk: Any, field: str, after_val: Any
+    ) -> bool:
+        """Check if a specific field change is allowed."""
+        for change in table_changes:
+            if (
+                str(change.get("pk")) == str(pk)
+                and change.get("field") == field
+                and _values_equivalent(change.get("after"), after_val)
+            ):
+                return True
+        return False
+
+    def _is_row_change_allowed(
+        self, table_changes: List[Dict[str, Any]], pk: Any, change_type: str
+    ) -> bool:
+        """Check if a row addition/deletion is allowed."""
+        for change in table_changes:
+            if str(change.get("pk")) == str(pk) and change.get("after") == change_type:
+                return True
+        return False
+
+    def _expect_only_targeted(self, allowed_changes: List[Dict[str, Any]]):
+        """Optimized version that only queries specific rows mentioned in allowed_changes."""
+        import concurrent.futures
+        from threading import Lock
+
+        # Group allowed changes by table
+        changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
+        for change in allowed_changes:
+            table = change["table"]
+            if table not in changes_by_table:
+                changes_by_table[table] = []
+            changes_by_table[table].append(change)
+
+        errors = []
+        errors_lock = Lock()
+
+        # Function to check a single row
+        def check_row(
+            table: str,
+            pk: Any,
+            table_changes: List[Dict[str, Any]],
+            pk_columns: List[str],
+        ):
+            try:
+                # Build WHERE clause for this PK
+                where_sql = self._build_pk_where_clause(pk_columns, pk)
+
+                # Query before snapshot
+                before_query = f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
+                before_response = self.before.resource.query(before_query)
+                before_row = (
+                    dict(zip(before_response.columns, before_response.rows[0]))
+                    if before_response.rows
+                    else None
+                )
+
+                # Query after snapshot
+                after_response = self.after.resource.query(before_query)
+                after_row = (
+                    dict(zip(after_response.columns, after_response.rows[0]))
+                    if after_response.rows
+                    else None
+                )
+
+                # Check changes for this row
+                if before_row and after_row:
+                    # Modified row - check fields
+                    for field in set(before_row.keys()) | set(after_row.keys()):
+                        if self.ignore_config.should_ignore_field(table, field):
+                            continue
+                        before_val = before_row.get(field)
+                        after_val = after_row.get(field)
+                        if not _values_equivalent(before_val, after_val):
+                            # Check if this change is allowed
+                            if not self._is_field_change_allowed(
+                                table_changes, pk, field, after_val
+                            ):
+                                error_msg = (
+                                    f"Unexpected change in table '{table}', "
+                                    f"row {pk}, field '{field}': "
+                                    f"{repr(before_val)} -> {repr(after_val)}"
+                                )
+                                with errors_lock:
+                                    errors.append(AssertionError(error_msg))
+                                return  # Stop checking this row
+                elif not before_row and after_row:
+                    # Added row
+                    if not self._is_row_change_allowed(table_changes, pk, "__added__"):
+                        error_msg = f"Unexpected row added in table '{table}': {pk}"
+                        with errors_lock:
+                            errors.append(AssertionError(error_msg))
+                elif before_row and not after_row:
+                    # Removed row
+                    if not self._is_row_change_allowed(table_changes, pk, "__removed__"):
+                        error_msg = f"Unexpected row removed from table '{table}': {pk}"
+                        with errors_lock:
+                            errors.append(AssertionError(error_msg))
+            except Exception as e:
+                with errors_lock:
+                    errors.append(e)
+
+        # Prepare all row checks
+        row_checks = []
+        for table, table_changes in changes_by_table.items():
+            if self.ignore_config.should_ignore_table(table):
+                continue
+
+            # Get primary key columns once per table
+            pk_columns = self._get_primary_key_columns(table)
+
+            # Extract unique PKs to check
+            pks_to_check = {change["pk"] for change in table_changes}
+
+            for pk in pks_to_check:
+                row_checks.append((table, pk, table_changes, pk_columns))
+
+        # Execute row checks in parallel
+        if row_checks:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                futures = [
+                    executor.submit(check_row, table, pk, table_changes, pk_columns)
+                    for table, pk, table_changes, pk_columns in row_checks
+                ]
+                concurrent.futures.wait(futures)
+
+        # Check for errors from row checks
+        if errors:
+            raise errors[0]
+
+        # Now check tables not mentioned in allowed_changes to ensure no changes
+        all_tables = set(self.before.tables()) | set(self.after.tables())
+        tables_to_verify = []
+
+        for table in all_tables:
+            if (
+                table not in changes_by_table
+                and not self.ignore_config.should_ignore_table(table)
+            ):
+                tables_to_verify.append(table)
+
+        # Function to verify no changes in a table
+        def verify_no_changes(table: str):
+            try:
+                # For tables with no allowed changes, just check row counts
+                before_count_response = self.before.resource.query(
+                    f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
+                )
+                before_count = (
+                    before_count_response.rows[0][0]
+                    if before_count_response.rows
+                    else 0
+                )
+
+                after_count_response = self.after.resource.query(
+                    f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
+                )
+                after_count = (
+                    after_count_response.rows[0][0] if after_count_response.rows else 0
+                )
+
+                if before_count != after_count:
+                    error_msg = (
+                        f"Unexpected change in table '{table}': "
+                        f"row count changed from {before_count} to {after_count}"
+                    )
+                    with errors_lock:
+                        errors.append(AssertionError(error_msg))
+            except Exception as e:
+                with errors_lock:
+                    errors.append(e)
+
+        # Execute table verification in parallel
+        if tables_to_verify:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                futures = [
+                    executor.submit(verify_no_changes, table) for table in tables_to_verify
+                ]
+                concurrent.futures.wait(futures)
+
+        # Final error check
+        if errors:
+            raise errors[0]
+
+        return self
+
+    def _validate_diff_against_allowed_changes(
+        self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
+    ):
+        """Validate a collected diff against allowed changes."""
 
         def _is_change_allowed(
             table: str, row_id: Any, field: Optional[str], after_value: Any
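Note: _expect_only_targeted validates only the rows named in allowed_changes and then count-checks every other table. A hedged sketch of the legacy spec shapes it matches against (table, field, and pk values are illustrative, and calling the private method directly is for demonstration only):

allowed = [
    {"table": "users", "pk": 7, "field": "status", "after": "active"},  # field change
    {"table": "orders", "pk": 42, "after": "__added__"},                # row insertion
    {"table": "carts", "pk": 3, "after": "__removed__"},                # row deletion
]
diff._expect_only_targeted(allowed)  # raises AssertionError on any change not listed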
@@ -453,6 +927,1077 @@ class SyncSnapshotDiff:
|
|
|
453
927
|
|
|
454
928
|
return self
|
|
455
929
|
|
|
930
|
+
def _expect_only_targeted_v2(self, allowed_changes: List[Dict[str, Any]]):
|
|
931
|
+
"""Optimized version that only queries specific rows mentioned in allowed_changes.
|
|
932
|
+
|
|
933
|
+
Supports v2 spec formats:
|
|
934
|
+
- {"table": "t", "pk": 1, "type": "insert", "fields": [...]}
|
|
935
|
+
- {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": bool}
|
|
936
|
+
- {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
|
|
937
|
+
- Legacy single-field specs: {"table": "t", "pk": 1, "field": "x", "after": val}
|
|
938
|
+
"""
|
|
939
|
+
import concurrent.futures
|
|
940
|
+
from threading import Lock
|
|
941
|
+
|
|
942
|
+
# Helper functions for v2 spec validation
|
|
943
|
+
def _parse_fields_spec(
|
|
944
|
+
fields_spec: List[Tuple[str, Any]]
|
|
945
|
+
) -> Dict[str, Tuple[bool, Any]]:
|
|
946
|
+
"""Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
|
|
947
|
+
spec_map: Dict[str, Tuple[bool, Any]] = {}
|
|
948
|
+
for spec_tuple in fields_spec:
|
|
949
|
+
if len(spec_tuple) != 2:
|
|
950
|
+
raise ValueError(
|
|
951
|
+
f"Invalid field spec tuple: {spec_tuple}. "
|
|
952
|
+
f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
|
|
953
|
+
)
|
|
954
|
+
field_name, expected_value = spec_tuple
|
|
955
|
+
if expected_value is ...:
|
|
956
|
+
spec_map[field_name] = (False, None)
|
|
957
|
+
else:
|
|
958
|
+
spec_map[field_name] = (True, expected_value)
|
|
959
|
+
return spec_map
|
|
960
|
+
|
|
961
|
+
def _get_spec_for_pk(table: str, pk: Any) -> Optional[Dict[str, Any]]:
|
|
962
|
+
"""Get the spec for a given table/pk."""
|
|
963
|
+
for allowed in allowed_changes:
|
|
964
|
+
if (
|
|
965
|
+
allowed["table"] == table
|
|
966
|
+
and str(allowed.get("pk")) == str(pk)
|
|
967
|
+
):
|
|
968
|
+
return allowed
|
|
969
|
+
return None
|
|
970
|
+
|
|
971
|
+
def _get_all_specs_for_pk(table: str, pk: Any) -> List[Dict[str, Any]]:
|
|
972
|
+
"""Get all specs for a given table/pk (for legacy multi-field specs)."""
|
|
973
|
+
specs = []
|
|
974
|
+
for allowed in allowed_changes:
|
|
975
|
+
if (
|
|
976
|
+
allowed["table"] == table
|
|
977
|
+
and str(allowed.get("pk")) == str(pk)
|
|
978
|
+
):
|
|
979
|
+
specs.append(allowed)
|
|
980
|
+
return specs
|
|
981
|
+
|
|
982
|
+
def _validate_insert_row(
|
|
983
|
+
table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
|
|
984
|
+
) -> Optional[str]:
|
|
985
|
+
"""Validate an inserted row against specs. Returns error message or None."""
|
|
986
|
+
# Check for type: "insert" spec with fields
|
|
987
|
+
for spec in specs:
|
|
988
|
+
if spec.get("type") == "insert":
|
|
989
|
+
fields_spec = spec.get("fields")
|
|
990
|
+
if fields_spec is not None:
|
|
991
|
+
# Validate each field
|
|
992
|
+
spec_map = _parse_fields_spec(fields_spec)
|
|
993
|
+
for field_name, field_value in row_data.items():
|
|
994
|
+
if field_name == "rowid":
|
|
995
|
+
continue
|
|
996
|
+
if self.ignore_config.should_ignore_field(table, field_name):
|
|
997
|
+
continue
|
|
998
|
+
if field_name not in spec_map:
|
|
999
|
+
return f"Field '{field_name}' not in insert spec for table '{table}' pk={pk}"
|
|
1000
|
+
should_check, expected_value = spec_map[field_name]
|
|
1001
|
+
if should_check and not _values_equivalent(expected_value, field_value):
|
|
1002
|
+
return (
|
|
1003
|
+
f"Insert mismatch in table '{table}' pk={pk}, "
|
|
1004
|
+
f"field '{field_name}': expected {repr(expected_value)}, got {repr(field_value)}"
|
|
1005
|
+
)
|
|
1006
|
+
# type: "insert" found (with or without fields) - allowed
|
|
1007
|
+
return None
|
|
1008
|
+
|
|
1009
|
+
# Check for legacy whole-row spec
|
|
1010
|
+
for spec in specs:
|
|
1011
|
+
if spec.get("fields") is None and spec.get("after") == "__added__":
|
|
1012
|
+
return None
|
|
1013
|
+
|
|
1014
|
+
return f"Unexpected row added in table '{table}': pk={pk}"
|
|
1015
|
+
|
|
1016
|
+
def _validate_delete_row(
|
|
1017
|
+
table: str, pk: Any, row_data: Dict[str, Any], specs: List[Dict[str, Any]]
|
|
1018
|
+
) -> Optional[str]:
|
|
1019
|
+
"""Validate a deleted row against specs. Returns error message or None."""
|
|
1020
|
+
# Check for type: "delete" spec with optional fields
|
|
1021
|
+
for spec in specs:
|
|
1022
|
+
if spec.get("type") == "delete":
|
|
1023
|
+
fields_spec = spec.get("fields")
|
|
1024
|
+
if fields_spec is not None:
|
|
1025
|
+
# Validate each field against the deleted row
|
|
1026
|
+
spec_map = _parse_fields_spec(fields_spec)
|
|
1027
|
+
for field_name, (should_check, expected_value) in spec_map.items():
|
|
1028
|
+
if field_name not in row_data:
|
|
1029
|
+
return f"Field '{field_name}' in delete spec not found in row for table '{table}' pk={pk}"
|
|
1030
|
+
if should_check and not _values_equivalent(expected_value, row_data[field_name]):
|
|
1031
|
+
return (
|
|
1032
|
+
f"Delete mismatch in table '{table}' pk={pk}, "
|
|
1033
|
+
f"field '{field_name}': expected {repr(expected_value)}, got {repr(row_data[field_name])}"
|
|
1034
|
+
)
|
|
1035
|
+
# type: "delete" found (with or without fields) - allowed
|
|
1036
|
+
return None
|
|
1037
|
+
|
|
1038
|
+
# Check for legacy whole-row spec
|
|
1039
|
+
for spec in specs:
|
|
1040
|
+
if spec.get("fields") is None and spec.get("after") == "__removed__":
|
|
1041
|
+
return None
|
|
1042
|
+
|
|
1043
|
+
return f"Unexpected row removed from table '{table}': pk={pk}"
|
|
1044
|
+
|
|
1045
|
+
def _validate_modify_row(
|
|
1046
|
+
table: str,
|
|
1047
|
+
pk: Any,
|
|
1048
|
+
before_row: Dict[str, Any],
|
|
1049
|
+
after_row: Dict[str, Any],
|
|
1050
|
+
specs: List[Dict[str, Any]],
|
|
1051
|
+
) -> Optional[str]:
|
|
1052
|
+
"""Validate a modified row against specs. Returns error message or None."""
|
|
1053
|
+
# Collect actual changes
|
|
1054
|
+
changed_fields: Dict[str, Dict[str, Any]] = {}
|
|
1055
|
+
for field in set(before_row.keys()) | set(after_row.keys()):
|
|
1056
|
+
if self.ignore_config.should_ignore_field(table, field):
|
|
1057
|
+
continue
|
|
1058
|
+
before_val = before_row.get(field)
|
|
1059
|
+
after_val = after_row.get(field)
|
|
1060
|
+
if not _values_equivalent(before_val, after_val):
|
|
1061
|
+
changed_fields[field] = {"before": before_val, "after": after_val}
|
|
1062
|
+
|
|
1063
|
+
if not changed_fields:
|
|
1064
|
+
return None # No changes
|
|
1065
|
+
|
|
1066
|
+
# Check for type: "modify" spec with resulting_fields
|
|
1067
|
+
for spec in specs:
|
|
1068
|
+
if spec.get("type") == "modify":
|
|
1069
|
+
resulting_fields = spec.get("resulting_fields")
|
|
1070
|
+
if resulting_fields is not None:
|
|
1071
|
+
# Validate no_other_changes is provided
|
|
1072
|
+
if "no_other_changes" not in spec:
|
|
1073
|
+
raise ValueError(
|
|
1074
|
+
f"Modify spec for table '{table}' pk={pk} "
|
|
1075
|
+
f"has 'resulting_fields' but missing required 'no_other_changes' field."
|
|
1076
|
+
)
|
|
1077
|
+
no_other_changes = spec["no_other_changes"]
|
|
1078
|
+
if not isinstance(no_other_changes, bool):
|
|
1079
|
+
raise ValueError(
|
|
1080
|
+
f"Modify spec for table '{table}' pk={pk} "
|
|
1081
|
+
f"'no_other_changes' must be boolean, got {type(no_other_changes).__name__}"
|
|
1082
|
+
)
|
|
1083
|
+
|
|
1084
|
+
spec_map = _parse_fields_spec(resulting_fields)
|
|
1085
|
+
|
|
1086
|
+
# Validate changed fields
|
|
1087
|
+
for field_name, vals in changed_fields.items():
|
|
1088
|
+
after_val = vals["after"]
|
|
1089
|
+
if field_name not in spec_map:
|
|
1090
|
+
if no_other_changes:
|
|
1091
|
+
return (
|
|
1092
|
+
f"Unexpected field change in table '{table}' pk={pk}: "
|
|
1093
|
+
f"field '{field_name}' not in resulting_fields"
|
|
1094
|
+
)
|
|
1095
|
+
# no_other_changes=False: ignore this field
|
|
1096
|
+
else:
|
|
1097
|
+
should_check, expected_value = spec_map[field_name]
|
|
1098
|
+
if should_check and not _values_equivalent(expected_value, after_val):
|
|
1099
|
+
return (
|
|
1100
|
+
f"Modify mismatch in table '{table}' pk={pk}, "
|
|
1101
|
+
f"field '{field_name}': expected {repr(expected_value)}, got {repr(after_val)}"
|
|
1102
|
+
)
|
|
1103
|
+
return None # Validation passed
|
|
1104
|
+
else:
|
|
1105
|
+
# type: "modify" without resulting_fields - allow any modification
|
|
1106
|
+
return None
|
|
1107
|
+
|
|
1108
|
+
# Check for legacy single-field specs
|
|
1109
|
+
for field_name, vals in changed_fields.items():
|
|
1110
|
+
after_val = vals["after"]
|
|
1111
|
+
field_allowed = False
|
|
1112
|
+
for spec in specs:
|
|
1113
|
+
if (
|
|
1114
|
+
spec.get("field") == field_name
|
|
1115
|
+
and _values_equivalent(spec.get("after"), after_val)
|
|
1116
|
+
):
|
|
1117
|
+
field_allowed = True
|
|
1118
|
+
break
|
|
1119
|
+
if not field_allowed:
|
|
1120
|
+
return (
|
|
1121
|
+
f"Unexpected change in table '{table}' pk={pk}, "
|
|
1122
|
+
f"field '{field_name}': {repr(vals['before'])} -> {repr(after_val)}"
|
|
1123
|
+
)
|
|
1124
|
+
|
|
1125
|
+
return None
|
|
1126
|
+
|
|
1127
|
+
# Group allowed changes by table
|
|
1128
|
+
changes_by_table: Dict[str, List[Dict[str, Any]]] = {}
|
|
1129
|
+
for change in allowed_changes:
|
|
1130
|
+
table = change["table"]
|
|
1131
|
+
if table not in changes_by_table:
|
|
1132
|
+
changes_by_table[table] = []
|
|
1133
|
+
changes_by_table[table].append(change)
|
|
1134
|
+
|
|
1135
|
+
errors = []
|
|
1136
|
+
errors_lock = Lock()
|
|
1137
|
+
|
|
1138
|
+
# Function to check a single row
|
|
1139
|
+
def check_row(
|
|
1140
|
+
table: str,
|
|
1141
|
+
pk: Any,
|
|
1142
|
+
pk_columns: List[str],
|
|
1143
|
+
):
|
|
1144
|
+
try:
|
|
1145
|
+
# Build WHERE clause for this PK
|
|
1146
|
+
where_sql = self._build_pk_where_clause(pk_columns, pk)
|
|
1147
|
+
|
|
1148
|
+
# Query before snapshot
|
|
1149
|
+
before_query = f"SELECT * FROM {_quote_identifier(table)} WHERE {where_sql}"
|
|
1150
|
+
before_response = self.before.resource.query(before_query)
|
|
1151
|
+
before_row = (
|
|
1152
|
+
dict(zip(before_response.columns, before_response.rows[0]))
|
|
1153
|
+
if before_response.rows
|
|
1154
|
+
else None
|
|
1155
|
+
)
|
|
1156
|
+
|
|
1157
|
+
# Query after snapshot
|
|
1158
|
+
after_response = self.after.resource.query(before_query)
|
|
1159
|
+
after_row = (
|
|
1160
|
+
dict(zip(after_response.columns, after_response.rows[0]))
|
|
1161
|
+
if after_response.rows
|
|
1162
|
+
else None
|
|
1163
|
+
)
|
|
1164
|
+
|
|
1165
|
+
# Get all specs for this table/pk
|
|
1166
|
+
specs = _get_all_specs_for_pk(table, pk)
|
|
1167
|
+
|
|
1168
|
+
# Check changes for this row
|
|
1169
|
+
if before_row and after_row:
|
|
1170
|
+
# Modified row
|
|
1171
|
+
error = _validate_modify_row(table, pk, before_row, after_row, specs)
|
|
1172
|
+
if error:
|
|
1173
|
+
with errors_lock:
|
|
1174
|
+
errors.append(AssertionError(error))
|
|
1175
|
+
elif not before_row and after_row:
|
|
1176
|
+
# Added row
|
|
1177
|
+
error = _validate_insert_row(table, pk, after_row, specs)
|
|
1178
|
+
if error:
|
|
1179
|
+
with errors_lock:
|
|
1180
|
+
errors.append(AssertionError(error))
|
|
1181
|
+
elif before_row and not after_row:
|
|
1182
|
+
# Removed row
|
|
1183
|
+
error = _validate_delete_row(table, pk, before_row, specs)
|
|
1184
|
+
if error:
|
|
1185
|
+
with errors_lock:
|
|
1186
|
+
errors.append(AssertionError(error))
|
|
1187
|
+
|
|
1188
|
+
except Exception as e:
|
|
1189
|
+
with errors_lock:
|
|
1190
|
+
errors.append(e)
|
|
1191
|
+
|
|
1192
|
+
# Prepare all row checks
|
|
1193
|
+
row_checks = []
|
|
1194
|
+
for table, table_changes in changes_by_table.items():
|
|
1195
|
+
if self.ignore_config.should_ignore_table(table):
|
|
1196
|
+
continue
|
|
1197
|
+
|
|
1198
|
+
# Get primary key columns once per table
|
|
1199
|
+
pk_columns = self._get_primary_key_columns(table)
|
|
1200
|
+
|
|
1201
|
+
# Extract unique PKs to check
|
|
1202
|
+
pks_to_check = {change["pk"] for change in table_changes}
|
|
1203
|
+
|
|
1204
|
+
for pk in pks_to_check:
|
|
1205
|
+
row_checks.append((table, pk, pk_columns))
|
|
1206
|
+
|
|
1207
|
+
# Execute row checks in parallel
|
|
1208
|
+
if row_checks:
|
|
1209
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
|
1210
|
+
futures = [
|
|
1211
|
+
executor.submit(check_row, table, pk, pk_columns)
|
|
1212
|
+
for table, pk, pk_columns in row_checks
|
|
1213
|
+
]
|
|
1214
|
+
concurrent.futures.wait(futures)
|
|
1215
|
+
|
|
1216
|
+
# Check for errors from row checks
|
|
1217
|
+
if errors:
|
|
1218
|
+
raise errors[0]
|
|
1219
|
+
|
|
1220
|
+
# Now check tables not mentioned in allowed_changes to ensure no changes
|
|
1221
|
+
all_tables = set(self.before.tables()) | set(self.after.tables())
|
|
1222
|
+
tables_to_verify = []
|
|
1223
|
+
|
|
1224
|
+
for table in all_tables:
|
|
1225
|
+
if (
|
|
1226
|
+
table not in changes_by_table
|
|
1227
|
+
and not self.ignore_config.should_ignore_table(table)
|
|
1228
|
+
):
|
|
1229
|
+
tables_to_verify.append(table)
|
|
1230
|
+
|
|
1231
|
+
# Function to verify no changes in a table
|
|
1232
|
+
def verify_no_changes(table: str):
|
|
1233
|
+
try:
|
|
1234
|
+
# For tables with no allowed changes, just check row counts
|
|
1235
|
+
before_count_response = self.before.resource.query(
|
|
1236
|
+
f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
|
|
1237
|
+
)
|
|
1238
|
+
before_count = (
|
|
1239
|
+
before_count_response.rows[0][0]
|
|
1240
|
+
if before_count_response.rows
|
|
1241
|
+
else 0
|
|
1242
|
+
)
|
|
1243
|
+
|
|
1244
|
+
after_count_response = self.after.resource.query(
|
|
1245
|
+
f"SELECT COUNT(*) FROM {_quote_identifier(table)}"
|
|
1246
|
+
)
|
|
1247
|
+
after_count = (
|
|
1248
|
+
after_count_response.rows[0][0] if after_count_response.rows else 0
|
|
1249
|
+
)
|
|
1250
|
+
|
|
1251
|
+
if before_count != after_count:
|
|
1252
|
+
error_msg = (
|
|
1253
|
+
f"Unexpected change in table '{table}': "
|
|
1254
|
+
f"row count changed from {before_count} to {after_count}"
|
|
1255
|
+
)
|
|
1256
|
+
with errors_lock:
|
|
1257
|
+
errors.append(AssertionError(error_msg))
|
|
1258
|
+
except Exception as e:
|
|
1259
|
+
with errors_lock:
|
|
1260
|
+
errors.append(e)
|
|
1261
|
+
|
|
1262
|
+
# Execute table verification in parallel
|
|
1263
|
+
if tables_to_verify:
|
|
1264
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
|
1265
|
+
futures = [
|
|
1266
|
+
executor.submit(verify_no_changes, table) for table in tables_to_verify
|
|
1267
|
+
]
|
|
1268
|
+
concurrent.futures.wait(futures)
|
|
1269
|
+
|
|
1270
|
+
# Final error check
|
|
1271
|
+
if errors:
|
|
1272
|
+
raise errors[0]
|
|
1273
|
+
|
|
1274
|
+
return self
|
|
1275
|
+
|
|
1276
|
+
def _validate_diff_against_allowed_changes_v2(
|
|
1277
|
+
self, diff: Dict[str, Any], allowed_changes: List[Dict[str, Any]]
|
|
1278
|
+
):
|
|
1279
|
+
"""Validate a collected diff against allowed changes with field-level spec support.
|
|
1280
|
+
|
|
1281
|
+
This version supports explicit change types via the "type" field:
|
|
1282
|
+
1. Insert specs: {"table": "t", "pk": 1, "type": "insert", "fields": [("name", "value"), ("status", ...)]}
|
|
1283
|
+
- ("name", value): check that field equals value
|
|
1284
|
+
- ("name", None): check that field is SQL NULL
|
|
1285
|
+
- ("name", ...): don't check the value, just acknowledge the field exists
|
|
1286
|
+
2. Modify specs: {"table": "t", "pk": 1, "type": "modify", "resulting_fields": [...], "no_other_changes": True/False}
|
|
1287
|
+
- Uses "resulting_fields" (not "fields") to be explicit about what's being checked
|
|
1288
|
+
- "no_other_changes" is REQUIRED and must be True or False:
|
|
1289
|
+
- True: Every changed field must be in resulting_fields (strict mode)
|
|
1290
|
+
- False: Only check fields in resulting_fields match, ignore other changes
|
|
1291
|
+
- ("field_name", value): check that after value equals value
|
|
1292
|
+
- ("field_name", None): check that after value is SQL NULL
|
|
1293
|
+
- ("field_name", ...): don't check value, just acknowledge field changed
|
|
1294
|
+
3. Delete specs:
|
|
1295
|
+
- Without field validation: {"table": "t", "pk": 1, "type": "delete"}
|
|
1296
|
+
- With field validation: {"table": "t", "pk": 1, "type": "delete", "fields": [...]}
|
|
1297
|
+
4. Whole-row specs (legacy):
|
|
1298
|
+
- For additions: {"table": "t", "pk": 1, "fields": None, "after": "__added__"}
|
|
1299
|
+
- For deletions: {"table": "t", "pk": 1, "fields": None, "after": "__removed__"}
|
|
1300
|
+
|
|
1301
|
+
When using "fields" for inserts, every field must be accounted for in the list.
|
|
1302
|
+
For modifications, use "resulting_fields" with explicit "no_other_changes".
|
|
1303
|
+
For deletions with "fields", all specified fields are validated against the deleted row.
|
|
1304
|
+
"""
|
|
1305
|
+
|
|
1306
|
+
def _is_change_allowed(
|
|
1307
|
+
table: str, row_id: Any, field: Optional[str], after_value: Any
|
|
1308
|
+
) -> bool:
|
|
1309
|
+
"""Check if a change is in the allowed list using semantic comparison."""
|
|
1310
|
+
for allowed in allowed_changes:
|
|
1311
|
+
allowed_pk = allowed.get("pk")
|
|
1312
|
+
# Handle type conversion for primary key comparison
|
|
1313
|
+
pk_match = (
|
|
1314
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
1315
|
+
)
|
|
1316
|
+
|
|
1317
|
+
# For whole-row specs, check "fields": None; for field-level, check "field"
|
|
1318
|
+
field_match = (
|
|
1319
|
+
("fields" in allowed and allowed.get("fields") is None)
|
|
1320
|
+
if field is None
|
|
1321
|
+
else allowed.get("field") == field
|
|
1322
|
+
)
|
|
1323
|
+
if (
|
|
1324
|
+
allowed["table"] == table
|
|
1325
|
+
and pk_match
|
|
1326
|
+
and field_match
|
|
1327
|
+
and _values_equivalent(allowed.get("after"), after_value)
|
|
1328
|
+
):
|
|
1329
|
+
return True
|
|
1330
|
+
return False
|
|
1331
|
+
|
|
1332
|
+
def _get_fields_spec_for_type(
|
|
1333
|
+
table: str, row_id: Any, change_type: str
|
|
1334
|
+
) -> Optional[List[Tuple[str, Any]]]:
|
|
1335
|
+
"""Get the bulk fields spec for a given table/row/type if it exists.
|
|
1336
|
+
|
|
1337
|
+
Args:
|
|
1338
|
+
table: The table name
|
|
1339
|
+
row_id: The primary key value
|
|
1340
|
+
change_type: One of "insert", "modify", or "delete"
|
|
1341
|
+
|
|
1342
|
+
Note: For "modify" type, use _get_modify_spec instead.
|
|
1343
|
+
"""
|
|
1344
|
+
for allowed in allowed_changes:
|
|
1345
|
+
allowed_pk = allowed.get("pk")
|
|
1346
|
+
pk_match = (
|
|
1347
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
1348
|
+
)
|
|
1349
|
+
if (
|
|
1350
|
+
allowed["table"] == table
|
|
1351
|
+
and pk_match
|
|
1352
|
+
and allowed.get("type") == change_type
|
|
1353
|
+
and "fields" in allowed
|
|
1354
|
+
):
|
|
1355
|
+
return allowed["fields"]
|
|
1356
|
+
return None
|
|
1357
|
+
|
|
1358
|
+
def _get_modify_spec(table: str, row_id: Any) -> Optional[Dict[str, Any]]:
|
|
1359
|
+
"""Get the modify spec for a given table/row if it exists.
|
|
1360
|
+
|
|
1361
|
+
Returns the full spec dict containing:
|
|
1362
|
+
- resulting_fields: List of field tuples
|
|
1363
|
+
- no_other_changes: Boolean (required)
|
|
1364
|
+
|
|
1365
|
+
Returns None if no modify spec found.
|
|
1366
|
+
"""
|
|
1367
|
+
for allowed in allowed_changes:
|
|
1368
|
+
allowed_pk = allowed.get("pk")
|
|
1369
|
+
pk_match = (
|
|
1370
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
1371
|
+
)
|
|
1372
|
+
if (
|
|
1373
|
+
allowed["table"] == table
|
|
1374
|
+
and pk_match
|
|
1375
|
+
and allowed.get("type") == "modify"
|
|
1376
|
+
):
|
|
1377
|
+
return allowed
|
|
1378
|
+
return None
|
|
1379
|
+
|
|
1380
|
+
def _is_type_allowed(table: str, row_id: Any, change_type: str) -> bool:
|
|
1381
|
+
"""Check if a change type is allowed for the given table/row (with or without fields)."""
|
|
1382
|
+
for allowed in allowed_changes:
|
|
1383
|
+
allowed_pk = allowed.get("pk")
|
|
1384
|
+
pk_match = (
|
|
1385
|
+
str(allowed_pk) == str(row_id) if allowed_pk is not None else False
|
|
1386
|
+
)
|
|
1387
|
+
if (
|
|
1388
|
+
allowed["table"] == table
|
|
1389
|
+
and pk_match
|
|
1390
|
+
and allowed.get("type") == change_type
|
|
1391
|
+
):
|
|
1392
|
+
return True
|
|
1393
|
+
return False
|
|
1394
|
+
|
|
1395
|
+
def _parse_fields_spec(
|
|
1396
|
+
fields_spec: List[Tuple[str, Any]]
|
|
1397
|
+
) -> Dict[str, Tuple[bool, Any]]:
|
|
1398
|
+
"""Parse a fields spec into a mapping of field_name -> (should_check_value, expected_value)."""
|
|
1399
|
+
spec_map: Dict[str, Tuple[bool, Any]] = {}
|
|
1400
|
+
for spec_tuple in fields_spec:
|
|
1401
|
+
if len(spec_tuple) != 2:
|
|
1402
|
+
raise ValueError(
|
|
1403
|
+
f"Invalid field spec tuple: {spec_tuple}. "
|
|
1404
|
+
f"Expected 2-tuple like ('field', value), ('field', None), or ('field', ...)"
|
|
1405
|
+
)
|
|
1406
|
+
field_name, expected_value = spec_tuple
|
|
1407
|
+
if expected_value is ...:
|
|
1408
|
+
# Ellipsis: don't check value, just acknowledge field exists
|
|
1409
|
+
spec_map[field_name] = (False, None)
|
|
1410
|
+
else:
|
|
1411
|
+
# Any other value (including None for NULL check): check value
|
|
1412
|
+
spec_map[field_name] = (True, expected_value)
|
|
1413
|
+
return spec_map
|
|
1414
|
+
|
|
1415
|
+
def _validate_row_with_fields_spec(
|
|
1416
|
+
table: str,
|
|
1417
|
+
row_id: Any,
|
|
1418
|
+
row_data: Dict[str, Any],
|
|
1419
|
+
fields_spec: List[Tuple[str, Any]],
|
|
1420
|
+
) -> Optional[List[Tuple[str, Any, str]]]:
|
|
1421
|
+
"""Validate a row against a bulk fields spec.
|
|
1422
|
+
|
|
1423
|
+
Returns None if validation passes, or a list of (field, actual_value, issue)
|
|
1424
|
+
tuples for mismatches.
|
|
1425
|
+
|
|
1426
|
+
Field spec semantics:
|
|
1427
|
+
- ("field_name", value): check that field equals value
|
|
1428
|
+
- ("field_name", None): check that field is SQL NULL
|
|
1429
|
+
- ("field_name", ...): don't check value (acknowledge field exists)
|
|
1430
|
+
"""
|
|
1431
|
+
spec_map = _parse_fields_spec(fields_spec)
|
|
1432
|
+
unmatched_fields: List[Tuple[str, Any, str]] = []
|
|
1433
|
+
|
|
1434
|
+
for field_name, field_value in row_data.items():
|
|
1435
|
+
# Skip rowid as it's internal
|
|
1436
|
+
if field_name == "rowid":
|
|
1437
|
+
continue
|
|
1438
|
+
# Skip ignored fields
|
|
1439
|
+
if self.ignore_config.should_ignore_field(table, field_name):
|
|
1440
|
+
continue
|
|
1441
|
+
|
|
1442
|
+
if field_name not in spec_map:
|
|
1443
|
+
# Field not in spec - this is an error
|
|
1444
|
+
unmatched_fields.append(
|
|
1445
|
+
(field_name, field_value, "NOT_IN_FIELDS_SPEC")
|
|
1446
|
+
)
|
|
1447
|
+
else:
|
|
1448
|
+
should_check, expected_value = spec_map[field_name]
|
|
1449
|
+
if should_check and not _values_equivalent(
|
|
1450
|
+
expected_value, field_value
|
|
1451
|
+
):
|
|
1452
|
+
# Value doesn't match
|
|
1453
|
+
unmatched_fields.append(
|
|
1454
|
+
(field_name, field_value, f"expected {repr(expected_value)}")
|
|
1455
|
+
)
|
|
1456
|
+
|
|
1457
|
+
return unmatched_fields if unmatched_fields else None
|
|
1458
|
+
|
|
1459
|
+
def _validate_modification_with_fields_spec(
|
|
1460
|
+
table: str,
|
|
1461
|
+
row_id: Any,
|
|
1462
|
+
row_changes: Dict[str, Dict[str, Any]],
|
|
1463
|
+
resulting_fields: List[Tuple[str, Any]],
|
|
1464
|
+
no_other_changes: bool,
|
|
1465
|
+
) -> Optional[List[Tuple[str, Any, str]]]:
|
|
1466
|
+
"""Validate a modification against a resulting_fields spec.
|
|
1467
|
+
|
|
1468
|
+
Returns None if validation passes, or a list of (field, actual_value, issue)
|
|
1469
|
+
tuples for mismatches.
|
|
1470
|
+
|
|
1471
|
+
Args:
|
|
1472
|
+
table: The table name
|
|
1473
|
+
row_id: The row primary key
|
|
1474
|
+
row_changes: Dict of field_name -> {"before": ..., "after": ...}
|
|
1475
|
+
resulting_fields: List of field tuples to validate
|
|
1476
|
+
no_other_changes: If True, all changed fields must be in resulting_fields.
|
|
1477
|
+
If False, only validate fields in resulting_fields, ignore others.
|
|
1478
|
+
|
|
1479
|
+
Field spec semantics for modifications:
|
|
1480
|
+
- ("field_name", value): check that after value equals value
|
|
1481
|
+
- ("field_name", None): check that after value is SQL NULL
|
|
1482
|
+
- ("field_name", ...): don't check value, just acknowledge field changed
|
|
1483
|
+
"""
|
|
1484
|
+
spec_map = _parse_fields_spec(resulting_fields)
|
|
1485
|
+
unmatched_fields: List[Tuple[str, Any, str]] = []
|
|
1486
|
+
|
|
1487
|
+
for field_name, vals in row_changes.items():
|
|
1488
|
+
# Skip ignored fields
|
|
1489
|
+
if self.ignore_config.should_ignore_field(table, field_name):
|
|
1490
|
+
continue
|
|
1491
|
+
|
|
1492
|
+
after_value = vals["after"]
|
|
1493
|
+
|
|
1494
|
+
if field_name not in spec_map:
|
|
1495
|
+
# Changed field not in spec
|
|
1496
|
+
if no_other_changes:
|
|
1497
|
+
# Strict mode: all changed fields must be accounted for
|
|
1498
|
+
unmatched_fields.append(
|
|
1499
|
+
(field_name, after_value, "NOT_IN_RESULTING_FIELDS")
|
|
1500
|
+
)
|
|
1501
|
+
# If no_other_changes=False, ignore fields not in spec
|
|
1502
|
+
else:
|
|
1503
|
+
should_check, expected_value = spec_map[field_name]
|
|
1504
|
+
if should_check and not _values_equivalent(
|
|
1505
|
+
expected_value, after_value
|
|
1506
|
+
):
|
|
1507
|
+
# Value doesn't match
|
|
1508
|
+
unmatched_fields.append(
|
|
1509
|
+
(field_name, after_value, f"expected {repr(expected_value)}")
|
|
1510
|
+
)
|
|
1511
|
+
|
|
1512
|
+
return unmatched_fields if unmatched_fields else None
|
|
1513
|
+
|
|
1514
|
+
|
|
1515
|
+
+        # Collect all unexpected changes for detailed reporting
+        unexpected_changes = []
+
+        for tbl, report in diff.items():
+            for row in report.get("modified_rows", []):
+                row_changes = row["changes"]
+
+                # Check for modify spec with resulting_fields
+                modify_spec = _get_modify_spec(tbl, row["row_id"])
+                if modify_spec is not None:
+                    resulting_fields = modify_spec.get("resulting_fields")
+                    if resulting_fields is not None:
+                        # Validate that no_other_changes is provided
+                        if "no_other_changes" not in modify_spec:
+                            raise ValueError(
+                                f"Modify spec for table '{tbl}' pk={row['row_id']} "
+                                f"has 'resulting_fields' but missing required 'no_other_changes' field. "
+                                f"Set 'no_other_changes': True to verify no other fields changed, "
+                                f"or 'no_other_changes': False to only check the specified fields."
+                            )
+                        no_other_changes = modify_spec["no_other_changes"]
+                        if not isinstance(no_other_changes, bool):
+                            raise ValueError(
+                                f"Modify spec for table '{tbl}' pk={row['row_id']} "
+                                f"has 'no_other_changes' but it must be a boolean (True or False), "
+                                f"got {type(no_other_changes).__name__}: {repr(no_other_changes)}"
+                            )
+
+                        unmatched = _validate_modification_with_fields_spec(
+                            tbl, row["row_id"], row_changes, resulting_fields, no_other_changes
+                        )
+                        if unmatched:
+                            unexpected_changes.append(
+                                {
+                                    "type": "modification",
+                                    "table": tbl,
+                                    "row_id": row["row_id"],
+                                    "field": None,
+                                    "before": None,
+                                    "after": None,
+                                    "full_row": row,
+                                    "unmatched_fields": unmatched,
+                                }
+                            )
+                        continue  # Skip to next row
+                    else:
+                        # Modify spec without resulting_fields - just allow the modification
+                        continue  # Skip to next row
+
+                # Fall back to single-field specs (legacy)
+                for f, vals in row_changes.items():
+                    if self.ignore_config.should_ignore_field(tbl, f):
+                        continue
+                    if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
+                        unexpected_changes.append(
+                            {
+                                "type": "modification",
+                                "table": tbl,
+                                "row_id": row["row_id"],
+                                "field": f,
+                                "before": vals.get("before"),
+                                "after": vals["after"],
+                                "full_row": row,
+                            }
+                        )
+
+            for row in report.get("added_rows", []):
+                row_data = row.get("data", {})
+
+                # Check for bulk fields spec (type: "insert")
+                fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "insert")
+                if fields_spec is not None:
+                    unmatched = _validate_row_with_fields_spec(
+                        tbl, row["row_id"], row_data, fields_spec
+                    )
+                    if unmatched:
+                        unexpected_changes.append(
+                            {
+                                "type": "insertion",
+                                "table": tbl,
+                                "row_id": row["row_id"],
+                                "field": None,
+                                "after": "__added__",
+                                "full_row": row,
+                                "unmatched_fields": unmatched,
+                            }
+                        )
+                    continue  # Skip to next row
+
+                # Check if insertion is allowed without field validation
+                if _is_type_allowed(tbl, row["row_id"], "insert"):
+                    continue  # Insertion is allowed, skip to next row
+
+                # Check for whole-row spec (legacy)
+                whole_row_allowed = _is_change_allowed(
+                    tbl, row["row_id"], None, "__added__"
+                )
+
+                if not whole_row_allowed:
+                    unexpected_changes.append(
+                        {
+                            "type": "insertion",
+                            "table": tbl,
+                            "row_id": row["row_id"],
+                            "field": None,
+                            "after": "__added__",
+                            "full_row": row,
+                        }
+                    )
+
+            for row in report.get("removed_rows", []):
+                row_data = row.get("data", {})
+
+                # Check for bulk fields spec (type: "delete")
+                fields_spec = _get_fields_spec_for_type(tbl, row["row_id"], "delete")
+                if fields_spec is not None:
+                    unmatched = _validate_row_with_fields_spec(
+                        tbl, row["row_id"], row_data, fields_spec
+                    )
+                    if unmatched:
+                        unexpected_changes.append(
+                            {
+                                "type": "deletion",
+                                "table": tbl,
+                                "row_id": row["row_id"],
+                                "field": None,
+                                "after": "__removed__",
+                                "full_row": row,
+                                "unmatched_fields": unmatched,
+                            }
+                        )
+                    continue  # Skip to next row
+
+                # Check if deletion is allowed without field validation
+                if _is_type_allowed(tbl, row["row_id"], "delete"):
+                    continue  # Deletion is allowed, skip to next row
+
+                # Check for whole-row spec (legacy)
+                whole_row_allowed = _is_change_allowed(
+                    tbl, row["row_id"], None, "__removed__"
+                )
+
+                if not whole_row_allowed:
+                    unexpected_changes.append(
+                        {
+                            "type": "deletion",
+                            "table": tbl,
+                            "row_id": row["row_id"],
+                            "field": None,
+                            "after": "__removed__",
+                            "full_row": row,
+                        }
+                    )
+
+        if unexpected_changes:
+            # Build comprehensive error message
+            error_lines = ["Unexpected database changes detected:"]
+            error_lines.append("")
+
+            for i, change in enumerate(unexpected_changes[:5], 1):
+                error_lines.append(
+                    f"{i}. {change['type'].upper()} in table '{change['table']}':"
+                )
+                error_lines.append(f"   Row ID: {change['row_id']}")
+
+                if change["type"] == "modification":
+                    error_lines.append(f"   Field: {change['field']}")
+                    error_lines.append(f"   Before: {repr(change['before'])}")
+                    error_lines.append(f"   After: {repr(change['after'])}")
+                elif change["type"] == "insertion":
+                    error_lines.append("   New row added")
+                elif change["type"] == "deletion":
+                    error_lines.append("   Row deleted")
+
+                # Show unmatched fields if present (from bulk fields spec validation)
+                if "unmatched_fields" in change and change["unmatched_fields"]:
+                    error_lines.append("   Unmatched fields:")
+                    for field_info in change["unmatched_fields"][:5]:
+                        field_name, actual_value, issue = field_info
+                        error_lines.append(
+                            f"     - {field_name}: {repr(actual_value)} ({issue})"
+                        )
+                    if len(change["unmatched_fields"]) > 5:
+                        error_lines.append(
+                            f"     ... and {len(change['unmatched_fields']) - 5} more"
+                        )
+
+                # Show some context from the row
+                if "full_row" in change and change["full_row"]:
+                    row_data = change["full_row"]
+                    if change["type"] == "modification" and "data" in row_data:
+                        # For modifications, show the current state
+                        formatted_row = _format_row_for_error(
+                            row_data.get("data", {}), max_fields=5
+                        )
+                        error_lines.append(f"   Row data: {formatted_row}")
+                    elif (
+                        change["type"] in ["insertion", "deletion"]
+                        and "data" in row_data
+                    ):
+                        # For insertions/deletions, show the row data
+                        formatted_row = _format_row_for_error(
+                            row_data.get("data", {}), max_fields=5
+                        )
+                        error_lines.append(f"   Row data: {formatted_row}")
+
+                error_lines.append("")
+
+            if len(unexpected_changes) > 5:
+                error_lines.append(
+                    f"... and {len(unexpected_changes) - 5} more unexpected changes"
+                )
+                error_lines.append("")
+
+            # Show what changes were allowed
+            error_lines.append("Allowed changes were:")
+            if allowed_changes:
+                for i, allowed in enumerate(allowed_changes[:3], 1):
+                    change_type = allowed.get("type", "unspecified")
+
+                    # For modify type, use resulting_fields
+                    if change_type == "modify" and "resulting_fields" in allowed and allowed["resulting_fields"] is not None:
+                        fields_summary = ", ".join(
+                            f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
+                            for f in allowed["resulting_fields"][:3]
+                        )
+                        if len(allowed["resulting_fields"]) > 3:
+                            fields_summary += f", ... +{len(allowed['resulting_fields']) - 3} more"
+                        no_other = allowed.get("no_other_changes", "NOT_SET")
+                        error_lines.append(
+                            f"  {i}. Table: {allowed.get('table')}, "
+                            f"ID: {allowed.get('pk')}, "
+                            f"Type: {change_type}, "
+                            f"resulting_fields: [{fields_summary}], "
+                            f"no_other_changes: {no_other}"
+                        )
+                    elif "fields" in allowed and allowed["fields"] is not None:
+                        # Show bulk fields spec (for insert/delete)
+                        fields_summary = ", ".join(
+                            f[0] if len(f) == 1 else f"{f[0]}={'NOT_CHECKED' if f[1] is ... else repr(f[1])}"
+                            for f in allowed["fields"][:3]
+                        )
+                        if len(allowed["fields"]) > 3:
+                            fields_summary += f", ... +{len(allowed['fields']) - 3} more"
+                        error_lines.append(
+                            f"  {i}. Table: {allowed.get('table')}, "
+                            f"ID: {allowed.get('pk')}, "
+                            f"Type: {change_type}, "
+                            f"Fields: [{fields_summary}]"
+                        )
+                    else:
+                        error_lines.append(
+                            f"  {i}. Table: {allowed.get('table')}, "
+                            f"ID: {allowed.get('pk')}, "
+                            f"Type: {change_type}"
+                        )
+                if len(allowed_changes) > 3:
+                    error_lines.append(
+                        f"  ... and {len(allowed_changes) - 3} more allowed changes"
+                    )
+            else:
+                error_lines.append("  (No changes were allowed)")
+
+            raise AssertionError("\n".join(error_lines))
+
+        return self
+
+    def expect_only(self, allowed_changes: List[Dict[str, Any]]):
+        """Ensure only specified changes occurred."""
+        # Normalize pk values: convert lists to tuples for hashability and consistency
+        for change in allowed_changes:
+            if "pk" in change and isinstance(change["pk"], list):
+                change["pk"] = tuple(change["pk"])
+
+        # Special case: empty allowed_changes means no changes should have occurred
+        if not allowed_changes:
+            return self._expect_no_changes()
+
+        # For expect_only, we can optimize by only checking the specific rows mentioned
+        if self._can_use_targeted_queries(allowed_changes):
+            return self._expect_only_targeted(allowed_changes)
+
+        # Fall back to full diff for complex cases
+        diff = self._collect()
+        return self._validate_diff_against_allowed_changes(diff, allowed_changes)
+
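A minimal usage sketch for expect_only (the table name and pk are hypothetical, and `changes` stands for the diff object these methods are defined on):

    # Passes when the only observed change is the allowed deletion; any other
    # insert, update, or delete raises AssertionError with the report built above.
    changes.expect_only([
        {"type": "delete", "table": "sessions", "pk": 42},
    ])
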
+    def expect_only_v2(self, allowed_changes: List[Dict[str, Any]]):
+        """Ensure only specified changes occurred, with field-level spec support.
+
+        This version supports field-level specifications for added/removed rows,
+        allowing users to specify expected field values instead of just whole-row specs.
+        """
+        # Normalize pk values: convert lists to tuples for hashability and consistency
+        for change in allowed_changes:
+            if "pk" in change and isinstance(change["pk"], list):
+                change["pk"] = tuple(change["pk"])
+
+        # Special case: empty allowed_changes means no changes should have occurred
+        if not allowed_changes:
+            return self._expect_no_changes()
+
+        resource = self.after.resource
+        # Disabled: structured diff endpoint not yet available
+        if False and resource.client is not None and resource._mode == "http":
+            api_diff = None
+            try:
+                payload = {}
+                if self.ignore_config:
+                    payload["ignore_config"] = {
+                        "tables": list(self.ignore_config.tables),
+                        "fields": list(self.ignore_config.fields),
+                        "table_fields": {
+                            table: list(fields) for table, fields in self.ignore_config.table_fields.items()
+                        }
+                    }
+                response = resource.client.request(
+                    "POST",
+                    "/diff/structured",
+                    json=payload,
+                )
+                result = response.json()
+                if result.get("success") and "diff" in result:
+                    api_diff = result["diff"]
+            except Exception as e:
+                # Fall back to local diff if API call fails
+                print(f"Warning: Failed to fetch structured diff from API: {e}")
+                print("Falling back to local diff computation...")
+
+            # Validate outside try block so AssertionError propagates
+            if api_diff is not None:
+                return self._validate_diff_against_allowed_changes_v2(api_diff, allowed_changes)
+
+        # For expect_only_v2, we can optimize by only checking the specific rows mentioned
+        if self._can_use_targeted_queries(allowed_changes):
+            return self._expect_only_targeted_v2(allowed_changes)
+
+        # Fall back to full diff for complex cases
+        diff = self._collect()
+        return self._validate_diff_against_allowed_changes_v2(diff, allowed_changes)
+
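A sketch of the field-level insert spec that expect_only_v2 adds (hypothetical schema; Ellipsis accepts any value for that field):

    changes.expect_only_v2([
        {
            "type": "insert",
            "table": "orders",
            "pk": 101,
            "fields": [
                ("status", "pending"),
                ("customer_id", 7),
                ("created_at", ...),  # present, value not checked
            ],
        }
    ])
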
+    def expect_exactly(self, expected_changes: List[Dict[str, Any]]):
+        """Verify that EXACTLY the specified changes occurred.
+
+        This is stricter than expect_only_v2:
+        1. All changes in diff must match a spec (no unexpected changes)
+        2. All specs must have a matching change in diff (no missing expected changes)
+
+        This method is ideal for verifying that an agent performed exactly what was expected -
+        not more, not less.
+
+        Args:
+            expected_changes: List of expected change specs. Each spec requires:
+                - "type": "insert", "modify", or "delete" (required)
+                - "table": table name (required)
+                - "pk": primary key value (required)
+
+        Spec formats by type:
+        - Insert: {"type": "insert", "table": "t", "pk": 1, "fields": [...]}
+        - Modify: {"type": "modify", "table": "t", "pk": 1, "resulting_fields": [...], "no_other_changes": True/False}
+        - Delete: {"type": "delete", "table": "t", "pk": 1}
+
+        Field specs are 2-tuples: (field_name, expected_value)
+        - ("name", "Alice"): check field equals "Alice"
+        - ("name", ...): accept any value (ellipsis)
+        - ("name", None): check field is SQL NULL
+
+        Note: Legacy specs without explicit "type" are not supported.
+
+        Returns:
+            self for method chaining
+
+        Raises:
+            AssertionError: If there are unexpected changes OR if expected changes are missing
+            ValueError: If specs are missing required fields or have invalid format
+        """
+        # Get the diff (using HTTP if available, otherwise local)
+        resource = self.after.resource
+        diff = None
+
+        if resource.client is not None and resource._mode == "http":
+            try:
+                payload = {}
+                if self.ignore_config:
+                    payload["ignore_config"] = {
+                        "tables": list(self.ignore_config.tables),
+                        "fields": list(self.ignore_config.fields),
+                        "table_fields": {
+                            table: list(fields) for table, fields in self.ignore_config.table_fields.items()
+                        }
+                    }
+                response = resource.client.request(
+                    "POST",
+                    "/diff/structured",
+                    json=payload,
+                )
+                result = response.json()
+                if result.get("success") and "diff" in result:
+                    diff = result["diff"]
+            except Exception as e:
+                print(f"Warning: Failed to fetch structured diff from API: {e}")
+                print("Falling back to local diff computation...")
+
+        if diff is None:
+            diff = self._collect()
+
+        # Use shared validation logic
+        success, error_msg, _ = validate_diff_expect_exactly(
+            diff, expected_changes, self.ignore_config
+        )
+
+        if not success:
+            raise AssertionError(error_msg)
+
+        return self
+
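Putting the spec formats from the docstring together, a hedged example of an exact-match assertion (tables, pks, and values are illustrative only):

    changes.expect_exactly([
        {"type": "insert", "table": "invoices", "pk": 12,
         "fields": [("total", 99.5), ("status", "open")]},
        {"type": "modify", "table": "customers", "pk": 4,
         "resulting_fields": [("balance", 0)], "no_other_changes": True},
        {"type": "delete", "table": "carts", "pk": 9},
    ])

Unlike expect_only_v2, a spec with no matching change in the diff also fails, so the agent must have done all three things and nothing else.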
+    def _ensure_all_fetched(self):
+        """Fetch ALL data from ALL tables upfront (non-lazy loading).
+
+        This is the old approach before lazy loading was introduced.
+        Used by expect_only_v1 for simpler, non-optimized diffing.
+        """
+        # Get all tables
+        tables_response = self.before.resource.query(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+        )
+
+        if tables_response.rows:
+            before_tables = [row[0] for row in tables_response.rows]
+            for table in before_tables:
+                self.before._ensure_table_data(table)
+
+        # Also fetch from after snapshot
+        tables_response = self.after.resource.query(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+        )
+
+        if tables_response.rows:
+            after_tables = [row[0] for row in tables_response.rows]
+            for table in after_tables:
+                self.after._ensure_table_data(table)
+
+    def expect_only_v1(self, allowed_changes: List[Dict[str, Any]]):
+        """Ensure only specified changes occurred using the original (non-optimized) approach.
+
+        This version attempts to use the /api/v1/env/diff/structured endpoint if available,
+        falling back to local diff computation if the endpoint is not available.
+
+        Use this when you want the simpler, more predictable behavior of the original
+        implementation without any query optimizations.
+        """
+        # Try to use the structured diff endpoint if we have an HTTP client
+        resource = self.after.resource
+        if resource.client is not None and resource._mode == "http":
+            api_diff = None
+            try:
+                payload = {}
+                if self.ignore_config:
+                    payload["ignore_config"] = {
+                        "tables": list(self.ignore_config.tables),
+                        "fields": list(self.ignore_config.fields),
+                        "table_fields": {
+                            table: list(fields) for table, fields in self.ignore_config.table_fields.items()
+                        }
+                    }
+                response = resource.client.request(
+                    "POST",
+                    "/diff/structured",
+                    json=payload,
+                )
+                result = response.json()
+                if result.get("success") and "diff" in result:
+                    api_diff = result["diff"]
+            except Exception as e:
+                # Fall back to local diff if API call fails
+                print(f"Warning: Failed to fetch structured diff from API: {e}")
+                print("Falling back to local diff computation...")
+
+            # Validate outside try block so AssertionError propagates
+            if api_diff is not None:
+                return self._validate_diff_against_allowed_changes(api_diff, allowed_changes)
+
+        # Fall back to local diff computation
+        self._ensure_all_fetched()
+        diff = self._collect()
+        return self._validate_diff_against_allowed_changes(diff, allowed_changes)
+
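For reference, the request body these methods POST to /diff/structured when an ignore_config is set has this shape (the values shown are hypothetical):

    {
        "ignore_config": {
            "tables": ["audit_log"],
            "fields": ["updated_at"],
            "table_fields": {"users": ["last_seen"]}
        }
    }
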
 
 class SyncQueryBuilder:
     """Sync query builder that translates DSL to SQL and executes through the API."""
@@ -551,13 +2096,13 @@ class SyncQueryBuilder:
     # Compile to SQL
     def _compile(self) -> Tuple[str, List[Any]]:
         cols = ", ".join(self._select_cols)
-        sql = [f"SELECT {cols} FROM {self._table}"]
+        sql = [f"SELECT {cols} FROM {_quote_identifier(self._table)}"]
         params: List[Any] = []
 
         # Joins
         for tbl, onmap in self._joins:
-            join_clauses = [f"{self._table}.{l} = {tbl}.{r}" for l, r in onmap.items()]
-            sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")
+            join_clauses = [f"{_quote_identifier(self._table)}.{_quote_identifier(l)} = {_quote_identifier(tbl)}.{_quote_identifier(r)}" for l, r in onmap.items()]
+            sql.append(f"JOIN {_quote_identifier(tbl)} ON {' AND '.join(join_clauses)}")
 
         # WHERE
         if self._conditions:
@@ -565,12 +2110,12 @@ class SyncQueryBuilder:
             for col, op, val in self._conditions:
                 if op in ("IN", "NOT IN") and isinstance(val, tuple):
                     ph = ", ".join(["?" for _ in val])
-                    placeholders.append(f"{col} {op} ({ph})")
+                    placeholders.append(f"{_quote_identifier(col)} {op} ({ph})")
                    params.extend(val)
                 elif op in ("IS", "IS NOT"):
-                    placeholders.append(f"{col} {op} NULL")
+                    placeholders.append(f"{_quote_identifier(col)} {op} NULL")
                 else:
-                    placeholders.append(f"{col} {op} ?")
+                    placeholders.append(f"{_quote_identifier(col)} {op} ?")
                     params.append(val)
             sql.append("WHERE " + " AND ".join(placeholders))
 
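These hunks route every table and column name through _quote_identifier before splicing it into SQL, so identifiers that collide with keywords or contain unusual characters cannot break the statement. A sketch of what such a helper conventionally does in SQLite (the package's actual implementation, defined elsewhere in this file, may differ):

    def _quote_identifier(name: str) -> str:
        # SQLite identifier quoting: double any embedded quotes, then wrap
        # the whole name in double quotes ('order' -> '"order"').
        return '"' + name.replace('"', '""') + '"'

Values, by contrast, still travel as ? placeholders with a params list, which keeps the builder parameterized on the data side as well.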
@@ -675,16 +2220,103 @@ class SQLiteResource(Resource):
 
 
 class SQLiteResource(Resource):
-    def __init__(
+    def __init__(
+        self,
+        resource: ResourceModel,
+        client: Optional["SyncWrapper"] = None,
+        db_path: Optional[str] = None,
+    ):
         super().__init__(resource)
         self.client = client
+        self.db_path = db_path
+        self._mode = "direct" if db_path else "http"
+
+    @property
+    def mode(self) -> str:
+        """Return the mode of this resource: 'direct' (local file) or 'http' (remote API)."""
+        return self._mode
 
     def describe(self) -> DescribeResponse:
         """Describe the SQLite database schema."""
+        if self._mode == "direct":
+            return self._describe_direct()
+        else:
+            return self._describe_http()
+
+    def _describe_http(self) -> DescribeResponse:
+        """Describe database schema via HTTP API."""
         response = self.client.request(
             "GET", f"/resources/sqlite/{self.resource.name}/describe"
         )
-
+        try:
+            return DescribeResponse(**response.json())
+        except json.JSONDecodeError as e:
+            raise ValueError(
+                f"Failed to parse JSON response from SQLite describe endpoint. "
+                f"Status: {response.status_code}, "
+                f"Response text: {response.text[:500]}"
+            ) from e
+
+    def _describe_direct(self) -> DescribeResponse:
+        """Describe database schema from local file or in-memory database."""
+        try:
+            # Check if we need URI mode (for shared memory databases)
+            use_uri = 'mode=memory' in self.db_path
+            conn = sqlite3.connect(self.db_path, uri=use_uri)
+            cursor = conn.cursor()
+
+            # Get all tables
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+            )
+            table_names = [row[0] for row in cursor.fetchall()]
+
+            tables = []
+            for table_name in table_names:
+                # Get table info
+                cursor.execute(f"PRAGMA table_info({_quote_identifier(table_name)})")
+                columns = cursor.fetchall()
+
+                # Get CREATE TABLE SQL
+                cursor.execute(
+                    "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
+                    (table_name,)
+                )
+                sql_row = cursor.fetchone()
+                create_sql = sql_row[0] if sql_row else ""
+
+                table_schema = {
+                    "name": table_name,
+                    "sql": create_sql,
+                    "columns": [
+                        {
+                            "name": col[1],
+                            "type": col[2],
+                            "notnull": bool(col[3]),
+                            "default_value": col[4],
+                            "primary_key": col[5] > 0,
+                        }
+                        for col in columns
+                    ],
+                }
+                tables.append(table_schema)
+
+            conn.close()
+
+            return DescribeResponse(
+                success=True,
+                resource_name=self.resource.name,
+                tables=tables,
+                message="Schema retrieved from local file",
+            )
+        except Exception as e:
+            return DescribeResponse(
+                success=False,
+                resource_name=self.resource.name,
+                tables=None,
+                error=str(e),
+                message=f"Failed to describe database: {str(e)}",
+            )
 
     def query(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:
         return self._query(query, args, read_only=True)
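A hedged usage sketch of the new dual-mode constructor (the resource model and path here are placeholders):

    res = SQLiteResource(resource_model, db_path="/tmp/app.db")
    assert res.mode == "direct"     # no HTTP client involved
    schema = res.describe()         # reads sqlite_master locally
    ok = res.query("SELECT 1")      # executed through sqlite3, read-only

With db_path omitted, the same calls go through the HTTP client as before.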
@@ -695,6 +2327,121 @@ class SQLiteResource(Resource):
     def _query(
         self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
     ) -> QueryResponse:
+        if self._mode == "direct":
+            return self._query_direct(query, args, read_only)
+        else:
+            # Check if this is a PRAGMA query - HTTP endpoints don't support PRAGMA
+            query_stripped = query.strip().upper()
+            if query_stripped.startswith("PRAGMA"):
+                return self._handle_pragma_query_http(query, args)
+            return self._query_http(query, args, read_only)
+
+    def _handle_pragma_query_http(
+        self, query: str, args: Optional[List[Any]] = None
+    ) -> QueryResponse:
+        """Handle PRAGMA queries in HTTP mode by using the describe endpoint."""
+        query_upper = query.strip().upper()
+
+        # Extract table name from PRAGMA table_info(table_name)
+        if "TABLE_INFO" in query_upper:
+            # Match: PRAGMA table_info("table") or PRAGMA table_info(table)
+            match = re.search(r'TABLE_INFO\s*\(\s*"([^"]+)"\s*\)', query, re.IGNORECASE)
+            if not match:
+                match = re.search(r"TABLE_INFO\s*\(\s*'([^']+)'\s*\)", query, re.IGNORECASE)
+            if not match:
+                match = re.search(r'TABLE_INFO\s*\(\s*([^\s\)]+)\s*\)', query, re.IGNORECASE)
+
+            if match:
+                table_name = match.group(1)
+
+                # Use the describe endpoint to get schema
+                describe_response = self.describe()
+                if not describe_response.success or not describe_response.tables:
+                    return QueryResponse(
+                        success=False,
+                        columns=None,
+                        rows=None,
+                        error="Failed to get schema information",
+                        message="PRAGMA query failed: could not retrieve schema"
+                    )
+
+                # Find the table in the schema
+                table_schema = None
+                for table in describe_response.tables:
+                    # Handle both dict and TableSchema objects
+                    table_name_in_schema = table.name if hasattr(table, 'name') else table.get("name")
+                    if table_name_in_schema == table_name:
+                        table_schema = table
+                        break
+
+                if not table_schema:
+                    return QueryResponse(
+                        success=False,
+                        columns=None,
+                        rows=None,
+                        error=f"Table '{table_name}' not found",
+                        message=f"PRAGMA query failed: table '{table_name}' not found"
+                    )
+
+                # Get columns from table schema
+                columns = table_schema.columns if hasattr(table_schema, 'columns') else table_schema.get("columns")
+                if not columns:
+                    return QueryResponse(
+                        success=False,
+                        columns=None,
+                        rows=None,
+                        error=f"Table '{table_name}' has no columns",
+                        message=f"PRAGMA query failed: table '{table_name}' has no columns"
+                    )
+
+                # Convert schema to PRAGMA table_info format
+                # Format: (cid, name, type, notnull, dflt_value, pk)
+                rows = []
+                for idx, col in enumerate(columns):
+                    # Handle both dict and object column definitions
+                    if isinstance(col, dict):
+                        col_name = col["name"]
+                        col_type = col.get("type", "")
+                        col_notnull = col.get("notnull", False)
+                        col_default = col.get("default_value")
+                        col_pk = col.get("pk", 0)
+                    else:
+                        col_name = col.name if hasattr(col, 'name') else str(col)
+                        col_type = getattr(col, 'type', "")
+                        col_notnull = getattr(col, 'notnull', False)
+                        col_default = getattr(col, 'default_value', None)
+                        col_pk = getattr(col, 'pk', 0)
+
+                    row = (
+                        idx,  # cid
+                        col_name,  # name
+                        col_type,  # type
+                        1 if col_notnull else 0,  # notnull
+                        col_default,  # dflt_value
+                        col_pk  # pk
+                    )
+                    rows.append(row)
+
+                return QueryResponse(
+                    success=True,
+                    columns=["cid", "name", "type", "notnull", "dflt_value", "pk"],
+                    rows=rows,
+                    message="PRAGMA query executed successfully via describe endpoint"
+                )
+
+        # For other PRAGMA queries, return an error indicating they're not supported
+        return QueryResponse(
+            success=False,
+            columns=None,
+            rows=None,
+            error="PRAGMA query not supported in HTTP mode",
+            message=f"PRAGMA query '{query}' is not supported via HTTP API"
+        )
+
+    def _query_http(
+        self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
+    ) -> QueryResponse:
+        """Execute query via HTTP API."""
         request = QueryRequest(query=query, args=args, read_only=read_only)
         response = self.client.request(
             "POST",
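In HTTP mode a table_info PRAGMA is therefore answered from the describe endpoint instead of being forwarded. An illustrative call (the table name is hypothetical):

    resp = res.query('PRAGMA table_info("users")')
    # resp.columns == ["cid", "name", "type", "notnull", "dflt_value", "pk"]
    # Any other PRAGMA returns success=False in HTTP mode.
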
@@ -703,6 +2450,59 @@ class SQLiteResource(Resource):
         )
         return QueryResponse(**response.json())
 
+    def _query_direct(
+        self, query: str, args: Optional[List[Any]] = None, read_only: bool = True
+    ) -> QueryResponse:
+        """Execute query directly on local SQLite file or in-memory database."""
+        try:
+            # Check if we need URI mode (for shared memory databases)
+            use_uri = 'mode=memory' in self.db_path
+            conn = sqlite3.connect(self.db_path, uri=use_uri)
+            cursor = conn.cursor()
+
+            # Execute the query
+            if args:
+                cursor.execute(query, args)
+            else:
+                cursor.execute(query)
+
+            # For write operations, commit the transaction
+            if not read_only:
+                conn.commit()
+
+            # Get column names if available
+            columns = [desc[0] for desc in cursor.description] if cursor.description else []
+
+            # Fetch results for SELECT queries
+            rows = []
+            rows_affected = 0
+            last_insert_id = None
+
+            if cursor.description:  # SELECT query
+                rows = cursor.fetchall()
+            else:  # INSERT/UPDATE/DELETE
+                rows_affected = cursor.rowcount
+                last_insert_id = cursor.lastrowid if cursor.lastrowid else None
+
+            conn.close()
+
+            return QueryResponse(
+                success=True,
+                columns=columns if columns else None,
+                rows=rows if rows else None,
+                rows_affected=rows_affected if rows_affected > 0 else None,
+                last_insert_id=last_insert_id,
+                message="Query executed successfully",
+            )
+        except Exception as e:
+            return QueryResponse(
+                success=False,
+                columns=None,
+                rows=None,
+                error=str(e),
+                message=f"Query failed: {str(e)}",
+            )
+
     def table(self, table_name: str) -> SyncQueryBuilder:
         """Create a query builder for the specified table."""
         return SyncQueryBuilder(self, table_name)
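The 'mode=memory' check above enables sqlite3's URI filename mode, which is what lets several connections share one in-memory database. The equivalent standard-library call, for reference (the database name is arbitrary):

    import sqlite3
    conn = sqlite3.connect("file:fleetdb?mode=memory&cache=shared", uri=True)
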
@@ -710,7 +2510,6 @@ class SQLiteResource(Resource):
     def snapshot(self, name: Optional[str] = None) -> SyncDatabaseSnapshot:
         """Create a snapshot of the current database state."""
         snapshot = SyncDatabaseSnapshot(self, name)
-        snapshot._ensure_fetched()
         return snapshot
 
     def diff(