fleet-python 0.2.42__py3-none-any.whl → 0.2.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fleet-python might be problematic. Click here for more details.
- fleet/__init__.py +9 -7
- fleet/_async/__init__.py +28 -16
- fleet/_async/client.py +161 -69
- fleet/_async/env/client.py +9 -2
- fleet/_async/instance/client.py +1 -1
- fleet/_async/resources/sqlite.py +34 -34
- fleet/_async/tasks.py +42 -41
- fleet/_async/verifiers/verifier.py +3 -3
- fleet/client.py +164 -61
- fleet/env/client.py +9 -2
- fleet/instance/client.py +2 -4
- fleet/models.py +3 -1
- fleet/resources/sqlite.py +37 -43
- fleet/tasks.py +49 -65
- fleet/verifiers/__init__.py +1 -1
- fleet/verifiers/db.py +41 -36
- fleet/verifiers/parse.py +4 -1
- fleet/verifiers/sql_differ.py +8 -8
- fleet/verifiers/verifier.py +19 -7
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/METADATA +1 -1
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/RECORD +24 -24
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/top_level.txt +0 -0
fleet/tasks.py
CHANGED
|
@@ -47,7 +47,7 @@ class Task(BaseModel):
|
|
|
47
47
|
@property
|
|
48
48
|
def env_key(self) -> str:
|
|
49
49
|
"""Get the environment key combining env_id and version."""
|
|
50
|
-
if self.version:
|
|
50
|
+
if self.version and self.version != "None" and ":" not in self.env_id:
|
|
51
51
|
return f"{self.env_id}:{self.version}"
|
|
52
52
|
return self.env_id
|
|
53
53
|
|
|
@@ -70,17 +70,13 @@ class Task(BaseModel):
|
|
|
70
70
|
import inspect
|
|
71
71
|
|
|
72
72
|
# Check if verifier has remote method (for decorated verifiers)
|
|
73
|
-
|
|
74
|
-
result = self.verifier.remote(env, *args, **kwargs)
|
|
75
|
-
else:
|
|
76
|
-
# For verifiers created from string, call directly
|
|
77
|
-
result = self.verifier(env, *args, **kwargs)
|
|
73
|
+
result = self.verifier.remote(env, *args, **kwargs)
|
|
78
74
|
|
|
79
75
|
# If the result is a coroutine, we need to run it
|
|
80
76
|
if inspect.iscoroutine(result):
|
|
81
77
|
# Check if we're already in an event loop
|
|
82
78
|
try:
|
|
83
|
-
|
|
79
|
+
asyncio.get_running_loop()
|
|
84
80
|
# We're in an async context, can't use asyncio.run()
|
|
85
81
|
raise RuntimeError(
|
|
86
82
|
"Cannot run async verifier in sync mode while event loop is running. "
|
|
@@ -126,93 +122,88 @@ class Task(BaseModel):
|
|
|
126
122
|
|
|
127
123
|
|
|
128
124
|
def verifier_from_string(
|
|
129
|
-
verifier_func: str,
|
|
130
|
-
|
|
131
|
-
verifier_key: str,
|
|
132
|
-
sha256: str = ""
|
|
133
|
-
) -> 'VerifierFunction':
|
|
125
|
+
verifier_func: str, verifier_id: str, verifier_key: str, sha256: str = ""
|
|
126
|
+
) -> "VerifierFunction":
|
|
134
127
|
"""Create a verifier function from string code.
|
|
135
|
-
|
|
128
|
+
|
|
136
129
|
Args:
|
|
137
130
|
verifier_func: The verifier function code as a string
|
|
138
131
|
verifier_id: Unique identifier for the verifier
|
|
139
132
|
verifier_key: Key/name for the verifier
|
|
140
133
|
sha256: SHA256 hash of the verifier code
|
|
141
|
-
|
|
134
|
+
|
|
142
135
|
Returns:
|
|
143
136
|
VerifierFunction instance that can be used to verify tasks
|
|
144
137
|
"""
|
|
145
138
|
try:
|
|
146
139
|
import inspect
|
|
147
|
-
from .verifiers import
|
|
140
|
+
from .verifiers import SyncVerifierFunction
|
|
148
141
|
from .verifiers.code import TASK_SUCCESSFUL_SCORE, TASK_FAILED_SCORE
|
|
149
142
|
from .verifiers.db import IgnoreConfig
|
|
150
|
-
|
|
143
|
+
|
|
151
144
|
# Create a globals namespace with all required imports
|
|
152
145
|
exec_globals = globals().copy()
|
|
153
|
-
exec_globals.update(
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
146
|
+
exec_globals.update(
|
|
147
|
+
{
|
|
148
|
+
"TASK_SUCCESSFUL_SCORE": TASK_SUCCESSFUL_SCORE,
|
|
149
|
+
"TASK_FAILED_SCORE": TASK_FAILED_SCORE,
|
|
150
|
+
"IgnoreConfig": IgnoreConfig,
|
|
151
|
+
"Environment": object, # Add Environment type if needed
|
|
152
|
+
}
|
|
153
|
+
)
|
|
154
|
+
|
|
160
155
|
# Create a local namespace for executing the code
|
|
161
156
|
local_namespace = {}
|
|
162
|
-
|
|
157
|
+
|
|
163
158
|
# Execute the verifier code in the namespace
|
|
164
159
|
exec(verifier_func, exec_globals, local_namespace)
|
|
165
|
-
|
|
160
|
+
|
|
166
161
|
# Find the function that was defined
|
|
167
162
|
func_obj = None
|
|
168
163
|
for name, obj in local_namespace.items():
|
|
169
164
|
if inspect.isfunction(obj):
|
|
170
165
|
func_obj = obj
|
|
171
166
|
break
|
|
172
|
-
|
|
167
|
+
|
|
173
168
|
if func_obj is None:
|
|
174
169
|
raise ValueError("No function found in verifier code")
|
|
175
|
-
|
|
176
|
-
# Create
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
# Create a new function with the updated globals
|
|
187
|
-
import types
|
|
188
|
-
new_func = types.FunctionType(
|
|
189
|
-
func_obj.__code__,
|
|
190
|
-
func_globals,
|
|
191
|
-
func_obj.__name__,
|
|
192
|
-
func_obj.__defaults__,
|
|
193
|
-
func_obj.__closure__
|
|
194
|
-
)
|
|
195
|
-
|
|
196
|
-
return new_func(env, *args, **kwargs)
|
|
197
|
-
|
|
198
|
-
# Create an AsyncVerifierFunction instance with the wrapped function
|
|
199
|
-
verifier_instance = SyncVerifierFunction(wrapped_verifier, verifier_key, verifier_id)
|
|
200
|
-
|
|
170
|
+
|
|
171
|
+
# Create an SyncVerifierFunction instance with raw code
|
|
172
|
+
verifier_instance = SyncVerifierFunction(
|
|
173
|
+
func=func_obj,
|
|
174
|
+
key=verifier_key,
|
|
175
|
+
verifier_id=verifier_id,
|
|
176
|
+
sha256=sha256,
|
|
177
|
+
raw_code=verifier_func,
|
|
178
|
+
)
|
|
179
|
+
|
|
201
180
|
# Store additional metadata
|
|
202
181
|
verifier_instance._verifier_code = verifier_func
|
|
203
182
|
verifier_instance._sha256 = sha256
|
|
204
|
-
|
|
183
|
+
|
|
205
184
|
return verifier_instance
|
|
206
|
-
|
|
185
|
+
|
|
207
186
|
except Exception as e:
|
|
208
187
|
raise ValueError(f"Failed to create verifier from string: {e}")
|
|
209
188
|
|
|
210
189
|
|
|
190
|
+
def load_tasks_from_file(filename: str) -> List[Task]:
|
|
191
|
+
"""Load tasks from a JSON file.
|
|
192
|
+
|
|
193
|
+
Example:
|
|
194
|
+
tasks = fleet.load_tasks_from_file("my_tasks.json")
|
|
195
|
+
"""
|
|
196
|
+
from .global_client import get_client
|
|
197
|
+
|
|
198
|
+
client = get_client()
|
|
199
|
+
return client.load_tasks_from_file(filename)
|
|
200
|
+
|
|
201
|
+
|
|
211
202
|
def load_tasks(
|
|
212
203
|
env_key: Optional[str] = None,
|
|
213
204
|
keys: Optional[List[str]] = None,
|
|
214
205
|
version: Optional[str] = None,
|
|
215
|
-
team_id: Optional[str] = None
|
|
206
|
+
team_id: Optional[str] = None,
|
|
216
207
|
) -> List[Task]:
|
|
217
208
|
"""Convenience function to load tasks with optional filtering.
|
|
218
209
|
|
|
@@ -232,17 +223,12 @@ def load_tasks(
|
|
|
232
223
|
|
|
233
224
|
client = get_client()
|
|
234
225
|
return client.load_tasks(
|
|
235
|
-
env_key=env_key,
|
|
236
|
-
keys=keys,
|
|
237
|
-
version=version,
|
|
238
|
-
team_id=team_id
|
|
226
|
+
env_key=env_key, keys=keys, version=version, team_id=team_id
|
|
239
227
|
)
|
|
240
228
|
|
|
241
229
|
|
|
242
230
|
def update_task(
|
|
243
|
-
task_key: str,
|
|
244
|
-
prompt: Optional[str] = None,
|
|
245
|
-
verifier_code: Optional[str] = None
|
|
231
|
+
task_key: str, prompt: Optional[str] = None, verifier_code: Optional[str] = None
|
|
246
232
|
):
|
|
247
233
|
"""Convenience function to update an existing task.
|
|
248
234
|
|
|
@@ -263,7 +249,5 @@ def update_task(
|
|
|
263
249
|
|
|
264
250
|
client = get_client()
|
|
265
251
|
return client.update_task(
|
|
266
|
-
task_key=task_key,
|
|
267
|
-
prompt=prompt,
|
|
268
|
-
verifier_code=verifier_code
|
|
252
|
+
task_key=task_key, prompt=prompt, verifier_code=verifier_code
|
|
269
253
|
)
|
fleet/verifiers/__init__.py
CHANGED
fleet/verifiers/db.py
CHANGED
|
@@ -22,9 +22,11 @@ import json
|
|
|
22
22
|
# Low‑level helpers
|
|
23
23
|
################################################################################
|
|
24
24
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
25
|
+
from typing import Union, Tuple, Dict, List, Optional, Any, Set
|
|
26
|
+
|
|
27
|
+
SQLValue = Union[str, int, float, None]
|
|
28
|
+
Condition = Tuple[str, str, SQLValue] # (column, op, value)
|
|
29
|
+
JoinSpec = Tuple[str, Dict[str, str]] # (table, on mapping)
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
def _is_json_string(value: Any) -> bool:
|
|
@@ -98,13 +100,13 @@ class QueryBuilder:
|
|
|
98
100
|
def __init__(self, snapshot: "DatabaseSnapshot", table: str): # noqa: UP037
|
|
99
101
|
self._snapshot = snapshot
|
|
100
102
|
self._table = table
|
|
101
|
-
self._select_cols:
|
|
102
|
-
self._conditions:
|
|
103
|
-
self._joins:
|
|
104
|
-
self._limit: int
|
|
105
|
-
self._order_by: str
|
|
103
|
+
self._select_cols: List[str] = ["*"]
|
|
104
|
+
self._conditions: List[Condition] = []
|
|
105
|
+
self._joins: List[JoinSpec] = []
|
|
106
|
+
self._limit: Optional[int] = None
|
|
107
|
+
self._order_by: Optional[str] = None
|
|
106
108
|
# Cache for idempotent executions
|
|
107
|
-
self._cached_rows:
|
|
109
|
+
self._cached_rows: Optional[List[Dict[str, Any]]] = None
|
|
108
110
|
|
|
109
111
|
# ---------------------------------------------------------------------
|
|
110
112
|
# Column projection / limiting / ordering
|
|
@@ -150,12 +152,12 @@ class QueryBuilder:
|
|
|
150
152
|
def lte(self, column: str, value: SQLValue) -> "QueryBuilder": # noqa: UP037
|
|
151
153
|
return self._add_condition(column, "<=", value)
|
|
152
154
|
|
|
153
|
-
def in_(self, column: str, values:
|
|
155
|
+
def in_(self, column: str, values: List[SQLValue]) -> "QueryBuilder": # noqa: UP037
|
|
154
156
|
qb = self._clone()
|
|
155
157
|
qb._conditions.append((column, "IN", tuple(values)))
|
|
156
158
|
return qb
|
|
157
159
|
|
|
158
|
-
def not_in(self, column: str, values:
|
|
160
|
+
def not_in(self, column: str, values: List[SQLValue]) -> "QueryBuilder": # noqa: UP037
|
|
159
161
|
qb = self._clone()
|
|
160
162
|
qb._conditions.append((column, "NOT IN", tuple(values)))
|
|
161
163
|
return qb
|
|
@@ -174,7 +176,7 @@ class QueryBuilder:
|
|
|
174
176
|
# ---------------------------------------------------------------------
|
|
175
177
|
# JOIN (simple inner join)
|
|
176
178
|
# ---------------------------------------------------------------------
|
|
177
|
-
def join(self, other_table: str, on:
|
|
179
|
+
def join(self, other_table: str, on: Dict[str, str]) -> "QueryBuilder": # noqa: UP037
|
|
178
180
|
"""`on` expects {local_col: remote_col}."""
|
|
179
181
|
qb = self._clone()
|
|
180
182
|
qb._joins.append((other_table, on))
|
|
@@ -183,10 +185,10 @@ class QueryBuilder:
|
|
|
183
185
|
# ---------------------------------------------------------------------
|
|
184
186
|
# Execution helpers
|
|
185
187
|
# ---------------------------------------------------------------------
|
|
186
|
-
def _compile(self) ->
|
|
188
|
+
def _compile(self) -> Tuple[str, List[Any]]:
|
|
187
189
|
cols = ", ".join(self._select_cols)
|
|
188
190
|
sql = [f"SELECT {cols} FROM {self._table}"]
|
|
189
|
-
params:
|
|
191
|
+
params: List[Any] = []
|
|
190
192
|
|
|
191
193
|
# Joins -------------------------------------------------------------
|
|
192
194
|
for tbl, onmap in self._joins:
|
|
@@ -224,7 +226,7 @@ class QueryBuilder:
|
|
|
224
226
|
|
|
225
227
|
return " ".join(sql), params
|
|
226
228
|
|
|
227
|
-
def _execute(self) ->
|
|
229
|
+
def _execute(self) -> List[Dict[str, Any]]:
|
|
228
230
|
if self._cached_rows is not None:
|
|
229
231
|
return self._cached_rows
|
|
230
232
|
|
|
@@ -255,10 +257,10 @@ class QueryBuilder:
|
|
|
255
257
|
conn.close()
|
|
256
258
|
return _CountResult(val)
|
|
257
259
|
|
|
258
|
-
def first(self) ->
|
|
260
|
+
def first(self) -> Optional[Dict[str, Any]]:
|
|
259
261
|
return self.limit(1)._execute()[0] if self.limit(1)._execute() else None
|
|
260
262
|
|
|
261
|
-
def all(self) ->
|
|
263
|
+
def all(self) -> List[Dict[str, Any]]:
|
|
262
264
|
return self._execute()
|
|
263
265
|
|
|
264
266
|
# Assertions -----------------------------------------------------------
|
|
@@ -357,9 +359,9 @@ class IgnoreConfig:
|
|
|
357
359
|
|
|
358
360
|
def __init__(
|
|
359
361
|
self,
|
|
360
|
-
tables:
|
|
361
|
-
fields:
|
|
362
|
-
table_fields:
|
|
362
|
+
tables: Optional[Set[str]] = None,
|
|
363
|
+
fields: Optional[Set[str]] = None,
|
|
364
|
+
table_fields: Optional[Dict[str, Set[str]]] = None,
|
|
363
365
|
):
|
|
364
366
|
"""
|
|
365
367
|
Args:
|
|
@@ -386,7 +388,7 @@ class IgnoreConfig:
|
|
|
386
388
|
return False
|
|
387
389
|
|
|
388
390
|
|
|
389
|
-
def _format_row_for_error(row:
|
|
391
|
+
def _format_row_for_error(row: Dict[str, Any], max_fields: int = 10) -> str:
|
|
390
392
|
"""Format a row dictionary for error messages with truncation if needed."""
|
|
391
393
|
if not row:
|
|
392
394
|
return "{empty row}"
|
|
@@ -402,7 +404,7 @@ def _format_row_for_error(row: dict[str, Any], max_fields: int = 10) -> str:
|
|
|
402
404
|
return "{" + ", ".join(shown_items) + f", ... +{remaining} more fields" + "}"
|
|
403
405
|
|
|
404
406
|
|
|
405
|
-
def _get_row_identifier(row:
|
|
407
|
+
def _get_row_identifier(row: Dict[str, Any]) -> str:
|
|
406
408
|
"""Extract a meaningful identifier from a row for error messages."""
|
|
407
409
|
# Try common ID fields first
|
|
408
410
|
for id_field in ["id", "pk", "primary_key", "key"]:
|
|
@@ -429,7 +431,7 @@ class SnapshotDiff:
|
|
|
429
431
|
self,
|
|
430
432
|
before: DatabaseSnapshot,
|
|
431
433
|
after: DatabaseSnapshot,
|
|
432
|
-
ignore_config: IgnoreConfig
|
|
434
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
433
435
|
):
|
|
434
436
|
from .sql_differ import SQLiteDiffer # local import to avoid circularity
|
|
435
437
|
|
|
@@ -437,14 +439,14 @@ class SnapshotDiff:
|
|
|
437
439
|
self.after = after
|
|
438
440
|
self.ignore_config = ignore_config or IgnoreConfig()
|
|
439
441
|
self._differ = SQLiteDiffer(before.db_path, after.db_path)
|
|
440
|
-
self._cached:
|
|
442
|
+
self._cached: Optional[Dict[str, Any]] = None
|
|
441
443
|
|
|
442
444
|
# ------------------------------------------------------------------
|
|
443
445
|
def _collect(self):
|
|
444
446
|
if self._cached is not None:
|
|
445
447
|
return self._cached
|
|
446
448
|
all_tables = set(self.before.tables()) | set(self.after.tables())
|
|
447
|
-
diff:
|
|
449
|
+
diff: Dict[str, Dict[str, Any]] = {}
|
|
448
450
|
for tbl in all_tables:
|
|
449
451
|
if self.ignore_config.should_ignore_table(tbl):
|
|
450
452
|
continue
|
|
@@ -453,12 +455,12 @@ class SnapshotDiff:
|
|
|
453
455
|
return diff
|
|
454
456
|
|
|
455
457
|
# ------------------------------------------------------------------
|
|
456
|
-
def expect_only(self, allowed_changes:
|
|
458
|
+
def expect_only(self, allowed_changes: List[Dict[str, Any]]):
|
|
457
459
|
"""Allowed changes is a list of {table, pk, field, after} (before optional)."""
|
|
458
460
|
diff = self._collect()
|
|
459
461
|
|
|
460
462
|
def _is_change_allowed(
|
|
461
|
-
table: str, row_id: str, field: str
|
|
463
|
+
table: str, row_id: str, field: Optional[str], after_value: Any
|
|
462
464
|
) -> bool:
|
|
463
465
|
"""Check if a change is in the allowed list using semantic comparison."""
|
|
464
466
|
for allowed in allowed_changes:
|
|
@@ -596,7 +598,10 @@ class SnapshotDiff:
|
|
|
596
598
|
return self
|
|
597
599
|
|
|
598
600
|
def expect(
|
|
599
|
-
self,
|
|
601
|
+
self,
|
|
602
|
+
*,
|
|
603
|
+
allow: Optional[List[Dict[str, Any]]] = None,
|
|
604
|
+
forbid: Optional[List[Dict[str, Any]]] = None,
|
|
600
605
|
):
|
|
601
606
|
"""More granular: allow / forbid per‑table and per‑field."""
|
|
602
607
|
allow = allow or []
|
|
@@ -629,7 +634,7 @@ class SnapshotDiff:
|
|
|
629
634
|
class DatabaseSnapshot:
|
|
630
635
|
"""Represents a snapshot of an SQLite DB with DSL entrypoints."""
|
|
631
636
|
|
|
632
|
-
def __init__(self, db_path: str, *, name: str
|
|
637
|
+
def __init__(self, db_path: str, *, name: Optional[str] = None):
|
|
633
638
|
self.db_path = db_path
|
|
634
639
|
self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
|
|
635
640
|
self.created_at = datetime.utcnow()
|
|
@@ -639,7 +644,7 @@ class DatabaseSnapshot:
|
|
|
639
644
|
return QueryBuilder(self, table)
|
|
640
645
|
|
|
641
646
|
# Metadata -------------------------------------------------------------
|
|
642
|
-
def tables(self) ->
|
|
647
|
+
def tables(self) -> List[str]:
|
|
643
648
|
conn = sqlite3.connect(self.db_path)
|
|
644
649
|
cur = conn.cursor()
|
|
645
650
|
cur.execute(
|
|
@@ -654,7 +659,7 @@ class DatabaseSnapshot:
|
|
|
654
659
|
def diff(
|
|
655
660
|
self,
|
|
656
661
|
other: "DatabaseSnapshot", # noqa: UP037
|
|
657
|
-
ignore_config: IgnoreConfig
|
|
662
|
+
ignore_config: Optional[IgnoreConfig] = None,
|
|
658
663
|
) -> SnapshotDiff:
|
|
659
664
|
return SnapshotDiff(self, other, ignore_config)
|
|
660
665
|
|
|
@@ -663,7 +668,7 @@ class DatabaseSnapshot:
|
|
|
663
668
|
############################################################################
|
|
664
669
|
|
|
665
670
|
def expect_row(
|
|
666
|
-
self, table: str, where:
|
|
671
|
+
self, table: str, where: Dict[str, SQLValue], expect: Dict[str, SQLValue]
|
|
667
672
|
):
|
|
668
673
|
qb = self.table(table)
|
|
669
674
|
for k, v in where.items():
|
|
@@ -676,10 +681,10 @@ class DatabaseSnapshot:
|
|
|
676
681
|
def expect_rows(
|
|
677
682
|
self,
|
|
678
683
|
table: str,
|
|
679
|
-
where:
|
|
684
|
+
where: Dict[str, SQLValue],
|
|
680
685
|
*,
|
|
681
|
-
count: int
|
|
682
|
-
contains:
|
|
686
|
+
count: Optional[int] = None,
|
|
687
|
+
contains: Optional[List[Dict[str, SQLValue]]] = None,
|
|
683
688
|
):
|
|
684
689
|
qb = self.table(table)
|
|
685
690
|
for k, v in where.items():
|
|
@@ -694,7 +699,7 @@ class DatabaseSnapshot:
|
|
|
694
699
|
raise AssertionError(f"Expected a row matching {cond} in {table}")
|
|
695
700
|
return self
|
|
696
701
|
|
|
697
|
-
def expect_absent_row(self, table: str, where:
|
|
702
|
+
def expect_absent_row(self, table: str, where: Dict[str, SQLValue]):
|
|
698
703
|
qb = self.table(table)
|
|
699
704
|
for k, v in where.items():
|
|
700
705
|
qb = qb.eq(k, v)
|
fleet/verifiers/parse.py
CHANGED
fleet/verifiers/sql_differ.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import sqlite3
|
|
2
|
-
from typing import Any
|
|
2
|
+
from typing import Any, Optional, List, Dict, Tuple
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class SQLiteDiffer:
|
|
@@ -7,7 +7,7 @@ class SQLiteDiffer:
|
|
|
7
7
|
self.before_db = before_db
|
|
8
8
|
self.after_db = after_db
|
|
9
9
|
|
|
10
|
-
def get_table_schema(self, db_path: str, table_name: str) ->
|
|
10
|
+
def get_table_schema(self, db_path: str, table_name: str) -> List[str]:
|
|
11
11
|
"""Get column names for a table"""
|
|
12
12
|
conn = sqlite3.connect(db_path)
|
|
13
13
|
cursor = conn.cursor()
|
|
@@ -16,7 +16,7 @@ class SQLiteDiffer:
|
|
|
16
16
|
conn.close()
|
|
17
17
|
return columns
|
|
18
18
|
|
|
19
|
-
def get_primary_key_columns(self, db_path: str, table_name: str) ->
|
|
19
|
+
def get_primary_key_columns(self, db_path: str, table_name: str) -> List[str]:
|
|
20
20
|
"""Get all primary key columns for a table, ordered by their position"""
|
|
21
21
|
conn = sqlite3.connect(db_path)
|
|
22
22
|
cursor = conn.cursor()
|
|
@@ -34,7 +34,7 @@ class SQLiteDiffer:
|
|
|
34
34
|
pk_columns.sort(key=lambda x: x[0])
|
|
35
35
|
return [col[1] for col in pk_columns]
|
|
36
36
|
|
|
37
|
-
def get_all_tables(self, db_path: str) ->
|
|
37
|
+
def get_all_tables(self, db_path: str) -> List[str]:
|
|
38
38
|
"""Get all table names from database"""
|
|
39
39
|
conn = sqlite3.connect(db_path)
|
|
40
40
|
cursor = conn.cursor()
|
|
@@ -49,8 +49,8 @@ class SQLiteDiffer:
|
|
|
49
49
|
self,
|
|
50
50
|
db_path: str,
|
|
51
51
|
table_name: str,
|
|
52
|
-
primary_key_columns:
|
|
53
|
-
) ->
|
|
52
|
+
primary_key_columns: Optional[List[str]] = None,
|
|
53
|
+
) -> Tuple[Dict[Any, dict], List[str]]:
|
|
54
54
|
"""Get table data indexed by primary key (single column or composite)"""
|
|
55
55
|
conn = sqlite3.connect(db_path)
|
|
56
56
|
conn.row_factory = sqlite3.Row
|
|
@@ -97,7 +97,7 @@ class SQLiteDiffer:
|
|
|
97
97
|
conn.close()
|
|
98
98
|
return data, primary_key_columns
|
|
99
99
|
|
|
100
|
-
def compare_rows(self, before_row: dict, after_row: dict) ->
|
|
100
|
+
def compare_rows(self, before_row: dict, after_row: dict) -> Dict[str, dict]:
|
|
101
101
|
"""Compare two rows field by field"""
|
|
102
102
|
changes = {}
|
|
103
103
|
|
|
@@ -113,7 +113,7 @@ class SQLiteDiffer:
|
|
|
113
113
|
return changes
|
|
114
114
|
|
|
115
115
|
def diff_table(
|
|
116
|
-
self, table_name: str, primary_key_columns:
|
|
116
|
+
self, table_name: str, primary_key_columns: Optional[List[str]] = None
|
|
117
117
|
) -> dict:
|
|
118
118
|
"""Create comprehensive diff of a table"""
|
|
119
119
|
before_data, detected_pk = self.get_table_data(
|
fleet/verifiers/verifier.py
CHANGED
|
@@ -12,10 +12,22 @@ import uuid
|
|
|
12
12
|
import logging
|
|
13
13
|
import hashlib
|
|
14
14
|
import inspect
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import (
|
|
16
|
+
Any,
|
|
17
|
+
Callable,
|
|
18
|
+
Dict,
|
|
19
|
+
Optional,
|
|
20
|
+
List,
|
|
21
|
+
TypeVar,
|
|
22
|
+
Set,
|
|
23
|
+
TYPE_CHECKING,
|
|
24
|
+
Tuple,
|
|
25
|
+
)
|
|
16
26
|
|
|
17
27
|
from .bundler import FunctionBundler
|
|
18
|
-
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from ..client import SyncEnv
|
|
19
31
|
|
|
20
32
|
logger = logging.getLogger(__name__)
|
|
21
33
|
|
|
@@ -56,7 +68,7 @@ class SyncVerifierFunction:
|
|
|
56
68
|
# Copy function metadata
|
|
57
69
|
functools.update_wrapper(self, func)
|
|
58
70
|
|
|
59
|
-
def _get_or_create_bundle(self) ->
|
|
71
|
+
def _get_or_create_bundle(self) -> Tuple[bytes, str]:
|
|
60
72
|
"""Get or create bundle data and return (bundle_data, sha)."""
|
|
61
73
|
if self._bundle_data is None or self._bundle_sha is None:
|
|
62
74
|
# If we have raw code, create a bundle from it
|
|
@@ -98,7 +110,7 @@ class SyncVerifierFunction:
|
|
|
98
110
|
|
|
99
111
|
return self._bundle_data, self._bundle_sha
|
|
100
112
|
|
|
101
|
-
def _check_bundle_status(self, env: SyncEnv) ->
|
|
113
|
+
def _check_bundle_status(self, env: "SyncEnv") -> Tuple[str, bool]:
|
|
102
114
|
"""Check if bundle needs to be uploaded and return (sha, needs_upload)."""
|
|
103
115
|
bundle_data, bundle_sha = self._get_or_create_bundle()
|
|
104
116
|
|
|
@@ -129,7 +141,7 @@ class SyncVerifierFunction:
|
|
|
129
141
|
logger.info(f"Bundle {bundle_sha[:8]}... needs to be uploaded")
|
|
130
142
|
return bundle_sha, True # Upload needed
|
|
131
143
|
|
|
132
|
-
def __call__(self, env: SyncEnv, *args, **kwargs) -> float:
|
|
144
|
+
def __call__(self, env: "SyncEnv", *args, **kwargs) -> float:
|
|
133
145
|
"""Local execution of the verifier function with env as first parameter."""
|
|
134
146
|
try:
|
|
135
147
|
if self._is_async:
|
|
@@ -160,7 +172,7 @@ class SyncVerifierFunction:
|
|
|
160
172
|
# Return error score 0
|
|
161
173
|
return 0.0
|
|
162
174
|
|
|
163
|
-
def remote(self, env: SyncEnv, *args, **kwargs) -> float:
|
|
175
|
+
def remote(self, env: "SyncEnv", *args, **kwargs) -> float:
|
|
164
176
|
"""Remote execution of the verifier function with SHA-based bundle caching."""
|
|
165
177
|
# Async verifiers are now supported by the backend
|
|
166
178
|
# if self._is_async:
|
|
@@ -272,7 +284,7 @@ Remote traceback:
|
|
|
272
284
|
except:
|
|
273
285
|
raise RuntimeError(full_message)
|
|
274
286
|
|
|
275
|
-
def _get_env_id(self, env: SyncEnv) -> str:
|
|
287
|
+
def _get_env_id(self, env: "SyncEnv") -> str:
|
|
276
288
|
"""Generate a unique identifier for the environment."""
|
|
277
289
|
# Use instance base URL or similar unique identifier
|
|
278
290
|
if hasattr(env, "instance") and hasattr(env.instance, "base_url"):
|