fleet-python 0.2.43__py3-none-any.whl → 0.2.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fleet-python might be problematic. Click here for more details.
- fleet/__init__.py +9 -7
- fleet/_async/__init__.py +28 -16
- fleet/_async/client.py +44 -16
- fleet/_async/instance/client.py +1 -1
- fleet/_async/resources/sqlite.py +34 -34
- fleet/_async/tasks.py +13 -1
- fleet/_async/verifiers/verifier.py +3 -3
- fleet/client.py +44 -13
- fleet/instance/client.py +2 -4
- fleet/resources/sqlite.py +37 -43
- fleet/tasks.py +22 -37
- fleet/verifiers/__init__.py +1 -1
- fleet/verifiers/db.py +41 -36
- fleet/verifiers/parse.py +4 -1
- fleet/verifiers/sql_differ.py +8 -8
- fleet/verifiers/verifier.py +19 -7
- {fleet_python-0.2.43.dist-info → fleet_python-0.2.44.dist-info}/METADATA +1 -1
- {fleet_python-0.2.43.dist-info → fleet_python-0.2.44.dist-info}/RECORD +21 -21
- {fleet_python-0.2.43.dist-info → fleet_python-0.2.44.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.43.dist-info → fleet_python-0.2.44.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.43.dist-info → fleet_python-0.2.44.dist-info}/top_level.txt +0 -0
fleet/resources/sqlite.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, List, Optional
|
|
1
|
+
from typing import Any, List, Optional, Dict, Tuple
|
|
2
2
|
from ..instance.models import Resource as ResourceModel
|
|
3
3
|
from ..instance.models import DescribeResponse, QueryRequest, QueryResponse
|
|
4
4
|
from .base import Resource
|
|
@@ -25,12 +25,12 @@ from fleet.verifiers.db import (
|
|
|
25
25
|
class SyncDatabaseSnapshot:
|
|
26
26
|
"""Async database snapshot that fetches data through API and stores locally for diffing."""
|
|
27
27
|
|
|
28
|
-
def __init__(self, resource: "SQLiteResource", name: str
|
|
28
|
+
def __init__(self, resource: "SQLiteResource", name: Optional[str] = None):
|
|
29
29
|
self.resource = resource
|
|
30
30
|
self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
|
|
31
31
|
self.created_at = datetime.utcnow()
|
|
32
|
-
self._data:
|
|
33
|
-
self._schemas:
|
|
32
|
+
self._data: Dict[str, List[Dict[str, Any]]] = {}
|
|
33
|
+
self._schemas: Dict[str, List[str]] = {}
|
|
34
34
|
self._fetched = False
|
|
35
35
|
|
|
36
36
|
def _ensure_fetched(self):
|
|
@@ -69,7 +69,7 @@ class SyncDatabaseSnapshot:
|
|
|
69
69
|
|
|
70
70
|
self._fetched = True
|
|
71
71
|
|
|
72
|
-
def tables(self) ->
|
|
72
|
+
def tables(self) -> List[str]:
|
|
73
73
|
"""Get list of all tables in the snapshot."""
|
|
74
74
|
self._ensure_fetched()
|
|
75
75
|
return list(self._data.keys())
|
|
@@ -81,7 +81,7 @@ class SyncDatabaseSnapshot:
|
|
|
81
81
|
def diff(
|
|
82
82
|
self,
|
|
83
83
|
other: "SyncDatabaseSnapshot",
|
|
84
|
-
ignore_config: IgnoreConfig
|
|
84
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
85
85
|
) -> "SyncSnapshotDiff":
|
|
86
86
|
"""Compare this snapshot with another."""
|
|
87
87
|
self._ensure_fetched()
|
|
@@ -95,13 +95,13 @@ class SyncSnapshotQueryBuilder:
|
|
|
95
95
|
def __init__(self, snapshot: SyncDatabaseSnapshot, table: str):
|
|
96
96
|
self._snapshot = snapshot
|
|
97
97
|
self._table = table
|
|
98
|
-
self._select_cols:
|
|
99
|
-
self._conditions:
|
|
100
|
-
self._limit: int
|
|
101
|
-
self._order_by: str
|
|
98
|
+
self._select_cols: List[str] = ["*"]
|
|
99
|
+
self._conditions: List[Tuple[str, str, Any]] = []
|
|
100
|
+
self._limit: Optional[int] = None
|
|
101
|
+
self._order_by: Optional[str] = None
|
|
102
102
|
self._order_desc: bool = False
|
|
103
103
|
|
|
104
|
-
def _get_data(self) ->
|
|
104
|
+
def _get_data(self) -> List[Dict[str, Any]]:
|
|
105
105
|
"""Get table data from snapshot."""
|
|
106
106
|
self._snapshot._ensure_fetched()
|
|
107
107
|
return self._snapshot._data.get(self._table, [])
|
|
@@ -122,11 +122,11 @@ class SyncSnapshotQueryBuilder:
|
|
|
122
122
|
qb._order_desc = desc
|
|
123
123
|
return qb
|
|
124
124
|
|
|
125
|
-
def first(self) ->
|
|
125
|
+
def first(self) -> Optional[Dict[str, Any]]:
|
|
126
126
|
rows = self.all()
|
|
127
127
|
return rows[0] if rows else None
|
|
128
128
|
|
|
129
|
-
def all(self) ->
|
|
129
|
+
def all(self) -> List[Dict[str, Any]]:
|
|
130
130
|
data = self._get_data()
|
|
131
131
|
|
|
132
132
|
# Apply filters
|
|
@@ -185,14 +185,14 @@ class SyncSnapshotDiff:
|
|
|
185
185
|
self,
|
|
186
186
|
before: SyncDatabaseSnapshot,
|
|
187
187
|
after: SyncDatabaseSnapshot,
|
|
188
|
-
ignore_config: IgnoreConfig
|
|
188
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
189
189
|
):
|
|
190
190
|
self.before = before
|
|
191
191
|
self.after = after
|
|
192
192
|
self.ignore_config = ignore_config or IgnoreConfig()
|
|
193
|
-
self._cached:
|
|
193
|
+
self._cached: Optional[Dict[str, Any]] = None
|
|
194
194
|
|
|
195
|
-
def _get_primary_key_columns(self, table: str) ->
|
|
195
|
+
def _get_primary_key_columns(self, table: str) -> List[str]:
|
|
196
196
|
"""Get primary key columns for a table."""
|
|
197
197
|
# Try to get from schema
|
|
198
198
|
schema_response = self.after.resource.query(f"PRAGMA table_info({table})")
|
|
@@ -222,7 +222,7 @@ class SyncSnapshotDiff:
|
|
|
222
222
|
return self._cached
|
|
223
223
|
|
|
224
224
|
all_tables = set(self.before.tables()) | set(self.after.tables())
|
|
225
|
-
diff:
|
|
225
|
+
diff: Dict[str, Dict[str, Any]] = {}
|
|
226
226
|
|
|
227
227
|
for tbl in all_tables:
|
|
228
228
|
if self.ignore_config.should_ignore_table(tbl):
|
|
@@ -236,7 +236,7 @@ class SyncSnapshotDiff:
|
|
|
236
236
|
after_data = self.after._data.get(tbl, [])
|
|
237
237
|
|
|
238
238
|
# Create indexes by primary key
|
|
239
|
-
def make_key(row: dict, pk_cols:
|
|
239
|
+
def make_key(row: dict, pk_cols: List[str]) -> Any:
|
|
240
240
|
if len(pk_cols) == 1:
|
|
241
241
|
return row.get(pk_cols[0])
|
|
242
242
|
return tuple(row.get(col) for col in pk_cols)
|
|
@@ -304,12 +304,12 @@ class SyncSnapshotDiff:
|
|
|
304
304
|
self._cached = diff
|
|
305
305
|
return diff
|
|
306
306
|
|
|
307
|
-
def expect_only(self, allowed_changes:
|
|
307
|
+
def expect_only(self, allowed_changes: List[Dict[str, Any]]):
|
|
308
308
|
"""Ensure only specified changes occurred."""
|
|
309
309
|
diff = self._collect()
|
|
310
310
|
|
|
311
311
|
def _is_change_allowed(
|
|
312
|
-
table: str, row_id: Any, field: str
|
|
312
|
+
table: str, row_id: Any, field: Optional[str], after_value: Any
|
|
313
313
|
) -> bool:
|
|
314
314
|
"""Check if a change is in the allowed list using semantic comparison."""
|
|
315
315
|
for allowed in allowed_changes:
|
|
@@ -440,11 +440,11 @@ class SyncQueryBuilder:
|
|
|
440
440
|
def __init__(self, resource: "SQLiteResource", table: str):
|
|
441
441
|
self._resource = resource
|
|
442
442
|
self._table = table
|
|
443
|
-
self._select_cols:
|
|
444
|
-
self._conditions:
|
|
445
|
-
self._joins:
|
|
446
|
-
self._limit: int
|
|
447
|
-
self._order_by: str
|
|
443
|
+
self._select_cols: List[str] = ["*"]
|
|
444
|
+
self._conditions: List[Tuple[str, str, Any]] = []
|
|
445
|
+
self._joins: List[Tuple[str, Dict[str, str]]] = []
|
|
446
|
+
self._limit: Optional[int] = None
|
|
447
|
+
self._order_by: Optional[str] = None
|
|
448
448
|
|
|
449
449
|
# Column projection / limiting / ordering
|
|
450
450
|
def select(self, *columns: str) -> "SyncQueryBuilder":
|
|
@@ -486,12 +486,12 @@ class SyncQueryBuilder:
|
|
|
486
486
|
def lte(self, column: str, value: Any) -> "SyncQueryBuilder":
|
|
487
487
|
return self._add_condition(column, "<=", value)
|
|
488
488
|
|
|
489
|
-
def in_(self, column: str, values:
|
|
489
|
+
def in_(self, column: str, values: List[Any]) -> "SyncQueryBuilder":
|
|
490
490
|
qb = self._clone()
|
|
491
491
|
qb._conditions.append((column, "IN", tuple(values)))
|
|
492
492
|
return qb
|
|
493
493
|
|
|
494
|
-
def not_in(self, column: str, values:
|
|
494
|
+
def not_in(self, column: str, values: List[Any]) -> "SyncQueryBuilder":
|
|
495
495
|
qb = self._clone()
|
|
496
496
|
qb._conditions.append((column, "NOT IN", tuple(values)))
|
|
497
497
|
return qb
|
|
@@ -508,16 +508,16 @@ class SyncQueryBuilder:
|
|
|
508
508
|
return qb
|
|
509
509
|
|
|
510
510
|
# JOIN
|
|
511
|
-
def join(self, other_table: str, on:
|
|
511
|
+
def join(self, other_table: str, on: Dict[str, str]) -> "SyncQueryBuilder":
|
|
512
512
|
qb = self._clone()
|
|
513
513
|
qb._joins.append((other_table, on))
|
|
514
514
|
return qb
|
|
515
515
|
|
|
516
516
|
# Compile to SQL
|
|
517
|
-
def _compile(self) ->
|
|
517
|
+
def _compile(self) -> Tuple[str, List[Any]]:
|
|
518
518
|
cols = ", ".join(self._select_cols)
|
|
519
519
|
sql = [f"SELECT {cols} FROM {self._table}"]
|
|
520
|
-
params:
|
|
520
|
+
params: List[Any] = []
|
|
521
521
|
|
|
522
522
|
# Joins
|
|
523
523
|
for tbl, onmap in self._joins:
|
|
@@ -558,11 +558,11 @@ class SyncQueryBuilder:
|
|
|
558
558
|
return row_dict.get("__cnt__", 0)
|
|
559
559
|
return 0
|
|
560
560
|
|
|
561
|
-
def first(self) ->
|
|
561
|
+
def first(self) -> Optional[Dict[str, Any]]:
|
|
562
562
|
rows = self.limit(1).all()
|
|
563
563
|
return rows[0] if rows else None
|
|
564
564
|
|
|
565
|
-
def all(self) ->
|
|
565
|
+
def all(self) -> List[Dict[str, Any]]:
|
|
566
566
|
sql, params = self._compile()
|
|
567
567
|
response = self._resource.query(sql, params)
|
|
568
568
|
if not response.rows:
|
|
@@ -651,9 +651,7 @@ class SQLiteResource(Resource):
|
|
|
651
651
|
)
|
|
652
652
|
return DescribeResponse(**response.json())
|
|
653
653
|
|
|
654
|
-
def query(
|
|
655
|
-
self, query: str, args: Optional[List[Any]] = None
|
|
656
|
-
) -> QueryResponse:
|
|
654
|
+
def query(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:
|
|
657
655
|
return self._query(query, args, read_only=True)
|
|
658
656
|
|
|
659
657
|
def exec(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:
|
|
@@ -674,7 +672,7 @@ class SQLiteResource(Resource):
|
|
|
674
672
|
"""Create a query builder for the specified table."""
|
|
675
673
|
return SyncQueryBuilder(self, table_name)
|
|
676
674
|
|
|
677
|
-
def snapshot(self, name: str
|
|
675
|
+
def snapshot(self, name: Optional[str] = None) -> SyncDatabaseSnapshot:
|
|
678
676
|
"""Create a snapshot of the current database state."""
|
|
679
677
|
snapshot = SyncDatabaseSnapshot(self, name)
|
|
680
678
|
snapshot._ensure_fetched()
|
|
@@ -683,7 +681,7 @@ class SQLiteResource(Resource):
|
|
|
683
681
|
def diff(
|
|
684
682
|
self,
|
|
685
683
|
other: "SQLiteResource",
|
|
686
|
-
ignore_config: IgnoreConfig
|
|
684
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
687
685
|
) -> SyncSnapshotDiff:
|
|
688
686
|
"""Compare this database with another AsyncSQLiteResource.
|
|
689
687
|
|
|
@@ -695,12 +693,8 @@ class SQLiteResource(Resource):
|
|
|
695
693
|
AsyncSnapshotDiff: Object containing the differences between the two databases
|
|
696
694
|
"""
|
|
697
695
|
# Create snapshots of both databases
|
|
698
|
-
before_snapshot = self.snapshot(
|
|
699
|
-
|
|
700
|
-
)
|
|
701
|
-
after_snapshot = other.snapshot(
|
|
702
|
-
name=f"after_{datetime.utcnow().isoformat()}"
|
|
703
|
-
)
|
|
696
|
+
before_snapshot = self.snapshot(name=f"before_{datetime.utcnow().isoformat()}")
|
|
697
|
+
after_snapshot = other.snapshot(name=f"after_{datetime.utcnow().isoformat()}")
|
|
704
698
|
|
|
705
699
|
# Return the diff between the snapshots
|
|
706
700
|
return before_snapshot.diff(after_snapshot, ignore_config)
|
fleet/tasks.py
CHANGED
|
@@ -47,7 +47,7 @@ class Task(BaseModel):
|
|
|
47
47
|
@property
|
|
48
48
|
def env_key(self) -> str:
|
|
49
49
|
"""Get the environment key combining env_id and version."""
|
|
50
|
-
if self.version and self.version != "None":
|
|
50
|
+
if self.version and self.version != "None" and ":" not in self.env_id:
|
|
51
51
|
return f"{self.env_id}:{self.version}"
|
|
52
52
|
return self.env_id
|
|
53
53
|
|
|
@@ -70,17 +70,13 @@ class Task(BaseModel):
|
|
|
70
70
|
import inspect
|
|
71
71
|
|
|
72
72
|
# Check if verifier has remote method (for decorated verifiers)
|
|
73
|
-
|
|
74
|
-
result = self.verifier.remote(env, *args, **kwargs)
|
|
75
|
-
else:
|
|
76
|
-
# For verifiers created from string, call directly
|
|
77
|
-
result = self.verifier(env, *args, **kwargs)
|
|
73
|
+
result = self.verifier.remote(env, *args, **kwargs)
|
|
78
74
|
|
|
79
75
|
# If the result is a coroutine, we need to run it
|
|
80
76
|
if inspect.iscoroutine(result):
|
|
81
77
|
# Check if we're already in an event loop
|
|
82
78
|
try:
|
|
83
|
-
|
|
79
|
+
asyncio.get_running_loop()
|
|
84
80
|
# We're in an async context, can't use asyncio.run()
|
|
85
81
|
raise RuntimeError(
|
|
86
82
|
"Cannot run async verifier in sync mode while event loop is running. "
|
|
@@ -141,7 +137,7 @@ def verifier_from_string(
|
|
|
141
137
|
"""
|
|
142
138
|
try:
|
|
143
139
|
import inspect
|
|
144
|
-
from .verifiers import
|
|
140
|
+
from .verifiers import SyncVerifierFunction
|
|
145
141
|
from .verifiers.code import TASK_SUCCESSFUL_SCORE, TASK_FAILED_SCORE
|
|
146
142
|
from .verifiers.db import IgnoreConfig
|
|
147
143
|
|
|
@@ -172,36 +168,13 @@ def verifier_from_string(
|
|
|
172
168
|
if func_obj is None:
|
|
173
169
|
raise ValueError("No function found in verifier code")
|
|
174
170
|
|
|
175
|
-
# Create
|
|
176
|
-
def wrapped_verifier(env, *args, **kwargs):
|
|
177
|
-
# Set up globals for the function execution
|
|
178
|
-
func_globals = (
|
|
179
|
-
func_obj.__globals__.copy() if hasattr(func_obj, "__globals__") else {}
|
|
180
|
-
)
|
|
181
|
-
func_globals.update(
|
|
182
|
-
{
|
|
183
|
-
"TASK_SUCCESSFUL_SCORE": TASK_SUCCESSFUL_SCORE,
|
|
184
|
-
"TASK_FAILED_SCORE": TASK_FAILED_SCORE,
|
|
185
|
-
"IgnoreConfig": IgnoreConfig,
|
|
186
|
-
}
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
# Create a new function with the updated globals
|
|
190
|
-
import types
|
|
191
|
-
|
|
192
|
-
new_func = types.FunctionType(
|
|
193
|
-
func_obj.__code__,
|
|
194
|
-
func_globals,
|
|
195
|
-
func_obj.__name__,
|
|
196
|
-
func_obj.__defaults__,
|
|
197
|
-
func_obj.__closure__,
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
return new_func(env, *args, **kwargs)
|
|
201
|
-
|
|
202
|
-
# Create an AsyncVerifierFunction instance with the wrapped function
|
|
171
|
+
# Create an SyncVerifierFunction instance with raw code
|
|
203
172
|
verifier_instance = SyncVerifierFunction(
|
|
204
|
-
|
|
173
|
+
func=func_obj,
|
|
174
|
+
key=verifier_key,
|
|
175
|
+
verifier_id=verifier_id,
|
|
176
|
+
sha256=sha256,
|
|
177
|
+
raw_code=verifier_func,
|
|
205
178
|
)
|
|
206
179
|
|
|
207
180
|
# Store additional metadata
|
|
@@ -214,6 +187,18 @@ def verifier_from_string(
|
|
|
214
187
|
raise ValueError(f"Failed to create verifier from string: {e}")
|
|
215
188
|
|
|
216
189
|
|
|
190
|
+
def load_tasks_from_file(filename: str) -> List[Task]:
|
|
191
|
+
"""Load tasks from a JSON file.
|
|
192
|
+
|
|
193
|
+
Example:
|
|
194
|
+
tasks = fleet.load_tasks_from_file("my_tasks.json")
|
|
195
|
+
"""
|
|
196
|
+
from .global_client import get_client
|
|
197
|
+
|
|
198
|
+
client = get_client()
|
|
199
|
+
return client.load_tasks_from_file(filename)
|
|
200
|
+
|
|
201
|
+
|
|
217
202
|
def load_tasks(
|
|
218
203
|
env_key: Optional[str] = None,
|
|
219
204
|
keys: Optional[List[str]] = None,
|
fleet/verifiers/__init__.py
CHANGED
fleet/verifiers/db.py
CHANGED
|
@@ -22,9 +22,11 @@ import json
|
|
|
22
22
|
# Low‑level helpers
|
|
23
23
|
################################################################################
|
|
24
24
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
25
|
+
from typing import Union, Tuple, Dict, List, Optional, Any, Set
|
|
26
|
+
|
|
27
|
+
SQLValue = Union[str, int, float, None]
|
|
28
|
+
Condition = Tuple[str, str, SQLValue] # (column, op, value)
|
|
29
|
+
JoinSpec = Tuple[str, Dict[str, str]] # (table, on mapping)
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
def _is_json_string(value: Any) -> bool:
|
|
@@ -98,13 +100,13 @@ class QueryBuilder:
|
|
|
98
100
|
def __init__(self, snapshot: "DatabaseSnapshot", table: str): # noqa: UP037
|
|
99
101
|
self._snapshot = snapshot
|
|
100
102
|
self._table = table
|
|
101
|
-
self._select_cols:
|
|
102
|
-
self._conditions:
|
|
103
|
-
self._joins:
|
|
104
|
-
self._limit: int
|
|
105
|
-
self._order_by: str
|
|
103
|
+
self._select_cols: List[str] = ["*"]
|
|
104
|
+
self._conditions: List[Condition] = []
|
|
105
|
+
self._joins: List[JoinSpec] = []
|
|
106
|
+
self._limit: Optional[int] = None
|
|
107
|
+
self._order_by: Optional[str] = None
|
|
106
108
|
# Cache for idempotent executions
|
|
107
|
-
self._cached_rows:
|
|
109
|
+
self._cached_rows: Optional[List[Dict[str, Any]]] = None
|
|
108
110
|
|
|
109
111
|
# ---------------------------------------------------------------------
|
|
110
112
|
# Column projection / limiting / ordering
|
|
@@ -150,12 +152,12 @@ class QueryBuilder:
|
|
|
150
152
|
def lte(self, column: str, value: SQLValue) -> "QueryBuilder": # noqa: UP037
|
|
151
153
|
return self._add_condition(column, "<=", value)
|
|
152
154
|
|
|
153
|
-
def in_(self, column: str, values:
|
|
155
|
+
def in_(self, column: str, values: List[SQLValue]) -> "QueryBuilder": # noqa: UP037
|
|
154
156
|
qb = self._clone()
|
|
155
157
|
qb._conditions.append((column, "IN", tuple(values)))
|
|
156
158
|
return qb
|
|
157
159
|
|
|
158
|
-
def not_in(self, column: str, values:
|
|
160
|
+
def not_in(self, column: str, values: List[SQLValue]) -> "QueryBuilder": # noqa: UP037
|
|
159
161
|
qb = self._clone()
|
|
160
162
|
qb._conditions.append((column, "NOT IN", tuple(values)))
|
|
161
163
|
return qb
|
|
@@ -174,7 +176,7 @@ class QueryBuilder:
|
|
|
174
176
|
# ---------------------------------------------------------------------
|
|
175
177
|
# JOIN (simple inner join)
|
|
176
178
|
# ---------------------------------------------------------------------
|
|
177
|
-
def join(self, other_table: str, on:
|
|
179
|
+
def join(self, other_table: str, on: Dict[str, str]) -> "QueryBuilder": # noqa: UP037
|
|
178
180
|
"""`on` expects {local_col: remote_col}."""
|
|
179
181
|
qb = self._clone()
|
|
180
182
|
qb._joins.append((other_table, on))
|
|
@@ -183,10 +185,10 @@ class QueryBuilder:
|
|
|
183
185
|
# ---------------------------------------------------------------------
|
|
184
186
|
# Execution helpers
|
|
185
187
|
# ---------------------------------------------------------------------
|
|
186
|
-
def _compile(self) ->
|
|
188
|
+
def _compile(self) -> Tuple[str, List[Any]]:
|
|
187
189
|
cols = ", ".join(self._select_cols)
|
|
188
190
|
sql = [f"SELECT {cols} FROM {self._table}"]
|
|
189
|
-
params:
|
|
191
|
+
params: List[Any] = []
|
|
190
192
|
|
|
191
193
|
# Joins -------------------------------------------------------------
|
|
192
194
|
for tbl, onmap in self._joins:
|
|
@@ -224,7 +226,7 @@ class QueryBuilder:
|
|
|
224
226
|
|
|
225
227
|
return " ".join(sql), params
|
|
226
228
|
|
|
227
|
-
def _execute(self) ->
|
|
229
|
+
def _execute(self) -> List[Dict[str, Any]]:
|
|
228
230
|
if self._cached_rows is not None:
|
|
229
231
|
return self._cached_rows
|
|
230
232
|
|
|
@@ -255,10 +257,10 @@ class QueryBuilder:
|
|
|
255
257
|
conn.close()
|
|
256
258
|
return _CountResult(val)
|
|
257
259
|
|
|
258
|
-
def first(self) ->
|
|
260
|
+
def first(self) -> Optional[Dict[str, Any]]:
|
|
259
261
|
return self.limit(1)._execute()[0] if self.limit(1)._execute() else None
|
|
260
262
|
|
|
261
|
-
def all(self) ->
|
|
263
|
+
def all(self) -> List[Dict[str, Any]]:
|
|
262
264
|
return self._execute()
|
|
263
265
|
|
|
264
266
|
# Assertions -----------------------------------------------------------
|
|
@@ -357,9 +359,9 @@ class IgnoreConfig:
|
|
|
357
359
|
|
|
358
360
|
def __init__(
|
|
359
361
|
self,
|
|
360
|
-
tables:
|
|
361
|
-
fields:
|
|
362
|
-
table_fields:
|
|
362
|
+
tables: Optional[Set[str]] = None,
|
|
363
|
+
fields: Optional[Set[str]] = None,
|
|
364
|
+
table_fields: Optional[Dict[str, Set[str]]] = None,
|
|
363
365
|
):
|
|
364
366
|
"""
|
|
365
367
|
Args:
|
|
@@ -386,7 +388,7 @@ class IgnoreConfig:
|
|
|
386
388
|
return False
|
|
387
389
|
|
|
388
390
|
|
|
389
|
-
def _format_row_for_error(row:
|
|
391
|
+
def _format_row_for_error(row: Dict[str, Any], max_fields: int = 10) -> str:
|
|
390
392
|
"""Format a row dictionary for error messages with truncation if needed."""
|
|
391
393
|
if not row:
|
|
392
394
|
return "{empty row}"
|
|
@@ -402,7 +404,7 @@ def _format_row_for_error(row: dict[str, Any], max_fields: int = 10) -> str:
|
|
|
402
404
|
return "{" + ", ".join(shown_items) + f", ... +{remaining} more fields" + "}"
|
|
403
405
|
|
|
404
406
|
|
|
405
|
-
def _get_row_identifier(row:
|
|
407
|
+
def _get_row_identifier(row: Dict[str, Any]) -> str:
|
|
406
408
|
"""Extract a meaningful identifier from a row for error messages."""
|
|
407
409
|
# Try common ID fields first
|
|
408
410
|
for id_field in ["id", "pk", "primary_key", "key"]:
|
|
@@ -429,7 +431,7 @@ class SnapshotDiff:
|
|
|
429
431
|
self,
|
|
430
432
|
before: DatabaseSnapshot,
|
|
431
433
|
after: DatabaseSnapshot,
|
|
432
|
-
ignore_config: IgnoreConfig
|
|
434
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
433
435
|
):
|
|
434
436
|
from .sql_differ import SQLiteDiffer # local import to avoid circularity
|
|
435
437
|
|
|
@@ -437,14 +439,14 @@ class SnapshotDiff:
|
|
|
437
439
|
self.after = after
|
|
438
440
|
self.ignore_config = ignore_config or IgnoreConfig()
|
|
439
441
|
self._differ = SQLiteDiffer(before.db_path, after.db_path)
|
|
440
|
-
self._cached:
|
|
442
|
+
self._cached: Optional[Dict[str, Any]] = None
|
|
441
443
|
|
|
442
444
|
# ------------------------------------------------------------------
|
|
443
445
|
def _collect(self):
|
|
444
446
|
if self._cached is not None:
|
|
445
447
|
return self._cached
|
|
446
448
|
all_tables = set(self.before.tables()) | set(self.after.tables())
|
|
447
|
-
diff:
|
|
449
|
+
diff: Dict[str, Dict[str, Any]] = {}
|
|
448
450
|
for tbl in all_tables:
|
|
449
451
|
if self.ignore_config.should_ignore_table(tbl):
|
|
450
452
|
continue
|
|
@@ -453,12 +455,12 @@ class SnapshotDiff:
|
|
|
453
455
|
return diff
|
|
454
456
|
|
|
455
457
|
# ------------------------------------------------------------------
|
|
456
|
-
def expect_only(self, allowed_changes:
|
|
458
|
+
def expect_only(self, allowed_changes: List[Dict[str, Any]]):
|
|
457
459
|
"""Allowed changes is a list of {table, pk, field, after} (before optional)."""
|
|
458
460
|
diff = self._collect()
|
|
459
461
|
|
|
460
462
|
def _is_change_allowed(
|
|
461
|
-
table: str, row_id: str, field: str
|
|
463
|
+
table: str, row_id: str, field: Optional[str], after_value: Any
|
|
462
464
|
) -> bool:
|
|
463
465
|
"""Check if a change is in the allowed list using semantic comparison."""
|
|
464
466
|
for allowed in allowed_changes:
|
|
@@ -596,7 +598,10 @@ class SnapshotDiff:
|
|
|
596
598
|
return self
|
|
597
599
|
|
|
598
600
|
def expect(
|
|
599
|
-
self,
|
|
601
|
+
self,
|
|
602
|
+
*,
|
|
603
|
+
allow: Optional[List[Dict[str, Any]]] = None,
|
|
604
|
+
forbid: Optional[List[Dict[str, Any]]] = None,
|
|
600
605
|
):
|
|
601
606
|
"""More granular: allow / forbid per‑table and per‑field."""
|
|
602
607
|
allow = allow or []
|
|
@@ -629,7 +634,7 @@ class SnapshotDiff:
|
|
|
629
634
|
class DatabaseSnapshot:
|
|
630
635
|
"""Represents a snapshot of an SQLite DB with DSL entrypoints."""
|
|
631
636
|
|
|
632
|
-
def __init__(self, db_path: str, *, name: str
|
|
637
|
+
def __init__(self, db_path: str, *, name: Optional[str] = None):
|
|
633
638
|
self.db_path = db_path
|
|
634
639
|
self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
|
|
635
640
|
self.created_at = datetime.utcnow()
|
|
@@ -639,7 +644,7 @@ class DatabaseSnapshot:
|
|
|
639
644
|
return QueryBuilder(self, table)
|
|
640
645
|
|
|
641
646
|
# Metadata -------------------------------------------------------------
|
|
642
|
-
def tables(self) ->
|
|
647
|
+
def tables(self) -> List[str]:
|
|
643
648
|
conn = sqlite3.connect(self.db_path)
|
|
644
649
|
cur = conn.cursor()
|
|
645
650
|
cur.execute(
|
|
@@ -654,7 +659,7 @@ class DatabaseSnapshot:
|
|
|
654
659
|
def diff(
|
|
655
660
|
self,
|
|
656
661
|
other: "DatabaseSnapshot", # noqa: UP037
|
|
657
|
-
ignore_config: IgnoreConfig
|
|
662
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
658
663
|
) -> SnapshotDiff:
|
|
659
664
|
return SnapshotDiff(self, other, ignore_config)
|
|
660
665
|
|
|
@@ -663,7 +668,7 @@ class DatabaseSnapshot:
|
|
|
663
668
|
############################################################################
|
|
664
669
|
|
|
665
670
|
def expect_row(
|
|
666
|
-
self, table: str, where:
|
|
671
|
+
self, table: str, where: Dict[str, SQLValue], expect: Dict[str, SQLValue]
|
|
667
672
|
):
|
|
668
673
|
qb = self.table(table)
|
|
669
674
|
for k, v in where.items():
|
|
@@ -676,10 +681,10 @@ class DatabaseSnapshot:
|
|
|
676
681
|
def expect_rows(
|
|
677
682
|
self,
|
|
678
683
|
table: str,
|
|
679
|
-
where:
|
|
684
|
+
where: Dict[str, SQLValue],
|
|
680
685
|
*,
|
|
681
|
-
count: int
|
|
682
|
-
contains:
|
|
686
|
+
count: Optional[int] = None,
|
|
687
|
+
contains: Optional[List[Dict[str, SQLValue]]] = None,
|
|
683
688
|
):
|
|
684
689
|
qb = self.table(table)
|
|
685
690
|
for k, v in where.items():
|
|
@@ -694,7 +699,7 @@ class DatabaseSnapshot:
|
|
|
694
699
|
raise AssertionError(f"Expected a row matching {cond} in {table}")
|
|
695
700
|
return self
|
|
696
701
|
|
|
697
|
-
def expect_absent_row(self, table: str, where:
|
|
702
|
+
def expect_absent_row(self, table: str, where: Dict[str, SQLValue]):
|
|
698
703
|
qb = self.table(table)
|
|
699
704
|
for k, v in where.items():
|
|
700
705
|
qb = qb.eq(k, v)
|
fleet/verifiers/parse.py
CHANGED
fleet/verifiers/sql_differ.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import sqlite3
|
|
2
|
-
from typing import Any
|
|
2
|
+
from typing import Any, Optional, List, Dict, Tuple
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class SQLiteDiffer:
|
|
@@ -7,7 +7,7 @@ class SQLiteDiffer:
|
|
|
7
7
|
self.before_db = before_db
|
|
8
8
|
self.after_db = after_db
|
|
9
9
|
|
|
10
|
-
def get_table_schema(self, db_path: str, table_name: str) ->
|
|
10
|
+
def get_table_schema(self, db_path: str, table_name: str) -> List[str]:
|
|
11
11
|
"""Get column names for a table"""
|
|
12
12
|
conn = sqlite3.connect(db_path)
|
|
13
13
|
cursor = conn.cursor()
|
|
@@ -16,7 +16,7 @@ class SQLiteDiffer:
|
|
|
16
16
|
conn.close()
|
|
17
17
|
return columns
|
|
18
18
|
|
|
19
|
-
def get_primary_key_columns(self, db_path: str, table_name: str) ->
|
|
19
|
+
def get_primary_key_columns(self, db_path: str, table_name: str) -> List[str]:
|
|
20
20
|
"""Get all primary key columns for a table, ordered by their position"""
|
|
21
21
|
conn = sqlite3.connect(db_path)
|
|
22
22
|
cursor = conn.cursor()
|
|
@@ -34,7 +34,7 @@ class SQLiteDiffer:
|
|
|
34
34
|
pk_columns.sort(key=lambda x: x[0])
|
|
35
35
|
return [col[1] for col in pk_columns]
|
|
36
36
|
|
|
37
|
-
def get_all_tables(self, db_path: str) ->
|
|
37
|
+
def get_all_tables(self, db_path: str) -> List[str]:
|
|
38
38
|
"""Get all table names from database"""
|
|
39
39
|
conn = sqlite3.connect(db_path)
|
|
40
40
|
cursor = conn.cursor()
|
|
@@ -49,8 +49,8 @@ class SQLiteDiffer:
|
|
|
49
49
|
self,
|
|
50
50
|
db_path: str,
|
|
51
51
|
table_name: str,
|
|
52
|
-
primary_key_columns:
|
|
53
|
-
) ->
|
|
52
|
+
primary_key_columns: Optional[List[str]] = None,
|
|
53
|
+
) -> Tuple[Dict[Any, dict], List[str]]:
|
|
54
54
|
"""Get table data indexed by primary key (single column or composite)"""
|
|
55
55
|
conn = sqlite3.connect(db_path)
|
|
56
56
|
conn.row_factory = sqlite3.Row
|
|
@@ -97,7 +97,7 @@ class SQLiteDiffer:
|
|
|
97
97
|
conn.close()
|
|
98
98
|
return data, primary_key_columns
|
|
99
99
|
|
|
100
|
-
def compare_rows(self, before_row: dict, after_row: dict) ->
|
|
100
|
+
def compare_rows(self, before_row: dict, after_row: dict) -> Dict[str, dict]:
|
|
101
101
|
"""Compare two rows field by field"""
|
|
102
102
|
changes = {}
|
|
103
103
|
|
|
@@ -113,7 +113,7 @@ class SQLiteDiffer:
|
|
|
113
113
|
return changes
|
|
114
114
|
|
|
115
115
|
def diff_table(
|
|
116
|
-
self, table_name: str, primary_key_columns:
|
|
116
|
+
self, table_name: str, primary_key_columns: Optional[List[str]] = None
|
|
117
117
|
) -> dict:
|
|
118
118
|
"""Create comprehensive diff of a table"""
|
|
119
119
|
before_data, detected_pk = self.get_table_data(
|