fleet-python 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fleet-python might be problematic. Click here for more details.
- examples/dsl_example.py +112 -0
- examples/example.py +38 -0
- examples/nova_act_example.py +180 -0
- examples/openai_example.py +448 -0
- examples/quickstart.py +5 -5
- fleet/__init__.py +24 -3
- fleet/base.py +1 -1
- fleet/client.py +60 -28
- fleet/env/__init__.py +2 -7
- fleet/env/client.py +9 -235
- fleet/manager/__init__.py +22 -0
- fleet/manager/client.py +258 -0
- fleet/{env → manager}/models.py +15 -14
- fleet/resources/base.py +5 -2
- fleet/resources/browser.py +32 -6
- fleet/resources/sqlite.py +5 -5
- fleet/verifiers/__init__.py +4 -0
- fleet/verifiers/database_snapshot.py +666 -0
- fleet/verifiers/sql_differ.py +187 -0
- {fleet_python-0.2.0.dist-info → fleet_python-0.2.2.dist-info}/METADATA +1 -1
- fleet_python-0.2.2.dist-info/RECORD +27 -0
- examples/browser_control_example.py +0 -51
- fleet_python-0.2.0.dist-info/RECORD +0 -19
- /fleet/{env → manager}/base.py +0 -0
- {fleet_python-0.2.0.dist-info → fleet_python-0.2.2.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.0.dist-info → fleet_python-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.0.dist-info → fleet_python-0.2.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,666 @@
|
|
|
1
|
+
# database_dsl.py
|
|
2
|
+
"""A schema‑agnostic, SQL‑native DSL for snapshot validation and diff invariants.
|
|
3
|
+
|
|
4
|
+
The module extends your original `DatabaseSnapshot` implementation with
|
|
5
|
+
|
|
6
|
+
* A **Supabase‑style query builder** (method‑chaining: `select`, `eq`, `join`, …).
|
|
7
|
+
* Assertion helpers (`assert_exists`, `assert_none`, `assert_eq`, `count().assert_eq`, …).
|
|
8
|
+
* A `SnapshotDiff` engine that enforces invariants (`expect_only`, `expect`).
|
|
9
|
+
* Convenience helpers (`expect_row`, `expect_rows`, `expect_absent_row`).
|
|
10
|
+
|
|
11
|
+
The public API stays tiny yet composable; everything else is built on
|
|
12
|
+
orthogonal primitives so it works for *any* relational schema.
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations

import sqlite3
from datetime import datetime, timezone
from typing import Any
|
|
19
|
+
|
|
20
|
+
################################################################################
|
|
21
|
+
# Low‑level helpers
|
|
22
|
+
################################################################################
|
|
23
|
+
|
|
24
|
+
# Type aliases shared by the query builder and diff helpers.
SQLValue = str | int | float | None
Condition = tuple[str, str, SQLValue]  # (column, op, value)
JoinSpec = tuple[str, dict[str, str]]  # (table, on mapping)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _CountResult:
|
|
30
|
+
"""Wraps an integer count so we can chain assertions fluently."""
|
|
31
|
+
|
|
32
|
+
def __init__(self, value: int):
|
|
33
|
+
self.value = value
|
|
34
|
+
|
|
35
|
+
# Assertions ------------------------------------------------------------
|
|
36
|
+
def assert_eq(self, expected: int):
|
|
37
|
+
if self.value != expected:
|
|
38
|
+
raise AssertionError(f"Expected {expected}, got {self.value}")
|
|
39
|
+
return self
|
|
40
|
+
|
|
41
|
+
def assert_gt(self, threshold: int):
|
|
42
|
+
if self.value <= threshold:
|
|
43
|
+
raise AssertionError(f"Expected > {threshold}, got {self.value}")
|
|
44
|
+
return self
|
|
45
|
+
|
|
46
|
+
def assert_between(self, low: int, high: int):
|
|
47
|
+
if not low <= self.value <= high:
|
|
48
|
+
raise AssertionError(f"Expected {low}‑{high}, got {self.value}")
|
|
49
|
+
return self
|
|
50
|
+
|
|
51
|
+
# Convenience -----------------------------------------------------------
|
|
52
|
+
def __int__(self):
|
|
53
|
+
return self.value
|
|
54
|
+
|
|
55
|
+
def __repr__(self):
|
|
56
|
+
return f"<Count {self.value}>"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
################################################################################
|
|
60
|
+
# Query Builder
|
|
61
|
+
################################################################################
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class QueryBuilder:
    """Fluent SQL SELECT builder executed against a single `DatabaseSnapshot`.

    Every chaining method (`select`, `eq`, `join`, ...) returns a *new*
    builder, so a partially built query can be shared and extended safely.
    Executed rows are cached per builder instance, so repeated executions of
    the same builder hit the database only once.
    """

    def __init__(self, snapshot: "DatabaseSnapshot", table: str):  # noqa: UP037
        self._snapshot = snapshot
        self._table = table
        self._select_cols: list[str] = ["*"]
        self._conditions: list[Condition] = []
        self._joins: list[JoinSpec] = []
        self._limit: int | None = None
        self._order_by: str | None = None
        # Cache for idempotent executions
        self._cached_rows: list[dict[str, Any]] | None = None

    # ---------------------------------------------------------------------
    # Column projection / limiting / ordering
    # ---------------------------------------------------------------------
    def select(self, *columns: str) -> "QueryBuilder":  # noqa: UP037
        """Project the given columns; calling with no arguments restores ``*``."""
        qb = self._clone()
        qb._select_cols = list(columns) if columns else ["*"]
        return qb

    def limit(self, n: int) -> "QueryBuilder":  # noqa: UP037
        """Cap the result set at *n* rows."""
        qb = self._clone()
        qb._limit = n
        return qb

    def sort(self, column: str, desc: bool = False) -> "QueryBuilder":  # noqa: UP037
        """Order results by *column*, ascending unless *desc* is true."""
        qb = self._clone()
        qb._order_by = f"{column} {'DESC' if desc else 'ASC'}"
        return qb

    # ---------------------------------------------------------------------
    # WHERE helpers (SQL‑like)
    # ---------------------------------------------------------------------
    def _add_condition(
        self, column: str, op: str, value: SQLValue
    ) -> "QueryBuilder":  # noqa: UP037
        qb = self._clone()
        qb._conditions.append((column, op, value))
        return qb

    def eq(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "=", value)

    def neq(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "!=", value)

    def gt(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, ">", value)

    def gte(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, ">=", value)

    def lt(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "<", value)

    def lte(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "<=", value)

    def in_(self, column: str, values: list[SQLValue]) -> "QueryBuilder":  # noqa: UP037
        qb = self._clone()
        qb._conditions.append((column, "IN", tuple(values)))
        return qb

    def not_in(
        self, column: str, values: list[SQLValue]
    ) -> "QueryBuilder":  # noqa: UP037
        qb = self._clone()
        qb._conditions.append((column, "NOT IN", tuple(values)))
        return qb

    def is_null(self, column: str) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "IS", None)

    def not_null(self, column: str) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "IS NOT", None)

    def ilike(self, column: str, pattern: str) -> "QueryBuilder":  # noqa: UP037
        qb = self._clone()
        qb._conditions.append((column, "ILIKE", pattern))
        return qb

    # ---------------------------------------------------------------------
    # JOIN (simple inner join)
    # ---------------------------------------------------------------------
    def join(
        self, other_table: str, on: dict[str, str]
    ) -> "QueryBuilder":  # noqa: UP037
        """`on` expects {local_col: remote_col}."""
        qb = self._clone()
        qb._joins.append((other_table, on))
        return qb

    # ---------------------------------------------------------------------
    # Execution helpers
    # ---------------------------------------------------------------------
    def _compile(self) -> tuple[str, list[Any]]:
        """Render the builder state as (sql, params) for sqlite3.execute."""
        cols = ", ".join(self._select_cols)
        sql = [f"SELECT {cols} FROM {self._table}"]
        params: list[Any] = []

        # Joins -------------------------------------------------------------
        for tbl, onmap in self._joins:
            join_clauses = [
                f"{self._table}.{local} = {tbl}.{remote}"
                for local, remote in onmap.items()
            ]
            sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")

        # WHERE -------------------------------------------------------------
        if self._conditions:
            placeholders = []
            for col, op, val in self._conditions:
                if op in ("IN", "NOT IN") and isinstance(val, tuple):
                    # One "?" per element of the tuple.
                    ph = ", ".join(["?" for _ in val])
                    placeholders.append(f"{col} {op} ({ph})")
                    params.extend(val)
                elif op in ("IS", "IS NOT"):
                    placeholders.append(f"{col} {op} NULL")
                elif op == "ILIKE":
                    placeholders.append(
                        f"{col} LIKE ?"
                    )  # SQLite has no ILIKE; LIKE is case‑insensitive for ASCII by default
                    params.append(val)
                else:
                    placeholders.append(f"{col} {op} ?")
                    params.append(val)
            sql.append("WHERE " + " AND ".join(placeholders))

        # ORDER / LIMIT -----------------------------------------------------
        if self._order_by:
            sql.append(f"ORDER BY {self._order_by}")
        if self._limit is not None:
            sql.append(f"LIMIT {self._limit}")

        return " ".join(sql), params

    def _execute(self) -> list[dict[str, Any]]:
        """Run the compiled query; rows are cached on first execution."""
        if self._cached_rows is not None:
            return self._cached_rows

        sql, params = self._compile()
        conn = sqlite3.connect(self._snapshot.db_path)
        try:
            conn.row_factory = sqlite3.Row
            cur = conn.cursor()
            cur.execute(sql, params)
            rows = [dict(r) for r in cur.fetchall()]
            cur.close()
        finally:
            # Release the handle even when the SQL itself raises.
            conn.close()
        self._cached_rows = rows
        return rows

    # ---------------------------------------------------------------------
    # High‑level result helpers / assertions
    # ---------------------------------------------------------------------
    def count(self) -> _CountResult:
        """Return the number of matching rows as a chainable `_CountResult`."""
        qb = self.select("COUNT(*) AS __cnt__")
        # Clear any LIMIT directly on the clone: counting overrides it, and
        # calling limit(None) would violate limit()'s int parameter contract.
        qb._limit = None
        sql, params = qb._compile()
        conn = sqlite3.connect(self._snapshot.db_path)
        try:
            cur = conn.cursor()
            cur.execute(sql, params)
            val = cur.fetchone()[0] or 0
            cur.close()
        finally:
            conn.close()
        return _CountResult(val)

    def first(self) -> dict[str, Any] | None:
        """Return the first matching row, or None when nothing matches.

        Executes the query exactly once (the original evaluated
        ``self.limit(1)._execute()`` twice on separate clones).
        """
        rows = self.limit(1)._execute()
        return rows[0] if rows else None

    def all(self) -> list[dict[str, Any]]:
        """Return every matching row as a list of plain dicts."""
        return self._execute()

    # Assertions -----------------------------------------------------------
    def assert_exists(self):
        """Fail unless at least one row matches; returns self for chaining."""
        row = self.first()
        if row is None:
            # Build descriptive error message
            sql, params = self._compile()
            error_msg = (
                f"Expected at least one matching row, but found none.\n"
                f"Query: {sql}\n"
                f"Parameters: {params}\n"
                f"Table: {self._table}"
            )
            if self._conditions:
                conditions_str = ", ".join(
                    [f"{col} {op} {val}" for col, op, val in self._conditions]
                )
                error_msg += f"\nConditions: {conditions_str}"
            raise AssertionError(error_msg)
        return self

    def assert_none(self):
        """Fail if any row matches; returns self for chaining."""
        row = self.first()
        if row is not None:
            row_id = _get_row_identifier(row)
            row_data = _format_row_for_error(row)
            sql, params = self._compile()
            error_msg = (
                f"Expected no matching rows, but found at least one.\n"
                f"Found row: {row_id}\n"
                f"Row data: {row_data}\n"
                f"Query: {sql}\n"
                f"Parameters: {params}\n"
                f"Table: {self._table}"
            )
            raise AssertionError(error_msg)
        return self

    def assert_eq(self, column: str, value: SQLValue):
        """Fail unless the first matching row has ``row[column] == value``."""
        row = self.first()
        if row is None:
            sql, params = self._compile()
            error_msg = (
                f"Row not found for equality assertion.\n"
                f"Expected to find a row with {column}={repr(value)}\n"
                f"Query: {sql}\n"
                f"Parameters: {params}\n"
                f"Table: {self._table}"
            )
            raise AssertionError(error_msg)

        actual_value = row.get(column)
        if actual_value != value:
            row_id = _get_row_identifier(row)
            row_data = _format_row_for_error(row)
            error_msg = (
                f"Field value assertion failed.\n"
                f"Row: {row_id}\n"
                f"Field: {column}\n"
                f"Expected: {repr(value)}\n"
                f"Actual: {repr(actual_value)}\n"
                f"Full row data: {row_data}\n"
                f"Table: {self._table}"
            )
            raise AssertionError(error_msg)
        return self

    # Misc -----------------------------------------------------------------
    def explain(self) -> str:
        """Return the compiled SQL and parameters for debugging."""
        sql, params = self._compile()
        return f"SQL: {sql}\nParams: {params}"

    # Utilities ------------------------------------------------------------
    def _clone(self) -> "QueryBuilder":  # noqa: UP037
        """Shallow-copy the builder state (cache is intentionally dropped)."""
        qb = QueryBuilder(self._snapshot, self._table)
        qb._select_cols = list(self._select_cols)
        qb._conditions = list(self._conditions)
        qb._joins = list(self._joins)
        qb._limit = self._limit
        qb._order_by = self._order_by
        return qb

    # Representation -------------------------------------------------------
    def __repr__(self):
        return f"<QueryBuilder {self.explain()}>"
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
################################################################################
|
|
325
|
+
# Snapshot Diff invariants
|
|
326
|
+
################################################################################
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class IgnoreConfig:
|
|
330
|
+
"""Configuration for ignoring specific tables, fields, or combinations during diff operations."""
|
|
331
|
+
|
|
332
|
+
def __init__(
|
|
333
|
+
self,
|
|
334
|
+
tables: set[str] | None = None,
|
|
335
|
+
fields: set[str] | None = None,
|
|
336
|
+
table_fields: dict[str, set[str]] | None = None,
|
|
337
|
+
):
|
|
338
|
+
"""
|
|
339
|
+
Args:
|
|
340
|
+
tables: Set of table names to completely ignore
|
|
341
|
+
fields: Set of field names to ignore across all tables
|
|
342
|
+
table_fields: Dict mapping table names to sets of field names to ignore in that table
|
|
343
|
+
"""
|
|
344
|
+
self.tables = tables or set()
|
|
345
|
+
self.fields = fields or set()
|
|
346
|
+
self.table_fields = table_fields or {}
|
|
347
|
+
|
|
348
|
+
def should_ignore_table(self, table: str) -> bool:
|
|
349
|
+
"""Check if a table should be completely ignored."""
|
|
350
|
+
return table in self.tables
|
|
351
|
+
|
|
352
|
+
def should_ignore_field(self, table: str, field: str) -> bool:
|
|
353
|
+
"""Check if a specific field in a table should be ignored."""
|
|
354
|
+
# Global field ignore
|
|
355
|
+
if field in self.fields:
|
|
356
|
+
return True
|
|
357
|
+
# Table-specific field ignore
|
|
358
|
+
if table in self.table_fields and field in self.table_fields[table]:
|
|
359
|
+
return True
|
|
360
|
+
return False
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _format_row_for_error(row: dict[str, Any], max_fields: int = 10) -> str:
|
|
364
|
+
"""Format a row dictionary for error messages with truncation if needed."""
|
|
365
|
+
if not row:
|
|
366
|
+
return "{empty row}"
|
|
367
|
+
|
|
368
|
+
items = list(row.items())
|
|
369
|
+
if len(items) <= max_fields:
|
|
370
|
+
formatted_items = [f"{k}={repr(v)}" for k, v in items]
|
|
371
|
+
return "{" + ", ".join(formatted_items) + "}"
|
|
372
|
+
else:
|
|
373
|
+
# Show first few fields and indicate truncation
|
|
374
|
+
shown_items = [f"{k}={repr(v)}" for k, v in items[:max_fields]]
|
|
375
|
+
remaining = len(items) - max_fields
|
|
376
|
+
return "{" + ", ".join(shown_items) + f", ... +{remaining} more fields" + "}"
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _get_row_identifier(row: dict[str, Any]) -> str:
|
|
380
|
+
"""Extract a meaningful identifier from a row for error messages."""
|
|
381
|
+
# Try common ID fields first
|
|
382
|
+
for id_field in ["id", "pk", "primary_key", "key"]:
|
|
383
|
+
if id_field in row and row[id_field] is not None:
|
|
384
|
+
return f"{id_field}={repr(row[id_field])}"
|
|
385
|
+
|
|
386
|
+
# Try name fields
|
|
387
|
+
for name_field in ["name", "title", "label"]:
|
|
388
|
+
if name_field in row and row[name_field] is not None:
|
|
389
|
+
return f"{name_field}={repr(row[name_field])}"
|
|
390
|
+
|
|
391
|
+
# Fall back to first non-None field
|
|
392
|
+
for key, value in row.items():
|
|
393
|
+
if value is not None:
|
|
394
|
+
return f"{key}={repr(value)}"
|
|
395
|
+
|
|
396
|
+
return "no identifier found"
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class SnapshotDiff:
    """Compute & validate changes between two snapshots.

    Wraps a `SQLiteDiffer` over the two snapshots' database files and
    exposes invariant checks: `expect_only` (whitelist of exact changes)
    and `expect` (per-table / per-field allow & forbid rules).
    """

    def __init__(
        self,
        before: DatabaseSnapshot,
        after: DatabaseSnapshot,
        ignore_config: IgnoreConfig | None = None,
    ):
        from .sql_differ import SQLiteDiffer  # local import to avoid circularity

        self.before = before
        self.after = after
        self.ignore_config = ignore_config or IgnoreConfig()
        self._differ = SQLiteDiffer(before.db_path, after.db_path)
        # Lazily computed full diff, keyed by table name.
        self._cached: dict[str, Any] | None = None

    # ------------------------------------------------------------------
    def _collect(self):
        """Diff every non-ignored table once and cache the result."""
        if self._cached is not None:
            return self._cached
        all_tables = set(self.before.tables()) | set(self.after.tables())
        diff: dict[str, dict[str, Any]] = {}
        for tbl in all_tables:
            if self.ignore_config.should_ignore_table(tbl):
                continue
            diff[tbl] = self._differ.diff_table(tbl)
        self._cached = diff
        return diff

    # ------------------------------------------------------------------
    def expect_only(self, allowed_changes: list[dict[str, Any]]):
        """Allowed changes is a list of {table, pk, field, after} (before optional).

        Raises AssertionError (with a detailed report) if any change outside
        the whitelist is present; returns self for chaining otherwise.
        """
        diff = self._collect()
        allowed_set = {
            (c["table"], c.get("pk"), c.get("field"), c.get("after"))
            for c in allowed_changes
        }

        # Collect all unexpected changes for detailed reporting
        unexpected_changes = []

        for tbl, report in diff.items():
            for row in report.get("modified_rows", []):
                for f, vals in row["changes"].items():
                    if self.ignore_config.should_ignore_field(tbl, f):
                        continue
                    tup = (tbl, row["row_id"], f, vals["after"])
                    if tup not in allowed_set:
                        unexpected_changes.append(
                            {
                                "type": "modification",
                                "table": tbl,
                                "row_id": row["row_id"],
                                "field": f,
                                "before": vals.get("before"),
                                "after": vals["after"],
                                "full_row": row,
                            }
                        )

            for row in report.get("added_rows", []):
                tup = (tbl, row["row_id"], None, "__added__")
                if tup not in allowed_set:
                    unexpected_changes.append(
                        {
                            "type": "insertion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__added__",
                            "full_row": row,
                        }
                    )

            for row in report.get("removed_rows", []):
                tup = (tbl, row["row_id"], None, "__removed__")
                if tup not in allowed_set:
                    unexpected_changes.append(
                        {
                            "type": "deletion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__removed__",
                            "full_row": row,
                        }
                    )

        if unexpected_changes:
            # Build comprehensive error message
            error_lines = ["Unexpected database changes detected:"]
            error_lines.append("")

            for i, change in enumerate(
                unexpected_changes[:5], 1
            ):  # Show first 5 changes
                error_lines.append(
                    f"{i}. {change['type'].upper()} in table '{change['table']}':"
                )
                error_lines.append(f"   Row ID: {change['row_id']}")

                if change["type"] == "modification":
                    error_lines.append(f"   Field: {change['field']}")
                    error_lines.append(f"   Before: {repr(change['before'])}")
                    error_lines.append(f"   After: {repr(change['after'])}")
                elif change["type"] == "insertion":
                    error_lines.append("   New row added")
                elif change["type"] == "deletion":
                    error_lines.append("   Row deleted")

                # Show some context from the row
                if "full_row" in change and change["full_row"]:
                    row_data = change["full_row"]
                    if change["type"] == "modification" and "data" in row_data:
                        # For modifications, show the current state
                        formatted_row = _format_row_for_error(
                            row_data.get("data", {}), max_fields=5
                        )
                        error_lines.append(f"   Row data: {formatted_row}")
                    elif (
                        change["type"] in ["insertion", "deletion"]
                        and "data" in row_data
                    ):
                        # For insertions/deletions, show the row data
                        formatted_row = _format_row_for_error(
                            row_data.get("data", {}), max_fields=5
                        )
                        error_lines.append(f"   Row data: {formatted_row}")

                error_lines.append("")

            if len(unexpected_changes) > 5:
                error_lines.append(
                    f"... and {len(unexpected_changes) - 5} more unexpected changes"
                )
                error_lines.append("")

            # Show what changes were allowed
            error_lines.append("Allowed changes were:")
            if allowed_changes:
                for i, allowed in enumerate(allowed_changes[:3], 1):
                    error_lines.append(
                        f"  {i}. Table: {allowed.get('table')}, "
                        f"ID: {allowed.get('pk')}, "
                        f"Field: {allowed.get('field')}, "
                        f"After: {repr(allowed.get('after'))}"
                    )
                if len(allowed_changes) > 3:
                    error_lines.append(
                        f"  ... and {len(allowed_changes) - 3} more allowed changes"
                    )
            else:
                error_lines.append("  (No changes were allowed)")

            raise AssertionError("\n".join(error_lines))

        return self

    def expect(
        self,
        *,
        allow: list[dict[str, Any]] | None = None,
        forbid: list[dict[str, Any]] | None = None,
    ):
        """More granular: allow / forbid per‑table and per‑field.

        Each rule is a dict with "table" and optional "field" keys.  A
        forbid rule with no field blocks row additions/removals in that
        table.  Returns self for chaining.
        """
        allow = allow or []
        forbid = forbid or []
        allow_tbl_field = {(c["table"], c.get("field")) for c in allow}
        forbid_tbl_field = {(c["table"], c.get("field")) for c in forbid}
        diff = self._collect()
        for tbl, report in diff.items():
            for row in report.get("modified_rows", []):
                # BUG FIX: per-field changes live under "changes" (the key
                # expect_only reads); "changed" raised KeyError on any
                # modified row.
                for f in row["changes"].keys():
                    if self.ignore_config.should_ignore_field(tbl, f):
                        continue
                    key = (tbl, f)
                    if key in forbid_tbl_field:
                        raise AssertionError(f"Modification to forbidden field {key}")
                    if allow_tbl_field and key not in allow_tbl_field:
                        raise AssertionError(f"Modification to unallowed field {key}")
            if (tbl, None) in forbid_tbl_field and (
                report.get("added_rows") or report.get("removed_rows")
            ):
                raise AssertionError(f"Changes in forbidden table {tbl}")
        return self
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
################################################################################
|
|
585
|
+
# DatabaseSnapshot with DSL entrypoints
|
|
586
|
+
################################################################################
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
class DatabaseSnapshot:
    """Represents a snapshot of an SQLite DB with DSL entrypoints.

    Entry points: :meth:`table` starts a fluent query, :meth:`diff` compares
    against another snapshot, and `expect_row` / `expect_rows` /
    `expect_absent_row` provide schema-agnostic one-shot assertions.
    """

    def __init__(self, db_path: str, *, name: str | None = None):
        self.db_path = db_path
        # Use one timezone-aware UTC timestamp for both attributes.
        # datetime.utcnow() is naive and deprecated since Python 3.12, and
        # the original sampled the clock twice (name and created_at could
        # reflect different instants).
        self.created_at = datetime.now(timezone.utc)
        self.name = name or f"snapshot_{self.created_at.isoformat()}"

    # DSL entry ------------------------------------------------------------
    def table(self, table: str) -> QueryBuilder:
        """Start a fluent query against *table*."""
        return QueryBuilder(self, table)

    # Metadata -------------------------------------------------------------
    def tables(self) -> list[str]:
        """Return user table names (sqlite internal tables excluded)."""
        conn = sqlite3.connect(self.db_path)
        try:
            cur = conn.cursor()
            cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )
            tbls = [r[0] for r in cur.fetchall()]
            cur.close()
        finally:
            # Close the handle even if the query raises.
            conn.close()
        return tbls

    # Diff interface -------------------------------------------------------
    def diff(
        self,
        other: "DatabaseSnapshot",  # noqa: UP037
        ignore_config: IgnoreConfig | None = None,
    ) -> SnapshotDiff:
        """Return a `SnapshotDiff` of self (before) against *other* (after)."""
        return SnapshotDiff(self, other, ignore_config)

    ############################################################################
    # Convenience, schema‑agnostic expectation helpers
    ############################################################################

    def expect_row(
        self, table: str, where: dict[str, SQLValue], expect: dict[str, SQLValue]
    ):
        """Assert a row matching *where* exists and carries the *expect* values."""
        qb = self.table(table)
        for k, v in where.items():
            qb = qb.eq(k, v)
        qb.assert_exists()
        for col, val in expect.items():
            qb.assert_eq(col, val)
        return self

    def expect_rows(
        self,
        table: str,
        where: dict[str, SQLValue],
        *,
        count: int | None = None,
        contains: list[dict[str, SQLValue]] | None = None,
    ):
        """Assert on the set of rows matching *where*.

        Args:
            count: exact number of matching rows required, when given.
            contains: each dict must be satisfied by at least one row.
        """
        qb = self.table(table)
        for k, v in where.items():
            qb = qb.eq(k, v)
        if count is not None:
            qb.count().assert_eq(count)
        if contains:
            rows = qb.all()
            for cond in contains:
                matched = any(all(r.get(k) == v for k, v in cond.items()) for r in rows)
                if not matched:
                    raise AssertionError(f"Expected a row matching {cond} in {table}")
        return self

    def expect_absent_row(self, table: str, where: dict[str, SQLValue]):
        """Assert that no row matching *where* exists in *table*."""
        qb = self.table(table)
        for k, v in where.items():
            qb = qb.eq(k, v)
        qb.assert_none()
        return self

    # ---------------------------------------------------------------------
    def __repr__(self):
        return f"<DatabaseSnapshot {self.name} at {self.db_path}>"
|