hud-python 0.5.1__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. hud/__init__.py +1 -1
  2. hud/agents/__init__.py +65 -6
  3. hud/agents/base.py +33 -15
  4. hud/agents/claude.py +60 -31
  5. hud/agents/gateway.py +42 -0
  6. hud/agents/gemini.py +15 -26
  7. hud/agents/gemini_cua.py +6 -17
  8. hud/agents/misc/response_agent.py +7 -0
  9. hud/agents/openai.py +16 -29
  10. hud/agents/openai_chat.py +3 -19
  11. hud/agents/operator.py +5 -17
  12. hud/agents/resolver.py +70 -0
  13. hud/agents/tests/test_claude.py +2 -4
  14. hud/agents/tests/test_openai.py +2 -1
  15. hud/agents/tests/test_resolver.py +192 -0
  16. hud/agents/types.py +148 -0
  17. hud/cli/__init__.py +34 -3
  18. hud/cli/build.py +37 -5
  19. hud/cli/dev.py +11 -2
  20. hud/cli/eval.py +51 -39
  21. hud/cli/flows/init.py +1 -1
  22. hud/cli/pull.py +1 -1
  23. hud/cli/push.py +9 -2
  24. hud/cli/tests/test_build.py +2 -2
  25. hud/cli/tests/test_push.py +1 -1
  26. hud/cli/utils/metadata.py +1 -1
  27. hud/cli/utils/tests/test_metadata.py +1 -1
  28. hud/clients/mcp_use.py +6 -1
  29. hud/datasets/loader.py +17 -18
  30. hud/datasets/runner.py +16 -10
  31. hud/datasets/tests/test_loader.py +15 -15
  32. hud/environment/__init__.py +5 -3
  33. hud/environment/connection.py +58 -6
  34. hud/environment/connectors/mcp_config.py +29 -1
  35. hud/environment/environment.py +218 -77
  36. hud/environment/router.py +175 -24
  37. hud/environment/scenarios.py +313 -186
  38. hud/environment/tests/test_connectors.py +10 -23
  39. hud/environment/tests/test_environment.py +432 -0
  40. hud/environment/tests/test_local_connectors.py +81 -40
  41. hud/environment/tests/test_scenarios.py +820 -14
  42. hud/eval/context.py +63 -10
  43. hud/eval/instrument.py +4 -2
  44. hud/eval/manager.py +79 -12
  45. hud/eval/task.py +36 -4
  46. hud/eval/tests/test_eval.py +1 -1
  47. hud/eval/tests/test_task.py +147 -1
  48. hud/eval/types.py +2 -0
  49. hud/eval/utils.py +14 -3
  50. hud/patches/mcp_patches.py +178 -21
  51. hud/telemetry/instrument.py +8 -1
  52. hud/telemetry/tests/test_eval_telemetry.py +8 -8
  53. hud/tools/__init__.py +2 -0
  54. hud/tools/agent.py +223 -0
  55. hud/tools/computer/__init__.py +34 -5
  56. hud/tools/shell.py +3 -3
  57. hud/tools/tests/test_agent_tool.py +355 -0
  58. hud/types.py +62 -34
  59. hud/utils/hud_console.py +30 -17
  60. hud/utils/strict_schema.py +1 -1
  61. hud/utils/tests/test_version.py +1 -1
  62. hud/version.py +1 -1
  63. {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/METADATA +2 -2
  64. {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/RECORD +67 -61
  65. {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/WHEEL +0 -0
  66. {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
  67. {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
@@ -5,8 +5,9 @@ from __future__ import annotations
5
5
  import inspect
6
6
  import json
7
7
  import logging
8
- import uuid
9
- from typing import TYPE_CHECKING, Any
8
+ from typing import TYPE_CHECKING, Any, get_type_hints
9
+
10
+ from pydantic import BaseModel, ConfigDict
10
11
 
11
12
  if TYPE_CHECKING:
12
13
  from collections.abc import AsyncGenerator, Callable
@@ -15,11 +16,28 @@ if TYPE_CHECKING:
15
16
  from fastmcp.resources import ResourceManager
16
17
  from fastmcp.tools import ToolManager
17
18
 
18
- __all__ = ["ScenarioMixin"]
19
+ __all__ = ["ScenarioMixin", "ScenarioSession"]
19
20
 
20
21
  logger = logging.getLogger(__name__)
21
22
 
22
23
 
24
+ class ScenarioSession(BaseModel):
25
+ """Tracks an active scenario from setup through evaluate.
26
+
27
+ Created during run_scenario_setup(), used by submit() and run_scenario_evaluate().
28
+ """
29
+
30
+ model_config = ConfigDict(arbitrary_types_allowed=True)
31
+
32
+ local_name: str # Canonical short name (e.g., "investigate")
33
+ full_name: str # Full name as called (e.g., "sentry-agent:investigate")
34
+ is_local: bool # True if running locally (generator exists)
35
+ connection_name: str | None # Which connection served it (if remote)
36
+ resource_uri: str # Full URI for reading evaluation result
37
+ generator: Any | None = None # AsyncGenerator (if local) - Any to avoid validation issues
38
+ answer: str | None = None # Submitted answer
39
+
40
+
23
41
  class ScenarioMixin:
24
42
  """Mixin providing @env.scenario decorator for setup/evaluate phases.
25
43
 
@@ -45,24 +63,25 @@ class ScenarioMixin:
45
63
  yield float(result > 0 or "found" in answer.lower())
46
64
  """
47
65
 
48
- # These come from Environment/MCPServer
66
+ # These come from Environment/MCPServer (type hints for mixin)
49
67
  name: str
50
68
  _prompt_manager: PromptManager
51
69
  _resource_manager: ResourceManager
52
70
  _tool_manager: ToolManager
53
71
 
54
- # Scenario state
72
+ # Scenario function registry
55
73
  _scenarios: dict[str, Callable[..., AsyncGenerator[Any, Any]]]
56
- _scenario_sessions: dict[str, AsyncGenerator[Any, Any]] # session_id -> generator
57
- _scenario_latest: dict[str, str] # scenario_name -> latest session_id
58
- _scenario_answers: dict[str, str] # scenario_name -> submitted answer
74
+
75
+ # Single active scenario session - used for BOTH:
76
+ # - Client-side: when we run scenarios (local or remote)
77
+ # - Server-side: when external clients call our scenarios via MCP
78
+ # Only one scenario can be active at a time.
79
+ _active_session: ScenarioSession | None
59
80
 
60
81
  def _init_scenarios(self) -> None:
61
82
  """Initialize scenario state. Called from Environment.__init__."""
62
83
  self._scenarios = {}
63
- self._scenario_sessions = {}
64
- self._scenario_latest = {}
65
- self._scenario_answers = {}
84
+ self._active_session = None
66
85
 
67
86
  # Register _hud_submit tool (underscore = hidden from agent)
68
87
  self._register_hud_submit_tool()
@@ -70,35 +89,41 @@ class ScenarioMixin:
70
89
  async def submit(self, scenario: str, answer: str) -> None:
71
90
  """Submit the agent's answer for a scenario's evaluate phase.
72
91
 
73
- This stores the answer locally and broadcasts to connected hubs
74
- that have the _hud_submit tool (auto-detected by Environment).
92
+ Uses _active_session to route to the correct connection (if remote)
93
+ or store locally (if local scenario).
75
94
 
76
95
  Args:
77
- scenario: Name of the scenario (without env prefix)
96
+ scenario: Name of the scenario (may include env prefix like "env:name")
78
97
  answer: The agent's answer/result to submit
98
+ """
99
+ local_name = scenario.split(":")[-1] if ":" in scenario else scenario
79
100
 
80
- Example:
81
- # Direct call with scenario name
82
- await env.submit("checkout", "Order completed successfully")
101
+ if not self._active_session:
102
+ raise ValueError(
103
+ "No active scenario session. Call run_scenario_setup() before submit()."
104
+ )
83
105
 
84
- # Or via EvalContext (knows its own scenario)
85
- await ctx.submit("Order completed successfully")
86
- """
87
- # Store locally for our scenarios
88
- self._scenario_answers[scenario] = answer
89
- logger.debug(
90
- "Stored answer for scenario '%s': %s...",
91
- scenario,
92
- answer[:50] if len(answer) > 50 else answer,
93
- )
94
-
95
- # Broadcast to connections that have _hud_submit
96
- # Environment._broadcast_tool auto-filters to connections with the tool
97
- await self._broadcast_tool( # type: ignore[attr-defined]
98
- "_hud_submit",
99
- scenario=scenario,
100
- answer=answer,
101
- )
106
+ if self._active_session.local_name != local_name:
107
+ raise ValueError(
108
+ f"Scenario mismatch: active session is '{self._active_session.local_name}', "
109
+ f"but submit() called with '{local_name}'"
110
+ )
111
+
112
+ self._active_session.answer = answer
113
+ logger.debug("Stored answer in session for scenario '%s'", local_name)
114
+
115
+ if not self._active_session.is_local:
116
+ # Remote scenario - send to specific connection
117
+ conn_name = self._active_session.connection_name
118
+ if not conn_name:
119
+ raise ValueError(f"Remote scenario '{local_name}' has no connection")
120
+
121
+ conn = self._connections.get(conn_name) # type: ignore[attr-defined]
122
+ if not conn or not conn.client:
123
+ raise ValueError(f"Connection '{conn_name}' not available")
124
+
125
+ await conn.call_tool("_hud_submit", {"scenario": local_name, "answer": answer})
126
+ logger.debug("Sent answer to connection '%s' for scenario '%s'", conn_name, local_name)
102
127
 
103
128
  def _register_hud_submit_tool(self) -> None:
104
129
  """Register the _hud_submit tool for receiving agent answers.
@@ -110,22 +135,33 @@ class ScenarioMixin:
110
135
  scenario_self = self
111
136
 
112
137
  async def _hud_submit(scenario: str, answer: str) -> str:
113
- """Submit the agent's answer for a scenario's evaluate phase.
138
+ """Receive an agent's answer from an external client.
114
139
 
115
- Internal tool - called by Environment.submit() on connected hubs.
140
+ Called when an external client's Environment.submit() sends an answer
141
+ to us via MCP. Stores in _active_session for resource_handler to use.
116
142
 
117
143
  Args:
118
- scenario: Name of the scenario (without env prefix)
144
+ scenario: Name of the scenario (may include env prefix like "env:name")
119
145
  answer: The agent's answer/result to submit
120
146
  """
121
- # Store locally (don't broadcast - we ARE the target)
122
- scenario_self._scenario_answers[scenario] = answer
147
+ local_name = scenario.split(":")[-1] if ":" in scenario else scenario
148
+
149
+ if not scenario_self._active_session:
150
+ raise ValueError(f"No active scenario session for '{local_name}'")
151
+
152
+ if scenario_self._active_session.local_name != local_name:
153
+ raise ValueError(
154
+ f"Scenario mismatch: active is '{scenario_self._active_session.local_name}', "
155
+ f"but received answer for '{local_name}'"
156
+ )
157
+
158
+ scenario_self._active_session.answer = answer
123
159
  logger.debug(
124
- "_hud_submit received answer for scenario '%s': %s...",
125
- scenario,
160
+ "_hud_submit stored answer for scenario '%s': %s...",
161
+ local_name,
126
162
  answer[:50] if len(answer) > 50 else answer,
127
163
  )
128
- return f"Answer submitted for scenario '{scenario}'"
164
+ return f"Answer submitted for scenario '{local_name}'"
129
165
 
130
166
  # Register the tool with underscore name
131
167
  tool = Tool.from_function(_hud_submit)
@@ -136,33 +172,58 @@ class ScenarioMixin:
136
172
  """Run a scenario's setup phase and return the prompt.
137
173
 
138
174
  Handles both local scenarios (registered via @env.scenario) and remote
139
- scenarios (via MCP prompt).
175
+ scenarios (via MCP prompt). Creates _active_session for use by submit/evaluate.
140
176
 
141
177
  Args:
142
- scenario_name: Name of the scenario to run
178
+ scenario_name: Name of the scenario to run (may include "env:" prefix)
143
179
  args: Arguments to pass to the scenario
144
180
 
145
181
  Returns:
146
182
  The prompt string from the scenario's setup phase, or None if failed
147
183
  """
148
- # Check if scenario is registered locally
149
- if scenario_name in self._scenarios:
184
+ # Determine if this should be local or remote:
185
+ # - No prefix ("greet") → check local first
186
+ # - Prefix matches our env name ("my-env:greet" when self.name="my-env") → local
187
+ # - Prefix is different ("other-env:greet") → remote only
188
+ local_name: str | None = None
189
+ is_explicitly_remote = False
190
+ if ":" in scenario_name:
191
+ prefix, short_name = scenario_name.rsplit(":", 1)
192
+ # self.name is already normalized (underscores → hyphens) in Environment.__init__
193
+ if prefix == self.name:
194
+ # Prefix matches our env - check local
195
+ local_name = short_name
196
+ else:
197
+ # Different prefix - explicitly remote
198
+ local_name = short_name
199
+ is_explicitly_remote = True
200
+ else:
201
+ # No prefix - check local
202
+ local_name = scenario_name
203
+
204
+ # Check if scenario is registered locally (unless explicitly remote)
205
+ if not is_explicitly_remote and local_name in self._scenarios:
150
206
  # Local scenario - run setup via generator
151
- scenario_fn = self._scenarios[scenario_name]
207
+ scenario_fn = self._scenarios[local_name]
152
208
  gen = scenario_fn(**args)
153
209
 
154
210
  # Run setup phase (code before first yield)
155
211
  prompt = await gen.__anext__()
156
212
 
157
- # Store generator for evaluate phase
158
- session_id = uuid.uuid4().hex[:8]
159
- self._scenario_sessions[session_id] = gen
160
- self._scenario_latest[scenario_name] = session_id
213
+ # Create session for local scenario
214
+ self._active_session = ScenarioSession(
215
+ local_name=local_name,
216
+ full_name=scenario_name,
217
+ is_local=True,
218
+ connection_name=None,
219
+ resource_uri=f"{self.name}:{local_name}",
220
+ generator=gen,
221
+ )
161
222
 
162
223
  logger.debug(
163
- "Scenario %s setup complete, session=%s",
164
- scenario_name,
165
- session_id,
224
+ "Local scenario setup: %s (session=%s)",
225
+ local_name,
226
+ self._active_session,
166
227
  )
167
228
  return str(prompt)
168
229
  else:
@@ -171,27 +232,50 @@ class ScenarioMixin:
171
232
  # Otherwise, prefix with env name: {env_name}:{scenario_name}
172
233
  if ":" in scenario_name:
173
234
  prompt_id = scenario_name
174
- logger.debug("Remote scenario (already namespaced): prompt_id=%s", prompt_id)
175
235
  else:
236
+ # Use _source_env_name (from EvalContext) or self.name - both are normalized
176
237
  env_name = getattr(self, "_source_env_name", None) or self.name
177
- safe_env_name = env_name.replace("_", "-")
178
- prompt_id = f"{safe_env_name}:{scenario_name}"
179
- logger.debug("Remote scenario (adding namespace): prompt_id=%s", prompt_id)
238
+ prompt_id = f"{env_name}:{scenario_name}"
239
+
240
+ # Serialize args for MCP prompt (only supports string values)
241
+ serialized_args: dict[str, str] = {}
242
+ for key, value in args.items():
243
+ serialized_args[key] = value if isinstance(value, str) else json.dumps(value)
244
+
180
245
  try:
181
- result = await self.get_prompt(prompt_id, args) # type: ignore[attr-defined]
246
+ result = await self.get_prompt(prompt_id, serialized_args) # type: ignore[attr-defined]
247
+ # Get connection AFTER get_prompt succeeds (routing is now guaranteed built)
248
+ conn_name = self._router.get_prompt_connection(prompt_id) # type: ignore[attr-defined]
249
+ logger.debug(
250
+ "Remote scenario: prompt_id=%s, connection=%s",
251
+ prompt_id,
252
+ conn_name or "(not found in router)",
253
+ )
182
254
  except Exception as e:
183
255
  # Fetch available scenarios for error context
184
256
  try:
185
257
  prompts = await self.list_prompts() # type: ignore[attr-defined]
186
258
  scenario_prompts = [p.name for p in prompts if ":" in p.name]
187
- available = (
188
- "\n ".join(scenario_prompts) if scenario_prompts else "(none found)"
189
- )
259
+ available = "\n ".join(scenario_prompts) if scenario_prompts else "(none)"
190
260
  except Exception:
191
- available = "(could not fetch available scenarios)"
261
+ available = "(could not fetch)"
262
+ scenario_prompts = []
263
+
264
+ original_error = str(e)
265
+ if prompt_id in scenario_prompts:
266
+ raise ValueError(
267
+ f"⚠️ ERROR: Scenario '{prompt_id}' exists but failed to execute.\n\n"
268
+ f"The scenario was found but encountered an error during setup:\n"
269
+ f" {original_error}\n\n"
270
+ f"This could be caused by:\n"
271
+ f" - Missing or invalid scenario arguments\n"
272
+ f" - An error in the scenario's setup function\n"
273
+ f" - Connection or serialization issues\n\n"
274
+ f"Check the scenario definition and required arguments."
275
+ ) from e
192
276
 
193
277
  raise ValueError(
194
- f"Scenario not found.\n\n"
278
+ f"⚠️ ERROR: Scenario not found.\n\n"
195
279
  f"Scenario IDs have the format 'environment_name:scenario_name'.\n"
196
280
  f"If you only specify 'scenario_name', the SDK uses your task's env name "
197
281
  f"as the prefix.\n"
@@ -203,35 +287,46 @@ class ScenarioMixin:
203
287
  f"Fix: Use one of the scenario IDs above in your task JSON."
204
288
  ) from e
205
289
 
206
- # Validate the response (outside try/except so errors aren't wrapped)
290
+ # Extract prompt text from response
291
+ prompt_text: str | None = None
207
292
  if result.messages:
208
293
  first_msg = result.messages[0]
209
294
  content = first_msg.content
210
295
  if hasattr(content, "text") and isinstance(content.text, str): # type: ignore[union-attr]
211
- return content.text # type: ignore[union-attr]
296
+ prompt_text = content.text # type: ignore[union-attr]
212
297
  elif isinstance(content, str):
213
- return content
214
- else:
215
- # Content exists but is neither text object nor string
216
- raise ValueError(
217
- f"Scenario '{scenario_name}' returned malformed content.\n\n"
218
- f"Expected: content with .text attribute (str) or content as str\n"
219
- f"Got: {type(content).__name__}\n\n"
220
- f"Check that the scenario's setup function returns a valid prompt."
221
- )
222
- else:
223
- # get_prompt succeeded but returned empty messages
298
+ prompt_text = content
299
+
300
+ if not prompt_text:
224
301
  raise ValueError(
225
302
  f"Scenario '{scenario_name}' returned an empty response.\n\n"
226
303
  f"The scenario's setup function was called but returned no messages.\n"
227
304
  f"Check that the scenario returns a valid prompt string."
228
305
  )
229
306
 
307
+ # Create session for remote scenario - use router's connection info
308
+ self._active_session = ScenarioSession(
309
+ local_name=local_name,
310
+ full_name=scenario_name,
311
+ is_local=False,
312
+ connection_name=conn_name,
313
+ resource_uri=prompt_id, # Resource has same URI as prompt
314
+ generator=None,
315
+ )
316
+
317
+ logger.debug(
318
+ "Remote scenario setup: %s (connection=%s)",
319
+ prompt_id,
320
+ conn_name,
321
+ )
322
+ return prompt_text
323
+
230
324
  async def run_scenario_evaluate(self, scenario_name: str) -> float | None:
231
325
  """Run a scenario's evaluate phase and return the reward.
232
326
 
233
- Uses the submitted answer (if any) via gen.asend().
234
- Handles both local and remote scenarios.
327
+ Uses _active_session created by run_scenario_setup():
328
+ - Local: use stored generator with submitted answer
329
+ - Remote: read resource from the connection that served setup
235
330
 
236
331
  Args:
237
332
  scenario_name: Name of the scenario to evaluate
@@ -239,56 +334,55 @@ class ScenarioMixin:
239
334
  Returns:
240
335
  The reward from the scenario's evaluate phase, or None if failed
241
336
  """
242
- # Check if we have a stored generator (local scenario)
243
- session_id = self._scenario_latest.get(scenario_name)
244
- if session_id:
245
- gen = self._scenario_sessions.pop(session_id, None)
246
- if gen:
247
- # Get submitted answer (if any)
248
- answer = self._scenario_answers.pop(scenario_name, None)
337
+ if not self._active_session:
338
+ logger.warning("No active session for scenario '%s'", scenario_name)
339
+ return None
249
340
 
250
- try:
251
- # Use asend to pass the answer to the scenario
252
- reward = await gen.asend(answer)
253
- logger.debug(
254
- "Scenario %s evaluate complete, answer=%s, reward=%s",
255
- scenario_name,
256
- answer[:50] if answer and len(answer) > 50 else answer,
257
- reward,
258
- )
259
- return float(reward)
260
- except StopAsyncIteration:
261
- # Generator ended without second yield - assume success
262
- return 1.0
263
- finally:
264
- # Clean up latest pointer
265
- if self._scenario_latest.get(scenario_name) == session_id:
266
- del self._scenario_latest[scenario_name]
267
-
268
- # Remote scenario - read via MCP resource
269
- # If scenario_name already contains ":", it's already namespaced - use directly
270
- if ":" in scenario_name:
271
- resource_id = scenario_name
341
+ session = self._active_session
342
+ self._active_session = None # Clear after use
343
+
344
+ if session.is_local:
345
+ # Local scenario - use generator
346
+ if not session.generator:
347
+ logger.warning("Local scenario '%s' has no generator", session.local_name)
348
+ return None
349
+
350
+ answer = session.answer
351
+ try:
352
+ reward = await session.generator.asend(answer)
353
+ logger.debug(
354
+ "Local scenario %s evaluate: answer=%s, reward=%s",
355
+ session.local_name,
356
+ answer[:50] if answer and len(answer) > 50 else answer,
357
+ reward,
358
+ )
359
+ return float(reward)
360
+ except StopAsyncIteration:
361
+ return 1.0
272
362
  else:
273
- env_name = getattr(self, "_source_env_name", None) or self.name
274
- safe_env_name = env_name.replace("_", "-")
275
- resource_id = f"{safe_env_name}:{scenario_name}"
276
- try:
277
- contents = await self.read_resource(resource_id) # type: ignore[attr-defined]
278
- if contents:
279
- first_content = contents[0]
280
- if hasattr(first_content, "text") and isinstance(first_content.text, str): # type: ignore[union-attr]
281
- data = json.loads(first_content.text) # type: ignore[union-attr]
282
- if "reward" in data:
283
- return float(data["reward"])
284
- except Exception as e:
285
- logger.warning("Failed to get scenario reward: %s", e)
286
- return None
363
+ # Remote scenario - read resource via router
364
+ try:
365
+ contents = await self.read_resource(session.resource_uri) # type: ignore[attr-defined]
366
+ if contents:
367
+ first = contents[0]
368
+ if hasattr(first, "text") and isinstance(first.text, str): # type: ignore[union-attr]
369
+ data = json.loads(first.text) # type: ignore[union-attr]
370
+ if "reward" in data:
371
+ logger.debug(
372
+ "Remote scenario %s evaluate: reward=%s",
373
+ session.local_name,
374
+ data["reward"],
375
+ )
376
+ return float(data["reward"])
377
+ except Exception as e:
378
+ logger.warning("Failed to get scenario reward from %s: %s", session.resource_uri, e)
379
+ return None
287
380
 
288
381
  def scenario(
289
382
  self,
290
383
  name: str | None = None,
291
384
  description: str | None = None,
385
+ required_env_vars: list[str] | None = None,
292
386
  ) -> Callable[
293
387
  [Callable[..., AsyncGenerator[Any, None]]],
294
388
  Callable[..., AsyncGenerator[Any, None]],
@@ -303,28 +397,37 @@ class ScenarioMixin:
303
397
  Args:
304
398
  name: Optional name for the scenario (defaults to function name)
305
399
  description: Optional description of what the scenario does
400
+ required_env_vars: Optional list of environment variable names this scenario requires.
401
+ These are used by the HUD platform to check if users have configured the
402
+ necessary API keys/credentials before running this specific scenario.
306
403
 
307
404
  Example:
308
- @env.scenario()
309
- async def search_cats(url: str):
310
- await env.call_tool("navigate", url=url)
311
- yield "Find cat images"
312
- result = await env.call_tool("count_cats")
313
- yield float(result > 0)
405
+ @env.scenario(required_env_vars=["OPENAI_API_KEY"])
406
+ async def chat(query: str):
407
+ yield f"Answer this question: {query}"
408
+ # ... evaluate
409
+ yield 1.0
314
410
 
315
411
  # MCP client usage:
316
- # 1. get_prompt("{env_name}:search_cats", {url: "..."}) -> prompt messages
412
+ # 1. get_prompt("{env_name}:chat", {query: "..."}) -> prompt messages
317
413
  # 2. agent runs...
318
- # 3. read_resource("{env_name}:search_cats") -> {"reward": 0.95}
414
+ # 3. read_resource("{env_name}:chat") -> {"reward": 0.95}
319
415
  """
320
416
 
321
417
  def decorator(
322
418
  fn: Callable[..., AsyncGenerator[Any, None]],
323
419
  ) -> Callable[..., AsyncGenerator[Any, None]]:
324
420
  scenario_name = name or fn.__name__
325
- # Sanitize env name for URI scheme (no underscores allowed)
326
- safe_env_name = self.name.replace("_", "-")
327
- scenario_id = f"{safe_env_name}:{scenario_name}"
421
+
422
+ # Validate scenario name - colons are reserved as env:scenario separator
423
+ if ":" in scenario_name:
424
+ raise ValueError(
425
+ f"Scenario name '{scenario_name}' cannot contain ':' "
426
+ "(reserved as separator between environment and scenario names)"
427
+ )
428
+
429
+ # self.name is already normalized (lowercase, hyphens) by Environment.__init__
430
+ scenario_id = f"{self.name}:{scenario_name}"
328
431
  scenario_desc = description or fn.__doc__ or f"Scenario: {scenario_name}"
329
432
 
330
433
  # Capture source code for reproducibility
@@ -353,7 +456,7 @@ class ScenarioMixin:
353
456
  # Only include JSON-serializable defaults
354
457
  default_val = p.default
355
458
  if default_val is None or isinstance(
356
- default_val, (str, int, float, bool, list, dict)
459
+ default_val, (str | int | float | bool | list | dict)
357
460
  ):
358
461
  arg_info["default"] = default_val
359
462
 
@@ -381,30 +484,81 @@ class ScenarioMixin:
381
484
  # Register PROMPT - runs setup, returns prompt messages
382
485
  # We need a reference to self and the outer variables
383
486
  scenario_self = self
384
- scenario_fn = fn
385
487
  scenario_name_ref = scenario_name
386
488
 
387
- async def prompt_handler(**handler_args: Any) -> list[str]:
388
- # Create generator instance
389
- gen = scenario_fn(**handler_args)
390
-
391
- # Run setup phase (code before first yield)
392
- prompt_text = await gen.__anext__()
393
-
394
- # Store generator with session ID
395
- session_id = uuid.uuid4().hex[:8]
396
- scenario_self._scenario_sessions[session_id] = gen
397
- scenario_self._scenario_latest[scenario_name_ref] = session_id
489
+ # Resolve parameter type hints for deserialization
490
+ # Use get_type_hints() to handle `from __future__ import annotations`
491
+ # which makes annotations lazy strings (PEP 563)
492
+ # MCP prompts only support string arguments, so we JSON-serialize complex types
493
+ # and use Pydantic TypeAdapter to properly deserialize them
494
+ try:
495
+ param_annotations = get_type_hints(fn)
496
+ except Exception:
497
+ # Fall back to raw annotations if get_type_hints fails
498
+ param_annotations = {
499
+ p.name: p.annotation
500
+ for p in sig.parameters.values()
501
+ if p.annotation is not inspect.Parameter.empty
502
+ }
398
503
 
399
- logger.debug(
400
- "Scenario %s setup complete, session=%s, prompt=%s",
401
- scenario_name_ref,
402
- session_id,
403
- prompt_text[:50] if isinstance(prompt_text, str) else prompt_text,
504
+ async def prompt_handler(**handler_args: Any) -> list[str]:
505
+ from pydantic import TypeAdapter
506
+
507
+ # Deserialize JSON-encoded arguments using Pydantic TypeAdapter
508
+ # MCP prompts only support string arguments, so complex types are
509
+ # JSON-serialized on the sending side and deserialized here
510
+ deserialized_args: dict[str, Any] = {}
511
+ for arg_name, arg_value in handler_args.items():
512
+ annotation = param_annotations.get(arg_name)
513
+
514
+ # Only attempt deserialization on string values
515
+ if not isinstance(arg_value, str):
516
+ deserialized_args[arg_name] = arg_value
517
+ continue
518
+
519
+ # If annotation is explicitly str, keep as string
520
+ if annotation is str:
521
+ deserialized_args[arg_name] = arg_value
522
+ continue
523
+
524
+ # If we have a non-str type annotation, use TypeAdapter
525
+ if annotation is not None:
526
+ try:
527
+ adapter = TypeAdapter(annotation)
528
+ deserialized_args[arg_name] = adapter.validate_json(arg_value)
529
+ continue
530
+ except Exception: # noqa: S110
531
+ pass # Fall through to generic JSON decode
532
+
533
+ # Try JSON decode for strings that look like JSON
534
+ stripped = arg_value.strip()
535
+ if (stripped and stripped[0] in "[{") or stripped in ("true", "false", "null"):
536
+ try:
537
+ deserialized_args[arg_name] = json.loads(arg_value)
538
+ continue
539
+ except json.JSONDecodeError:
540
+ pass
541
+
542
+ # Try to decode if it looks like a number
543
+ if stripped.lstrip("-").replace(".", "", 1).isdigit():
544
+ try:
545
+ deserialized_args[arg_name] = json.loads(arg_value)
546
+ continue
547
+ except json.JSONDecodeError:
548
+ pass
549
+
550
+ # Keep as string
551
+ deserialized_args[arg_name] = arg_value
552
+
553
+ # Delegate to run_scenario_setup (consolidates client/server logic)
554
+ prompt_text = await scenario_self.run_scenario_setup(
555
+ scenario_name_ref, deserialized_args
404
556
  )
405
557
 
558
+ if prompt_text is None:
559
+ raise ValueError(f"Scenario '{scenario_name_ref}' setup returned no prompt")
560
+
406
561
  # Return just the string - FastMCP wraps it in PromptMessage
407
- # Don't return dict or it gets JSON-serialized as text content
408
562
  return [str(prompt_text)]
409
563
 
410
564
  # Register prompt using FastMCP - create FunctionPrompt directly
@@ -417,6 +571,8 @@ class ScenarioMixin:
417
571
  scenario_meta["code"] = source_code
418
572
  if prompt_args:
419
573
  scenario_meta["arguments"] = prompt_args
574
+ if required_env_vars:
575
+ scenario_meta["required_env_vars"] = required_env_vars
420
576
 
421
577
  prompt = FunctionPrompt(
422
578
  name=scenario_id,
@@ -432,40 +588,11 @@ class ScenarioMixin:
432
588
 
433
589
  # Register RESOURCE - runs evaluate, returns reward
434
590
  async def resource_handler() -> str:
435
- # Get latest session for this scenario
436
- session_id = scenario_self._scenario_latest.get(scenario_name_ref)
437
- if not session_id:
438
- raise ValueError(
439
- f"No active session for scenario '{scenario_name_ref}'. "
440
- "Call the prompt first to run setup."
441
- )
442
-
443
- gen = scenario_self._scenario_sessions.pop(session_id, None)
444
- if gen is None:
445
- raise ValueError(f"Session '{session_id}' not found or already evaluated.")
446
-
447
- # Get submitted answer (if any)
448
- answer = scenario_self._scenario_answers.pop(scenario_name_ref, None)
449
-
450
- # Run evaluate phase (code after first yield)
451
- # Use asend to pass the answer (or None if not submitted)
452
- try:
453
- reward = await gen.asend(answer)
454
- except StopAsyncIteration:
455
- # Generator ended without second yield - assume success
456
- reward = 1.0
457
-
458
- logger.debug(
459
- "Scenario %s evaluate complete, session=%s, answer=%s, reward=%s",
460
- scenario_name_ref,
461
- session_id,
462
- answer[:50] if answer and len(answer) > 50 else answer,
463
- reward,
464
- )
591
+ # Delegate to run_scenario_evaluate (consolidates client/server logic)
592
+ reward = await scenario_self.run_scenario_evaluate(scenario_name_ref)
465
593
 
466
- # Clean up latest pointer if it matches
467
- if scenario_self._scenario_latest.get(scenario_name_ref) == session_id:
468
- del scenario_self._scenario_latest[scenario_name_ref]
594
+ if reward is None:
595
+ raise ValueError(f"Scenario '{scenario_name_ref}' evaluation failed")
469
596
 
470
597
  return json.dumps({"reward": float(reward)})
471
598