hud-python 0.5.1__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +1 -1
- hud/agents/__init__.py +65 -6
- hud/agents/base.py +33 -15
- hud/agents/claude.py +60 -31
- hud/agents/gateway.py +42 -0
- hud/agents/gemini.py +15 -26
- hud/agents/gemini_cua.py +6 -17
- hud/agents/misc/response_agent.py +7 -0
- hud/agents/openai.py +16 -29
- hud/agents/openai_chat.py +3 -19
- hud/agents/operator.py +5 -17
- hud/agents/resolver.py +70 -0
- hud/agents/tests/test_claude.py +2 -4
- hud/agents/tests/test_openai.py +2 -1
- hud/agents/tests/test_resolver.py +192 -0
- hud/agents/types.py +148 -0
- hud/cli/__init__.py +34 -3
- hud/cli/build.py +37 -5
- hud/cli/dev.py +11 -2
- hud/cli/eval.py +51 -39
- hud/cli/flows/init.py +1 -1
- hud/cli/pull.py +1 -1
- hud/cli/push.py +9 -2
- hud/cli/tests/test_build.py +2 -2
- hud/cli/tests/test_push.py +1 -1
- hud/cli/utils/metadata.py +1 -1
- hud/cli/utils/tests/test_metadata.py +1 -1
- hud/clients/mcp_use.py +6 -1
- hud/datasets/loader.py +17 -18
- hud/datasets/runner.py +16 -10
- hud/datasets/tests/test_loader.py +15 -15
- hud/environment/__init__.py +5 -3
- hud/environment/connection.py +58 -6
- hud/environment/connectors/mcp_config.py +29 -1
- hud/environment/environment.py +218 -77
- hud/environment/router.py +175 -24
- hud/environment/scenarios.py +313 -186
- hud/environment/tests/test_connectors.py +10 -23
- hud/environment/tests/test_environment.py +432 -0
- hud/environment/tests/test_local_connectors.py +81 -40
- hud/environment/tests/test_scenarios.py +820 -14
- hud/eval/context.py +63 -10
- hud/eval/instrument.py +4 -2
- hud/eval/manager.py +79 -12
- hud/eval/task.py +36 -4
- hud/eval/tests/test_eval.py +1 -1
- hud/eval/tests/test_task.py +147 -1
- hud/eval/types.py +2 -0
- hud/eval/utils.py +14 -3
- hud/patches/mcp_patches.py +178 -21
- hud/telemetry/instrument.py +8 -1
- hud/telemetry/tests/test_eval_telemetry.py +8 -8
- hud/tools/__init__.py +2 -0
- hud/tools/agent.py +223 -0
- hud/tools/computer/__init__.py +34 -5
- hud/tools/shell.py +3 -3
- hud/tools/tests/test_agent_tool.py +355 -0
- hud/types.py +62 -34
- hud/utils/hud_console.py +30 -17
- hud/utils/strict_schema.py +1 -1
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/METADATA +2 -2
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/RECORD +67 -61
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/WHEEL +0 -0
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/environment/scenarios.py
CHANGED
|
@@ -5,8 +5,9 @@ from __future__ import annotations
|
|
|
5
5
|
import inspect
|
|
6
6
|
import json
|
|
7
7
|
import logging
|
|
8
|
-
import
|
|
9
|
-
|
|
8
|
+
from typing import TYPE_CHECKING, Any, get_type_hints
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, ConfigDict
|
|
10
11
|
|
|
11
12
|
if TYPE_CHECKING:
|
|
12
13
|
from collections.abc import AsyncGenerator, Callable
|
|
@@ -15,11 +16,28 @@ if TYPE_CHECKING:
|
|
|
15
16
|
from fastmcp.resources import ResourceManager
|
|
16
17
|
from fastmcp.tools import ToolManager
|
|
17
18
|
|
|
18
|
-
__all__ = ["ScenarioMixin"]
|
|
19
|
+
__all__ = ["ScenarioMixin", "ScenarioSession"]
|
|
19
20
|
|
|
20
21
|
logger = logging.getLogger(__name__)
|
|
21
22
|
|
|
22
23
|
|
|
24
|
+
class ScenarioSession(BaseModel):
|
|
25
|
+
"""Tracks an active scenario from setup through evaluate.
|
|
26
|
+
|
|
27
|
+
Created during run_scenario_setup(), used by submit() and run_scenario_evaluate().
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
31
|
+
|
|
32
|
+
local_name: str # Canonical short name (e.g., "investigate")
|
|
33
|
+
full_name: str # Full name as called (e.g., "sentry-agent:investigate")
|
|
34
|
+
is_local: bool # True if running locally (generator exists)
|
|
35
|
+
connection_name: str | None # Which connection served it (if remote)
|
|
36
|
+
resource_uri: str # Full URI for reading evaluation result
|
|
37
|
+
generator: Any | None = None # AsyncGenerator (if local) - Any to avoid validation issues
|
|
38
|
+
answer: str | None = None # Submitted answer
|
|
39
|
+
|
|
40
|
+
|
|
23
41
|
class ScenarioMixin:
|
|
24
42
|
"""Mixin providing @env.scenario decorator for setup/evaluate phases.
|
|
25
43
|
|
|
@@ -45,24 +63,25 @@ class ScenarioMixin:
|
|
|
45
63
|
yield float(result > 0 or "found" in answer.lower())
|
|
46
64
|
"""
|
|
47
65
|
|
|
48
|
-
# These come from Environment/MCPServer
|
|
66
|
+
# These come from Environment/MCPServer (type hints for mixin)
|
|
49
67
|
name: str
|
|
50
68
|
_prompt_manager: PromptManager
|
|
51
69
|
_resource_manager: ResourceManager
|
|
52
70
|
_tool_manager: ToolManager
|
|
53
71
|
|
|
54
|
-
# Scenario
|
|
72
|
+
# Scenario function registry
|
|
55
73
|
_scenarios: dict[str, Callable[..., AsyncGenerator[Any, Any]]]
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
74
|
+
|
|
75
|
+
# Single active scenario session - used for BOTH:
|
|
76
|
+
# - Client-side: when we run scenarios (local or remote)
|
|
77
|
+
# - Server-side: when external clients call our scenarios via MCP
|
|
78
|
+
# Only one scenario can be active at a time.
|
|
79
|
+
_active_session: ScenarioSession | None
|
|
59
80
|
|
|
60
81
|
def _init_scenarios(self) -> None:
|
|
61
82
|
"""Initialize scenario state. Called from Environment.__init__."""
|
|
62
83
|
self._scenarios = {}
|
|
63
|
-
self.
|
|
64
|
-
self._scenario_latest = {}
|
|
65
|
-
self._scenario_answers = {}
|
|
84
|
+
self._active_session = None
|
|
66
85
|
|
|
67
86
|
# Register _hud_submit tool (underscore = hidden from agent)
|
|
68
87
|
self._register_hud_submit_tool()
|
|
@@ -70,35 +89,41 @@ class ScenarioMixin:
|
|
|
70
89
|
async def submit(self, scenario: str, answer: str) -> None:
|
|
71
90
|
"""Submit the agent's answer for a scenario's evaluate phase.
|
|
72
91
|
|
|
73
|
-
|
|
74
|
-
|
|
92
|
+
Uses _active_session to route to the correct connection (if remote)
|
|
93
|
+
or store locally (if local scenario).
|
|
75
94
|
|
|
76
95
|
Args:
|
|
77
|
-
scenario: Name of the scenario (
|
|
96
|
+
scenario: Name of the scenario (may include env prefix like "env:name")
|
|
78
97
|
answer: The agent's answer/result to submit
|
|
98
|
+
"""
|
|
99
|
+
local_name = scenario.split(":")[-1] if ":" in scenario else scenario
|
|
79
100
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
101
|
+
if not self._active_session:
|
|
102
|
+
raise ValueError(
|
|
103
|
+
"No active scenario session. Call run_scenario_setup() before submit()."
|
|
104
|
+
)
|
|
83
105
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
106
|
+
if self._active_session.local_name != local_name:
|
|
107
|
+
raise ValueError(
|
|
108
|
+
f"Scenario mismatch: active session is '{self._active_session.local_name}', "
|
|
109
|
+
f"but submit() called with '{local_name}'"
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
self._active_session.answer = answer
|
|
113
|
+
logger.debug("Stored answer in session for scenario '%s'", local_name)
|
|
114
|
+
|
|
115
|
+
if not self._active_session.is_local:
|
|
116
|
+
# Remote scenario - send to specific connection
|
|
117
|
+
conn_name = self._active_session.connection_name
|
|
118
|
+
if not conn_name:
|
|
119
|
+
raise ValueError(f"Remote scenario '{local_name}' has no connection")
|
|
120
|
+
|
|
121
|
+
conn = self._connections.get(conn_name) # type: ignore[attr-defined]
|
|
122
|
+
if not conn or not conn.client:
|
|
123
|
+
raise ValueError(f"Connection '{conn_name}' not available")
|
|
124
|
+
|
|
125
|
+
await conn.call_tool("_hud_submit", {"scenario": local_name, "answer": answer})
|
|
126
|
+
logger.debug("Sent answer to connection '%s' for scenario '%s'", conn_name, local_name)
|
|
102
127
|
|
|
103
128
|
def _register_hud_submit_tool(self) -> None:
|
|
104
129
|
"""Register the _hud_submit tool for receiving agent answers.
|
|
@@ -110,22 +135,33 @@ class ScenarioMixin:
|
|
|
110
135
|
scenario_self = self
|
|
111
136
|
|
|
112
137
|
async def _hud_submit(scenario: str, answer: str) -> str:
|
|
113
|
-
"""
|
|
138
|
+
"""Receive an agent's answer from an external client.
|
|
114
139
|
|
|
115
|
-
|
|
140
|
+
Called when an external client's Environment.submit() sends an answer
|
|
141
|
+
to us via MCP. Stores in _active_session for resource_handler to use.
|
|
116
142
|
|
|
117
143
|
Args:
|
|
118
|
-
scenario: Name of the scenario (
|
|
144
|
+
scenario: Name of the scenario (may include env prefix like "env:name")
|
|
119
145
|
answer: The agent's answer/result to submit
|
|
120
146
|
"""
|
|
121
|
-
|
|
122
|
-
|
|
147
|
+
local_name = scenario.split(":")[-1] if ":" in scenario else scenario
|
|
148
|
+
|
|
149
|
+
if not scenario_self._active_session:
|
|
150
|
+
raise ValueError(f"No active scenario session for '{local_name}'")
|
|
151
|
+
|
|
152
|
+
if scenario_self._active_session.local_name != local_name:
|
|
153
|
+
raise ValueError(
|
|
154
|
+
f"Scenario mismatch: active is '{scenario_self._active_session.local_name}', "
|
|
155
|
+
f"but received answer for '{local_name}'"
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
scenario_self._active_session.answer = answer
|
|
123
159
|
logger.debug(
|
|
124
|
-
"_hud_submit
|
|
125
|
-
|
|
160
|
+
"_hud_submit stored answer for scenario '%s': %s...",
|
|
161
|
+
local_name,
|
|
126
162
|
answer[:50] if len(answer) > 50 else answer,
|
|
127
163
|
)
|
|
128
|
-
return f"Answer submitted for scenario '{
|
|
164
|
+
return f"Answer submitted for scenario '{local_name}'"
|
|
129
165
|
|
|
130
166
|
# Register the tool with underscore name
|
|
131
167
|
tool = Tool.from_function(_hud_submit)
|
|
@@ -136,33 +172,58 @@ class ScenarioMixin:
|
|
|
136
172
|
"""Run a scenario's setup phase and return the prompt.
|
|
137
173
|
|
|
138
174
|
Handles both local scenarios (registered via @env.scenario) and remote
|
|
139
|
-
scenarios (via MCP prompt).
|
|
175
|
+
scenarios (via MCP prompt). Creates _active_session for use by submit/evaluate.
|
|
140
176
|
|
|
141
177
|
Args:
|
|
142
|
-
scenario_name: Name of the scenario to run
|
|
178
|
+
scenario_name: Name of the scenario to run (may include "env:" prefix)
|
|
143
179
|
args: Arguments to pass to the scenario
|
|
144
180
|
|
|
145
181
|
Returns:
|
|
146
182
|
The prompt string from the scenario's setup phase, or None if failed
|
|
147
183
|
"""
|
|
148
|
-
#
|
|
149
|
-
|
|
184
|
+
# Determine if this should be local or remote:
|
|
185
|
+
# - No prefix ("greet") → check local first
|
|
186
|
+
# - Prefix matches our env name ("my-env:greet" when self.name="my-env") → local
|
|
187
|
+
# - Prefix is different ("other-env:greet") → remote only
|
|
188
|
+
local_name: str | None = None
|
|
189
|
+
is_explicitly_remote = False
|
|
190
|
+
if ":" in scenario_name:
|
|
191
|
+
prefix, short_name = scenario_name.rsplit(":", 1)
|
|
192
|
+
# self.name is already normalized (underscores → hyphens) in Environment.__init__
|
|
193
|
+
if prefix == self.name:
|
|
194
|
+
# Prefix matches our env - check local
|
|
195
|
+
local_name = short_name
|
|
196
|
+
else:
|
|
197
|
+
# Different prefix - explicitly remote
|
|
198
|
+
local_name = short_name
|
|
199
|
+
is_explicitly_remote = True
|
|
200
|
+
else:
|
|
201
|
+
# No prefix - check local
|
|
202
|
+
local_name = scenario_name
|
|
203
|
+
|
|
204
|
+
# Check if scenario is registered locally (unless explicitly remote)
|
|
205
|
+
if not is_explicitly_remote and local_name in self._scenarios:
|
|
150
206
|
# Local scenario - run setup via generator
|
|
151
|
-
scenario_fn = self._scenarios[
|
|
207
|
+
scenario_fn = self._scenarios[local_name]
|
|
152
208
|
gen = scenario_fn(**args)
|
|
153
209
|
|
|
154
210
|
# Run setup phase (code before first yield)
|
|
155
211
|
prompt = await gen.__anext__()
|
|
156
212
|
|
|
157
|
-
#
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
213
|
+
# Create session for local scenario
|
|
214
|
+
self._active_session = ScenarioSession(
|
|
215
|
+
local_name=local_name,
|
|
216
|
+
full_name=scenario_name,
|
|
217
|
+
is_local=True,
|
|
218
|
+
connection_name=None,
|
|
219
|
+
resource_uri=f"{self.name}:{local_name}",
|
|
220
|
+
generator=gen,
|
|
221
|
+
)
|
|
161
222
|
|
|
162
223
|
logger.debug(
|
|
163
|
-
"
|
|
164
|
-
|
|
165
|
-
|
|
224
|
+
"Local scenario setup: %s (session=%s)",
|
|
225
|
+
local_name,
|
|
226
|
+
self._active_session,
|
|
166
227
|
)
|
|
167
228
|
return str(prompt)
|
|
168
229
|
else:
|
|
@@ -171,27 +232,50 @@ class ScenarioMixin:
|
|
|
171
232
|
# Otherwise, prefix with env name: {env_name}:{scenario_name}
|
|
172
233
|
if ":" in scenario_name:
|
|
173
234
|
prompt_id = scenario_name
|
|
174
|
-
logger.debug("Remote scenario (already namespaced): prompt_id=%s", prompt_id)
|
|
175
235
|
else:
|
|
236
|
+
# Use _source_env_name (from EvalContext) or self.name - both are normalized
|
|
176
237
|
env_name = getattr(self, "_source_env_name", None) or self.name
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
238
|
+
prompt_id = f"{env_name}:{scenario_name}"
|
|
239
|
+
|
|
240
|
+
# Serialize args for MCP prompt (only supports string values)
|
|
241
|
+
serialized_args: dict[str, str] = {}
|
|
242
|
+
for key, value in args.items():
|
|
243
|
+
serialized_args[key] = value if isinstance(value, str) else json.dumps(value)
|
|
244
|
+
|
|
180
245
|
try:
|
|
181
|
-
result = await self.get_prompt(prompt_id,
|
|
246
|
+
result = await self.get_prompt(prompt_id, serialized_args) # type: ignore[attr-defined]
|
|
247
|
+
# Get connection AFTER get_prompt succeeds (routing is now guaranteed built)
|
|
248
|
+
conn_name = self._router.get_prompt_connection(prompt_id) # type: ignore[attr-defined]
|
|
249
|
+
logger.debug(
|
|
250
|
+
"Remote scenario: prompt_id=%s, connection=%s",
|
|
251
|
+
prompt_id,
|
|
252
|
+
conn_name or "(not found in router)",
|
|
253
|
+
)
|
|
182
254
|
except Exception as e:
|
|
183
255
|
# Fetch available scenarios for error context
|
|
184
256
|
try:
|
|
185
257
|
prompts = await self.list_prompts() # type: ignore[attr-defined]
|
|
186
258
|
scenario_prompts = [p.name for p in prompts if ":" in p.name]
|
|
187
|
-
available = (
|
|
188
|
-
"\n ".join(scenario_prompts) if scenario_prompts else "(none found)"
|
|
189
|
-
)
|
|
259
|
+
available = "\n ".join(scenario_prompts) if scenario_prompts else "(none)"
|
|
190
260
|
except Exception:
|
|
191
|
-
available = "(could not fetch
|
|
261
|
+
available = "(could not fetch)"
|
|
262
|
+
scenario_prompts = []
|
|
263
|
+
|
|
264
|
+
original_error = str(e)
|
|
265
|
+
if prompt_id in scenario_prompts:
|
|
266
|
+
raise ValueError(
|
|
267
|
+
f"⚠️ ERROR: Scenario '{prompt_id}' exists but failed to execute.\n\n"
|
|
268
|
+
f"The scenario was found but encountered an error during setup:\n"
|
|
269
|
+
f" {original_error}\n\n"
|
|
270
|
+
f"This could be caused by:\n"
|
|
271
|
+
f" - Missing or invalid scenario arguments\n"
|
|
272
|
+
f" - An error in the scenario's setup function\n"
|
|
273
|
+
f" - Connection or serialization issues\n\n"
|
|
274
|
+
f"Check the scenario definition and required arguments."
|
|
275
|
+
) from e
|
|
192
276
|
|
|
193
277
|
raise ValueError(
|
|
194
|
-
f"Scenario not found.\n\n"
|
|
278
|
+
f"⚠️ ERROR: Scenario not found.\n\n"
|
|
195
279
|
f"Scenario IDs have the format 'environment_name:scenario_name'.\n"
|
|
196
280
|
f"If you only specify 'scenario_name', the SDK uses your task's env name "
|
|
197
281
|
f"as the prefix.\n"
|
|
@@ -203,35 +287,46 @@ class ScenarioMixin:
|
|
|
203
287
|
f"Fix: Use one of the scenario IDs above in your task JSON."
|
|
204
288
|
) from e
|
|
205
289
|
|
|
206
|
-
#
|
|
290
|
+
# Extract prompt text from response
|
|
291
|
+
prompt_text: str | None = None
|
|
207
292
|
if result.messages:
|
|
208
293
|
first_msg = result.messages[0]
|
|
209
294
|
content = first_msg.content
|
|
210
295
|
if hasattr(content, "text") and isinstance(content.text, str): # type: ignore[union-attr]
|
|
211
|
-
|
|
296
|
+
prompt_text = content.text # type: ignore[union-attr]
|
|
212
297
|
elif isinstance(content, str):
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
raise ValueError(
|
|
217
|
-
f"Scenario '{scenario_name}' returned malformed content.\n\n"
|
|
218
|
-
f"Expected: content with .text attribute (str) or content as str\n"
|
|
219
|
-
f"Got: {type(content).__name__}\n\n"
|
|
220
|
-
f"Check that the scenario's setup function returns a valid prompt."
|
|
221
|
-
)
|
|
222
|
-
else:
|
|
223
|
-
# get_prompt succeeded but returned empty messages
|
|
298
|
+
prompt_text = content
|
|
299
|
+
|
|
300
|
+
if not prompt_text:
|
|
224
301
|
raise ValueError(
|
|
225
302
|
f"Scenario '{scenario_name}' returned an empty response.\n\n"
|
|
226
303
|
f"The scenario's setup function was called but returned no messages.\n"
|
|
227
304
|
f"Check that the scenario returns a valid prompt string."
|
|
228
305
|
)
|
|
229
306
|
|
|
307
|
+
# Create session for remote scenario - use router's connection info
|
|
308
|
+
self._active_session = ScenarioSession(
|
|
309
|
+
local_name=local_name,
|
|
310
|
+
full_name=scenario_name,
|
|
311
|
+
is_local=False,
|
|
312
|
+
connection_name=conn_name,
|
|
313
|
+
resource_uri=prompt_id, # Resource has same URI as prompt
|
|
314
|
+
generator=None,
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
logger.debug(
|
|
318
|
+
"Remote scenario setup: %s (connection=%s)",
|
|
319
|
+
prompt_id,
|
|
320
|
+
conn_name,
|
|
321
|
+
)
|
|
322
|
+
return prompt_text
|
|
323
|
+
|
|
230
324
|
async def run_scenario_evaluate(self, scenario_name: str) -> float | None:
|
|
231
325
|
"""Run a scenario's evaluate phase and return the reward.
|
|
232
326
|
|
|
233
|
-
Uses
|
|
234
|
-
|
|
327
|
+
Uses _active_session created by run_scenario_setup():
|
|
328
|
+
- Local: use stored generator with submitted answer
|
|
329
|
+
- Remote: read resource from the connection that served setup
|
|
235
330
|
|
|
236
331
|
Args:
|
|
237
332
|
scenario_name: Name of the scenario to evaluate
|
|
@@ -239,56 +334,55 @@ class ScenarioMixin:
|
|
|
239
334
|
Returns:
|
|
240
335
|
The reward from the scenario's evaluate phase, or None if failed
|
|
241
336
|
"""
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
gen = self._scenario_sessions.pop(session_id, None)
|
|
246
|
-
if gen:
|
|
247
|
-
# Get submitted answer (if any)
|
|
248
|
-
answer = self._scenario_answers.pop(scenario_name, None)
|
|
337
|
+
if not self._active_session:
|
|
338
|
+
logger.warning("No active session for scenario '%s'", scenario_name)
|
|
339
|
+
return None
|
|
249
340
|
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
if
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
resource_id = scenario_name
|
|
341
|
+
session = self._active_session
|
|
342
|
+
self._active_session = None # Clear after use
|
|
343
|
+
|
|
344
|
+
if session.is_local:
|
|
345
|
+
# Local scenario - use generator
|
|
346
|
+
if not session.generator:
|
|
347
|
+
logger.warning("Local scenario '%s' has no generator", session.local_name)
|
|
348
|
+
return None
|
|
349
|
+
|
|
350
|
+
answer = session.answer
|
|
351
|
+
try:
|
|
352
|
+
reward = await session.generator.asend(answer)
|
|
353
|
+
logger.debug(
|
|
354
|
+
"Local scenario %s evaluate: answer=%s, reward=%s",
|
|
355
|
+
session.local_name,
|
|
356
|
+
answer[:50] if answer and len(answer) > 50 else answer,
|
|
357
|
+
reward,
|
|
358
|
+
)
|
|
359
|
+
return float(reward)
|
|
360
|
+
except StopAsyncIteration:
|
|
361
|
+
return 1.0
|
|
272
362
|
else:
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
363
|
+
# Remote scenario - read resource via router
|
|
364
|
+
try:
|
|
365
|
+
contents = await self.read_resource(session.resource_uri) # type: ignore[attr-defined]
|
|
366
|
+
if contents:
|
|
367
|
+
first = contents[0]
|
|
368
|
+
if hasattr(first, "text") and isinstance(first.text, str): # type: ignore[union-attr]
|
|
369
|
+
data = json.loads(first.text) # type: ignore[union-attr]
|
|
370
|
+
if "reward" in data:
|
|
371
|
+
logger.debug(
|
|
372
|
+
"Remote scenario %s evaluate: reward=%s",
|
|
373
|
+
session.local_name,
|
|
374
|
+
data["reward"],
|
|
375
|
+
)
|
|
376
|
+
return float(data["reward"])
|
|
377
|
+
except Exception as e:
|
|
378
|
+
logger.warning("Failed to get scenario reward from %s: %s", session.resource_uri, e)
|
|
379
|
+
return None
|
|
287
380
|
|
|
288
381
|
def scenario(
|
|
289
382
|
self,
|
|
290
383
|
name: str | None = None,
|
|
291
384
|
description: str | None = None,
|
|
385
|
+
required_env_vars: list[str] | None = None,
|
|
292
386
|
) -> Callable[
|
|
293
387
|
[Callable[..., AsyncGenerator[Any, None]]],
|
|
294
388
|
Callable[..., AsyncGenerator[Any, None]],
|
|
@@ -303,28 +397,37 @@ class ScenarioMixin:
|
|
|
303
397
|
Args:
|
|
304
398
|
name: Optional name for the scenario (defaults to function name)
|
|
305
399
|
description: Optional description of what the scenario does
|
|
400
|
+
required_env_vars: Optional list of environment variable names this scenario requires.
|
|
401
|
+
These are used by the HUD platform to check if users have configured the
|
|
402
|
+
necessary API keys/credentials before running this specific scenario.
|
|
306
403
|
|
|
307
404
|
Example:
|
|
308
|
-
@env.scenario()
|
|
309
|
-
async def
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
yield float(result > 0)
|
|
405
|
+
@env.scenario(required_env_vars=["OPENAI_API_KEY"])
|
|
406
|
+
async def chat(query: str):
|
|
407
|
+
yield f"Answer this question: {query}"
|
|
408
|
+
# ... evaluate
|
|
409
|
+
yield 1.0
|
|
314
410
|
|
|
315
411
|
# MCP client usage:
|
|
316
|
-
# 1. get_prompt("{env_name}:
|
|
412
|
+
# 1. get_prompt("{env_name}:chat", {query: "..."}) -> prompt messages
|
|
317
413
|
# 2. agent runs...
|
|
318
|
-
# 3. read_resource("{env_name}:
|
|
414
|
+
# 3. read_resource("{env_name}:chat") -> {"reward": 0.95}
|
|
319
415
|
"""
|
|
320
416
|
|
|
321
417
|
def decorator(
|
|
322
418
|
fn: Callable[..., AsyncGenerator[Any, None]],
|
|
323
419
|
) -> Callable[..., AsyncGenerator[Any, None]]:
|
|
324
420
|
scenario_name = name or fn.__name__
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
421
|
+
|
|
422
|
+
# Validate scenario name - colons are reserved as env:scenario separator
|
|
423
|
+
if ":" in scenario_name:
|
|
424
|
+
raise ValueError(
|
|
425
|
+
f"Scenario name '{scenario_name}' cannot contain ':' "
|
|
426
|
+
"(reserved as separator between environment and scenario names)"
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
# self.name is already normalized (lowercase, hyphens) by Environment.__init__
|
|
430
|
+
scenario_id = f"{self.name}:{scenario_name}"
|
|
328
431
|
scenario_desc = description or fn.__doc__ or f"Scenario: {scenario_name}"
|
|
329
432
|
|
|
330
433
|
# Capture source code for reproducibility
|
|
@@ -353,7 +456,7 @@ class ScenarioMixin:
|
|
|
353
456
|
# Only include JSON-serializable defaults
|
|
354
457
|
default_val = p.default
|
|
355
458
|
if default_val is None or isinstance(
|
|
356
|
-
default_val, (str
|
|
459
|
+
default_val, (str | int | float | bool | list | dict)
|
|
357
460
|
):
|
|
358
461
|
arg_info["default"] = default_val
|
|
359
462
|
|
|
@@ -381,30 +484,81 @@ class ScenarioMixin:
|
|
|
381
484
|
# Register PROMPT - runs setup, returns prompt messages
|
|
382
485
|
# We need a reference to self and the outer variables
|
|
383
486
|
scenario_self = self
|
|
384
|
-
scenario_fn = fn
|
|
385
487
|
scenario_name_ref = scenario_name
|
|
386
488
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
489
|
+
# Resolve parameter type hints for deserialization
|
|
490
|
+
# Use get_type_hints() to handle `from __future__ import annotations`
|
|
491
|
+
# which makes annotations lazy strings (PEP 563)
|
|
492
|
+
# MCP prompts only support string arguments, so we JSON-serialize complex types
|
|
493
|
+
# and use Pydantic TypeAdapter to properly deserialize them
|
|
494
|
+
try:
|
|
495
|
+
param_annotations = get_type_hints(fn)
|
|
496
|
+
except Exception:
|
|
497
|
+
# Fall back to raw annotations if get_type_hints fails
|
|
498
|
+
param_annotations = {
|
|
499
|
+
p.name: p.annotation
|
|
500
|
+
for p in sig.parameters.values()
|
|
501
|
+
if p.annotation is not inspect.Parameter.empty
|
|
502
|
+
}
|
|
398
503
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
504
|
+
async def prompt_handler(**handler_args: Any) -> list[str]:
|
|
505
|
+
from pydantic import TypeAdapter
|
|
506
|
+
|
|
507
|
+
# Deserialize JSON-encoded arguments using Pydantic TypeAdapter
|
|
508
|
+
# MCP prompts only support string arguments, so complex types are
|
|
509
|
+
# JSON-serialized on the sending side and deserialized here
|
|
510
|
+
deserialized_args: dict[str, Any] = {}
|
|
511
|
+
for arg_name, arg_value in handler_args.items():
|
|
512
|
+
annotation = param_annotations.get(arg_name)
|
|
513
|
+
|
|
514
|
+
# Only attempt deserialization on string values
|
|
515
|
+
if not isinstance(arg_value, str):
|
|
516
|
+
deserialized_args[arg_name] = arg_value
|
|
517
|
+
continue
|
|
518
|
+
|
|
519
|
+
# If annotation is explicitly str, keep as string
|
|
520
|
+
if annotation is str:
|
|
521
|
+
deserialized_args[arg_name] = arg_value
|
|
522
|
+
continue
|
|
523
|
+
|
|
524
|
+
# If we have a non-str type annotation, use TypeAdapter
|
|
525
|
+
if annotation is not None:
|
|
526
|
+
try:
|
|
527
|
+
adapter = TypeAdapter(annotation)
|
|
528
|
+
deserialized_args[arg_name] = adapter.validate_json(arg_value)
|
|
529
|
+
continue
|
|
530
|
+
except Exception: # noqa: S110
|
|
531
|
+
pass # Fall through to generic JSON decode
|
|
532
|
+
|
|
533
|
+
# Try JSON decode for strings that look like JSON
|
|
534
|
+
stripped = arg_value.strip()
|
|
535
|
+
if (stripped and stripped[0] in "[{") or stripped in ("true", "false", "null"):
|
|
536
|
+
try:
|
|
537
|
+
deserialized_args[arg_name] = json.loads(arg_value)
|
|
538
|
+
continue
|
|
539
|
+
except json.JSONDecodeError:
|
|
540
|
+
pass
|
|
541
|
+
|
|
542
|
+
# Try to decode if it looks like a number
|
|
543
|
+
if stripped.lstrip("-").replace(".", "", 1).isdigit():
|
|
544
|
+
try:
|
|
545
|
+
deserialized_args[arg_name] = json.loads(arg_value)
|
|
546
|
+
continue
|
|
547
|
+
except json.JSONDecodeError:
|
|
548
|
+
pass
|
|
549
|
+
|
|
550
|
+
# Keep as string
|
|
551
|
+
deserialized_args[arg_name] = arg_value
|
|
552
|
+
|
|
553
|
+
# Delegate to run_scenario_setup (consolidates client/server logic)
|
|
554
|
+
prompt_text = await scenario_self.run_scenario_setup(
|
|
555
|
+
scenario_name_ref, deserialized_args
|
|
404
556
|
)
|
|
405
557
|
|
|
558
|
+
if prompt_text is None:
|
|
559
|
+
raise ValueError(f"Scenario '{scenario_name_ref}' setup returned no prompt")
|
|
560
|
+
|
|
406
561
|
# Return just the string - FastMCP wraps it in PromptMessage
|
|
407
|
-
# Don't return dict or it gets JSON-serialized as text content
|
|
408
562
|
return [str(prompt_text)]
|
|
409
563
|
|
|
410
564
|
# Register prompt using FastMCP - create FunctionPrompt directly
|
|
@@ -417,6 +571,8 @@ class ScenarioMixin:
|
|
|
417
571
|
scenario_meta["code"] = source_code
|
|
418
572
|
if prompt_args:
|
|
419
573
|
scenario_meta["arguments"] = prompt_args
|
|
574
|
+
if required_env_vars:
|
|
575
|
+
scenario_meta["required_env_vars"] = required_env_vars
|
|
420
576
|
|
|
421
577
|
prompt = FunctionPrompt(
|
|
422
578
|
name=scenario_id,
|
|
@@ -432,40 +588,11 @@ class ScenarioMixin:
|
|
|
432
588
|
|
|
433
589
|
# Register RESOURCE - runs evaluate, returns reward
|
|
434
590
|
async def resource_handler() -> str:
|
|
435
|
-
#
|
|
436
|
-
|
|
437
|
-
if not session_id:
|
|
438
|
-
raise ValueError(
|
|
439
|
-
f"No active session for scenario '{scenario_name_ref}'. "
|
|
440
|
-
"Call the prompt first to run setup."
|
|
441
|
-
)
|
|
442
|
-
|
|
443
|
-
gen = scenario_self._scenario_sessions.pop(session_id, None)
|
|
444
|
-
if gen is None:
|
|
445
|
-
raise ValueError(f"Session '{session_id}' not found or already evaluated.")
|
|
446
|
-
|
|
447
|
-
# Get submitted answer (if any)
|
|
448
|
-
answer = scenario_self._scenario_answers.pop(scenario_name_ref, None)
|
|
449
|
-
|
|
450
|
-
# Run evaluate phase (code after first yield)
|
|
451
|
-
# Use asend to pass the answer (or None if not submitted)
|
|
452
|
-
try:
|
|
453
|
-
reward = await gen.asend(answer)
|
|
454
|
-
except StopAsyncIteration:
|
|
455
|
-
# Generator ended without second yield - assume success
|
|
456
|
-
reward = 1.0
|
|
457
|
-
|
|
458
|
-
logger.debug(
|
|
459
|
-
"Scenario %s evaluate complete, session=%s, answer=%s, reward=%s",
|
|
460
|
-
scenario_name_ref,
|
|
461
|
-
session_id,
|
|
462
|
-
answer[:50] if answer and len(answer) > 50 else answer,
|
|
463
|
-
reward,
|
|
464
|
-
)
|
|
591
|
+
# Delegate to run_scenario_evaluate (consolidates client/server logic)
|
|
592
|
+
reward = await scenario_self.run_scenario_evaluate(scenario_name_ref)
|
|
465
593
|
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
del scenario_self._scenario_latest[scenario_name_ref]
|
|
594
|
+
if reward is None:
|
|
595
|
+
raise ValueError(f"Scenario '{scenario_name_ref}' evaluation failed")
|
|
469
596
|
|
|
470
597
|
return json.dumps({"reward": float(reward)})
|
|
471
598
|
|