fleet-python 0.2.28__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fleet-python might be problematic.
- examples/diff_example.py +30 -20
- examples/dsl_example.py +12 -7
- examples/example.py +4 -4
- examples/example_account.py +8 -0
- examples/example_action_log.py +2 -2
- examples/example_client.py +2 -2
- examples/example_mcp_anthropic.py +8 -5
- examples/example_mcp_openai.py +2 -2
- examples/example_sync.py +4 -4
- examples/example_task.py +16 -6
- examples/example_tasks.py +3 -6
- examples/example_verifier.py +16 -3
- examples/gemini_example.py +6 -6
- examples/json_tasks_example.py +2 -2
- examples/nova_act_example.py +2 -2
- examples/openai_example.py +3 -3
- examples/openai_simple_example.py +3 -3
- examples/query_builder_example.py +11 -7
- fleet/__init__.py +60 -5
- fleet/_async/__init__.py +258 -1
- fleet/_async/base.py +2 -1
- fleet/_async/client.py +194 -127
- fleet/_async/env/client.py +5 -1
- fleet/_async/global_client.py +43 -0
- fleet/_async/instance/client.py +1 -1
- fleet/_async/models.py +172 -171
- fleet/_async/resources/base.py +1 -1
- fleet/_async/resources/mcp.py +55 -0
- fleet/_async/resources/sqlite.py +141 -130
- fleet/_async/tasks.py +71 -16
- fleet/_async/verifiers/__init__.py +2 -2
- fleet/_async/verifiers/bundler.py +18 -14
- fleet/_async/verifiers/verifier.py +77 -71
- fleet/base.py +2 -1
- fleet/client.py +176 -136
- fleet/config.py +3 -2
- fleet/env/__init__.py +10 -1
- fleet/env/client.py +5 -1
- fleet/global_client.py +43 -0
- fleet/instance/__init__.py +1 -1
- fleet/instance/client.py +2 -4
- fleet/models.py +172 -171
- fleet/resources/base.py +1 -1
- fleet/resources/mcp.py +27 -33
- fleet/resources/sqlite.py +136 -131
- fleet/tasks.py +197 -16
- fleet/types.py +1 -1
- fleet/verifiers/__init__.py +2 -2
- fleet/verifiers/bundler.py +18 -14
- fleet/verifiers/code.py +1 -1
- fleet/verifiers/decorator.py +25 -34
- fleet/verifiers/parse.py +98 -68
- fleet/verifiers/verifier.py +77 -78
- {fleet_python-0.2.28.dist-info → fleet_python-0.2.32.dist-info}/METADATA +9 -9
- fleet_python-0.2.32.dist-info/RECORD +74 -0
- scripts/fix_sync_imports.py +87 -59
- scripts/unasync.py +10 -9
- fleet_python-0.2.28.dist-info/RECORD +0 -70
- {fleet_python-0.2.28.dist-info → fleet_python-0.2.32.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.28.dist-info → fleet_python-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.28.dist-info → fleet_python-0.2.32.dist-info}/top_level.txt +0 -0
fleet/verifiers/parse.py
CHANGED
@@ -4,68 +4,86 @@ import re
 def extract_function_name(function_code: str) -> str | None:
     """
     Extract function name from Python function code.
-
+
     Handles both regular functions (def) and async functions (async def).
-
+
     Args:
         function_code: Python function code as a string
-
+
     Returns:
         The function name if found, None otherwise
     """
-    #
-
-
-
-
+    # Normalize escaped newlines and strip common Markdown code fences
+    code = function_code.replace("\\n", "\n").strip()
+
+    if "```" in code:
+        # Extract the first fenced block if present
+        fence_blocks = re.findall(r"```[a-zA-Z0-9_+-]*\n([\s\S]*?)\n```", code)
+        if fence_blocks:
+            code = fence_blocks[0].strip()
+
+    # Remove leading decorators (keep them for regex but allow preceding lines)
+    # Robust regex: allow optional decorators and whitespace before the def
+    pattern = r"^\s*(?:@[\w\.\n+() ,]*\n\s*)*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\("
+    match = re.search(pattern, code, flags=re.MULTILINE)
     if match:
         return match.group(1)
-
+
+    # Fallback: search anywhere (not anchored) for a def signature
+    fallback = r"(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\("
+    match = re.search(fallback, code)
+    if match:
+        return match.group(1)
+
     return None
 
 
 def convert_verifier_string(verifier_str: str) -> str:
     """
-    Convert a verifier function string from the old format (env: Environment)
+    Convert a verifier function string from the old format (env: Environment)
     to the new format (before: DatabaseSnapshot, after: DatabaseSnapshot).
-
+
     Args:
         verifier_str: The original verifier function as a string
-
+
     Returns:
         The converted verifier function string
     """
     # First, handle escaped newlines in the input
-    verifier_str = verifier_str.replace(
-
+    verifier_str = verifier_str.replace("\\n", "\n")
+
     # Extract function name, docstring, and body
     # More flexible pattern that accepts both int and float return types
     func_pattern = r'def\s+(\w+)\s*\(\s*env(?:\s*:\s*Environment)?\s*,?\s*final_answer(?:\s*:\s*str\s*\|\s*None)?\s*(?:=\s*None)?\s*\)\s*->\s*(?:float|int):\s*\n((?:\s*""".*?"""\s*\n)?)(.*)'
     match = re.match(func_pattern, verifier_str.strip(), re.DOTALL)
-
+
     if not match:
         # Try with multiline pattern
         func_pattern_multiline = r'def\s+(\w+)\s*\(\s*\n?\s*env(?:\s*:\s*Environment)?\s*,?\s*\n?\s*final_answer(?:\s*:\s*str\s*\|\s*None)?\s*(?:=\s*None)?\s*\n?\s*\)\s*->\s*(?:float|int):\s*\n((?:\s*""".*?"""\s*\n)?)(.*)'
         match = re.match(func_pattern_multiline, verifier_str.strip(), re.DOTALL)
-
+
     if not match:
-        raise ValueError(
-
+        raise ValueError(
+            "Could not parse verifier function. Expected format: def function_name(env: Environment, final_answer: str | None = None) -> float/int:"
+        )
+
     func_name = match.group(1)
     docstring = match.group(2).strip()
     body = match.group(3)
-
+
     # Find all unique env.db() calls
     db_calls = re.findall(r'env\.db\("(\w+)"\)', body)
-    unique_db_names = list(
-
+    unique_db_names = list(
+        dict.fromkeys(db_calls)
+    )  # Remove duplicates while preserving order
+
     # Build the new function
-    new_func = f
+    new_func = f"""def {func_name}(
     before: DatabaseSnapshot, after: DatabaseSnapshot, transcript: str | None = None
 ) -> int:
     class Environment:
-        def db(self, name: str) -> DatabaseSnapshot:
-
+        def db(self, name: str) -> DatabaseSnapshot:"""
+
     # Build the db method based on found database names
     if unique_db_names:
         conditions = []
@@ -73,28 +91,32 @@ def convert_verifier_string(verifier_str: str) -> str:
             if db_name == "seed":
                 conditions.append('before if name == "seed"')
             elif db_name == "current":
-                conditions.append(
+                conditions.append("after")
             else:
                 # Handle other database names if needed
                 conditions.append(f'None  # Handle "{db_name}"')
-
-        if
-
-
+
+        if (
+            len(conditions) == 2
+            and "seed" in unique_db_names
+            and "current" in unique_db_names
+        ):
+            new_func += f"""
+            return before if name == "seed" else after"""
         else:
             # More complex mapping if needed
-            new_func += f
+            new_func += f"""
             if name == "seed":
                 return before
             elif name == "current":
                 return after
             else:
-                raise ValueError(f"Unknown database name: {{name}}")
+                raise ValueError(f"Unknown database name: {{name}}")"""
     else:
-        new_func +=
-        return before if name == "seed" else after
-
-    new_func +=
+        new_func += """
+            return before if name == "seed" else after"""
+
+    new_func += """
 
         @property
         def instance(self):
@@ -103,42 +125,44 @@ def convert_verifier_string(verifier_str: str) -> str:
         def load(self):
             pass
 
-    def verifier(env: Environment, final_answer: str | None = None) -> float:
-
+    def verifier(env: Environment, final_answer: str | None = None) -> float:"""
+
     if docstring:
-        new_func += f
-
+        new_func += f"\n    {docstring}"
+
     # First, find the minimum indentation in the body (excluding empty lines)
     body_lines = body.splitlines()
-    min_indent = float(
+    min_indent = float("inf")
     for line in body_lines:
         if line.strip():  # Non-empty line
             indent_len = len(line) - len(line.lstrip())
             min_indent = min(min_indent, indent_len)
-
+
     # If we didn't find any non-empty lines, set min_indent to 0
-    if min_indent == float(
+    if min_indent == float("inf"):
         min_indent = 0
-
+
     # Now strip the minimum indentation and re-indent to 8 spaces
     if body_lines:
         indented_lines = []
         for line in body_lines:
             if line.strip():  # Non-empty line
                 # Remove the minimum indentation and add 8 spaces
-                stripped_line =
-
+                stripped_line = (
+                    line[min_indent:] if len(line) > min_indent else line.lstrip()
+                )
+                indented_lines.append("        " + stripped_line)
             else:  # Empty line
-                indented_lines.append(
-
-    indented_body =
-    new_func += f
-
+                indented_lines.append("")
+
+        indented_body = "\n".join(indented_lines)
+        new_func += f"\n{indented_body}"
+
     # Add the return statement
-    new_func +=
+    new_func += "\n\n    return verifier(Environment(), transcript)"
 
     # Replace TASK_FAILED_SCORE with 0 in the function string
-    new_func = new_func.replace(
+    new_func = new_func.replace("TASK_FAILED_SCORE", "0")
 
     return new_func
 
@@ -147,39 +171,45 @@ def convert_new_to_old_verifier(verifier_str: str) -> str:
     """
     Convert a verifier function from the new format (before/after: DatabaseSnapshot)
     to the old format (env: Environment).
-
+
     This is the inverse of convert_verifier_string.
-
+
     Args:
         verifier_str: The new format verifier function as a string
-
+
     Returns:
         The converted verifier function string that accepts env
-    """
+    """
     # Extract function name, parameters, docstring, and body
     # Pattern for new format with flexible whitespace and multiline support
     func_pattern = r'def\s+(\w+)\s*\(\s*before\s*:\s*DatabaseSnapshot\s*,?\s*after\s*:\s*DatabaseSnapshot\s*,?\s*transcript\s*:\s*str\s*\|\s*None\s*=\s*None\s*,?\s*\)\s*->\s*int:\s*((?:\s*""".*?"""\s*)?)(.*)'
-
+
     # Try multiline pattern that's more flexible
     func_pattern_multiline = r'def\s+(\w+)\s*\(\s*\n?\s*before\s*:\s*DatabaseSnapshot\s*,?\s*\n?\s*after\s*:\s*DatabaseSnapshot\s*,?\s*\n?\s*transcript\s*:\s*str\s*\|\s*None\s*=\s*None\s*,?\s*\n?\s*\)\s*->\s*int:\s*\n?((?:\s*""".*?"""\s*)?)(.*)'
-
-    match = re.match(
-
+
+    match = re.match(
+        func_pattern_multiline, verifier_str.strip(), re.DOTALL | re.MULTILINE
+    )
+
     if not match:
         # Even more flexible pattern
-        func_pattern_flexible =
+        func_pattern_flexible = (
+            r'def\s+(\w+)\s*\([^)]*\)\s*->\s*int:\s*\n?((?:\s*""".*?"""\s*)?)(.*)'
+        )
         match = re.match(func_pattern_flexible, verifier_str.strip(), re.DOTALL)
-
+
     if not match:
         raise ValueError("Could not parse new format verifier function")
-
+
     func_name = match.group(1)
     docstring = match.group(2).strip()
     body = match.group(3)
-
+
     # Indent the original function body
-    indented_verifier =
-
+    indented_verifier = "\n".join(
+        "    " + line if line.strip() else line for line in verifier_str.splitlines()
+    )
+
     # Build the wrapper function
     wrapper_func = f'''def {func_name}_wrapper(env, *args, **kwargs) -> float:
     """Wrapper to adapt new format verifier to old format."""
@@ -203,5 +233,5 @@ def convert_new_to_old_verifier(verifier_str: str) -> str:
     # Call the inner function and convert result
     result = {func_name}(before, after, transcript)
     return float(result)'''
-
-    return wrapper_func
+
+    return wrapper_func
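For orientation, the reworked extract_function_name above now normalizes escaped newlines, unwraps Markdown code fences, and falls back to a non-anchored search, so fenced snippets and async defs resolve to a name. A minimal illustrative sketch of that behavior; the import path fleet.verifiers.parse is inferred from the file name above, not confirmed as documented API:

# Illustrative sketch only; the import path is an assumption based on the file's location.
from fleet.verifiers.parse import extract_function_name

snippet = '```python\nasync def check_rows(env, n: int) -> float:\n    return 1.0\n```'
# The fenced block is unwrapped and the async def is matched.
print(extract_function_name(snippet))  # -> "check_rows"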
fleet/verifiers/verifier.py
CHANGED
@@ -19,7 +19,7 @@ from ..client import SyncEnv
 
 logger = logging.getLogger(__name__)
 
-F = TypeVar(
+F = TypeVar("F", bound=Callable[..., Any])
 
 # Global cache to track which bundle SHAs have been uploaded to S3
 _uploaded_bundle_shas: Set[str] = set()
@@ -33,7 +33,7 @@ def _get_bundle_sha(bundle_data: bytes) -> str:
 
 class SyncVerifierFunction:
     """Wrapper for a verified function that supports local execution with env-first pattern."""
-
+
     def __init__(
         self,
         func: F,
@@ -41,7 +41,7 @@ class SyncVerifierFunction:
         extra_requirements: Optional[List[str]] = None,
         verifier_id: Optional[str] = None,
         sha256: Optional[str] = None,
-        raw_code: Optional[str] = None
+        raw_code: Optional[str] = None,
     ):
         self.func = func
         self.key = key
@@ -52,10 +52,10 @@ class SyncVerifierFunction:
         self._bundle_data: Optional[bytes] = None  # Cached bundle data
         self._raw_code: Optional[str] = raw_code  # Store raw code if provided
         self._is_async = inspect.iscoroutinefunction(func)
-
+
         # Copy function metadata
         functools.update_wrapper(self, func)
-
+
     def _get_or_create_bundle(self) -> tuple[bytes, str]:
         """Get or create bundle data and return (bundle_data, sha)."""
         if self._bundle_data is None or self._bundle_sha is None:
@@ -63,68 +63,72 @@ class SyncVerifierFunction:
             if self._raw_code:
                 import io
                 import zipfile
-
+
                 # Create zip bundle directly (matching bundler format)
                 zip_buffer = io.BytesIO()
-                with zipfile.ZipFile(zip_buffer,
+                with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
                     # Add requirements.txt
                     requirements = self.extra_requirements or []
                     if "fleet-python" not in requirements:
                         requirements.append("fleet-python")
                     req_content = "\n".join(requirements)
                     zf.writestr("requirements.txt", req_content)
-
+
                     # Add verifier.py with the raw code
                     zf.writestr("verifier.py", self._raw_code)
-
+
                 self._bundle_data = zip_buffer.getvalue()
                 self._bundle_sha = _get_bundle_sha(self._bundle_data)
-                logger.debug(
+                logger.debug(
+                    f"Created bundle from raw code for {self.key} with SHA: {self._bundle_sha}"
+                )
             else:
                 # Try to create bundle from function source
                 try:
                     self._bundle_data = self._bundler.create_bundle(
-                        self.func,
-                        self.extra_requirements,
-                        self.verifier_id
+                        self.func, self.extra_requirements, self.verifier_id
                     )
                     self._bundle_sha = _get_bundle_sha(self._bundle_data)
-                    logger.debug(
+                    logger.debug(
+                        f"Created bundle for {self.key} with SHA: {self._bundle_sha}"
+                    )
                 except OSError as e:
                     # Can't create bundle - no source and no raw code
                     raise OSError(f"Cannot create bundle for {self.key}: {e}")
-
+
         return self._bundle_data, self._bundle_sha
-
+
     def _check_bundle_status(self, env: SyncEnv) -> tuple[str, bool]:
         """Check if bundle needs to be uploaded and return (sha, needs_upload)."""
         bundle_data, bundle_sha = self._get_or_create_bundle()
-
+
         # If bundle_data is empty, we're using server-side bundle
         if not bundle_data:
             logger.debug(f"Using server-side bundle {bundle_sha[:8]}...")
             return bundle_sha, False  # No upload needed, server has it
-
+
         # 1. Check local process cache first
         if bundle_sha in _uploaded_bundle_shas:
             logger.debug(f"Bundle {bundle_sha[:8]}... found in local cache")
             return bundle_sha, False  # Already uploaded, no upload needed
-
+
         # 2. Check if bundle exists on server (pseudocode)
         # TODO: Add endpoint to check if bundle SHA exists in S3
         try:
             exists = env.check_bundle_exists(bundle_sha)
             if exists.success:
-                logger.info(
+                logger.info(
+                    f"Bundle {bundle_sha[:8]}... found on server, updating cache"
+                )
                 _uploaded_bundle_shas.add(bundle_sha)
                 return bundle_sha, False  # Found on server, no upload needed
         except Exception as e:
             logger.warning(f"Failed to check bundle existence: {e}")
-
+
         # 3. Bundle not found locally or on server - upload needed
         logger.info(f"Bundle {bundle_sha[:8]}... needs to be uploaded")
         return bundle_sha, True  # Upload needed
-
+
     def __call__(self, env: SyncEnv, *args, **kwargs) -> float:
         """Local execution of the verifier function with env as first parameter."""
         try:
@@ -134,7 +138,7 @@ class SyncVerifierFunction:
             else:
                 # For sync functions, call directly
                 result = self.func(env, *args, **kwargs)
-
+
             # Handle different return types
             if isinstance(result, (int, float)):
                 # Direct score return
@@ -144,39 +148,34 @@ class SyncVerifierFunction:
                 return result
             else:
                 # Try to extract score from object attributes
                if hasattr(result, "score"):
-                if hasattr(result,
+                if hasattr(result, "score"):
                     return float(result.score)
                 else:
-                    raise ValueError(
-
+                    raise ValueError(
+                        f"Verifier function must return a score (number). Got {type(result)}"
+                    )
+
         except Exception as e:
             logger.error(f"Error in verifier {self.key}: {e}")
             # Return error score 0
             return 0.0
-
+
     def remote(self, env: SyncEnv, *args, **kwargs) -> float:
         """Remote execution of the verifier function with SHA-based bundle caching."""
-
-        # if self._is_async:
-        #     raise NotImplementedError(
-        #         f"Async verifier '{self.key}' cannot be executed remotely. "
-        #         "The remote execution environment only supports synchronous functions. "
-        #         "Please provide a synchronous version of your verifier."
-        #     )
-
+
         args_array = list(args)
         args_array.append({"env": env.instance_id})
         args = tuple(args_array)
-
+
         try:
             # Check if bundle needs to be uploaded
             bundle_sha, needs_upload = self._check_bundle_status(env)
-
+
             if needs_upload:
                 # Need to upload bundle to S3
                 logger.info(f"Uploading bundle {bundle_sha[:8]}... for {self.key}")
                 bundle_data, _ = self._get_or_create_bundle()
-
+
                 response = env.execute_verifier_remote(
                     bundle_data=bundle_data,
                     bundle_sha=bundle_sha,
@@ -185,42 +184,46 @@ class SyncVerifierFunction:
                     args=args,
                     args_array=args_array,
                     kwargs=kwargs,
-                    needs_upload=True
+                    needs_upload=True,
                 )
-
+
                 # Mark as uploaded after successful execution
                 _uploaded_bundle_shas.add(bundle_sha)
                 logger.debug(f"Registered bundle {bundle_sha[:8]}... as uploaded")
-
+
             else:
                 # Bundle already available - execute without upload
-                logger.info(
+                logger.info(
+                    f"Executing cached bundle {bundle_sha[:8]}... for {self.key}"
+                )
                 bundle_data, _ = self._get_or_create_bundle()
-
+
                 response = env.execute_verifier_remote(
-                    bundle_data=bundle_data or b
+                    bundle_data=bundle_data or b"",  # Empty if using server-side bundle
                     bundle_sha=bundle_sha,
                     key=self.key,
                     function_name=self.func.__name__,
                     args=args,
                     args_array=args_array,
                     kwargs=kwargs,
-                    needs_upload=False  # Don't upload, just execute
+                    needs_upload=False,  # Don't upload, just execute
                 )
-
+
             # Handle response
+            if response.stdout:
+                print(response.stdout)
             if response.success:
                 return self._process_result(response.result)
             else:
                 self._raise_remote_error(response.error)
-
+
         except Exception as e:
             logger.error(f"Remote execution failed for {self.key}: {e}")
             # If it's an HTTP error, try to get more details
-            if hasattr(e,
+            if hasattr(e, "response") and hasattr(e.response, "text"):
                 logger.error(f"Server response: {e.response.text}")
             raise
-
+
     def _process_result(self, result: Any) -> float:
         """Process remote execution result, handling different return types."""
         # Handle different return types like local execution
@@ -230,7 +233,7 @@ class SyncVerifierFunction:
             return float(result["score"])
         else:
             # Try to extract score from object attributes
-            if hasattr(result,
+            if hasattr(result, "score"):
                 return float(result.score)
             else:
                 # Best effort conversion
@@ -239,13 +242,13 @@ class SyncVerifierFunction:
             except (ValueError, TypeError):
                 logger.warning(f"Could not convert result to float: {result}")
                 return 0.0
-
+
     def _raise_remote_error(self, error_info: Dict[str, Any]):
         """Reconstruct remote error as local exception."""
         error_type = error_info.get("type", "RuntimeError")
         message = error_info.get("message", "Remote execution failed")
         traceback_str = error_info.get("traceback", "")
-
+
         # Create a rich error message
         full_message = f"""
 Remote verifier execution failed:
@@ -254,32 +257,32 @@ Remote verifier execution failed:
 Remote traceback:
 {traceback_str}
 """.strip()
-
+
         # Try to raise the original exception type
         try:
             exception_class = getattr(__builtins__, error_type, RuntimeError)
             raise exception_class(full_message)
         except:
             raise RuntimeError(full_message)
-
+
     def _get_env_id(self, env: SyncEnv) -> str:
         """Generate a unique identifier for the environment."""
         # Use instance base URL or similar unique identifier
-        if hasattr(env,
+        if hasattr(env, "instance") and hasattr(env.instance, "base_url"):
            return f"{env.instance.base_url}"
        else:
            # Fallback to object id (less ideal but works)
            return str(id(env))
-
+
    def _is_bundle_not_found_error(self, error: Exception) -> bool:
        """Check if the error indicates the bundle was not found on the server."""
        # Check for common "bundle not found" error patterns
        error_msg = str(error).lower()
        return (
-            "bundle not found" in error_msg
-            "verifier not found" in error_msg
-            "404" in error_msg
-            "not found" in error_msg
+            "bundle not found" in error_msg
+            or "verifier not found" in error_msg
+            or "404" in error_msg
+            or "not found" in error_msg
        )
 
 
@@ -287,21 +290,21 @@ def verifier(
     key: Optional[str] = None,
     extra_requirements: Optional[List[str]] = None,
     sha256: Optional[str] = None,
-    raw_code: Optional[str] = None
+    raw_code: Optional[str] = None,
 ) -> Callable[[F], SyncVerifierFunction]:
     """
     Decorator to create a verifier function with env-first pattern.
-
+
     The decorated function must take 'env' as its first parameter, making it explicit
     that verifiers operate within an environment context. This makes verifiers reusable
     across different environments.
-
+
     Args:
         key: Optional key for the verifier. Defaults to function name.
         extra_requirements: Additional PyPI packages needed by the verifier.
         sha256: Optional SHA256 hash of existing server-side bundle to use.
         raw_code: Optional raw code to use as bundle (bypasses source extraction).
-
+
     Example:
         # Synchronous verifier (works locally and remotely)
         @verifier(key="check_user_count")
@@ -310,7 +313,7 @@ def verifier(
             result = db.query("SELECT COUNT(*) FROM users")
             actual_count = result.rows[0][0]
             return 1.0 if actual_count >= expected_count else 0.0
-
+
         # Async verifier (only works locally)
         @verifier(key="check_user_async")
         async def check_user_async(env, expected_count: int) -> float:
@@ -318,29 +321,25 @@ def verifier(
             result = await db.query("SELECT COUNT(*) FROM users")
             actual_count = result.rows[0][0]
             return 1.0 if actual_count >= expected_count else 0.0
-
+
         # Usage
-        env = await
-
+        env = await fleet.env.make_async("fira")
+
         # Local execution
         result = await check_user_count(env, 5)  # sync verifier
         result = await check_user_async(env, 5)  # async verifier
-
+
         # Remote execution
         result = await check_user_count.remote(env, 5)  # sync verifier works
         # await check_user_async.remote(env, 5)  # raises NotImplementedError
     """
+
     def decorator(func: F) -> SyncVerifierFunction:
         verifier_key = key or func.__name__
         verifier_uuid = str(uuid.uuid4())
-
+
         return SyncVerifierFunction(
-            func,
-            verifier_key,
-            extra_requirements,
-            verifier_uuid,
-            sha256,
-            raw_code
+            func, verifier_key, extra_requirements, verifier_uuid, sha256, raw_code
         )
-
-    return decorator
+
+    return decorator
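For context, the @verifier decorator diffed above wraps a function in SyncVerifierFunction so it can run locally or remotely with SHA-based bundle caching. A hedged usage sketch, adapted from the docstring example in the diff; the import path, the env object, and the env.db("current") access are assumptions rather than confirmed 0.2.32 API:

# Sketch adapted from the decorator docstring above; import path and env access
# (env.db("current")) are assumptions, not verified against this release.
from fleet.verifiers import verifier

@verifier(key="check_user_count")
def check_user_count(env, expected_count: int) -> float:
    db = env.db("current")
    result = db.query("SELECT COUNT(*) FROM users")
    actual_count = result.rows[0][0]
    return 1.0 if actual_count >= expected_count else 0.0

# check_user_count(env, 5)          # local execution
# check_user_count.remote(env, 5)   # remote execution; the bundle is uploaded once, then reused by SHA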