fleet-python 0.2.29__py3-none-any.whl → 0.2.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fleet-python might be problematic. (The original page linked to more details; that link is not preserved in this copy.)
- examples/diff_example.py +30 -20
- examples/dsl_example.py +12 -7
- examples/example.py +4 -4
- examples/exampleResume.py +191 -0
- examples/example_account.py +8 -0
- examples/example_action_log.py +2 -2
- examples/example_client.py +2 -2
- examples/example_mcp_anthropic.py +8 -5
- examples/example_mcp_openai.py +2 -2
- examples/example_sync.py +4 -4
- examples/example_task.py +16 -6
- examples/example_tasks.py +3 -6
- examples/example_verifier.py +16 -3
- examples/gemini_example.py +6 -6
- examples/json_tasks_example.py +2 -2
- examples/nova_act_example.py +2 -2
- examples/openai_example.py +3 -3
- examples/openai_simple_example.py +3 -3
- examples/query_builder_example.py +11 -7
- examples/test_cdp_logging.py +80 -0
- fleet/__init__.py +60 -5
- fleet/_async/__init__.py +258 -1
- fleet/_async/base.py +2 -1
- fleet/_async/client.py +164 -144
- fleet/_async/env/client.py +2 -0
- fleet/_async/global_client.py +43 -0
- fleet/_async/instance/client.py +1 -1
- fleet/_async/models.py +172 -171
- fleet/_async/resources/base.py +1 -1
- fleet/_async/resources/mcp.py +55 -0
- fleet/_async/resources/sqlite.py +141 -130
- fleet/_async/tasks.py +69 -16
- fleet/_async/verifiers/__init__.py +2 -2
- fleet/_async/verifiers/bundler.py +18 -14
- fleet/_async/verifiers/verifier.py +77 -71
- fleet/base.py +2 -1
- fleet/client.py +162 -148
- fleet/config.py +3 -2
- fleet/env/__init__.py +9 -1
- fleet/env/client.py +4 -1
- fleet/global_client.py +43 -0
- fleet/instance/__init__.py +1 -1
- fleet/instance/client.py +1 -1
- fleet/models.py +172 -171
- fleet/resources/base.py +1 -1
- fleet/resources/mcp.py +11 -16
- fleet/resources/sqlite.py +141 -130
- fleet/tasks.py +86 -15
- fleet/types.py +1 -1
- fleet/verifiers/__init__.py +2 -2
- fleet/verifiers/bundler.py +18 -14
- fleet/verifiers/code.py +1 -1
- fleet/verifiers/decorator.py +25 -34
- fleet/verifiers/parse.py +98 -68
- fleet/verifiers/verifier.py +77 -71
- {fleet_python-0.2.29.dist-info → fleet_python-0.2.34.dist-info}/METADATA +9 -9
- fleet_python-0.2.34.dist-info/RECORD +76 -0
- scripts/fix_sync_imports.py +87 -59
- scripts/unasync.py +10 -9
- fleet_python-0.2.29.dist-info/RECORD +0 -70
- {fleet_python-0.2.29.dist-info → fleet_python-0.2.34.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.29.dist-info → fleet_python-0.2.34.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.29.dist-info → fleet_python-0.2.34.dist-info}/top_level.txt +0 -0
fleet/tasks.py
CHANGED
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import re
|
|
6
6
|
from datetime import datetime
|
|
7
|
-
from typing import Any, Dict, Optional
|
|
7
|
+
from typing import Any, Dict, Optional, List
|
|
8
8
|
from uuid import UUID
|
|
9
9
|
|
|
10
10
|
from pydantic import BaseModel, Field, validator
|
|
@@ -15,31 +15,39 @@ from fleet.types import VerifierFunction
|
|
|
15
15
|
|
|
16
16
|
class Task(BaseModel):
|
|
17
17
|
"""A task model representing a single task in the Fleet system."""
|
|
18
|
-
|
|
18
|
+
|
|
19
19
|
key: str = Field(..., description="Unique task key identifier")
|
|
20
20
|
prompt: str = Field(..., description="Task prompt or instruction")
|
|
21
21
|
env_id: str = Field(..., description="Environment identifier")
|
|
22
|
-
env_variables: Optional[Dict[str, Any]] = Field(
|
|
22
|
+
env_variables: Optional[Dict[str, Any]] = Field(
|
|
23
|
+
default_factory=dict, description="Environment variables"
|
|
24
|
+
)
|
|
23
25
|
created_at: Optional[datetime] = Field(None, description="Task creation timestamp")
|
|
24
26
|
version: Optional[str] = Field(None, description="Task version")
|
|
25
27
|
verifier_func: Optional[str] = Field(None, description="Verifier function code")
|
|
26
|
-
verifier: Optional[Any] = Field(
|
|
28
|
+
verifier: Optional[Any] = Field(
|
|
29
|
+
None, description="Verifier function with decorator (async or sync)"
|
|
30
|
+
)
|
|
27
31
|
verifier_id: Optional[str] = Field(None, description="Verifier identifier")
|
|
28
32
|
verifier_sha: Optional[str] = Field(None, description="Verifier SHA256 hash")
|
|
29
|
-
metadata: Optional[Dict[str, Any]] = Field(
|
|
33
|
+
metadata: Optional[Dict[str, Any]] = Field(
|
|
34
|
+
default_factory=dict, description="Additional task metadata"
|
|
35
|
+
)
|
|
30
36
|
|
|
31
|
-
@validator(
|
|
37
|
+
@validator("key")
|
|
32
38
|
def validate_key_format(cls, v):
|
|
33
39
|
"""Validate key follows kebab-case format."""
|
|
34
|
-
if not re.match(r
|
|
35
|
-
raise ValueError(
|
|
40
|
+
if not re.match(r"^[a-z0-9]+(-[a-z0-9]+)*$", v):
|
|
41
|
+
raise ValueError(
|
|
42
|
+
f"Invalid task key format: {v}. Must follow kebab-case format."
|
|
43
|
+
)
|
|
36
44
|
return v
|
|
37
45
|
|
|
38
|
-
@validator(
|
|
46
|
+
@validator("created_at", pre=True, always=True)
|
|
39
47
|
def set_created_at(cls, v):
|
|
40
48
|
"""Set created_at to current time if not provided."""
|
|
41
49
|
return v or datetime.now()
|
|
42
|
-
|
|
50
|
+
|
|
43
51
|
@property
|
|
44
52
|
def env_key(self) -> str:
|
|
45
53
|
"""Get the environment key combining env_id and version."""
|
|
@@ -49,26 +57,45 @@ class Task(BaseModel):
|
|
|
49
57
|
|
|
50
58
|
class Config:
|
|
51
59
|
"""Pydantic model configuration."""
|
|
60
|
+
|
|
52
61
|
json_encoders = {
|
|
53
62
|
datetime: lambda v: v.isoformat(),
|
|
54
63
|
}
|
|
55
64
|
# Allow arbitrary types for the verifier field
|
|
56
|
-
arbitrary_types_allowed = True
|
|
65
|
+
arbitrary_types_allowed = True
|
|
57
66
|
|
|
58
67
|
def verify(self, env, *args, **kwargs) -> float:
|
|
59
68
|
"""Verify the task using the verifier function (sync version).
|
|
60
|
-
|
|
69
|
+
|
|
61
70
|
For sync environments, calls the sync verifier directly.
|
|
62
71
|
For async verifiers, automatically runs them with asyncio.run().
|
|
63
72
|
"""
|
|
64
73
|
if self.verifier:
|
|
65
|
-
|
|
74
|
+
import inspect
|
|
75
|
+
|
|
76
|
+
result = self.verifier.remote(env, *args, **kwargs)
|
|
77
|
+
|
|
78
|
+
# If the result is a coroutine, we need to run it
|
|
79
|
+
if inspect.iscoroutine(result):
|
|
80
|
+
# Check if we're already in an event loop
|
|
81
|
+
try:
|
|
82
|
+
loop = asyncio.get_running_loop()
|
|
83
|
+
# We're in an async context, can't use asyncio.run()
|
|
84
|
+
raise RuntimeError(
|
|
85
|
+
"Cannot run async verifier in sync mode while event loop is running. "
|
|
86
|
+
"Use await task.verify_async() instead."
|
|
87
|
+
)
|
|
88
|
+
except RuntimeError:
|
|
89
|
+
# No event loop running, safe to use asyncio.run()
|
|
90
|
+
return asyncio.run(result)
|
|
91
|
+
else:
|
|
92
|
+
return result
|
|
66
93
|
else:
|
|
67
94
|
raise ValueError("No verifier function found for this task")
|
|
68
|
-
|
|
95
|
+
|
|
69
96
|
def verify_async(self, *args, **kwargs) -> float:
|
|
70
97
|
"""Verify the task using the verifier function (async version).
|
|
71
|
-
|
|
98
|
+
|
|
72
99
|
For async environments, awaits the async verifier.
|
|
73
100
|
Works with both sync and async verifiers in async contexts.
|
|
74
101
|
"""
|
|
@@ -76,9 +103,53 @@ class Task(BaseModel):
|
|
|
76
103
|
result = self.verifier.remote(*args, **kwargs)
|
|
77
104
|
# If it's a coroutine, await it
|
|
78
105
|
import inspect
|
|
106
|
+
|
|
79
107
|
if inspect.iscoroutine(result):
|
|
80
108
|
return result
|
|
81
109
|
else:
|
|
82
110
|
return result
|
|
83
111
|
else:
|
|
84
112
|
raise ValueError("No verifier function found for this task")
|
|
113
|
+
|
|
114
|
+
def make_env(self, region: Optional[str] = None):
|
|
115
|
+
"""Create an environment instance for this task's environment.
|
|
116
|
+
|
|
117
|
+
Uses the task's env_id (and version if present) to create the env.
|
|
118
|
+
"""
|
|
119
|
+
if not self.env_id:
|
|
120
|
+
raise ValueError("Task has no env_id defined")
|
|
121
|
+
# Deferred import to avoid circular dependencies
|
|
122
|
+
from .client import Fleet
|
|
123
|
+
|
|
124
|
+
return Fleet().make(env_key=self.env_key, region=region)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def load_tasks(
|
|
128
|
+
env_key: Optional[str] = None,
|
|
129
|
+
keys: Optional[List[str]] = None,
|
|
130
|
+
version: Optional[str] = None,
|
|
131
|
+
team_id: Optional[str] = None
|
|
132
|
+
) -> List[Task]:
|
|
133
|
+
"""Convenience function to load tasks with optional filtering.
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
env_key: Optional environment key to filter tasks by
|
|
137
|
+
keys: Optional list of task keys to filter by
|
|
138
|
+
version: Optional version to filter tasks by
|
|
139
|
+
team_id: Optional team_id to filter by (admin only)
|
|
140
|
+
|
|
141
|
+
Examples:
|
|
142
|
+
tasks = await fleet.load_tasks(env_key="fira")
|
|
143
|
+
tasks = await fleet.load_tasks(keys=["task1", "task2"])
|
|
144
|
+
tasks = await fleet.load_tasks(env_key="fira", version="v1.0")
|
|
145
|
+
"""
|
|
146
|
+
# Use the global client by default so users can pre-configure it once
|
|
147
|
+
from .global_client import get_client
|
|
148
|
+
|
|
149
|
+
client = get_client()
|
|
150
|
+
return client.load_tasks(
|
|
151
|
+
env_key=env_key,
|
|
152
|
+
keys=keys,
|
|
153
|
+
version=version,
|
|
154
|
+
team_id=team_id
|
|
155
|
+
)
|
fleet/types.py
CHANGED
|
@@ -15,4 +15,4 @@ if TYPE_CHECKING:
|
|
|
15
15
|
|
|
16
16
|
# Union type to support both async and sync verifiers
|
|
17
17
|
# This definition works for both the async and sync versions of the codebase
|
|
18
|
-
VerifierFunction = Union["SyncVerifierFunction", "AsyncVerifierFunction"]
|
|
18
|
+
VerifierFunction = Union["SyncVerifierFunction", "AsyncVerifierFunction"]
|
fleet/verifiers/__init__.py
CHANGED
fleet/verifiers/bundler.py
CHANGED
|
@@ -544,7 +544,9 @@ class FunctionBundler:
|
|
|
544
544
|
# Ensure fleet-python is always included
|
|
545
545
|
if not requirements:
|
|
546
546
|
requirements = ["fleet-python"]
|
|
547
|
-
elif "fleet-python" not in [
|
|
547
|
+
elif "fleet-python" not in [
|
|
548
|
+
r.split("==")[0].split(">=")[0] for r in requirements
|
|
549
|
+
]:
|
|
548
550
|
requirements.append("fleet-python")
|
|
549
551
|
requirements_file.write_text("\n".join(sorted(set(requirements))))
|
|
550
552
|
|
|
@@ -663,37 +665,39 @@ class FunctionBundler:
|
|
|
663
665
|
logger.warning(f"Failed to extract function {function_name}: {e}")
|
|
664
666
|
|
|
665
667
|
return None
|
|
666
|
-
|
|
668
|
+
|
|
667
669
|
def _get_function_source_without_decorator(self, func: Callable) -> str:
|
|
668
670
|
"""Get function source code without the @verifier decorator."""
|
|
669
671
|
source = inspect.getsource(func)
|
|
670
|
-
lines = source.split(
|
|
671
|
-
|
|
672
|
+
lines = source.split("\n")
|
|
673
|
+
|
|
672
674
|
# Find where the function definition starts
|
|
673
675
|
func_start = -1
|
|
674
676
|
for i, line in enumerate(lines):
|
|
675
|
-
if line.strip().startswith(
|
|
677
|
+
if line.strip().startswith("def "):
|
|
676
678
|
func_start = i
|
|
677
679
|
break
|
|
678
|
-
|
|
680
|
+
|
|
679
681
|
if func_start == -1:
|
|
680
682
|
# Couldn't find function definition, return original
|
|
681
683
|
return source
|
|
682
|
-
|
|
684
|
+
|
|
683
685
|
# Return only from the function definition onward
|
|
684
686
|
func_lines = lines[func_start:]
|
|
685
|
-
|
|
687
|
+
|
|
686
688
|
# Remove common indentation
|
|
687
689
|
if func_lines:
|
|
688
690
|
# Find minimum indentation (excluding empty lines)
|
|
689
|
-
min_indent = float(
|
|
691
|
+
min_indent = float("inf")
|
|
690
692
|
for line in func_lines:
|
|
691
693
|
if line.strip():
|
|
692
694
|
indent = len(line) - len(line.lstrip())
|
|
693
695
|
min_indent = min(min_indent, indent)
|
|
694
|
-
|
|
696
|
+
|
|
695
697
|
# Remove the common indentation
|
|
696
|
-
if min_indent < float(
|
|
697
|
-
func_lines = [
|
|
698
|
-
|
|
699
|
-
|
|
698
|
+
if min_indent < float("inf"):
|
|
699
|
+
func_lines = [
|
|
700
|
+
line[min_indent:] if line.strip() else line for line in func_lines
|
|
701
|
+
]
|
|
702
|
+
|
|
703
|
+
return "\n".join(func_lines)
|
fleet/verifiers/code.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
TASK_SUCCESSFUL_SCORE = 1
|
|
2
|
-
TASK_FAILED_SCORE = 0
|
|
2
|
+
TASK_FAILED_SCORE = 0
|
fleet/verifiers/decorator.py
CHANGED
|
@@ -13,34 +13,26 @@ import logging
|
|
|
13
13
|
|
|
14
14
|
logger = logging.getLogger(__name__)
|
|
15
15
|
|
|
16
|
-
F = TypeVar(
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
16
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
20
17
|
|
|
21
18
|
|
|
22
19
|
class SyncVerifierFunction:
|
|
23
20
|
"""Wrapper for a verified function that supports local execution with env-first pattern."""
|
|
24
|
-
|
|
25
|
-
def __init__(
|
|
26
|
-
self,
|
|
27
|
-
func: F,
|
|
28
|
-
key: str,
|
|
29
|
-
verifier_id: str
|
|
30
|
-
):
|
|
21
|
+
|
|
22
|
+
def __init__(self, func: F, key: str, verifier_id: str):
|
|
31
23
|
self.func = func
|
|
32
24
|
self.key = key
|
|
33
25
|
self.name = key # Keep name for backward compatibility
|
|
34
26
|
self.verifier_id = verifier_id
|
|
35
|
-
|
|
27
|
+
|
|
36
28
|
# Copy function metadata
|
|
37
29
|
functools.update_wrapper(self, func)
|
|
38
|
-
|
|
30
|
+
|
|
39
31
|
def __call__(self, env, *args, **kwargs) -> float:
|
|
40
32
|
"""Local execution of the verifier function with env as first parameter."""
|
|
41
33
|
try:
|
|
42
34
|
result = self.func(env, *args, **kwargs)
|
|
43
|
-
|
|
35
|
+
|
|
44
36
|
# Handle different return types
|
|
45
37
|
if isinstance(result, (int, float)):
|
|
46
38
|
# Direct score return
|
|
@@ -49,31 +41,33 @@ class SyncVerifierFunction:
|
|
|
49
41
|
return float(result["score"])
|
|
50
42
|
else:
|
|
51
43
|
# Try to extract score from object attributes
|
|
52
|
-
if hasattr(result,
|
|
44
|
+
if hasattr(result, "score"):
|
|
53
45
|
return float(result.score)
|
|
54
46
|
else:
|
|
55
|
-
raise ValueError(
|
|
56
|
-
|
|
47
|
+
raise ValueError(
|
|
48
|
+
f"Verifier function must return a score (number). Got {type(result)}"
|
|
49
|
+
)
|
|
50
|
+
|
|
57
51
|
except Exception as e:
|
|
58
52
|
logger.error(f"Error in verifier {self.key}: {e}")
|
|
59
53
|
# Return error score 0
|
|
60
54
|
return 0.0
|
|
61
55
|
|
|
56
|
+
|
|
62
57
|
def verifier(
|
|
63
|
-
key: Optional[str] = None,
|
|
64
|
-
verifier_id: Optional[str] = None
|
|
58
|
+
key: Optional[str] = None, verifier_id: Optional[str] = None
|
|
65
59
|
) -> Callable[[F], SyncVerifierFunction]:
|
|
66
60
|
"""
|
|
67
61
|
Decorator to create a verifier function with env-first pattern.
|
|
68
|
-
|
|
62
|
+
|
|
69
63
|
The decorated function must take 'env' as its first parameter, making it explicit
|
|
70
64
|
that verifiers operate within an environment context. This makes verifiers reusable
|
|
71
65
|
across different environments.
|
|
72
|
-
|
|
66
|
+
|
|
73
67
|
Args:
|
|
74
68
|
key: Optional key for the verifier. Defaults to function name.
|
|
75
69
|
verifier_id: Optional unique ID for the verifier. Defaults to generated UUID.
|
|
76
|
-
|
|
70
|
+
|
|
77
71
|
Example:
|
|
78
72
|
@verifier(key="test_database_state")
|
|
79
73
|
def check_user_count(env, expected_count: int) -> float:
|
|
@@ -81,23 +75,20 @@ def verifier(
|
|
|
81
75
|
result = db.query("SELECT COUNT(*) FROM users")
|
|
82
76
|
actual_count = result.rows[0][0]
|
|
83
77
|
return 1.0 if actual_count >= expected_count else 0.0
|
|
84
|
-
|
|
78
|
+
|
|
85
79
|
# Usage with different environments
|
|
86
|
-
env1 =
|
|
87
|
-
env2 =
|
|
88
|
-
|
|
80
|
+
env1 = fleet.env.make("fira")
|
|
81
|
+
env2 = fleet.env.make("another_env")
|
|
82
|
+
|
|
89
83
|
# Local execution
|
|
90
84
|
result = await check_user_count(env1, 5)
|
|
91
85
|
result = await check_user_count(env2, 5) # Same verifier, different env
|
|
92
86
|
"""
|
|
87
|
+
|
|
93
88
|
def decorator(func: F) -> SyncVerifierFunction:
|
|
94
89
|
verifier_key = key or func.__name__
|
|
95
90
|
verifier_uuid = verifier_id or str(uuid.uuid4())
|
|
96
|
-
|
|
97
|
-
return SyncVerifierFunction(
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
verifier_uuid
|
|
101
|
-
)
|
|
102
|
-
|
|
103
|
-
return decorator
|
|
91
|
+
|
|
92
|
+
return SyncVerifierFunction(func, verifier_key, verifier_uuid)
|
|
93
|
+
|
|
94
|
+
return decorator
|
fleet/verifiers/parse.py
CHANGED
|
@@ -4,68 +4,86 @@ import re
|
|
|
4
4
|
def extract_function_name(function_code: str) -> str | None:
|
|
5
5
|
"""
|
|
6
6
|
Extract function name from Python function code.
|
|
7
|
-
|
|
7
|
+
|
|
8
8
|
Handles both regular functions (def) and async functions (async def).
|
|
9
|
-
|
|
9
|
+
|
|
10
10
|
Args:
|
|
11
11
|
function_code: Python function code as a string
|
|
12
|
-
|
|
12
|
+
|
|
13
13
|
Returns:
|
|
14
14
|
The function name if found, None otherwise
|
|
15
15
|
"""
|
|
16
|
-
#
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
16
|
+
# Normalize escaped newlines and strip common Markdown code fences
|
|
17
|
+
code = function_code.replace("\\n", "\n").strip()
|
|
18
|
+
|
|
19
|
+
if "```" in code:
|
|
20
|
+
# Extract the first fenced block if present
|
|
21
|
+
fence_blocks = re.findall(r"```[a-zA-Z0-9_+-]*\n([\s\S]*?)\n```", code)
|
|
22
|
+
if fence_blocks:
|
|
23
|
+
code = fence_blocks[0].strip()
|
|
24
|
+
|
|
25
|
+
# Remove leading decorators (keep them for regex but allow preceding lines)
|
|
26
|
+
# Robust regex: allow optional decorators and whitespace before the def
|
|
27
|
+
pattern = r"^\s*(?:@[\w\.\n+() ,]*\n\s*)*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\("
|
|
28
|
+
match = re.search(pattern, code, flags=re.MULTILINE)
|
|
21
29
|
if match:
|
|
22
30
|
return match.group(1)
|
|
23
|
-
|
|
31
|
+
|
|
32
|
+
# Fallback: search anywhere (not anchored) for a def signature
|
|
33
|
+
fallback = r"(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\("
|
|
34
|
+
match = re.search(fallback, code)
|
|
35
|
+
if match:
|
|
36
|
+
return match.group(1)
|
|
37
|
+
|
|
24
38
|
return None
|
|
25
39
|
|
|
26
40
|
|
|
27
41
|
def convert_verifier_string(verifier_str: str) -> str:
|
|
28
42
|
"""
|
|
29
|
-
Convert a verifier function string from the old format (env: Environment)
|
|
43
|
+
Convert a verifier function string from the old format (env: Environment)
|
|
30
44
|
to the new format (before: DatabaseSnapshot, after: DatabaseSnapshot).
|
|
31
|
-
|
|
45
|
+
|
|
32
46
|
Args:
|
|
33
47
|
verifier_str: The original verifier function as a string
|
|
34
|
-
|
|
48
|
+
|
|
35
49
|
Returns:
|
|
36
50
|
The converted verifier function string
|
|
37
51
|
"""
|
|
38
52
|
# First, handle escaped newlines in the input
|
|
39
|
-
verifier_str = verifier_str.replace(
|
|
40
|
-
|
|
53
|
+
verifier_str = verifier_str.replace("\\n", "\n")
|
|
54
|
+
|
|
41
55
|
# Extract function name, docstring, and body
|
|
42
56
|
# More flexible pattern that accepts both int and float return types
|
|
43
57
|
func_pattern = r'def\s+(\w+)\s*\(\s*env(?:\s*:\s*Environment)?\s*,?\s*final_answer(?:\s*:\s*str\s*\|\s*None)?\s*(?:=\s*None)?\s*\)\s*->\s*(?:float|int):\s*\n((?:\s*""".*?"""\s*\n)?)(.*)'
|
|
44
58
|
match = re.match(func_pattern, verifier_str.strip(), re.DOTALL)
|
|
45
|
-
|
|
59
|
+
|
|
46
60
|
if not match:
|
|
47
61
|
# Try with multiline pattern
|
|
48
62
|
func_pattern_multiline = r'def\s+(\w+)\s*\(\s*\n?\s*env(?:\s*:\s*Environment)?\s*,?\s*\n?\s*final_answer(?:\s*:\s*str\s*\|\s*None)?\s*(?:=\s*None)?\s*\n?\s*\)\s*->\s*(?:float|int):\s*\n((?:\s*""".*?"""\s*\n)?)(.*)'
|
|
49
63
|
match = re.match(func_pattern_multiline, verifier_str.strip(), re.DOTALL)
|
|
50
|
-
|
|
64
|
+
|
|
51
65
|
if not match:
|
|
52
|
-
raise ValueError(
|
|
53
|
-
|
|
66
|
+
raise ValueError(
|
|
67
|
+
"Could not parse verifier function. Expected format: def function_name(env: Environment, final_answer: str | None = None) -> float/int:"
|
|
68
|
+
)
|
|
69
|
+
|
|
54
70
|
func_name = match.group(1)
|
|
55
71
|
docstring = match.group(2).strip()
|
|
56
72
|
body = match.group(3)
|
|
57
|
-
|
|
73
|
+
|
|
58
74
|
# Find all unique env.db() calls
|
|
59
75
|
db_calls = re.findall(r'env\.db\("(\w+)"\)', body)
|
|
60
|
-
unique_db_names = list(
|
|
61
|
-
|
|
76
|
+
unique_db_names = list(
|
|
77
|
+
dict.fromkeys(db_calls)
|
|
78
|
+
) # Remove duplicates while preserving order
|
|
79
|
+
|
|
62
80
|
# Build the new function
|
|
63
|
-
new_func = f
|
|
81
|
+
new_func = f"""def {func_name}(
|
|
64
82
|
before: DatabaseSnapshot, after: DatabaseSnapshot, transcript: str | None = None
|
|
65
83
|
) -> int:
|
|
66
84
|
class Environment:
|
|
67
|
-
def db(self, name: str) -> DatabaseSnapshot:
|
|
68
|
-
|
|
85
|
+
def db(self, name: str) -> DatabaseSnapshot:"""
|
|
86
|
+
|
|
69
87
|
# Build the db method based on found database names
|
|
70
88
|
if unique_db_names:
|
|
71
89
|
conditions = []
|
|
@@ -73,28 +91,32 @@ def convert_verifier_string(verifier_str: str) -> str:
|
|
|
73
91
|
if db_name == "seed":
|
|
74
92
|
conditions.append('before if name == "seed"')
|
|
75
93
|
elif db_name == "current":
|
|
76
|
-
conditions.append(
|
|
94
|
+
conditions.append("after")
|
|
77
95
|
else:
|
|
78
96
|
# Handle other database names if needed
|
|
79
97
|
conditions.append(f'None # Handle "{db_name}"')
|
|
80
|
-
|
|
81
|
-
if
|
|
82
|
-
|
|
83
|
-
|
|
98
|
+
|
|
99
|
+
if (
|
|
100
|
+
len(conditions) == 2
|
|
101
|
+
and "seed" in unique_db_names
|
|
102
|
+
and "current" in unique_db_names
|
|
103
|
+
):
|
|
104
|
+
new_func += f"""
|
|
105
|
+
return before if name == "seed" else after"""
|
|
84
106
|
else:
|
|
85
107
|
# More complex mapping if needed
|
|
86
|
-
new_func += f
|
|
108
|
+
new_func += f"""
|
|
87
109
|
if name == "seed":
|
|
88
110
|
return before
|
|
89
111
|
elif name == "current":
|
|
90
112
|
return after
|
|
91
113
|
else:
|
|
92
|
-
raise ValueError(f"Unknown database name: {{name}}")
|
|
114
|
+
raise ValueError(f"Unknown database name: {{name}}")"""
|
|
93
115
|
else:
|
|
94
|
-
new_func +=
|
|
95
|
-
return before if name == "seed" else after
|
|
96
|
-
|
|
97
|
-
new_func +=
|
|
116
|
+
new_func += """
|
|
117
|
+
return before if name == "seed" else after"""
|
|
118
|
+
|
|
119
|
+
new_func += """
|
|
98
120
|
|
|
99
121
|
@property
|
|
100
122
|
def instance(self):
|
|
@@ -103,42 +125,44 @@ def convert_verifier_string(verifier_str: str) -> str:
|
|
|
103
125
|
def load(self):
|
|
104
126
|
pass
|
|
105
127
|
|
|
106
|
-
def verifier(env: Environment, final_answer: str | None = None) -> float:
|
|
107
|
-
|
|
128
|
+
def verifier(env: Environment, final_answer: str | None = None) -> float:"""
|
|
129
|
+
|
|
108
130
|
if docstring:
|
|
109
|
-
new_func += f
|
|
110
|
-
|
|
131
|
+
new_func += f"\n {docstring}"
|
|
132
|
+
|
|
111
133
|
# First, find the minimum indentation in the body (excluding empty lines)
|
|
112
134
|
body_lines = body.splitlines()
|
|
113
|
-
min_indent = float(
|
|
135
|
+
min_indent = float("inf")
|
|
114
136
|
for line in body_lines:
|
|
115
137
|
if line.strip(): # Non-empty line
|
|
116
138
|
indent_len = len(line) - len(line.lstrip())
|
|
117
139
|
min_indent = min(min_indent, indent_len)
|
|
118
|
-
|
|
140
|
+
|
|
119
141
|
# If we didn't find any non-empty lines, set min_indent to 0
|
|
120
|
-
if min_indent == float(
|
|
142
|
+
if min_indent == float("inf"):
|
|
121
143
|
min_indent = 0
|
|
122
|
-
|
|
144
|
+
|
|
123
145
|
# Now strip the minimum indentation and re-indent to 8 spaces
|
|
124
146
|
if body_lines:
|
|
125
147
|
indented_lines = []
|
|
126
148
|
for line in body_lines:
|
|
127
149
|
if line.strip(): # Non-empty line
|
|
128
150
|
# Remove the minimum indentation and add 8 spaces
|
|
129
|
-
stripped_line =
|
|
130
|
-
|
|
151
|
+
stripped_line = (
|
|
152
|
+
line[min_indent:] if len(line) > min_indent else line.lstrip()
|
|
153
|
+
)
|
|
154
|
+
indented_lines.append(" " + stripped_line)
|
|
131
155
|
else: # Empty line
|
|
132
|
-
indented_lines.append(
|
|
133
|
-
|
|
134
|
-
indented_body =
|
|
135
|
-
new_func += f
|
|
136
|
-
|
|
156
|
+
indented_lines.append("")
|
|
157
|
+
|
|
158
|
+
indented_body = "\n".join(indented_lines)
|
|
159
|
+
new_func += f"\n{indented_body}"
|
|
160
|
+
|
|
137
161
|
# Add the return statement
|
|
138
|
-
new_func +=
|
|
162
|
+
new_func += "\n\n return verifier(Environment(), transcript)"
|
|
139
163
|
|
|
140
164
|
# Replace TASK_FAILED_SCORE with 0 in the function string
|
|
141
|
-
new_func = new_func.replace(
|
|
165
|
+
new_func = new_func.replace("TASK_FAILED_SCORE", "0")
|
|
142
166
|
|
|
143
167
|
return new_func
|
|
144
168
|
|
|
@@ -147,39 +171,45 @@ def convert_new_to_old_verifier(verifier_str: str) -> str:
|
|
|
147
171
|
"""
|
|
148
172
|
Convert a verifier function from the new format (before/after: DatabaseSnapshot)
|
|
149
173
|
to the old format (env: Environment).
|
|
150
|
-
|
|
174
|
+
|
|
151
175
|
This is the inverse of convert_verifier_string.
|
|
152
|
-
|
|
176
|
+
|
|
153
177
|
Args:
|
|
154
178
|
verifier_str: The new format verifier function as a string
|
|
155
|
-
|
|
179
|
+
|
|
156
180
|
Returns:
|
|
157
181
|
The converted verifier function string that accepts env
|
|
158
|
-
"""
|
|
182
|
+
"""
|
|
159
183
|
# Extract function name, parameters, docstring, and body
|
|
160
184
|
# Pattern for new format with flexible whitespace and multiline support
|
|
161
185
|
func_pattern = r'def\s+(\w+)\s*\(\s*before\s*:\s*DatabaseSnapshot\s*,?\s*after\s*:\s*DatabaseSnapshot\s*,?\s*transcript\s*:\s*str\s*\|\s*None\s*=\s*None\s*,?\s*\)\s*->\s*int:\s*((?:\s*""".*?"""\s*)?)(.*)'
|
|
162
|
-
|
|
186
|
+
|
|
163
187
|
# Try multiline pattern that's more flexible
|
|
164
188
|
func_pattern_multiline = r'def\s+(\w+)\s*\(\s*\n?\s*before\s*:\s*DatabaseSnapshot\s*,?\s*\n?\s*after\s*:\s*DatabaseSnapshot\s*,?\s*\n?\s*transcript\s*:\s*str\s*\|\s*None\s*=\s*None\s*,?\s*\n?\s*\)\s*->\s*int:\s*\n?((?:\s*""".*?"""\s*)?)(.*)'
|
|
165
|
-
|
|
166
|
-
match = re.match(
|
|
167
|
-
|
|
189
|
+
|
|
190
|
+
match = re.match(
|
|
191
|
+
func_pattern_multiline, verifier_str.strip(), re.DOTALL | re.MULTILINE
|
|
192
|
+
)
|
|
193
|
+
|
|
168
194
|
if not match:
|
|
169
195
|
# Even more flexible pattern
|
|
170
|
-
func_pattern_flexible =
|
|
196
|
+
func_pattern_flexible = (
|
|
197
|
+
r'def\s+(\w+)\s*\([^)]*\)\s*->\s*int:\s*\n?((?:\s*""".*?"""\s*)?)(.*)'
|
|
198
|
+
)
|
|
171
199
|
match = re.match(func_pattern_flexible, verifier_str.strip(), re.DOTALL)
|
|
172
|
-
|
|
200
|
+
|
|
173
201
|
if not match:
|
|
174
202
|
raise ValueError("Could not parse new format verifier function")
|
|
175
|
-
|
|
203
|
+
|
|
176
204
|
func_name = match.group(1)
|
|
177
205
|
docstring = match.group(2).strip()
|
|
178
206
|
body = match.group(3)
|
|
179
|
-
|
|
207
|
+
|
|
180
208
|
# Indent the original function body
|
|
181
|
-
indented_verifier =
|
|
182
|
-
|
|
209
|
+
indented_verifier = "\n".join(
|
|
210
|
+
" " + line if line.strip() else line for line in verifier_str.splitlines()
|
|
211
|
+
)
|
|
212
|
+
|
|
183
213
|
# Build the wrapper function
|
|
184
214
|
wrapper_func = f'''def {func_name}_wrapper(env, *args, **kwargs) -> float:
|
|
185
215
|
"""Wrapper to adapt new format verifier to old format."""
|
|
@@ -203,5 +233,5 @@ def convert_new_to_old_verifier(verifier_str: str) -> str:
|
|
|
203
233
|
# Call the inner function and convert result
|
|
204
234
|
result = {func_name}(before, after, transcript)
|
|
205
235
|
return float(result)'''
|
|
206
|
-
|
|
207
|
-
return wrapper_func
|
|
236
|
+
|
|
237
|
+
return wrapper_func
|