hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +70 -5
- hud/agents/base.py +238 -500
- hud/agents/claude.py +236 -247
- hud/agents/gateway.py +42 -0
- hud/agents/gemini.py +264 -0
- hud/agents/gemini_cua.py +324 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +48 -36
- hud/agents/openai.py +282 -296
- hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
- hud/agents/operator.py +199 -0
- hud/agents/resolver.py +70 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +381 -214
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +377 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_resolver.py +192 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/agents/types.py +148 -0
- hud/cli/__init__.py +493 -546
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +699 -113
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +889 -732
- hud/cli/eval.py +793 -667
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/pull.py +1 -1
- hud/cli/push.py +38 -13
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +110 -8
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push.py +1 -1
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +70 -1
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +45 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +326 -0
- hud/datasets/runner.py +198 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +52 -0
- hud/environment/connection.py +258 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +137 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +835 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +263 -0
- hud/environment/scenarios.py +620 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +205 -0
- hud/environment/tests/test_environment.py +593 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +242 -0
- hud/environment/tests/test_scenarios.py +1086 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +727 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +187 -0
- hud/eval/manager.py +533 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +372 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +291 -0
- hud/eval/types.py +65 -0
- hud/eval/utils.py +194 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +308 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +165 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +18 -2
- hud/tools/agent.py +223 -0
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +36 -3
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_agent_tool.py +355 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +194 -56
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +89 -18
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.13.dist-info/METADATA +264 -0
- hud_python-0.5.13.dist-info/RECORD +305 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/environment/tests/test_scenarios.py
@@ -0,0 +1,1086 @@
"""Tests for Environment scenario decorator."""

from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any

import pytest
from pydantic import BaseModel

from hud.environment import Environment


# Module-level models for Pydantic/Enum/datetime deserialization tests
# (prefixed with underscore to avoid pytest collection warnings)
class _UserConfig(BaseModel):
    """Pydantic model for testing."""

    name: str
    age: int
    active: bool = True


class _Status(Enum):
    """Enum for testing."""

    PENDING = "pending"
    ACTIVE = "active"
    COMPLETED = "completed"


class _Address(BaseModel):
    """Nested Pydantic model for testing."""

    street: str
    city: str


class _Person(BaseModel):
    """Pydantic model with nested model for testing."""

    name: str
    address: _Address


class _Item(BaseModel):
    """Pydantic model for list tests."""

    id: int
    name: str


class TestScenarioDecorator:
    """Tests for @env.scenario decorator."""

    def test_scenario_registers_function(self) -> None:
        """@env.scenario registers the function."""
        env = Environment("test-env")

        @env.scenario("greet")
        async def greet_scenario(name: str):
            yield f"Hello, {name}!"
            yield 1.0

        assert "greet" in env._scenarios

    def test_scenario_creates_mcp_prompt(self) -> None:
        """@env.scenario creates an MCP prompt."""
        env = Environment("test-env")

        @env.scenario("greet", description="Greeting scenario")
        async def greet_scenario(name: str):
            yield f"Hello, {name}!"
            yield 1.0

        # Check that prompt was registered via prompt manager
        prompt_names = list(env._prompt_manager._prompts.keys())
        assert "test-env:greet" in prompt_names

    def test_scenario_creates_mcp_resource(self) -> None:
        """@env.scenario creates an MCP resource."""
        env = Environment("test-env")

        @env.scenario("greet")
        async def greet_scenario(name: str):
            yield f"Hello, {name}!"
            yield 1.0

        # Check that resource was registered via resource manager
        resource_uris = list(env._resource_manager._resources.keys())
        assert "test-env:greet" in resource_uris

    def test_scenario_extracts_arguments(self) -> None:
        """@env.scenario extracts function arguments for prompt."""
        env = Environment("test-env")

        @env.scenario("checkout")
        async def checkout_scenario(user_id: str, amount: int = 100):
            yield f"Checkout for {user_id}: ${amount}"
            yield 1.0

        # Find the prompt
        prompt = env._prompt_manager._prompts.get("test-env:checkout")
        assert prompt is not None
        assert prompt.arguments is not None

        # Check arguments
        arg_names = [arg.name for arg in prompt.arguments]
        assert "user_id" in arg_names
        assert "amount" in arg_names


class TestScenarioExecution:
    """Tests for scenario execution flow."""

    @pytest.mark.asyncio
    async def test_scenario_setup_phase(self) -> None:
        """Scenario setup phase yields prompt."""
        env = Environment("test-env")
        setup_ran = False

        @env.scenario("test")
        async def test_scenario():
            nonlocal setup_ran
            setup_ran = True
            yield "Test prompt"
            yield 1.0

        # Get the prompt handler
        prompt = env._prompt_manager._prompts.get("test-env:test")
        assert prompt is not None

        # Run setup via prompt render (which calls fn) - no need for context
        result = await prompt.render({})

        assert setup_ran
        # Result is list of PromptMessage
        assert len(result) > 0
        assert "Test prompt" in str(result[0].content)

    @pytest.mark.asyncio
    async def test_scenario_stores_session(self) -> None:
        """Scenario stores generator in session for evaluate phase."""
        env = Environment("test-env")

        @env.scenario("test")
        async def test_scenario():
            yield "Test prompt"
            yield 1.0

        # Run setup via prompt - no need for context
        prompt = env._prompt_manager._prompts.get("test-env:test")
        assert prompt is not None
        await prompt.render({})

        # Check session was stored in _active_session
        assert env._active_session is not None
        assert env._active_session.local_name == "test"

    @pytest.mark.asyncio
    async def test_scenario_full_flow(self) -> None:
        """Scenario runs setup and evaluate phases correctly."""
        env = Environment("test-env")
        phases = []

        @env.scenario("test")
        async def test_scenario():
            phases.append("setup")
            yield "Test prompt"
            phases.append("evaluate")
            yield 0.95

        # Setup phase - no context needed for prompt/resource
        prompt = env._prompt_manager._prompts.get("test-env:test")
        assert prompt is not None
        await prompt.render({})
        assert "setup" in phases
        assert "evaluate" not in phases

        # Evaluate phase
        resource = env._resource_manager._resources.get("test-env:test")
        assert resource is not None
        await resource.read()
        assert "evaluate" in phases


class TestScenarioWithArgs:
    """Tests for scenarios with arguments."""

    @pytest.mark.asyncio
    async def test_scenario_receives_args(self) -> None:
        """Scenario receives arguments from prompt call."""
        env = Environment("test-env")
        received_args = {}

        @env.scenario("checkout")
        async def checkout_scenario(user_id: str, amount: int = 100):
            received_args["user_id"] = user_id
            received_args["amount"] = amount
            yield f"Checkout {user_id}: ${amount}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:checkout")
        assert prompt is not None
        # No context needed for prompt render
        await prompt.render({"user_id": "alice", "amount": 50})

        assert received_args["user_id"] == "alice"
        assert received_args["amount"] == 50


class TestScenarioSubmit:
    """Tests for scenario submit and answer flow."""

    @pytest.mark.asyncio
    async def test_submit_stores_answer(self) -> None:
        """submit() stores answer in active session."""
        env = Environment("test-env")

        @env.scenario("test")
        async def test_scenario():
            yield "What is 2+2?"
            yield 1.0

        # Run setup via proper API (creates _active_session)
        await env.run_scenario_setup("test", {})

        # Submit answer
        await env.submit("test", "4")

        # Answer is stored in active session (not _scenario_answers for client-side)
        assert env._active_session is not None
        assert env._active_session.answer == "4"

    @pytest.mark.asyncio
    async def test_scenario_receives_answer(self) -> None:
        """Scenario receives submitted answer via yield."""
        env = Environment("test-env")
        received_answer = None

        @env.scenario("qa")
        async def qa_scenario():
            nonlocal received_answer
            answer = yield "What is 2+2?"
            received_answer = answer
            yield 1.0 if answer == "4" else 0.0

        # Run setup (creates _active_session)
        prompt = env._prompt_manager._prompts.get("test-env:qa")
        assert prompt is not None
        await prompt.render({})

        # Submit answer via _active_session
        assert env._active_session is not None
        env._active_session.answer = "4"

        # Run evaluate
        resource = env._resource_manager._resources.get("test-env:qa")
        assert resource is not None
        await resource.read()

        assert received_answer == "4"

    @pytest.mark.asyncio
    async def test_scenario_evaluates_answer(self) -> None:
        """Scenario evaluates answer and returns reward."""
        env = Environment("test-env")

        @env.scenario("grading")
        async def grading_scenario():
            answer = yield "What is the capital of France?"
            yield 1.0 if "paris" in answer.lower() else 0.0

        # Run setup (creates _active_session)
        prompt = env._prompt_manager._prompts.get("test-env:grading")
        assert prompt is not None
        await prompt.render({})

        # Submit correct answer via _active_session
        assert env._active_session is not None
        env._active_session.answer = "Paris"

        # Run evaluate
        resource = env._resource_manager._resources.get("test-env:grading")
        assert resource is not None
        result = await resource.read()

        import json

        data = json.loads(result)
        assert data["reward"] == 1.0

    @pytest.mark.asyncio
    async def test_hud_submit_normalizes_prefixed_scenario_name(self) -> None:
        """_hud_submit with prefixed name stores answer in _active_session.

        Regression test: answers submitted with "env:scenario" prefix must
        match the active session's local_name for storage.
        """
        env = Environment("my-env")

        @env.scenario("greet")
        async def greet_scenario():
            answer = yield "Say hello"
            yield 1.0 if answer == "hello" else 0.0

        # Run setup via prompt (creates _active_session)
        prompt = env._prompt_manager._prompts.get("my-env:greet")
        assert prompt is not None
        await prompt.render({})

        # Verify session exists before _hud_submit
        assert env._active_session is not None
        assert env._active_session.local_name == "greet"

        # Simulate _hud_submit with PREFIXED scenario name (as happens in remote calls)
        # This should normalize to "greet" and match the active session
        await env.call_tool("_hud_submit", scenario="my-env:greet", answer="hello")

        # Verify answer was stored in _active_session
        assert env._active_session.answer == "hello"

        # Verify evaluation works
        resource = env._resource_manager._resources.get("my-env:greet")
        assert resource is not None
        result = await resource.read()

        import json

        data = json.loads(result)
        assert data["reward"] == 1.0


class TestScenarioMeta:
    """Tests for scenario _meta containing code."""

    def test_scenario_captures_source_code(self) -> None:
        """@env.scenario captures function source in meta."""
        env = Environment("test-env")

        @env.scenario("example")
        async def example_scenario(x: int):
            yield f"Process {x}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:example")
        assert prompt is not None
        assert prompt.meta is not None
        assert "code" in prompt.meta
        assert "async def example_scenario" in prompt.meta["code"]
        assert "yield" in prompt.meta["code"]

    def test_scenario_meta_on_resource(self) -> None:
        """Resource also has source code in meta."""
        env = Environment("test-env")

        @env.scenario("example")
        async def example_scenario():
            yield "Test"
            yield 1.0

        resource = env._resource_manager._resources.get("test-env:example")
        assert resource is not None
        assert resource.meta is not None
        assert "code" in resource.meta
        assert "async def example_scenario" in resource.meta["code"]


class TestScenarioJsonSerialization:
    """Tests for JSON serialization of complex argument types.

    MCP prompts only support string arguments (dict[str, str]).
    Complex types like lists, dicts, and numbers are JSON-serialized
    when sent and deserialized based on type annotations when received.
    """

    @pytest.mark.asyncio
    async def test_list_argument_deserialization(self) -> None:
        """List arguments are JSON-deserialized from strings."""
        env = Environment("test-env")
        received_items: list[str] = []

        @env.scenario("process_items")
        async def process_items_scenario(items: list[str]):
            received_items.extend(items)
            yield f"Processing {len(items)} items"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:process_items")
        assert prompt is not None

        # Simulate MCP sending JSON-encoded list as string
        await prompt.render({"items": '["apple", "banana", "cherry"]'})

        assert received_items == ["apple", "banana", "cherry"]

    @pytest.mark.asyncio
    async def test_dict_argument_deserialization(self) -> None:
        """Dict arguments are JSON-deserialized from strings."""
        env = Environment("test-env")
        received_config: dict[str, Any] = {}

        @env.scenario("configure")
        async def configure_scenario(config: dict[str, Any]):
            received_config.update(config)
            yield "Configuring..."
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:configure")
        assert prompt is not None

        # Simulate MCP sending JSON-encoded dict as string
        await prompt.render({"config": '{"timeout": 30, "retries": 3}'})

        assert received_config == {"timeout": 30, "retries": 3}

    @pytest.mark.asyncio
    async def test_int_argument_deserialization(self) -> None:
        """Integer arguments are JSON-deserialized from strings."""
        env = Environment("test-env")
        received_count = 0

        @env.scenario("count")
        async def count_scenario(count: int):
            nonlocal received_count
            received_count = count
            yield f"Counting to {count}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:count")
        assert prompt is not None

        # Simulate MCP sending JSON-encoded int as string
        await prompt.render({"count": "42"})

        assert received_count == 42
        assert isinstance(received_count, int)

    @pytest.mark.asyncio
    async def test_float_argument_deserialization(self) -> None:
        """Float arguments are JSON-deserialized from strings."""
        env = Environment("test-env")
        received_value = 0.0

        @env.scenario("precision")
        async def precision_scenario(value: float):
            nonlocal received_value
            received_value = value
            yield f"Value is {value}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:precision")
        assert prompt is not None

        # Simulate MCP sending JSON-encoded float as string
        await prompt.render({"value": "3.14159"})

        assert received_value == 3.14159
        assert isinstance(received_value, float)

    @pytest.mark.asyncio
    async def test_bool_argument_deserialization(self) -> None:
        """Boolean arguments are JSON-deserialized from strings."""
        env = Environment("test-env")
        received_flag = False

        @env.scenario("toggle")
        async def toggle_scenario(enabled: bool):
            nonlocal received_flag
            received_flag = enabled
            yield f"Enabled: {enabled}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:toggle")
        assert prompt is not None

        # Simulate MCP sending JSON-encoded bool as string
        await prompt.render({"enabled": "true"})

        assert received_flag is True
        assert isinstance(received_flag, bool)

    @pytest.mark.asyncio
    async def test_string_argument_unchanged(self) -> None:
        """String arguments are passed through unchanged."""
        env = Environment("test-env")
        received_name = ""

        @env.scenario("greet")
        async def greet_scenario(name: str):
            nonlocal received_name
            received_name = name
            yield f"Hello, {name}!"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:greet")
        assert prompt is not None

        # String should pass through as-is (not double-encoded)
        await prompt.render({"name": "Alice"})

        assert received_name == "Alice"

    @pytest.mark.asyncio
    async def test_mixed_argument_types(self) -> None:
        """Mixed argument types are handled correctly."""
        env = Environment("test-env")
        received_args: dict[str, Any] = {}

        @env.scenario("mixed")
        async def mixed_scenario(
            name: str,
            count: int,
            items: list[str],
            options: dict[str, bool],
        ):
            received_args["name"] = name
            received_args["count"] = count
            received_args["items"] = items
            received_args["options"] = options
            yield "Processing..."
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:mixed")
        assert prompt is not None

        await prompt.render(
            {
                "name": "test",
                "count": "5",
                "items": '["a", "b", "c"]',
                "options": '{"verbose": true, "dry_run": false}',
            }
        )

        assert received_args["name"] == "test"
        assert received_args["count"] == 5
        assert received_args["items"] == ["a", "b", "c"]
        assert received_args["options"] == {"verbose": True, "dry_run": False}

    @pytest.mark.asyncio
    async def test_invalid_json_falls_back_to_string(self) -> None:
        """Invalid JSON for non-string type falls back to string value."""
        env = Environment("test-env")
        received_items: list[str] = []

        @env.scenario("fallback")
        async def fallback_scenario(items: list[str]):
            # This will receive the raw string if JSON parsing fails
            received_items.append(str(items))
            yield "Processing..."
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:fallback")
        assert prompt is not None

        # Invalid JSON - should fall back to string
        await prompt.render({"items": "not valid json ["})

        # Falls back to raw string
        assert received_items == ["not valid json ["]

    @pytest.mark.asyncio
    async def test_nested_complex_types(self) -> None:
        """Nested complex types are deserialized correctly."""
        env = Environment("test-env")
        received_data: dict[str, Any] = {}

        @env.scenario("nested")
        async def nested_scenario(data: dict[str, Any]):
            received_data.update(data)
            yield "Processing nested data..."
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:nested")
        assert prompt is not None

        nested_json = (
            '{"users": [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}], '
            '"metadata": {"version": 1}}'
        )
        await prompt.render({"data": nested_json})

        assert received_data == {
            "users": [
                {"name": "Alice", "age": 30},
                {"name": "Bob", "age": 25},
            ],
            "metadata": {"version": 1},
        }

    @pytest.mark.asyncio
    async def test_optional_list_with_value(self) -> None:
        """Optional[list[str]] receives list when provided."""
        env = Environment("test-env")
        received_items: list[str] | None = None

        @env.scenario("optional_list")
        async def optional_list_scenario(items: list[str] | None = None):
            nonlocal received_items
            received_items = items
            yield f"Got {items}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:optional_list")
        assert prompt is not None

        await prompt.render({"items": '["x", "y", "z"]'})

        assert received_items == ["x", "y", "z"]

    @pytest.mark.asyncio
    async def test_optional_list_with_null(self) -> None:
        """Optional[list[str]] receives None when 'null' is passed."""
        env = Environment("test-env")
        received_items: list[str] | None = ["initial"]

        @env.scenario("optional_list_null")
        async def optional_list_null_scenario(items: list[str] | None = None):
            nonlocal received_items
            received_items = items
            yield f"Got {items}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:optional_list_null")
        assert prompt is not None

        await prompt.render({"items": "null"})

        assert received_items is None

    @pytest.mark.asyncio
    async def test_optional_str_with_value(self) -> None:
        """Optional[str] receives string value correctly."""
        env = Environment("test-env")
        received_name: str | None = None

        @env.scenario("optional_str")
        async def optional_str_scenario(name: str | None = None):
            nonlocal received_name
            received_name = name
            yield f"Got {name}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:optional_str")
        assert prompt is not None

        await prompt.render({"name": "Alice"})

        assert received_name == "Alice"

    @pytest.mark.asyncio
    async def test_optional_str_with_null(self) -> None:
        """Optional[str] receives None when 'null' is passed."""
        env = Environment("test-env")
        received_name: str | None = "initial"

        @env.scenario("optional_str_null")
        async def optional_str_null_scenario(name: str | None = None):
            nonlocal received_name
            received_name = name
            yield f"Got {name}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:optional_str_null")
        assert prompt is not None

        await prompt.render({"name": "null"})

        assert received_name is None

    @pytest.mark.asyncio
    async def test_pydantic_model_deserialization(self) -> None:
        """Pydantic models are properly deserialized from JSON."""
        env = Environment("test-env")
        received_config: _UserConfig | None = None

        @env.scenario("pydantic_model")
        async def pydantic_model_scenario(config: _UserConfig):
            nonlocal received_config
            received_config = config
            yield f"Got config for {config.name}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:pydantic_model")
        assert prompt is not None

        await prompt.render({"config": '{"name": "Alice", "age": 30}'})

        assert received_config is not None
        assert isinstance(received_config, _UserConfig)
        assert received_config.name == "Alice"
        assert received_config.age == 30
        assert received_config.active is True  # default value

    @pytest.mark.asyncio
    async def test_enum_deserialization(self) -> None:
        """Enum values are properly deserialized from JSON strings."""
        env = Environment("test-env")
        received_status: _Status | None = None

        @env.scenario("enum_status")
        async def enum_scenario(status: _Status):
            nonlocal received_status
            received_status = status
            yield f"Status is {status.value}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:enum_status")
        assert prompt is not None

        await prompt.render({"status": '"active"'})

        assert received_status is not None
        assert isinstance(received_status, _Status)
        assert received_status == _Status.ACTIVE

    @pytest.mark.asyncio
    async def test_datetime_deserialization(self) -> None:
        """Datetime values are properly deserialized from ISO strings."""
        env = Environment("test-env")
        received_dt: datetime | None = None

        @env.scenario("datetime_scenario")
        async def datetime_scenario(created_at: datetime):
            nonlocal received_dt
            received_dt = created_at
            yield f"Created at {created_at}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:datetime_scenario")
        assert prompt is not None

        await prompt.render({"created_at": '"2024-06-15T10:30:00"'})

        assert received_dt is not None
        assert isinstance(received_dt, datetime)
        assert received_dt.year == 2024
        assert received_dt.month == 6
        assert received_dt.day == 15
        assert received_dt.hour == 10
        assert received_dt.minute == 30

    @pytest.mark.asyncio
    async def test_nested_pydantic_model(self) -> None:
        """Nested Pydantic models are properly deserialized."""
        env = Environment("test-env")
        received_person: _Person | None = None

        @env.scenario("nested_pydantic")
        async def nested_pydantic_scenario(person: _Person):
            nonlocal received_person
            received_person = person
            yield f"Person {person.name} from {person.address.city}"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:nested_pydantic")
        assert prompt is not None

        json_data = '{"name": "Bob", "address": {"street": "123 Main St", "city": "NYC"}}'
        await prompt.render({"person": json_data})

        assert received_person is not None
        assert isinstance(received_person, _Person)
        assert received_person.name == "Bob"
        assert isinstance(received_person.address, _Address)
        assert received_person.address.city == "NYC"

    @pytest.mark.asyncio
    async def test_list_of_pydantic_models(self) -> None:
        """List of Pydantic models are properly deserialized."""
        env = Environment("test-env")
        received_items: list[_Item] = []

        @env.scenario("list_pydantic")
        async def list_pydantic_scenario(items: list[_Item]):
            nonlocal received_items
            received_items = items
            yield f"Got {len(items)} items"
            yield 1.0

        prompt = env._prompt_manager._prompts.get("test-env:list_pydantic")
        assert prompt is not None

        json_data = '[{"id": 1, "name": "Apple"}, {"id": 2, "name": "Banana"}]'
        await prompt.render({"items": json_data})

        assert len(received_items) == 2
        assert all(isinstance(item, _Item) for item in received_items)
        assert received_items[0].name == "Apple"
        assert received_items[1].name == "Banana"


class TestScenarioNameNormalization:
    """Test edge cases for environment and scenario name handling."""

    @pytest.mark.asyncio
    async def test_env_name_with_underscores_normalizes(self) -> None:
        """Environment name with underscores normalizes to hyphens."""
        env = Environment("my_test_env")
        assert env.name == "my-test-env"

        @env.scenario("greet")
        async def greet():
            yield "Hello"
            yield 1.0

        # Scenario should be registered with normalized name
        assert "my-test-env:greet" in [p.name for p in env._prompt_manager._prompts.values()]

    @pytest.mark.asyncio
    async def test_env_name_with_spaces_normalizes(self) -> None:
        """Environment name with spaces normalizes to hyphens."""
        env = Environment("my test env")
        assert env.name == "my-test-env"

    @pytest.mark.asyncio
    async def test_env_name_with_caps_normalizes(self) -> None:
        """Environment name with capitals normalizes to lowercase."""
        env = Environment("MyTestEnv")
        assert env.name == "mytestenv"

    @pytest.mark.asyncio
    async def test_env_name_mixed_formatting(self) -> None:
        """Environment name with mixed formatting normalizes correctly."""
        env = Environment("My_Test Env")
        assert env.name == "my-test-env"

    @pytest.mark.asyncio
    async def test_prefix_matches_normalized_name(self) -> None:
        """Scenario prefix should match normalized env name."""
        env = Environment("my_env")  # Normalizes to "my-env"

        @env.scenario("test")
        async def test_scenario():
            yield "Prompt"
            yield 1.0

        # Calling with normalized prefix should work as local
        prompt = await env.run_scenario_setup("my-env:test", {})
        assert prompt == "Prompt"
        assert env._active_session is not None
        assert env._active_session.is_local is True

    @pytest.mark.asyncio
    async def test_unnormalized_prefix_treated_as_remote(self) -> None:
        """Calling with unnormalized prefix treats as remote (different env)."""
        env = Environment("my_env")  # Normalizes to "my-env"

        @env.scenario("test")
        async def test_scenario():
            yield "Prompt"
            yield 1.0

        # Calling with "my_env:test" (underscore) won't match "my-env"
        # So it's treated as remote - which will fail since no connection
        with pytest.raises(ValueError, match="Scenario not found"):
            await env.run_scenario_setup("my_env:test", {})


class TestScenarioMalformedNames:
    """Test handling of malformed scenario names."""

    @pytest.mark.asyncio
    async def test_empty_scenario_name_rejected(self) -> None:
        """Empty scenario name should be handled gracefully."""
        env = Environment("test-env")

        @env.scenario("valid")
        async def valid_scenario():
            yield "Prompt"
            yield 1.0

        # Empty name - should fail since not registered
        with pytest.raises((ValueError, KeyError)):
            await env.run_scenario_setup("", {})

    @pytest.mark.asyncio
    async def test_only_colon_handled(self) -> None:
        """Scenario name that is just ':' should be handled."""
        env = Environment("test-env")

        # ":" splits to prefix="" and short_name=""
        with pytest.raises((ValueError, KeyError)):
            await env.run_scenario_setup(":", {})

    @pytest.mark.asyncio
    async def test_colon_in_scenario_name_rejected_at_registration(self) -> None:
        """Scenario names with colons are rejected at registration time."""
        env = Environment("test-env")

        # Colons are reserved as the separator between env and scenario names
        with pytest.raises(ValueError, match="cannot contain ':'"):

            @env.scenario("invalid:name")
            async def scenario_with_colon():
                yield "Prompt"
                yield 1.0

    @pytest.mark.asyncio
    async def test_whitespace_in_scenario_name(self) -> None:
        """Scenario names with whitespace should work (not normalized)."""
        env = Environment("test-env")

        @env.scenario("my scenario")
        async def scenario_with_space():
            yield "Prompt"
            yield 1.0

        # Scenario names are NOT normalized (only env names are)
        prompt = await env.run_scenario_setup("my scenario", {})
        assert prompt == "Prompt"


class TestScenarioRegistration:
    """Test scenario registration edge cases."""

    @pytest.mark.asyncio
    async def test_duplicate_scenario_name_overwrites(self) -> None:
        """Registering same scenario name twice should overwrite."""
        env = Environment("test-env")

        @env.scenario("greet")
        async def greet_v1():
            yield "Hello v1"
            yield 1.0

        @env.scenario("greet")
        async def greet_v2():
            yield "Hello v2"
            yield 1.0

        # Should use v2
        prompt = await env.run_scenario_setup("greet", {})
        assert prompt == "Hello v2"

    @pytest.mark.asyncio
    async def test_scenario_with_special_chars(self) -> None:
        """Scenario names can contain special characters."""
        env = Environment("test-env")

        @env.scenario("test-scenario_v2.0")
        async def special_scenario():
            yield "Prompt"
            yield 1.0

        prompt = await env.run_scenario_setup("test-scenario_v2.0", {})
        assert prompt == "Prompt"

    @pytest.mark.asyncio
    async def test_scenario_that_yields_once(self) -> None:
        """Scenario that yields only once (no evaluate) should handle gracefully."""
        env = Environment("test-env")

        @env.scenario("one-yield")
        async def one_yield_scenario():
            yield "Prompt"
            # No second yield!

        prompt = await env.run_scenario_setup("one-yield", {})
        assert prompt == "Prompt"

        assert env._active_session is not None
        env._active_session.answer = "test"
        # Evaluate should handle StopAsyncIteration and return 1.0
        reward = await env.run_scenario_evaluate("one-yield")
        assert reward == 1.0

    @pytest.mark.asyncio
    async def test_scenario_that_yields_three_times(self) -> None:
        """Scenario that yields more than twice - third yield ignored."""
        env = Environment("test-env")

        @env.scenario("three-yields")
        async def three_yield_scenario():
            yield "Prompt"
            yield 0.5
            yield "This should be ignored"

        prompt = await env.run_scenario_setup("three-yields", {})
        assert prompt == "Prompt"

        assert env._active_session is not None
        env._active_session.answer = "test"
        reward = await env.run_scenario_evaluate("three-yields")
        assert reward == 0.5


class TestScenarioSessionState:
    """Test session state management edge cases."""

    @pytest.mark.asyncio
    async def test_submit_before_setup_raises(self) -> None:
        """Calling submit() before run_scenario_setup() should raise."""
        env = Environment("test-env")

        @env.scenario("test")
        async def test_scenario():
            yield "Prompt"
            yield 1.0

        with pytest.raises(ValueError, match="No active scenario session"):
            await env.submit("test", "answer")

    @pytest.mark.asyncio
    async def test_evaluate_before_setup_returns_none(self) -> None:
        """Calling evaluate() before setup() should return None."""
        env = Environment("test-env")

        @env.scenario("test")
        async def test_scenario():
            yield "Prompt"
            yield 1.0

        result = await env.run_scenario_evaluate("test")
        assert result is None

    @pytest.mark.asyncio
    async def test_double_evaluate_returns_none(self) -> None:
        """Calling evaluate() twice should return None on second call."""
        env = Environment("test-env")

        @env.scenario("test")
        async def test_scenario():
            yield "Prompt"
            yield 0.75

        await env.run_scenario_setup("test", {})
        assert env._active_session is not None
        env._active_session.answer = "answer"

        reward1 = await env.run_scenario_evaluate("test")
        assert reward1 == 0.75

        # Second call - session cleared
        reward2 = await env.run_scenario_evaluate("test")
        assert reward2 is None

    @pytest.mark.asyncio
    async def test_submit_wrong_scenario_raises(self) -> None:
        """Submitting answer for wrong scenario should raise."""
        env = Environment("test-env")

        @env.scenario("scenario-a")
        async def scenario_a():
            yield "Prompt A"
            yield 1.0

        @env.scenario("scenario-b")
        async def scenario_b():
            yield "Prompt B"
            yield 1.0

        await env.run_scenario_setup("scenario-a", {})

        with pytest.raises(ValueError, match="Scenario mismatch"):
            await env.submit("scenario-b", "answer")

    @pytest.mark.asyncio
    async def test_second_setup_overwrites_first(self) -> None:
        """Starting a new scenario before evaluating previous one overwrites."""
        env = Environment("test-env")

        @env.scenario("first")
        async def first_scenario():
            yield "First"
            yield 1.0

        @env.scenario("second")
        async def second_scenario():
            yield "Second"
            yield 0.5

        await env.run_scenario_setup("first", {})
        assert env._active_session is not None
        assert env._active_session.local_name == "first"

        # Start second without evaluating first
        await env.run_scenario_setup("second", {})
        assert env._active_session is not None
        assert env._active_session.local_name == "second"

        env._active_session.answer = "answer"
        reward = await env.run_scenario_evaluate("second")
        assert reward == 0.5