fleet-python 0.2.43__py3-none-any.whl → 0.2.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fleet-python might be problematic. See the package's registry page for more details.

fleet/resources/sqlite.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, List, Optional
1
+ from typing import Any, List, Optional, Dict, Tuple
2
2
  from ..instance.models import Resource as ResourceModel
3
3
  from ..instance.models import DescribeResponse, QueryRequest, QueryResponse
4
4
  from .base import Resource
@@ -25,12 +25,12 @@ from fleet.verifiers.db import (
25
25
  class SyncDatabaseSnapshot:
26
26
  """Async database snapshot that fetches data through API and stores locally for diffing."""
27
27
 
28
- def __init__(self, resource: "SQLiteResource", name: str | None = None):
28
+ def __init__(self, resource: "SQLiteResource", name: Optional[str] = None):
29
29
  self.resource = resource
30
30
  self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
31
31
  self.created_at = datetime.utcnow()
32
- self._data: dict[str, list[dict[str, Any]]] = {}
33
- self._schemas: dict[str, list[str]] = {}
32
+ self._data: Dict[str, List[Dict[str, Any]]] = {}
33
+ self._schemas: Dict[str, List[str]] = {}
34
34
  self._fetched = False
35
35
 
36
36
  def _ensure_fetched(self):
@@ -69,7 +69,7 @@ class SyncDatabaseSnapshot:
69
69
 
70
70
  self._fetched = True
71
71
 
72
- def tables(self) -> list[str]:
72
+ def tables(self) -> List[str]:
73
73
  """Get list of all tables in the snapshot."""
74
74
  self._ensure_fetched()
75
75
  return list(self._data.keys())
@@ -81,7 +81,7 @@ class SyncDatabaseSnapshot:
81
81
  def diff(
82
82
  self,
83
83
  other: "SyncDatabaseSnapshot",
84
- ignore_config: IgnoreConfig | None = None,
84
+ ignore_config: Optional[IgnoreConfig] = None,
85
85
  ) -> "SyncSnapshotDiff":
86
86
  """Compare this snapshot with another."""
87
87
  self._ensure_fetched()
@@ -95,13 +95,13 @@ class SyncSnapshotQueryBuilder:
95
95
  def __init__(self, snapshot: SyncDatabaseSnapshot, table: str):
96
96
  self._snapshot = snapshot
97
97
  self._table = table
98
- self._select_cols: list[str] = ["*"]
99
- self._conditions: list[tuple[str, str, Any]] = []
100
- self._limit: int | None = None
101
- self._order_by: str | None = None
98
+ self._select_cols: List[str] = ["*"]
99
+ self._conditions: List[Tuple[str, str, Any]] = []
100
+ self._limit: Optional[int] = None
101
+ self._order_by: Optional[str] = None
102
102
  self._order_desc: bool = False
103
103
 
104
- def _get_data(self) -> list[dict[str, Any]]:
104
+ def _get_data(self) -> List[Dict[str, Any]]:
105
105
  """Get table data from snapshot."""
106
106
  self._snapshot._ensure_fetched()
107
107
  return self._snapshot._data.get(self._table, [])
@@ -122,11 +122,11 @@ class SyncSnapshotQueryBuilder:
122
122
  qb._order_desc = desc
123
123
  return qb
124
124
 
125
- def first(self) -> dict[str, Any] | None:
125
+ def first(self) -> Optional[Dict[str, Any]]:
126
126
  rows = self.all()
127
127
  return rows[0] if rows else None
128
128
 
129
- def all(self) -> list[dict[str, Any]]:
129
+ def all(self) -> List[Dict[str, Any]]:
130
130
  data = self._get_data()
131
131
 
132
132
  # Apply filters
@@ -185,14 +185,14 @@ class SyncSnapshotDiff:
185
185
  self,
186
186
  before: SyncDatabaseSnapshot,
187
187
  after: SyncDatabaseSnapshot,
188
- ignore_config: IgnoreConfig | None = None,
188
+ ignore_config: Optional[IgnoreConfig] = None,
189
189
  ):
190
190
  self.before = before
191
191
  self.after = after
192
192
  self.ignore_config = ignore_config or IgnoreConfig()
193
- self._cached: dict[str, Any] | None = None
193
+ self._cached: Optional[Dict[str, Any]] = None
194
194
 
195
- def _get_primary_key_columns(self, table: str) -> list[str]:
195
+ def _get_primary_key_columns(self, table: str) -> List[str]:
196
196
  """Get primary key columns for a table."""
197
197
  # Try to get from schema
198
198
  schema_response = self.after.resource.query(f"PRAGMA table_info({table})")
@@ -222,7 +222,7 @@ class SyncSnapshotDiff:
222
222
  return self._cached
223
223
 
224
224
  all_tables = set(self.before.tables()) | set(self.after.tables())
225
- diff: dict[str, dict[str, Any]] = {}
225
+ diff: Dict[str, Dict[str, Any]] = {}
226
226
 
227
227
  for tbl in all_tables:
228
228
  if self.ignore_config.should_ignore_table(tbl):
@@ -236,7 +236,7 @@ class SyncSnapshotDiff:
236
236
  after_data = self.after._data.get(tbl, [])
237
237
 
238
238
  # Create indexes by primary key
239
- def make_key(row: dict, pk_cols: list[str]) -> Any:
239
+ def make_key(row: dict, pk_cols: List[str]) -> Any:
240
240
  if len(pk_cols) == 1:
241
241
  return row.get(pk_cols[0])
242
242
  return tuple(row.get(col) for col in pk_cols)
@@ -304,12 +304,12 @@ class SyncSnapshotDiff:
304
304
  self._cached = diff
305
305
  return diff
306
306
 
307
- def expect_only(self, allowed_changes: list[dict[str, Any]]):
307
+ def expect_only(self, allowed_changes: List[Dict[str, Any]]):
308
308
  """Ensure only specified changes occurred."""
309
309
  diff = self._collect()
310
310
 
311
311
  def _is_change_allowed(
312
- table: str, row_id: Any, field: str | None, after_value: Any
312
+ table: str, row_id: Any, field: Optional[str], after_value: Any
313
313
  ) -> bool:
314
314
  """Check if a change is in the allowed list using semantic comparison."""
315
315
  for allowed in allowed_changes:
@@ -440,11 +440,11 @@ class SyncQueryBuilder:
440
440
  def __init__(self, resource: "SQLiteResource", table: str):
441
441
  self._resource = resource
442
442
  self._table = table
443
- self._select_cols: list[str] = ["*"]
444
- self._conditions: list[tuple[str, str, Any]] = []
445
- self._joins: list[tuple[str, dict[str, str]]] = []
446
- self._limit: int | None = None
447
- self._order_by: str | None = None
443
+ self._select_cols: List[str] = ["*"]
444
+ self._conditions: List[Tuple[str, str, Any]] = []
445
+ self._joins: List[Tuple[str, Dict[str, str]]] = []
446
+ self._limit: Optional[int] = None
447
+ self._order_by: Optional[str] = None
448
448
 
449
449
  # Column projection / limiting / ordering
450
450
  def select(self, *columns: str) -> "SyncQueryBuilder":
@@ -486,12 +486,12 @@ class SyncQueryBuilder:
486
486
  def lte(self, column: str, value: Any) -> "SyncQueryBuilder":
487
487
  return self._add_condition(column, "<=", value)
488
488
 
489
- def in_(self, column: str, values: list[Any]) -> "SyncQueryBuilder":
489
+ def in_(self, column: str, values: List[Any]) -> "SyncQueryBuilder":
490
490
  qb = self._clone()
491
491
  qb._conditions.append((column, "IN", tuple(values)))
492
492
  return qb
493
493
 
494
- def not_in(self, column: str, values: list[Any]) -> "SyncQueryBuilder":
494
+ def not_in(self, column: str, values: List[Any]) -> "SyncQueryBuilder":
495
495
  qb = self._clone()
496
496
  qb._conditions.append((column, "NOT IN", tuple(values)))
497
497
  return qb
@@ -508,16 +508,16 @@ class SyncQueryBuilder:
508
508
  return qb
509
509
 
510
510
  # JOIN
511
- def join(self, other_table: str, on: dict[str, str]) -> "SyncQueryBuilder":
511
+ def join(self, other_table: str, on: Dict[str, str]) -> "SyncQueryBuilder":
512
512
  qb = self._clone()
513
513
  qb._joins.append((other_table, on))
514
514
  return qb
515
515
 
516
516
  # Compile to SQL
517
- def _compile(self) -> tuple[str, list[Any]]:
517
+ def _compile(self) -> Tuple[str, List[Any]]:
518
518
  cols = ", ".join(self._select_cols)
519
519
  sql = [f"SELECT {cols} FROM {self._table}"]
520
- params: list[Any] = []
520
+ params: List[Any] = []
521
521
 
522
522
  # Joins
523
523
  for tbl, onmap in self._joins:
@@ -558,11 +558,11 @@ class SyncQueryBuilder:
558
558
  return row_dict.get("__cnt__", 0)
559
559
  return 0
560
560
 
561
- def first(self) -> dict[str, Any] | None:
561
+ def first(self) -> Optional[Dict[str, Any]]:
562
562
  rows = self.limit(1).all()
563
563
  return rows[0] if rows else None
564
564
 
565
- def all(self) -> list[dict[str, Any]]:
565
+ def all(self) -> List[Dict[str, Any]]:
566
566
  sql, params = self._compile()
567
567
  response = self._resource.query(sql, params)
568
568
  if not response.rows:
@@ -651,9 +651,7 @@ class SQLiteResource(Resource):
651
651
  )
652
652
  return DescribeResponse(**response.json())
653
653
 
654
- def query(
655
- self, query: str, args: Optional[List[Any]] = None
656
- ) -> QueryResponse:
654
+ def query(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:
657
655
  return self._query(query, args, read_only=True)
658
656
 
659
657
  def exec(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:
@@ -674,7 +672,7 @@ class SQLiteResource(Resource):
674
672
  """Create a query builder for the specified table."""
675
673
  return SyncQueryBuilder(self, table_name)
676
674
 
677
- def snapshot(self, name: str | None = None) -> SyncDatabaseSnapshot:
675
+ def snapshot(self, name: Optional[str] = None) -> SyncDatabaseSnapshot:
678
676
  """Create a snapshot of the current database state."""
679
677
  snapshot = SyncDatabaseSnapshot(self, name)
680
678
  snapshot._ensure_fetched()
@@ -683,7 +681,7 @@ class SQLiteResource(Resource):
683
681
  def diff(
684
682
  self,
685
683
  other: "SQLiteResource",
686
- ignore_config: IgnoreConfig | None = None,
684
+ ignore_config: Optional[IgnoreConfig] = None,
687
685
  ) -> SyncSnapshotDiff:
688
686
  """Compare this database with another AsyncSQLiteResource.
689
687
 
@@ -695,12 +693,8 @@ class SQLiteResource(Resource):
695
693
  AsyncSnapshotDiff: Object containing the differences between the two databases
696
694
  """
697
695
  # Create snapshots of both databases
698
- before_snapshot = self.snapshot(
699
- name=f"before_{datetime.utcnow().isoformat()}"
700
- )
701
- after_snapshot = other.snapshot(
702
- name=f"after_{datetime.utcnow().isoformat()}"
703
- )
696
+ before_snapshot = self.snapshot(name=f"before_{datetime.utcnow().isoformat()}")
697
+ after_snapshot = other.snapshot(name=f"after_{datetime.utcnow().isoformat()}")
704
698
 
705
699
  # Return the diff between the snapshots
706
700
  return before_snapshot.diff(after_snapshot, ignore_config)
fleet/tasks.py CHANGED
@@ -47,7 +47,7 @@ class Task(BaseModel):
47
47
  @property
48
48
  def env_key(self) -> str:
49
49
  """Get the environment key combining env_id and version."""
50
- if self.version and self.version != "None":
50
+ if self.version and self.version != "None" and ":" not in self.env_id:
51
51
  return f"{self.env_id}:{self.version}"
52
52
  return self.env_id
53
53
 
@@ -70,17 +70,13 @@ class Task(BaseModel):
70
70
  import inspect
71
71
 
72
72
  # Check if verifier has remote method (for decorated verifiers)
73
- if hasattr(self.verifier, "remote"):
74
- result = self.verifier.remote(env, *args, **kwargs)
75
- else:
76
- # For verifiers created from string, call directly
77
- result = self.verifier(env, *args, **kwargs)
73
+ result = self.verifier.remote(env, *args, **kwargs)
78
74
 
79
75
  # If the result is a coroutine, we need to run it
80
76
  if inspect.iscoroutine(result):
81
77
  # Check if we're already in an event loop
82
78
  try:
83
- loop = asyncio.get_running_loop()
79
+ asyncio.get_running_loop()
84
80
  # We're in an async context, can't use asyncio.run()
85
81
  raise RuntimeError(
86
82
  "Cannot run async verifier in sync mode while event loop is running. "
@@ -141,7 +137,7 @@ def verifier_from_string(
141
137
  """
142
138
  try:
143
139
  import inspect
144
- from .verifiers import verifier, SyncVerifierFunction
140
+ from .verifiers import SyncVerifierFunction
145
141
  from .verifiers.code import TASK_SUCCESSFUL_SCORE, TASK_FAILED_SCORE
146
142
  from .verifiers.db import IgnoreConfig
147
143
 
@@ -172,36 +168,13 @@ def verifier_from_string(
172
168
  if func_obj is None:
173
169
  raise ValueError("No function found in verifier code")
174
170
 
175
- # Create a wrapper function that provides the necessary globals
176
- def wrapped_verifier(env, *args, **kwargs):
177
- # Set up globals for the function execution
178
- func_globals = (
179
- func_obj.__globals__.copy() if hasattr(func_obj, "__globals__") else {}
180
- )
181
- func_globals.update(
182
- {
183
- "TASK_SUCCESSFUL_SCORE": TASK_SUCCESSFUL_SCORE,
184
- "TASK_FAILED_SCORE": TASK_FAILED_SCORE,
185
- "IgnoreConfig": IgnoreConfig,
186
- }
187
- )
188
-
189
- # Create a new function with the updated globals
190
- import types
191
-
192
- new_func = types.FunctionType(
193
- func_obj.__code__,
194
- func_globals,
195
- func_obj.__name__,
196
- func_obj.__defaults__,
197
- func_obj.__closure__,
198
- )
199
-
200
- return new_func(env, *args, **kwargs)
201
-
202
- # Create an AsyncVerifierFunction instance with the wrapped function
171
+ # Create an SyncVerifierFunction instance with raw code
203
172
  verifier_instance = SyncVerifierFunction(
204
- wrapped_verifier, verifier_key, verifier_id
173
+ func=func_obj,
174
+ key=verifier_key,
175
+ verifier_id=verifier_id,
176
+ sha256=sha256,
177
+ raw_code=verifier_func,
205
178
  )
206
179
 
207
180
  # Store additional metadata
@@ -214,6 +187,18 @@ def verifier_from_string(
214
187
  raise ValueError(f"Failed to create verifier from string: {e}")
215
188
 
216
189
 
190
+ def load_tasks_from_file(filename: str) -> List[Task]:
191
+ """Load tasks from a JSON file.
192
+
193
+ Example:
194
+ tasks = fleet.load_tasks_from_file("my_tasks.json")
195
+ """
196
+ from .global_client import get_client
197
+
198
+ client = get_client()
199
+ return client.load_tasks_from_file(filename)
200
+
201
+
217
202
  def load_tasks(
218
203
  env_key: Optional[str] = None,
219
204
  keys: Optional[List[str]] = None,
@@ -2,7 +2,7 @@
2
2
 
3
3
  from .db import DatabaseSnapshot, IgnoreConfig, SnapshotDiff
4
4
  from .code import TASK_SUCCESSFUL_SCORE, TASK_FAILED_SCORE
5
- from .decorator import (
5
+ from .verifier import (
6
6
  verifier,
7
7
  SyncVerifierFunction,
8
8
  )
fleet/verifiers/db.py CHANGED
@@ -22,9 +22,11 @@ import json
22
22
  # Low‑level helpers
23
23
  ################################################################################
24
24
 
25
- SQLValue = str | int | float | None
26
- Condition = tuple[str, str, SQLValue] # (column, op, value)
27
- JoinSpec = tuple[str, dict[str, str]] # (table, on mapping)
25
+ from typing import Union, Tuple, Dict, List, Optional, Any, Set
26
+
27
+ SQLValue = Union[str, int, float, None]
28
+ Condition = Tuple[str, str, SQLValue] # (column, op, value)
29
+ JoinSpec = Tuple[str, Dict[str, str]] # (table, on mapping)
28
30
 
29
31
 
30
32
  def _is_json_string(value: Any) -> bool:
@@ -98,13 +100,13 @@ class QueryBuilder:
98
100
  def __init__(self, snapshot: "DatabaseSnapshot", table: str): # noqa: UP037
99
101
  self._snapshot = snapshot
100
102
  self._table = table
101
- self._select_cols: list[str] = ["*"]
102
- self._conditions: list[Condition] = []
103
- self._joins: list[JoinSpec] = []
104
- self._limit: int | None = None
105
- self._order_by: str | None = None
103
+ self._select_cols: List[str] = ["*"]
104
+ self._conditions: List[Condition] = []
105
+ self._joins: List[JoinSpec] = []
106
+ self._limit: Optional[int] = None
107
+ self._order_by: Optional[str] = None
106
108
  # Cache for idempotent executions
107
- self._cached_rows: list[dict[str, Any]] | None = None
109
+ self._cached_rows: Optional[List[Dict[str, Any]]] = None
108
110
 
109
111
  # ---------------------------------------------------------------------
110
112
  # Column projection / limiting / ordering
@@ -150,12 +152,12 @@ class QueryBuilder:
150
152
  def lte(self, column: str, value: SQLValue) -> "QueryBuilder": # noqa: UP037
151
153
  return self._add_condition(column, "<=", value)
152
154
 
153
- def in_(self, column: str, values: list[SQLValue]) -> "QueryBuilder": # noqa: UP037
155
+ def in_(self, column: str, values: List[SQLValue]) -> "QueryBuilder": # noqa: UP037
154
156
  qb = self._clone()
155
157
  qb._conditions.append((column, "IN", tuple(values)))
156
158
  return qb
157
159
 
158
- def not_in(self, column: str, values: list[SQLValue]) -> "QueryBuilder": # noqa: UP037
160
+ def not_in(self, column: str, values: List[SQLValue]) -> "QueryBuilder": # noqa: UP037
159
161
  qb = self._clone()
160
162
  qb._conditions.append((column, "NOT IN", tuple(values)))
161
163
  return qb
@@ -174,7 +176,7 @@ class QueryBuilder:
174
176
  # ---------------------------------------------------------------------
175
177
  # JOIN (simple inner join)
176
178
  # ---------------------------------------------------------------------
177
- def join(self, other_table: str, on: dict[str, str]) -> "QueryBuilder": # noqa: UP037
179
+ def join(self, other_table: str, on: Dict[str, str]) -> "QueryBuilder": # noqa: UP037
178
180
  """`on` expects {local_col: remote_col}."""
179
181
  qb = self._clone()
180
182
  qb._joins.append((other_table, on))
@@ -183,10 +185,10 @@ class QueryBuilder:
183
185
  # ---------------------------------------------------------------------
184
186
  # Execution helpers
185
187
  # ---------------------------------------------------------------------
186
- def _compile(self) -> tuple[str, list[Any]]:
188
+ def _compile(self) -> Tuple[str, List[Any]]:
187
189
  cols = ", ".join(self._select_cols)
188
190
  sql = [f"SELECT {cols} FROM {self._table}"]
189
- params: list[Any] = []
191
+ params: List[Any] = []
190
192
 
191
193
  # Joins -------------------------------------------------------------
192
194
  for tbl, onmap in self._joins:
@@ -224,7 +226,7 @@ class QueryBuilder:
224
226
 
225
227
  return " ".join(sql), params
226
228
 
227
- def _execute(self) -> list[dict[str, Any]]:
229
+ def _execute(self) -> List[Dict[str, Any]]:
228
230
  if self._cached_rows is not None:
229
231
  return self._cached_rows
230
232
 
@@ -255,10 +257,10 @@ class QueryBuilder:
255
257
  conn.close()
256
258
  return _CountResult(val)
257
259
 
258
- def first(self) -> dict[str, Any] | None:
260
+ def first(self) -> Optional[Dict[str, Any]]:
259
261
  return self.limit(1)._execute()[0] if self.limit(1)._execute() else None
260
262
 
261
- def all(self) -> list[dict[str, Any]]:
263
+ def all(self) -> List[Dict[str, Any]]:
262
264
  return self._execute()
263
265
 
264
266
  # Assertions -----------------------------------------------------------
@@ -357,9 +359,9 @@ class IgnoreConfig:
357
359
 
358
360
  def __init__(
359
361
  self,
360
- tables: set[str] | None = None,
361
- fields: set[str] | None = None,
362
- table_fields: dict[str, set[str]] | None = None,
362
+ tables: Optional[Set[str]] = None,
363
+ fields: Optional[Set[str]] = None,
364
+ table_fields: Optional[Dict[str, Set[str]]] = None,
363
365
  ):
364
366
  """
365
367
  Args:
@@ -386,7 +388,7 @@ class IgnoreConfig:
386
388
  return False
387
389
 
388
390
 
389
- def _format_row_for_error(row: dict[str, Any], max_fields: int = 10) -> str:
391
+ def _format_row_for_error(row: Dict[str, Any], max_fields: int = 10) -> str:
390
392
  """Format a row dictionary for error messages with truncation if needed."""
391
393
  if not row:
392
394
  return "{empty row}"
@@ -402,7 +404,7 @@ def _format_row_for_error(row: dict[str, Any], max_fields: int = 10) -> str:
402
404
  return "{" + ", ".join(shown_items) + f", ... +{remaining} more fields" + "}"
403
405
 
404
406
 
405
- def _get_row_identifier(row: dict[str, Any]) -> str:
407
+ def _get_row_identifier(row: Dict[str, Any]) -> str:
406
408
  """Extract a meaningful identifier from a row for error messages."""
407
409
  # Try common ID fields first
408
410
  for id_field in ["id", "pk", "primary_key", "key"]:
@@ -429,7 +431,7 @@ class SnapshotDiff:
429
431
  self,
430
432
  before: DatabaseSnapshot,
431
433
  after: DatabaseSnapshot,
432
- ignore_config: IgnoreConfig | None = None,
434
+ ignore_config: Optional[IgnoreConfig] = None,
433
435
  ):
434
436
  from .sql_differ import SQLiteDiffer # local import to avoid circularity
435
437
 
@@ -437,14 +439,14 @@ class SnapshotDiff:
437
439
  self.after = after
438
440
  self.ignore_config = ignore_config or IgnoreConfig()
439
441
  self._differ = SQLiteDiffer(before.db_path, after.db_path)
440
- self._cached: dict[str, Any] | None = None
442
+ self._cached: Optional[Dict[str, Any]] = None
441
443
 
442
444
  # ------------------------------------------------------------------
443
445
  def _collect(self):
444
446
  if self._cached is not None:
445
447
  return self._cached
446
448
  all_tables = set(self.before.tables()) | set(self.after.tables())
447
- diff: dict[str, dict[str, Any]] = {}
449
+ diff: Dict[str, Dict[str, Any]] = {}
448
450
  for tbl in all_tables:
449
451
  if self.ignore_config.should_ignore_table(tbl):
450
452
  continue
@@ -453,12 +455,12 @@ class SnapshotDiff:
453
455
  return diff
454
456
 
455
457
  # ------------------------------------------------------------------
456
- def expect_only(self, allowed_changes: list[dict[str, Any]]):
458
+ def expect_only(self, allowed_changes: List[Dict[str, Any]]):
457
459
  """Allowed changes is a list of {table, pk, field, after} (before optional)."""
458
460
  diff = self._collect()
459
461
 
460
462
  def _is_change_allowed(
461
- table: str, row_id: str, field: str | None, after_value: Any
463
+ table: str, row_id: str, field: Optional[str], after_value: Any
462
464
  ) -> bool:
463
465
  """Check if a change is in the allowed list using semantic comparison."""
464
466
  for allowed in allowed_changes:
@@ -596,7 +598,10 @@ class SnapshotDiff:
596
598
  return self
597
599
 
598
600
  def expect(
599
- self, *, allow: list[dict[str, Any]] = None, forbid: list[dict[str, Any]] = None
601
+ self,
602
+ *,
603
+ allow: Optional[List[Dict[str, Any]]] = None,
604
+ forbid: Optional[List[Dict[str, Any]]] = None,
600
605
  ):
601
606
  """More granular: allow / forbid per‑table and per‑field."""
602
607
  allow = allow or []
@@ -629,7 +634,7 @@ class SnapshotDiff:
629
634
  class DatabaseSnapshot:
630
635
  """Represents a snapshot of an SQLite DB with DSL entrypoints."""
631
636
 
632
- def __init__(self, db_path: str, *, name: str | None = None):
637
+ def __init__(self, db_path: str, *, name: Optional[str] = None):
633
638
  self.db_path = db_path
634
639
  self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
635
640
  self.created_at = datetime.utcnow()
@@ -639,7 +644,7 @@ class DatabaseSnapshot:
639
644
  return QueryBuilder(self, table)
640
645
 
641
646
  # Metadata -------------------------------------------------------------
642
- def tables(self) -> list[str]:
647
+ def tables(self) -> List[str]:
643
648
  conn = sqlite3.connect(self.db_path)
644
649
  cur = conn.cursor()
645
650
  cur.execute(
@@ -654,7 +659,7 @@ class DatabaseSnapshot:
654
659
  def diff(
655
660
  self,
656
661
  other: "DatabaseSnapshot", # noqa: UP037
657
- ignore_config: IgnoreConfig | None = None,
662
+ ignore_config: Optional[IgnoreConfig] = None,
658
663
  ) -> SnapshotDiff:
659
664
  return SnapshotDiff(self, other, ignore_config)
660
665
 
@@ -663,7 +668,7 @@ class DatabaseSnapshot:
663
668
  ############################################################################
664
669
 
665
670
  def expect_row(
666
- self, table: str, where: dict[str, SQLValue], expect: dict[str, SQLValue]
671
+ self, table: str, where: Dict[str, SQLValue], expect: Dict[str, SQLValue]
667
672
  ):
668
673
  qb = self.table(table)
669
674
  for k, v in where.items():
@@ -676,10 +681,10 @@ class DatabaseSnapshot:
676
681
  def expect_rows(
677
682
  self,
678
683
  table: str,
679
- where: dict[str, SQLValue],
684
+ where: Dict[str, SQLValue],
680
685
  *,
681
- count: int | None = None,
682
- contains: list[dict[str, SQLValue]] | None = None,
686
+ count: Optional[int] = None,
687
+ contains: Optional[List[Dict[str, SQLValue]]] = None,
683
688
  ):
684
689
  qb = self.table(table)
685
690
  for k, v in where.items():
@@ -694,7 +699,7 @@ class DatabaseSnapshot:
694
699
  raise AssertionError(f"Expected a row matching {cond} in {table}")
695
700
  return self
696
701
 
697
- def expect_absent_row(self, table: str, where: dict[str, SQLValue]):
702
+ def expect_absent_row(self, table: str, where: Dict[str, SQLValue]):
698
703
  qb = self.table(table)
699
704
  for k, v in where.items():
700
705
  qb = qb.eq(k, v)
fleet/verifiers/parse.py CHANGED
@@ -1,7 +1,10 @@
1
1
  import re
2
2
 
3
3
 
4
- def extract_function_name(function_code: str) -> str | None:
4
+ from typing import Optional
5
+
6
+
7
+ def extract_function_name(function_code: str) -> Optional[str]:
5
8
  """
6
9
  Extract function name from Python function code.
7
10
 
@@ -1,5 +1,5 @@
1
1
  import sqlite3
2
- from typing import Any
2
+ from typing import Any, Optional, List, Dict, Tuple
3
3
 
4
4
 
5
5
  class SQLiteDiffer:
@@ -7,7 +7,7 @@ class SQLiteDiffer:
7
7
  self.before_db = before_db
8
8
  self.after_db = after_db
9
9
 
10
- def get_table_schema(self, db_path: str, table_name: str) -> list[str]:
10
+ def get_table_schema(self, db_path: str, table_name: str) -> List[str]:
11
11
  """Get column names for a table"""
12
12
  conn = sqlite3.connect(db_path)
13
13
  cursor = conn.cursor()
@@ -16,7 +16,7 @@ class SQLiteDiffer:
16
16
  conn.close()
17
17
  return columns
18
18
 
19
- def get_primary_key_columns(self, db_path: str, table_name: str) -> list[str]:
19
+ def get_primary_key_columns(self, db_path: str, table_name: str) -> List[str]:
20
20
  """Get all primary key columns for a table, ordered by their position"""
21
21
  conn = sqlite3.connect(db_path)
22
22
  cursor = conn.cursor()
@@ -34,7 +34,7 @@ class SQLiteDiffer:
34
34
  pk_columns.sort(key=lambda x: x[0])
35
35
  return [col[1] for col in pk_columns]
36
36
 
37
- def get_all_tables(self, db_path: str) -> list[str]:
37
+ def get_all_tables(self, db_path: str) -> List[str]:
38
38
  """Get all table names from database"""
39
39
  conn = sqlite3.connect(db_path)
40
40
  cursor = conn.cursor()
@@ -49,8 +49,8 @@ class SQLiteDiffer:
49
49
  self,
50
50
  db_path: str,
51
51
  table_name: str,
52
- primary_key_columns: list[str] | None = None,
53
- ) -> tuple[dict[Any, dict], list[str]]:
52
+ primary_key_columns: Optional[List[str]] = None,
53
+ ) -> Tuple[Dict[Any, dict], List[str]]:
54
54
  """Get table data indexed by primary key (single column or composite)"""
55
55
  conn = sqlite3.connect(db_path)
56
56
  conn.row_factory = sqlite3.Row
@@ -97,7 +97,7 @@ class SQLiteDiffer:
97
97
  conn.close()
98
98
  return data, primary_key_columns
99
99
 
100
- def compare_rows(self, before_row: dict, after_row: dict) -> dict[str, dict]:
100
+ def compare_rows(self, before_row: dict, after_row: dict) -> Dict[str, dict]:
101
101
  """Compare two rows field by field"""
102
102
  changes = {}
103
103
 
@@ -113,7 +113,7 @@ class SQLiteDiffer:
113
113
  return changes
114
114
 
115
115
  def diff_table(
116
- self, table_name: str, primary_key_columns: list[str] | None = None
116
+ self, table_name: str, primary_key_columns: Optional[List[str]] = None
117
117
  ) -> dict:
118
118
  """Create comprehensive diff of a table"""
119
119
  before_data, detected_pk = self.get_table_data(