fleet-python 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fleet-python might be problematic. Click here for more details.

@@ -0,0 +1,666 @@
1
+ # database_dsl.py
2
+ """A schema‑agnostic, SQL‑native DSL for snapshot validation and diff invariants.
3
+
4
+ The module extends your original `DatabaseSnapshot` implementation with
5
+
6
+ * A **Supabase‑style query builder** (method‑chaining: `select`, `eq`, `join`, …).
7
+ * Assertion helpers (`assert_exists`, `assert_none`, `assert_eq`, `count().assert_eq`, …).
8
+ * A `SnapshotDiff` engine that enforces invariants (`expect_only`, `expect`).
9
+ * Convenience helpers (`expect_row`, `expect_rows`, `expect_absent_row`).
10
+
11
+ The public API stays tiny yet composable; everything else is built on
12
+ orthogonal primitives so it works for *any* relational schema.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import sqlite3
17
+ from datetime import datetime
18
+ from typing import Any
19
+
20
+ ################################################################################
21
+ # Low‑level helpers
22
+ ################################################################################
23
+
24
+ SQLValue = str | int | float | None
25
+ Condition = tuple[str, str, SQLValue] # (column, op, value)
26
+ JoinSpec = tuple[str, dict[str, str]] # (table, on mapping)
27
+
28
+
29
+ class _CountResult:
30
+ """Wraps an integer count so we can chain assertions fluently."""
31
+
32
+ def __init__(self, value: int):
33
+ self.value = value
34
+
35
+ # Assertions ------------------------------------------------------------
36
+ def assert_eq(self, expected: int):
37
+ if self.value != expected:
38
+ raise AssertionError(f"Expected {expected}, got {self.value}")
39
+ return self
40
+
41
+ def assert_gt(self, threshold: int):
42
+ if self.value <= threshold:
43
+ raise AssertionError(f"Expected > {threshold}, got {self.value}")
44
+ return self
45
+
46
+ def assert_between(self, low: int, high: int):
47
+ if not low <= self.value <= high:
48
+ raise AssertionError(f"Expected {low}‑{high}, got {self.value}")
49
+ return self
50
+
51
+ # Convenience -----------------------------------------------------------
52
+ def __int__(self):
53
+ return self.value
54
+
55
+ def __repr__(self):
56
+ return f"<Count {self.value}>"
57
+
58
+
59
+ ################################################################################
60
+ # Query Builder
61
+ ################################################################################
62
+
63
+
64
class QueryBuilder:
    """Fluent SQL builder executed against a single `DatabaseSnapshot`.

    Builder methods never mutate the receiver: each returns a modified copy
    produced by ``_clone``, so a partially built query can be reused as the
    base for several variations. Rows fetched by ``_execute`` are cached per
    instance; a clone always starts with an empty cache.
    """

    def __init__(self, snapshot: "DatabaseSnapshot", table: str):  # noqa: UP037
        self._snapshot = snapshot
        self._table = table
        self._select_cols: list[str] = ["*"]
        self._conditions: list[Condition] = []
        self._joins: list[JoinSpec] = []
        self._limit: int | None = None
        self._order_by: str | None = None
        # Cache for idempotent executions (not copied by ``_clone``).
        self._cached_rows: list[dict[str, Any]] | None = None

    # ---------------------------------------------------------------------
    # Column projection / limiting / ordering
    # ---------------------------------------------------------------------
    def select(self, *columns: str) -> "QueryBuilder":  # noqa: UP037
        """Project *columns*; with no arguments select ``*``."""
        qb = self._clone()
        qb._select_cols = list(columns) if columns else ["*"]
        return qb

    def limit(self, n: int | None) -> "QueryBuilder":  # noqa: UP037
        """Return at most *n* rows; ``None`` removes any existing limit."""
        qb = self._clone()
        qb._limit = n
        return qb

    def sort(self, column: str, desc: bool = False) -> "QueryBuilder":  # noqa: UP037
        """Order by *column* (ascending unless *desc*); replaces any earlier sort."""
        qb = self._clone()
        qb._order_by = f"{column} {'DESC' if desc else 'ASC'}"
        return qb

    # ---------------------------------------------------------------------
    # WHERE helpers (SQL‑like); all conditions are AND-ed together
    # ---------------------------------------------------------------------
    def _add_condition(
        self, column: str, op: str, value: SQLValue | tuple[SQLValue, ...]
    ) -> "QueryBuilder":  # noqa: UP037
        """Return a copy with ``(column, op, value)`` appended to the WHERE clause."""
        qb = self._clone()
        qb._conditions.append((column, op, value))
        return qb

    def eq(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        """``column = value``; a ``None`` value compiles to ``column IS NULL``."""
        return self._add_condition(column, "=", value)

    def neq(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        """``column != value``; a ``None`` value compiles to ``column IS NOT NULL``."""
        return self._add_condition(column, "!=", value)

    def gt(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, ">", value)

    def gte(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, ">=", value)

    def lt(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "<", value)

    def lte(self, column: str, value: SQLValue) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "<=", value)

    def in_(self, column: str, values: list[SQLValue]) -> "QueryBuilder":  # noqa: UP037
        """``column IN (values...)``."""
        return self._add_condition(column, "IN", tuple(values))

    def not_in(
        self, column: str, values: list[SQLValue]
    ) -> "QueryBuilder":  # noqa: UP037
        """``column NOT IN (values...)``."""
        return self._add_condition(column, "NOT IN", tuple(values))

    def is_null(self, column: str) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "IS", None)

    def not_null(self, column: str) -> "QueryBuilder":  # noqa: UP037
        return self._add_condition(column, "IS NOT", None)

    def ilike(self, column: str, pattern: str) -> "QueryBuilder":  # noqa: UP037
        """Case-insensitive (ASCII) ``LIKE`` match against *pattern*."""
        return self._add_condition(column, "ILIKE", pattern)

    # ---------------------------------------------------------------------
    # JOIN (simple inner join)
    # ---------------------------------------------------------------------
    def join(
        self, other_table: str, on: dict[str, str]
    ) -> "QueryBuilder":  # noqa: UP037
        """`on` expects {local_col: remote_col}."""
        qb = self._clone()
        qb._joins.append((other_table, on))
        return qb

    # ---------------------------------------------------------------------
    # Execution helpers
    # ---------------------------------------------------------------------
    def _compile(self) -> tuple[str, list[Any]]:
        """Render the query as ``(sql, params)`` with ``?`` placeholders for values.

        Table and column names are spliced in verbatim; only values are bound
        as parameters.
        """
        cols = ", ".join(self._select_cols)
        sql = [f"SELECT {cols} FROM {self._table}"]
        params: list[Any] = []

        # Joins -------------------------------------------------------------
        for tbl, onmap in self._joins:
            join_clauses = [
                f"{self._table}.{local} = {tbl}.{remote}"
                for local, remote in onmap.items()
            ]
            sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")

        # WHERE -------------------------------------------------------------
        if self._conditions:
            clauses: list[str] = []
            for col, op, val in self._conditions:
                if op in ("IN", "NOT IN") and isinstance(val, tuple):
                    ph = ", ".join("?" for _ in val)
                    clauses.append(f"{col} {op} ({ph})")
                    params.extend(val)
                elif op in ("IS", "IS NOT"):
                    clauses.append(f"{col} {op} NULL")
                elif val is None and op in ("=", "!="):
                    # `col = NULL` is never true in SQL, so eq/neq with None would
                    # silently match nothing; compare with IS [NOT] NULL instead.
                    clauses.append(f"{col} {'IS' if op == '=' else 'IS NOT'} NULL")
                elif op == "ILIKE":
                    # SQLite has no ILIKE; its LIKE is already case-insensitive for
                    # ASCII letters (unless PRAGMA case_sensitive_like is enabled).
                    clauses.append(f"{col} LIKE ?")
                    params.append(val)
                else:
                    clauses.append(f"{col} {op} ?")
                    params.append(val)
            sql.append("WHERE " + " AND ".join(clauses))

        # ORDER / LIMIT -----------------------------------------------------
        if self._order_by:
            sql.append(f"ORDER BY {self._order_by}")
        if self._limit is not None:
            # int() guarantees only a numeric literal is spliced into the SQL.
            sql.append(f"LIMIT {int(self._limit)}")

        return " ".join(sql), params

    def _fetch(self, sql: str, params: list[Any]) -> list[sqlite3.Row]:
        """Run *sql* on a fresh connection and return every row.

        The connection is closed even if execution raises.
        """
        conn = sqlite3.connect(self._snapshot.db_path)
        try:
            conn.row_factory = sqlite3.Row
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def _execute(self) -> list[dict[str, Any]]:
        """Return the query's rows as dicts, executing at most once per instance."""
        if self._cached_rows is None:
            sql, params = self._compile()
            self._cached_rows = [dict(r) for r in self._fetch(sql, params)]
        return self._cached_rows

    # ---------------------------------------------------------------------
    # High‑level result helpers / assertions
    # ---------------------------------------------------------------------
    def count(self) -> _CountResult:
        """Return the number of matching rows; any ``limit`` is ignored."""
        qb = self.select("COUNT(*) AS __cnt__").limit(None)
        sql, params = qb._compile()
        rows = self._fetch(sql, params)
        return _CountResult((rows[0][0] if rows else 0) or 0)

    def first(self) -> dict[str, Any] | None:
        """Return the first matching row, or ``None`` when nothing matches."""
        # Execute once: each `self.limit(1)` is a fresh, uncached clone.
        rows = self.limit(1)._execute()
        return rows[0] if rows else None

    def all(self) -> list[dict[str, Any]]:
        """Return every matching row as a dict."""
        return self._execute()

    # Assertions -----------------------------------------------------------
    def assert_exists(self):
        """Raise ``AssertionError`` unless at least one row matches; return ``self``."""
        if self.first() is not None:
            return self
        sql, params = self._compile()
        error_msg = (
            f"Expected at least one matching row, but found none.\n"
            f"Query: {sql}\n"
            f"Parameters: {params}\n"
            f"Table: {self._table}"
        )
        if self._conditions:
            conditions_str = ", ".join(
                f"{col} {op} {val}" for col, op, val in self._conditions
            )
            error_msg += f"\nConditions: {conditions_str}"
        raise AssertionError(error_msg)

    def assert_none(self):
        """Raise ``AssertionError`` if any row matches; return ``self``."""
        row = self.first()
        if row is None:
            return self
        row_id = _get_row_identifier(row)
        row_data = _format_row_for_error(row)
        sql, params = self._compile()
        error_msg = (
            f"Expected no matching rows, but found at least one.\n"
            f"Found row: {row_id}\n"
            f"Row data: {row_data}\n"
            f"Query: {sql}\n"
            f"Parameters: {params}\n"
            f"Table: {self._table}"
        )
        raise AssertionError(error_msg)

    def assert_eq(self, column: str, value: SQLValue):
        """Assert the first matching row has ``row[column] == value``; return ``self``.

        Raises ``AssertionError`` if no row matches, if *column* is not part
        of the selected row (e.g. a typo), or if the value differs.
        """
        row = self.first()
        if row is None:
            sql, params = self._compile()
            error_msg = (
                f"Row not found for equality assertion.\n"
                f"Expected to find a row with {column}={repr(value)}\n"
                f"Query: {sql}\n"
                f"Parameters: {params}\n"
                f"Table: {self._table}"
            )
            raise AssertionError(error_msg)

        # A missing column must not be confused with a NULL value.
        if column not in row:
            raise AssertionError(
                f"Field value assertion failed.\n"
                f"Column {column!r} is not present in the selected row.\n"
                f"Available columns: {list(row)}\n"
                f"Table: {self._table}"
            )

        actual_value = row[column]
        if actual_value != value:
            row_id = _get_row_identifier(row)
            row_data = _format_row_for_error(row)
            error_msg = (
                f"Field value assertion failed.\n"
                f"Row: {row_id}\n"
                f"Field: {column}\n"
                f"Expected: {repr(value)}\n"
                f"Actual: {repr(actual_value)}\n"
                f"Full row data: {row_data}\n"
                f"Table: {self._table}"
            )
            raise AssertionError(error_msg)
        return self

    # Misc -----------------------------------------------------------------
    def explain(self) -> str:
        """Return the compiled SQL and its parameters for debugging."""
        sql, params = self._compile()
        return f"SQL: {sql}\nParams: {params}"

    # Utilities ------------------------------------------------------------
    def _clone(self) -> "QueryBuilder":  # noqa: UP037
        """Copy the query definition (not the row cache) into a new builder."""
        qb = QueryBuilder(self._snapshot, self._table)
        qb._select_cols = list(self._select_cols)
        qb._conditions = list(self._conditions)
        qb._joins = list(self._joins)
        qb._limit = self._limit
        qb._order_by = self._order_by
        return qb

    # Representation -------------------------------------------------------
    def __repr__(self):
        return f"<QueryBuilder {self.explain()}>"
322
+
323
+
324
+ ################################################################################
325
+ # Snapshot Diff invariants
326
+ ################################################################################
327
+
328
+
329
class IgnoreConfig:
    """Describe which tables and fields a snapshot diff should skip.

    Three granularities are supported: whole tables, field names skipped in
    every table, and field names skipped only inside one particular table.
    """

    def __init__(
        self,
        tables: set[str] | None = None,
        fields: set[str] | None = None,
        table_fields: dict[str, set[str]] | None = None,
    ):
        """
        Args:
            tables: Names of tables to skip entirely.
            fields: Field names to skip in every table.
            table_fields: Mapping of table name -> field names skipped only in that table.
        """
        self.tables = tables if tables else set()
        self.fields = fields if fields else set()
        self.table_fields = table_fields if table_fields else {}

    def should_ignore_table(self, table: str) -> bool:
        """Return True when *table* is listed in ``tables``."""
        return table in self.tables

    def should_ignore_field(self, table: str, field: str) -> bool:
        """Return True when *field* is skipped globally or for *table* specifically."""
        if field in self.fields:
            return True
        return field in self.table_fields.get(table, ())
361
+
362
+
363
+ def _format_row_for_error(row: dict[str, Any], max_fields: int = 10) -> str:
364
+ """Format a row dictionary for error messages with truncation if needed."""
365
+ if not row:
366
+ return "{empty row}"
367
+
368
+ items = list(row.items())
369
+ if len(items) <= max_fields:
370
+ formatted_items = [f"{k}={repr(v)}" for k, v in items]
371
+ return "{" + ", ".join(formatted_items) + "}"
372
+ else:
373
+ # Show first few fields and indicate truncation
374
+ shown_items = [f"{k}={repr(v)}" for k, v in items[:max_fields]]
375
+ remaining = len(items) - max_fields
376
+ return "{" + ", ".join(shown_items) + f", ... +{remaining} more fields" + "}"
377
+
378
+
379
+ def _get_row_identifier(row: dict[str, Any]) -> str:
380
+ """Extract a meaningful identifier from a row for error messages."""
381
+ # Try common ID fields first
382
+ for id_field in ["id", "pk", "primary_key", "key"]:
383
+ if id_field in row and row[id_field] is not None:
384
+ return f"{id_field}={repr(row[id_field])}"
385
+
386
+ # Try name fields
387
+ for name_field in ["name", "title", "label"]:
388
+ if name_field in row and row[name_field] is not None:
389
+ return f"{name_field}={repr(row[name_field])}"
390
+
391
+ # Fall back to first non-None field
392
+ for key, value in row.items():
393
+ if value is not None:
394
+ return f"{key}={repr(value)}"
395
+
396
+ return "no identifier found"
397
+
398
+
399
class SnapshotDiff:
    """Compute & validate changes between two snapshots.

    Per-table reports come from ``SQLiteDiffer.diff_table`` and are cached on
    first use. Each report is read through its ``modified_rows``,
    ``added_rows`` and ``removed_rows`` lists. Every row entry carries a
    ``row_id``; modified-row entries also carry a per-field mapping whose
    values hold ``"after"`` (and optionally ``"before"``).
    """

    def __init__(
        self,
        before: DatabaseSnapshot,
        after: DatabaseSnapshot,
        ignore_config: IgnoreConfig | None = None,
    ):
        from .sql_differ import SQLiteDiffer  # local import to avoid circularity

        self.before = before
        self.after = after
        self.ignore_config = ignore_config or IgnoreConfig()
        self._differ = SQLiteDiffer(before.db_path, after.db_path)
        self._cached: dict[str, Any] | None = None

    # ------------------------------------------------------------------
    def _collect(self):
        """Return ``{table: diff report}`` for every non-ignored table (cached)."""
        if self._cached is None:
            all_tables = set(self.before.tables()) | set(self.after.tables())
            self._cached = {
                tbl: self._differ.diff_table(tbl)
                for tbl in all_tables
                if not self.ignore_config.should_ignore_table(tbl)
            }
        return self._cached

    @staticmethod
    def _row_changes(row: dict[str, Any]) -> dict[str, Any]:
        """Return the per-field change mapping of a modified-row entry.

        NOTE(review): the differ's key name is not visible from this module;
        ``expect_only`` read ``"changes"`` while ``expect`` read ``"changed"``,
        so accept either rather than KeyError in one of them.
        """
        if "changes" in row:
            return row["changes"]
        return row.get("changed", {})

    # ------------------------------------------------------------------
    def expect_only(self, allowed_changes: list[dict[str, Any]]):
        """Assert that the diff contains *only* the listed changes.

        Each allowed change is a dict with ``table`` and ``pk``, plus ``field``
        and ``after`` for field modifications. Insertions/deletions are
        whitelisted with ``field`` omitted and ``after`` set to ``"__added__"``
        / ``"__removed__"``. A ``before`` key may be present but is not
        compared. Fields ignored by ``ignore_config`` are skipped.

        Returns ``self``; raises ``AssertionError`` describing the unexpected
        changes otherwise.
        """
        diff = self._collect()
        allowed_set = {
            (c["table"], c.get("pk"), c.get("field"), c.get("after"))
            for c in allowed_changes
        }

        # Collect all unexpected changes for detailed reporting
        unexpected_changes = []

        for tbl, report in diff.items():
            for row in report.get("modified_rows", []):
                for f, vals in self._row_changes(row).items():
                    if self.ignore_config.should_ignore_field(tbl, f):
                        continue
                    tup = (tbl, row["row_id"], f, vals["after"])
                    if tup not in allowed_set:
                        unexpected_changes.append(
                            {
                                "type": "modification",
                                "table": tbl,
                                "row_id": row["row_id"],
                                "field": f,
                                "before": vals.get("before"),
                                "after": vals["after"],
                                "full_row": row,
                            }
                        )

            for row in report.get("added_rows", []):
                tup = (tbl, row["row_id"], None, "__added__")
                if tup not in allowed_set:
                    unexpected_changes.append(
                        {
                            "type": "insertion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__added__",
                            "full_row": row,
                        }
                    )

            for row in report.get("removed_rows", []):
                tup = (tbl, row["row_id"], None, "__removed__")
                if tup not in allowed_set:
                    unexpected_changes.append(
                        {
                            "type": "deletion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__removed__",
                            "full_row": row,
                        }
                    )

        if unexpected_changes:
            raise AssertionError(
                self._format_unexpected(unexpected_changes, allowed_changes)
            )
        return self

    @staticmethod
    def _format_unexpected(
        unexpected_changes: list[dict[str, Any]],
        allowed_changes: list[dict[str, Any]],
    ) -> str:
        """Build the ``expect_only`` failure message (first 5 unexpected, first 3 allowed)."""
        error_lines = ["Unexpected database changes detected:", ""]

        for i, change in enumerate(unexpected_changes[:5], 1):
            error_lines.append(
                f"{i}. {change['type'].upper()} in table '{change['table']}':"
            )
            error_lines.append(f" Row ID: {change['row_id']}")

            if change["type"] == "modification":
                error_lines.append(f" Field: {change['field']}")
                error_lines.append(f" Before: {repr(change['before'])}")
                error_lines.append(f" After: {repr(change['after'])}")
            elif change["type"] == "insertion":
                error_lines.append(" New row added")
            elif change["type"] == "deletion":
                error_lines.append(" Row deleted")

            # Show some context from the row when the differ supplied its data.
            row_data = change.get("full_row")
            if row_data and "data" in row_data:
                formatted_row = _format_row_for_error(
                    row_data.get("data", {}), max_fields=5
                )
                error_lines.append(f" Row data: {formatted_row}")

            error_lines.append("")

        if len(unexpected_changes) > 5:
            error_lines.append(
                f"... and {len(unexpected_changes) - 5} more unexpected changes"
            )
            error_lines.append("")

        # Show what changes were allowed
        error_lines.append("Allowed changes were:")
        if allowed_changes:
            for i, allowed in enumerate(allowed_changes[:3], 1):
                error_lines.append(
                    f" {i}. Table: {allowed.get('table')}, "
                    f"ID: {allowed.get('pk')}, "
                    f"Field: {allowed.get('field')}, "
                    f"After: {repr(allowed.get('after'))}"
                )
            if len(allowed_changes) > 3:
                error_lines.append(
                    f" ... and {len(allowed_changes) - 3} more allowed changes"
                )
        else:
            error_lines.append(" (No changes were allowed)")

        return "\n".join(error_lines)

    def expect(
        self,
        *,
        allow: list[dict[str, Any]] | None = None,
        forbid: list[dict[str, Any]] | None = None,
    ):
        """Check changes against per-table / per-field allow and forbid lists.

        Entries are ``{"table": ..., "field": ...}``; omitting ``field`` makes
        the entry apply to the whole table.

        * A table forbidden as a whole must have no modified (non-ignored)
          fields and no added or removed rows.
        * A forbidden ``(table, field)`` pair must not be modified.
        * When ``allow`` is non-empty, every modified field must be allowed,
          either explicitly or through a whole-table entry.

        Fields ignored by ``ignore_config`` are skipped. Returns ``self``;
        raises ``AssertionError`` on the first violation.
        """
        allow = allow or []
        forbid = forbid or []
        allow_tbl_field = {(c["table"], c.get("field")) for c in allow}
        forbid_tbl_field = {(c["table"], c.get("field")) for c in forbid}

        for tbl, report in self._collect().items():
            table_forbidden = (tbl, None) in forbid_tbl_field
            table_allowed = (tbl, None) in allow_tbl_field

            for row in report.get("modified_rows", []):
                for f in self._row_changes(row):
                    if self.ignore_config.should_ignore_field(tbl, f):
                        continue
                    key = (tbl, f)
                    if table_forbidden:
                        raise AssertionError(f"Changes in forbidden table {tbl}")
                    if key in forbid_tbl_field:
                        raise AssertionError(f"Modification to forbidden field {key}")
                    if allow_tbl_field and not table_allowed and key not in allow_tbl_field:
                        raise AssertionError(f"Modification to unallowed field {key}")

            if table_forbidden and (
                report.get("added_rows") or report.get("removed_rows")
            ):
                raise AssertionError(f"Changes in forbidden table {tbl}")
        return self
582
+
583
+
584
+ ################################################################################
585
+ # DatabaseSnapshot with DSL entrypoints
586
+ ################################################################################
587
+
588
+
589
class DatabaseSnapshot:
    """Represents a snapshot of an SQLite DB with DSL entrypoints.

    Only the database path is stored; each query opens its own short-lived
    connection to that file.
    """

    def __init__(self, db_path: str, *, name: str | None = None):
        from datetime import timezone

        # One timestamp so the default name and `created_at` agree. This is the
        # same naive-UTC value `datetime.utcnow()` gave, without its 3.12+
        # deprecation warning.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        self.db_path = db_path
        self.name = name or f"snapshot_{now.isoformat()}"
        self.created_at = now

    # DSL entry ------------------------------------------------------------
    def table(self, table: str) -> QueryBuilder:
        """Start a fluent query against *table*."""
        return QueryBuilder(self, table)

    # Metadata -------------------------------------------------------------
    def tables(self) -> list[str]:
        """Return the names of user tables (SQLite-internal ``sqlite_*`` excluded)."""
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            ).fetchall()
        finally:
            conn.close()
        return [r[0] for r in rows]

    # Diff interface -------------------------------------------------------
    def diff(
        self,
        other: "DatabaseSnapshot",  # noqa: UP037
        ignore_config: IgnoreConfig | None = None,
    ) -> SnapshotDiff:
        """Return a `SnapshotDiff` from this snapshot (before) to *other* (after)."""
        return SnapshotDiff(self, other, ignore_config)

    ############################################################################
    # Convenience, schema‑agnostic expectation helpers
    ############################################################################

    def _where(self, table: str, where: dict[str, SQLValue]) -> QueryBuilder:
        """Build a query on *table* with one AND-ed equality filter per *where* item.

        ``None`` values become ``IS NULL`` filters: SQL ``col = NULL`` never
        matches, which would make ``expect_absent_row`` pass vacuously.
        """
        qb = self.table(table)
        for column, value in where.items():
            qb = qb.is_null(column) if value is None else qb.eq(column, value)
        return qb

    def expect_row(
        self, table: str, where: dict[str, SQLValue], expect: dict[str, SQLValue]
    ):
        """Assert a row matching *where* exists and has every column value in *expect*."""
        qb = self._where(table, where)
        qb.assert_exists()
        for col, val in expect.items():
            qb.assert_eq(col, val)
        return self

    def expect_rows(
        self,
        table: str,
        where: dict[str, SQLValue],
        *,
        count: int | None = None,
        contains: list[dict[str, SQLValue]] | None = None,
    ):
        """Assert facts about the rows matching *where*.

        *count*: exact number of matching rows. *contains*: for each dict,
        some matching row must have all of its column values.
        """
        qb = self._where(table, where)
        if count is not None:
            qb.count().assert_eq(count)
        if contains:
            rows = qb.all()
            for cond in contains:
                matched = any(all(r.get(k) == v for k, v in cond.items()) for r in rows)
                if not matched:
                    raise AssertionError(f"Expected a row matching {cond} in {table}")
        return self

    def expect_absent_row(self, table: str, where: dict[str, SQLValue]):
        """Assert that no row of *table* matches *where*."""
        self._where(table, where).assert_none()
        return self

    # ---------------------------------------------------------------------
    def __repr__(self):
        return f"<DatabaseSnapshot {self.name} at {self.db_path}>"