fleet-python 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fleet-python might be problematic.

fleet/verifiers/db.py ADDED
@@ -0,0 +1,706 @@
+"""A schema‑agnostic, SQL‑native DSL for snapshot validation and diff invariants.
+
+The module extends your original `DatabaseSnapshot` implementation with
+
+* A **Supabase‑style query builder** (method‑chaining: `select`, `eq`, `join`, …).
+* Assertion helpers (`assert_exists`, `assert_none`, `assert_eq`, `count().assert_eq`, …).
+* A `SnapshotDiff` engine that enforces invariants (`expect_only`, `expect`).
+* Convenience helpers (`expect_row`, `expect_rows`, `expect_absent_row`).
+
+The public API stays tiny yet composable; everything else is built on
+orthogonal primitives so it works for *any* relational schema.
+"""
+
+from __future__ import annotations
+
+import sqlite3
+from datetime import datetime
+from typing import Any
+import json
+
+################################################################################
+# Low‑level helpers
+################################################################################
+
+SQLValue = str | int | float | None
+Condition = tuple[str, str, SQLValue]  # (column, op, value)
+JoinSpec = tuple[str, dict[str, str]]  # (table, on mapping)
+
+
+def _is_json_string(value: Any) -> bool:
+    """Check if a value looks like a JSON string."""
+    if not isinstance(value, str):
+        return False
+    value = value.strip()
+    return (value.startswith("{") and value.endswith("}")) or (
+        value.startswith("[") and value.endswith("]")
+    )
+
+
+def _values_equivalent(val1: Any, val2: Any) -> bool:
+    """Compare two values, using JSON semantic comparison for JSON strings."""
+    # If both are exactly equal, return True
+    if val1 == val2:
+        return True
+
+    # If both look like JSON strings, try semantic comparison
+    if _is_json_string(val1) and _is_json_string(val2):
+        try:
+            parsed1 = json.loads(val1)
+            parsed2 = json.loads(val2)
+            return parsed1 == parsed2
+        except (json.JSONDecodeError, TypeError):
+            # If parsing fails, fall back to string comparison
+            pass
+
+    # Default to exact comparison
+    return val1 == val2
+
+
+class _CountResult:
+    """Wraps an integer count so we can chain assertions fluently."""
+
+    def __init__(self, value: int):
+        self.value = value
+
+    # Assertions ------------------------------------------------------------
+    def assert_eq(self, expected: int):
+        if self.value != expected:
+            raise AssertionError(f"Expected {expected}, got {self.value}")
+        return self
+
+    def assert_gt(self, threshold: int):
+        if self.value <= threshold:
+            raise AssertionError(f"Expected > {threshold}, got {self.value}")
+        return self
+
+    def assert_between(self, low: int, high: int):
+        if not low <= self.value <= high:
+            raise AssertionError(f"Expected {low}‑{high}, got {self.value}")
+        return self
+
+    # Convenience -----------------------------------------------------------
+    def __int__(self):
+        return self.value
+
+    def __repr__(self):
+        return f"<Count {self.value}>"
+
+
+################################################################################
+# Query Builder
+################################################################################
+
+
+class QueryBuilder:
+    """Fluent SQL builder executed against a single `DatabaseSnapshot`."""
+
+    def __init__(self, snapshot: "DatabaseSnapshot", table: str):  # noqa: UP037
+        self._snapshot = snapshot
+        self._table = table
+        self._select_cols: list[str] = ["*"]
+        self._conditions: list[Condition] = []
+        self._joins: list[JoinSpec] = []
+        self._limit: int | None = None
+        self._order_by: str | None = None
+        # Cache for idempotent executions
+        self._cached_rows: list[dict[str, Any]] | None = None
+
+    # ---------------------------------------------------------------------
+    # Column projection / limiting / ordering
+    # ---------------------------------------------------------------------
+    def select(self, *columns: str) -> "QueryBuilder":  # noqa: UP037
+        qb = self._clone()
+        qb._select_cols = list(columns) if columns else ["*"]
+        return qb
+
+    def limit(self, n: int | None) -> "QueryBuilder":  # noqa: UP037
+        qb = self._clone()
+        qb._limit = n
+        return qb
+
+    def sort(self, column: str, desc: bool = False) -> "QueryBuilder":  # noqa: UP037
+        qb = self._clone()
+        qb._order_by = f"{column} {'DESC' if desc else 'ASC'}"
+        return qb
+
+    # ---------------------------------------------------------------------
+    # WHERE helpers (SQL‑like)
+    # ---------------------------------------------------------------------
+    def _add_condition(self, column: str, op: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
+        qb = self._clone()
+        qb._conditions.append((column, op, value))
+        return qb
+
+    def eq(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, "=", value)
+
+    def neq(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, "!=", value)
+
+    def gt(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, ">", value)
+
+    def gte(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, ">=", value)
+
+    def lt(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, "<", value)
+
+    def lte(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, "<=", value)
+
+    def in_(self, column: str, values: list[SQLValue]) -> "QueryBuilder":  # noqa: UP037
+        qb = self._clone()
+        qb._conditions.append((column, "IN", tuple(values)))
+        return qb
+
+    def not_in(self, column: str, values: list[SQLValue]) -> "QueryBuilder":  # noqa: UP037
+        qb = self._clone()
+        qb._conditions.append((column, "NOT IN", tuple(values)))
+        return qb
+
+    def is_null(self, column: str) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, "IS", None)
+
+    def not_null(self, column: str) -> "QueryBuilder":  # noqa: UP037
+        return self._add_condition(column, "IS NOT", None)
+
+    def ilike(self, column: str, pattern: str) -> "QueryBuilder":  # noqa: UP037
+        qb = self._clone()
+        qb._conditions.append((column, "ILIKE", pattern))
+        return qb
+
+    # ---------------------------------------------------------------------
+    # JOIN (simple inner join)
+    # ---------------------------------------------------------------------
+    def join(self, other_table: str, on: dict[str, str]) -> "QueryBuilder":  # noqa: UP037
+        """`on` expects {local_col: remote_col}."""
+        qb = self._clone()
+        qb._joins.append((other_table, on))
+        return qb
+
+    # ---------------------------------------------------------------------
+    # Execution helpers
+    # ---------------------------------------------------------------------
+    def _compile(self) -> tuple[str, list[Any]]:
+        cols = ", ".join(self._select_cols)
+        sql = [f"SELECT {cols} FROM {self._table}"]
+        params: list[Any] = []
+
+        # Joins -------------------------------------------------------------
+        for tbl, onmap in self._joins:
+            join_clauses = [
+                f"{self._table}.{l} = {tbl}.{r}"
+                for l, r in onmap.items()  # noqa: E741
+            ]
+            sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")
+
+        # WHERE -------------------------------------------------------------
+        if self._conditions:
+            placeholders = []
+            for col, op, val in self._conditions:
+                if op in ("IN", "NOT IN") and isinstance(val, tuple):
+                    ph = ", ".join(["?" for _ in val])
+                    placeholders.append(f"{col} {op} ({ph})")
+                    params.extend(val)
+                elif op in ("IS", "IS NOT"):
+                    placeholders.append(f"{col} {op} NULL")
+                elif op == "ILIKE":
+                    placeholders.append(
+                        f"{col} LIKE ?"
+                    )  # SQLite has no ILIKE; LIKE is case‑insensitive for ASCII by default
+                    params.append(val)
+                else:
+                    placeholders.append(f"{col} {op} ?")
+                    params.append(val)
+            sql.append("WHERE " + " AND ".join(placeholders))
+
+        # ORDER / LIMIT -----------------------------------------------------
+        if self._order_by:
+            sql.append(f"ORDER BY {self._order_by}")
+        if self._limit is not None:
+            sql.append(f"LIMIT {self._limit}")
+
+        return " ".join(sql), params
+
+    def _execute(self) -> list[dict[str, Any]]:
+        if self._cached_rows is not None:
+            return self._cached_rows
+
+        sql, params = self._compile()
+        conn = sqlite3.connect(self._snapshot.db_path)
+        conn.row_factory = sqlite3.Row
+        cur = conn.cursor()
+        cur.execute(sql, params)
+        rows = [dict(r) for r in cur.fetchall()]
+        cur.close()
+        conn.close()
+        self._cached_rows = rows
+        return rows
+
+    # ---------------------------------------------------------------------
+    # High‑level result helpers / assertions
+    # ---------------------------------------------------------------------
+    def count(self) -> _CountResult:
+        qb = self.select("COUNT(*) AS __cnt__").limit(
+            None
+        )  # remove limit since counting overrides
+        sql, params = qb._compile()
+        conn = sqlite3.connect(self._snapshot.db_path)
+        cur = conn.cursor()
+        cur.execute(sql, params)
+        val = cur.fetchone()[0] or 0
+        cur.close()
+        conn.close()
+        return _CountResult(val)
+
+    def first(self) -> dict[str, Any] | None:
+        rows = self.limit(1)._execute()  # execute once and reuse the result
+        return rows[0] if rows else None
+
+    def all(self) -> list[dict[str, Any]]:
+        return self._execute()
+
+    # Assertions -----------------------------------------------------------
+    def assert_exists(self):
+        row = self.first()
+        if row is None:
+            # Build descriptive error message
+            sql, params = self._compile()
+            error_msg = (
+                f"Expected at least one matching row, but found none.\n"
+                f"Query: {sql}\n"
+                f"Parameters: {params}\n"
+                f"Table: {self._table}"
+            )
+            if hasattr(self, "_conditions") and self._conditions:
+                conditions_str = ", ".join(
+                    [f"{col} {op} {val}" for col, op, val in self._conditions]
+                )
+                error_msg += f"\nConditions: {conditions_str}"
+            raise AssertionError(error_msg)
+        return self
+
+    def assert_none(self):
+        row = self.first()
+        if row is not None:
+            row_id = _get_row_identifier(row)
+            row_data = _format_row_for_error(row)
+            sql, params = self._compile()
+            error_msg = (
+                f"Expected no matching rows, but found at least one.\n"
+                f"Found row: {row_id}\n"
+                f"Row data: {row_data}\n"
+                f"Query: {sql}\n"
+                f"Parameters: {params}\n"
+                f"Table: {self._table}"
+            )
+            raise AssertionError(error_msg)
+        return self
+
+    def assert_eq(self, column: str, value: SQLValue):
+        row = self.first()
+        if row is None:
+            sql, params = self._compile()
+            error_msg = (
+                f"Row not found for equality assertion.\n"
+                f"Expected to find a row with {column}={repr(value)}\n"
+                f"Query: {sql}\n"
+                f"Parameters: {params}\n"
+                f"Table: {self._table}"
+            )
+            raise AssertionError(error_msg)
+
+        actual_value = row.get(column)
+        if actual_value != value:
+            row_id = _get_row_identifier(row)
+            row_data = _format_row_for_error(row)
+            error_msg = (
+                f"Field value assertion failed.\n"
+                f"Row: {row_id}\n"
+                f"Field: {column}\n"
+                f"Expected: {repr(value)}\n"
+                f"Actual: {repr(actual_value)}\n"
+                f"Full row data: {row_data}\n"
+                f"Table: {self._table}"
+            )
+            raise AssertionError(error_msg)
+        return self
+
+    # Misc -----------------------------------------------------------------
+    def explain(self) -> str:
+        sql, params = self._compile()
+        return f"SQL: {sql}\nParams: {params}"
+
+    # Utilities ------------------------------------------------------------
+    def _clone(self) -> "QueryBuilder":  # noqa: UP037
+        qb = QueryBuilder(self._snapshot, self._table)
+        qb._select_cols = list(self._select_cols)
+        qb._conditions = list(self._conditions)
+        qb._joins = list(self._joins)
+        qb._limit = self._limit
+        qb._order_by = self._order_by
+        return qb
+
+    # Representation -------------------------------------------------------
+    def __repr__(self):
+        return f"<QueryBuilder {self.explain()}>"
+
+
+################################################################################
+# Snapshot Diff invariants
+################################################################################
+
+
+class IgnoreConfig:
+    """Configuration for ignoring specific tables, fields, or combinations during diff operations."""
+
+    def __init__(
+        self,
+        tables: set[str] | None = None,
+        fields: set[str] | None = None,
+        table_fields: dict[str, set[str]] | None = None,
+    ):
+        """
+        Args:
+            tables: Set of table names to completely ignore
+            fields: Set of field names to ignore across all tables
+            table_fields: Dict mapping table names to sets of field names to ignore in that table
+        """
+        self.tables = tables or set()
+        self.fields = fields or set()
+        self.table_fields = table_fields or {}
+
+    def should_ignore_table(self, table: str) -> bool:
+        """Check if a table should be completely ignored."""
+        return table in self.tables
+
+    def should_ignore_field(self, table: str, field: str) -> bool:
+        """Check if a specific field in a table should be ignored."""
+        # Global field ignore
+        if field in self.fields:
+            return True
+        # Table-specific field ignore
+        if table in self.table_fields and field in self.table_fields[table]:
+            return True
+        return False
+
+
+def _format_row_for_error(row: dict[str, Any], max_fields: int = 10) -> str:
+    """Format a row dictionary for error messages with truncation if needed."""
+    if not row:
+        return "{empty row}"
+
+    items = list(row.items())
+    if len(items) <= max_fields:
+        formatted_items = [f"{k}={repr(v)}" for k, v in items]
+        return "{" + ", ".join(formatted_items) + "}"
+    else:
+        # Show first few fields and indicate truncation
+        shown_items = [f"{k}={repr(v)}" for k, v in items[:max_fields]]
+        remaining = len(items) - max_fields
+        return "{" + ", ".join(shown_items) + f", ... +{remaining} more fields" + "}"
+
+
+def _get_row_identifier(row: dict[str, Any]) -> str:
+    """Extract a meaningful identifier from a row for error messages."""
+    # Try common ID fields first
+    for id_field in ["id", "pk", "primary_key", "key"]:
+        if id_field in row and row[id_field] is not None:
+            return f"{id_field}={repr(row[id_field])}"
+
+    # Try name fields
+    for name_field in ["name", "title", "label"]:
+        if name_field in row and row[name_field] is not None:
+            return f"{name_field}={repr(row[name_field])}"
+
+    # Fall back to first non-None field
+    for key, value in row.items():
+        if value is not None:
+            return f"{key}={repr(value)}"
+
+    return "no identifier found"
+
+
+class SnapshotDiff:
+    """Compute & validate changes between two snapshots."""
+
+    def __init__(
+        self,
+        before: DatabaseSnapshot,
+        after: DatabaseSnapshot,
+        ignore_config: IgnoreConfig | None = None,
+    ):
+        from .sql_differ import SQLiteDiffer  # local import to avoid circularity
+
+        self.before = before
+        self.after = after
+        self.ignore_config = ignore_config or IgnoreConfig()
+        self._differ = SQLiteDiffer(before.db_path, after.db_path)
+        self._cached: dict[str, Any] | None = None
+
+    # ------------------------------------------------------------------
+    def _collect(self):
+        if self._cached is not None:
+            return self._cached
+        all_tables = set(self.before.tables()) | set(self.after.tables())
+        diff: dict[str, dict[str, Any]] = {}
+        for tbl in all_tables:
+            if self.ignore_config.should_ignore_table(tbl):
+                continue
+            diff[tbl] = self._differ.diff_table(tbl)
+        self._cached = diff
+        return diff
+
+    # ------------------------------------------------------------------
+    def expect_only(self, allowed_changes: list[dict[str, Any]]):
+        """`allowed_changes` is a list of {table, pk, field, after} dicts (`before` optional)."""
+        diff = self._collect()
+
+        def _is_change_allowed(
+            table: str, row_id: str, field: str | None, after_value: Any
+        ) -> bool:
+            """Check if a change is in the allowed list using semantic comparison."""
+            for allowed in allowed_changes:
+                allowed_pk = allowed.get("pk")
+                # Handle type conversion for primary key comparison
+                # Convert both to strings for comparison to handle int/string mismatches
+                pk_match = (
+                    str(allowed_pk) == str(row_id) if allowed_pk is not None else False
+                )
+
+                if (
+                    allowed["table"] == table
+                    and pk_match
+                    and allowed.get("field") == field
+                    and _values_equivalent(allowed.get("after"), after_value)
+                ):
+                    return True
+            return False
+
+        # Collect all unexpected changes for detailed reporting
+        unexpected_changes = []
+
+        for tbl, report in diff.items():
+            for row in report.get("modified_rows", []):
+                for f, vals in row["changes"].items():
+                    if self.ignore_config.should_ignore_field(tbl, f):
+                        continue
+                    if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
+                        unexpected_changes.append(
+                            {
+                                "type": "modification",
+                                "table": tbl,
+                                "row_id": row["row_id"],
+                                "field": f,
+                                "before": vals.get("before"),
+                                "after": vals["after"],
+                                "full_row": row,
+                            }
+                        )
+
+            for row in report.get("added_rows", []):
+                if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
+                    unexpected_changes.append(
+                        {
+                            "type": "insertion",
+                            "table": tbl,
+                            "row_id": row["row_id"],
+                            "field": None,
+                            "after": "__added__",
+                            "full_row": row,
+                        }
+                    )
+
+            for row in report.get("removed_rows", []):
+                if not _is_change_allowed(tbl, row["row_id"], None, "__removed__"):
+                    unexpected_changes.append(
+                        {
+                            "type": "deletion",
+                            "table": tbl,
+                            "row_id": row["row_id"],
+                            "field": None,
+                            "after": "__removed__",
+                            "full_row": row,
+                        }
+                    )
+
+        if unexpected_changes:
+            # Build comprehensive error message
+            error_lines = ["Unexpected database changes detected:"]
+            error_lines.append("")
+
+            for i, change in enumerate(
+                unexpected_changes[:5], 1
+            ):  # Show first 5 changes
+                error_lines.append(
+                    f"{i}. {change['type'].upper()} in table '{change['table']}':"
+                )
+                error_lines.append(f" Row ID: {change['row_id']}")
+
+                if change["type"] == "modification":
+                    error_lines.append(f" Field: {change['field']}")
+                    error_lines.append(f" Before: {repr(change['before'])}")
+                    error_lines.append(f" After: {repr(change['after'])}")
+                elif change["type"] == "insertion":
+                    error_lines.append(" New row added")
+                elif change["type"] == "deletion":
+                    error_lines.append(" Row deleted")
+
+                # Show some context from the row
+                if "full_row" in change and change["full_row"]:
+                    row_data = change["full_row"]
+                    if change["type"] == "modification" and "data" in row_data:
+                        # For modifications, show the current state
+                        formatted_row = _format_row_for_error(
+                            row_data.get("data", {}), max_fields=5
+                        )
+                        error_lines.append(f" Row data: {formatted_row}")
+                    elif (
+                        change["type"] in ["insertion", "deletion"]
+                        and "data" in row_data
+                    ):
+                        # For insertions/deletions, show the row data
+                        formatted_row = _format_row_for_error(
+                            row_data.get("data", {}), max_fields=5
+                        )
+                        error_lines.append(f" Row data: {formatted_row}")
+
+                error_lines.append("")
+
+            if len(unexpected_changes) > 5:
+                error_lines.append(
+                    f"... and {len(unexpected_changes) - 5} more unexpected changes"
+                )
+                error_lines.append("")
+
+            # Show what changes were allowed
+            error_lines.append("Allowed changes were:")
+            if allowed_changes:
+                for i, allowed in enumerate(allowed_changes[:3], 1):
+                    error_lines.append(
+                        f" {i}. Table: {allowed.get('table')}, "
+                        f"ID: {allowed.get('pk')}, "
+                        f"Field: {allowed.get('field')}, "
+                        f"After: {repr(allowed.get('after'))}"
+                    )
+                if len(allowed_changes) > 3:
+                    error_lines.append(
+                        f" ... and {len(allowed_changes) - 3} more allowed changes"
+                    )
+            else:
+                error_lines.append(" (No changes were allowed)")
+
+            raise AssertionError("\n".join(error_lines))
+
+        return self
+
+    def expect(
+        self, *, allow: list[dict[str, Any]] | None = None, forbid: list[dict[str, Any]] | None = None
+    ):
+        """More granular: allow / forbid per‑table and per‑field."""
+        allow = allow or []
+        forbid = forbid or []
+        allow_tbl_field = {(c["table"], c.get("field")) for c in allow}
+        forbid_tbl_field = {(c["table"], c.get("field")) for c in forbid}
+        diff = self._collect()
+        for tbl, report in diff.items():
+            for row in report.get("modified_rows", []):
+                for f in row["changes"].keys():
+                    if self.ignore_config.should_ignore_field(tbl, f):
+                        continue
+                    key = (tbl, f)
+                    if key in forbid_tbl_field:
+                        raise AssertionError(f"Modification to forbidden field {key}")
+                    if allow_tbl_field and key not in allow_tbl_field:
+                        raise AssertionError(f"Modification to unallowed field {key}")
+            if (tbl, None) in forbid_tbl_field and (
+                report.get("added_rows") or report.get("removed_rows")
+            ):
+                raise AssertionError(f"Changes in forbidden table {tbl}")
+        return self
+
+
+################################################################################
+# DatabaseSnapshot with DSL entrypoints
+################################################################################
+
+
+class DatabaseSnapshot:
+    """Represents a snapshot of an SQLite DB with DSL entrypoints."""
+
+    def __init__(self, db_path: str, *, name: str | None = None):
+        self.db_path = db_path
+        self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
+        self.created_at = datetime.utcnow()
+
+    # DSL entry ------------------------------------------------------------
+    def table(self, table: str) -> QueryBuilder:
+        return QueryBuilder(self, table)
+
+    # Metadata -------------------------------------------------------------
+    def tables(self) -> list[str]:
+        conn = sqlite3.connect(self.db_path)
+        cur = conn.cursor()
+        cur.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+        )
+        tbls = [r[0] for r in cur.fetchall()]
+        cur.close()
+        conn.close()
+        return tbls
+
+    # Diff interface -------------------------------------------------------
+    def diff(
+        self,
+        other: "DatabaseSnapshot",  # noqa: UP037
+        ignore_config: IgnoreConfig | None = None,
+    ) -> SnapshotDiff:
+        return SnapshotDiff(self, other, ignore_config)
+
+    ############################################################################
+    # Convenience, schema‑agnostic expectation helpers
+    ############################################################################
+
+    def expect_row(
+        self, table: str, where: dict[str, SQLValue], expect: dict[str, SQLValue]
+    ):
+        qb = self.table(table)
+        for k, v in where.items():
+            qb = qb.eq(k, v)
+        qb.assert_exists()
+        for col, val in expect.items():
+            qb.assert_eq(col, val)
+        return self
+
+    def expect_rows(
+        self,
+        table: str,
+        where: dict[str, SQLValue],
+        *,
+        count: int | None = None,
+        contains: list[dict[str, SQLValue]] | None = None,
+    ):
+        qb = self.table(table)
+        for k, v in where.items():
+            qb = qb.eq(k, v)
+        if count is not None:
+            qb.count().assert_eq(count)
+        if contains:
+            rows = qb.all()
+            for cond in contains:
+                matched = any(all(r.get(k) == v for k, v in cond.items()) for r in rows)
+                if not matched:
+                    raise AssertionError(f"Expected a row matching {cond} in {table}")
+        return self
+
+    def expect_absent_row(self, table: str, where: dict[str, SQLValue]):
+        qb = self.table(table)
+        for k, v in where.items():
+            qb = qb.eq(k, v)
+        qb.assert_none()
+        return self
+
+    # ---------------------------------------------------------------------
+    def __repr__(self):
+        return f"<DatabaseSnapshot {self.name} at {self.db_path}>"
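
For orientation, here is a minimal usage sketch of the DSL added in this file, based only on the API visible in the diff above. The database paths, table names, and column values are hypothetical examples, and the diff helpers assume the `fleet.verifiers.sql_differ` module imported by `SnapshotDiff` is importable alongside this one.

# Hypothetical usage sketch (not part of the package): exercising the DSL above.
from fleet.verifiers.db import DatabaseSnapshot, IgnoreConfig

before = DatabaseSnapshot("before.db", name="before")
after = DatabaseSnapshot("after.db", name="after")

# Query-builder assertions against a single snapshot.
after.table("users").eq("email", "alice@example.com").assert_exists()
after.table("users").eq("status", "banned").assert_none()
after.table("orders").gte("total", 100).count().assert_gt(0)

# Schema-agnostic convenience helpers.
after.expect_row("users", where={"id": 1}, expect={"status": "active"})
after.expect_absent_row("sessions", where={"user_id": 999})

# Diff invariants: only the listed change may differ between the snapshots,
# and the ignored field is excluded from the comparison.
ignore = IgnoreConfig(fields={"updated_at"})
before.diff(after, ignore).expect_only(
    [{"table": "users", "pk": 1, "field": "status", "after": "active"}]
)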