hud-python 0.4.28__py3-none-any.whl → 0.4.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic.

Files changed (75)
  1. hud/__init__.py +2 -1
  2. hud/agents/base.py +73 -45
  3. hud/agents/claude.py +8 -4
  4. hud/agents/openai_chat_generic.py +65 -40
  5. hud/agents/tests/test_base.py +0 -4
  6. hud/agents/tests/test_openai.py +1 -1
  7. hud/cli/__init__.py +182 -52
  8. hud/cli/dev.py +8 -9
  9. hud/cli/eval.py +317 -119
  10. hud/cli/flows/__init__.py +0 -0
  11. hud/cli/flows/tasks.py +0 -0
  12. hud/cli/get.py +160 -0
  13. hud/cli/rl/__init__.py +563 -71
  14. hud/cli/rl/config.py +94 -0
  15. hud/cli/rl/display.py +133 -0
  16. hud/cli/rl/gpu.py +63 -0
  17. hud/cli/rl/gpu_utils.py +318 -0
  18. hud/cli/rl/presets.py +96 -0
  19. hud/cli/rl/remote_runner.py +348 -0
  20. hud/cli/rl/rl_api.py +150 -0
  21. hud/cli/rl/vllm.py +177 -0
  22. hud/cli/tests/test_analyze_metadata.py +0 -1
  23. hud/cli/utils/tasks.py +26 -0
  24. hud/clients/base.py +21 -23
  25. hud/clients/mcp_use.py +36 -44
  26. hud/clients/tests/test_mcp_use_retry.py +10 -10
  27. hud/datasets/__init__.py +4 -3
  28. hud/datasets/{execution/parallel.py → parallel.py} +1 -1
  29. hud/datasets/{execution/runner.py → runner.py} +1 -1
  30. hud/datasets/utils.py +1 -1
  31. hud/native/tests/test_native_init.py +1 -1
  32. hud/otel/config.py +1 -1
  33. hud/otel/instrumentation.py +35 -0
  34. hud/rl/README.md +31 -0
  35. hud/rl/__init__.py +1 -0
  36. hud/rl/actor.py +174 -0
  37. hud/rl/buffer.py +371 -0
  38. hud/rl/chat_template.jinja +101 -0
  39. hud/rl/config.py +184 -0
  40. hud/rl/distributed.py +95 -0
  41. hud/rl/learner.py +586 -0
  42. hud/rl/tests/__init__.py +1 -0
  43. hud/rl/tests/test_learner.py +171 -0
  44. hud/rl/train.py +354 -0
  45. hud/rl/types.py +101 -0
  46. hud/rl/utils/start_vllm_server.sh +30 -0
  47. hud/rl/utils.py +524 -0
  48. hud/rl/vllm_adapter.py +125 -0
  49. hud/settings.py +6 -0
  50. hud/telemetry/__init__.py +2 -1
  51. hud/telemetry/job.py +46 -3
  52. hud/telemetry/tests/test_trace.py +3 -3
  53. hud/telemetry/trace.py +85 -13
  54. hud/tools/tests/test_computer.py +3 -3
  55. hud/tools/tests/test_computer_actions.py +1 -1
  56. hud/types.py +123 -2
  57. hud/utils/group_eval.py +223 -0
  58. hud/utils/hud_console.py +113 -13
  59. hud/utils/tasks.py +119 -0
  60. hud/utils/tests/test_version.py +1 -1
  61. hud/version.py +1 -1
  62. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/METADATA +20 -2
  63. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/RECORD +66 -46
  64. hud/cli/hf.py +0 -406
  65. hud/cli/rl/README.md +0 -243
  66. hud/cli/rl/init.py +0 -370
  67. hud/cli/rl/pod.py +0 -501
  68. hud/cli/rl/ssh.py +0 -322
  69. hud/cli/rl/train.py +0 -562
  70. hud/cli/rl/utils.py +0 -165
  71. hud/datasets/execution/__init__.py +0 -13
  72. hud/datasets/task.py +0 -116
  73. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/WHEEL +0 -0
  74. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/entry_points.txt +0 -0
  75. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/licenses/LICENSE +0 -0
hud/telemetry/job.py CHANGED
@@ -89,6 +89,46 @@ class Job:
             except Exception as e:
                 logger.warning("Failed to update job status: %s", e)

+    async def log(self, metrics: dict[str, Any]) -> None:
+        """Log metrics to the job.
+
+        Args:
+            metrics: Dictionary of metric name to value pairs
+
+        Example:
+            await job.log({"loss": 0.5, "accuracy": 0.95, "epoch": 1})
+        """
+        if settings.telemetry_enabled:
+            try:
+                await make_request(
+                    method="POST",
+                    url=f"{settings.hud_telemetry_url}/jobs/{self.id}/log",
+                    json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()},
+                    api_key=settings.api_key,
+                )
+            except Exception as e:
+                logger.warning("Failed to log metrics to job: %s", e)
+
+    def log_sync(self, metrics: dict[str, Any]) -> None:
+        """Synchronously log metrics to the job.
+
+        Args:
+            metrics: Dictionary of metric name to value pairs
+
+        Example:
+            job.log_sync({"loss": 0.5, "accuracy": 0.95, "epoch": 1})
+        """
+        if settings.telemetry_enabled:
+            try:
+                make_request_sync(
+                    method="POST",
+                    url=f"{settings.hud_telemetry_url}/jobs/{self.id}/log",
+                    json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()},
+                    api_key=settings.api_key,
+                )
+            except Exception as e:
+                logger.warning("Failed to log metrics to job: %s", e)
+
     def __repr__(self) -> str:
         return f"Job(id={self.id!r}, name={self.name!r}, status={self.status!r})"

@@ -225,7 +265,10 @@ def job(


 def create_job(
-    name: str, metadata: dict[str, Any] | None = None, dataset_link: str | None = None
+    name: str,
+    metadata: dict[str, Any] | None = None,
+    dataset_link: str | None = None,
+    job_id: str | None = None,
 ) -> Job:
     """Create a job without using context manager.

@@ -235,7 +278,7 @@ def create_job(
         name: Human-readable job name
         metadata: Optional metadata dictionary
         dataset_link: Optional HuggingFace dataset identifier (e.g. "hud-evals/SheetBench-50")
-
+        job_id: Optional job ID (auto-generated if not provided)
     Returns:
         Job: The created job object

@@ -248,7 +291,7 @@ def create_job(
         finally:
             await job.update_status("completed")
     """
-    job_id = str(uuid.uuid4())
+    job_id = job_id or str(uuid.uuid4())
     return Job(job_id, name, metadata, dataset_link)
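The hunks above give Job a metrics-logging surface (log / log_sync) and let create_job accept an explicit job_id. A minimal usage sketch, assuming a configured HUD_API_KEY and importing from this module's own path (names and metric values are placeholders):

import asyncio
from hud.telemetry.job import create_job

job = create_job(
    name="demo-run",                # placeholder name
    metadata={"stage": "example"},  # placeholder metadata
    job_id=None,                    # pass a fixed ID to reuse a job, omit to auto-generate a UUID
)

# Blocking variant for synchronous code paths
job.log_sync({"loss": 0.5, "accuracy": 0.95, "epoch": 1})

# Awaitable variant for async code paths
asyncio.run(job.log({"loss": 0.4, "accuracy": 0.96, "epoch": 2}))

Per the code above, both calls are no-ops when telemetry is disabled, and request failures are swallowed with a warning.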
 
hud/telemetry/tests/test_trace.py CHANGED
@@ -23,7 +23,7 @@ class TestTraceAPI:

         with trace("test-trace") as task_run_id:
             # Should use placeholder ID for custom backends
-            assert task_run_id == "custom-otlp-trace"
+            assert task_run_id.id == "custom-otlp-trace"

     def test_trace_with_enabled_telemetry_and_api_key(self):
         """Test trace behavior when telemetry is enabled with API key."""
@@ -39,7 +39,7 @@ class TestTraceAPI:

         with trace("test-trace") as task_run_id:
             # Should use generated UUID
-            assert task_run_id == "mock-uuid-123"
+            assert task_run_id.id == "mock-uuid-123"

     def test_trace_with_no_api_key(self):
         """Test trace behavior with no API key (custom backend scenario)."""
@@ -60,4 +60,4 @@ class TestTraceAPI:

         with trace("test-trace") as task_run_id:
             # Should use custom backend placeholder
-            assert task_run_id == "custom-otlp-trace"
+            assert task_run_id.id == "custom-otlp-trace"
hud/telemetry/trace.py CHANGED
@@ -6,17 +6,83 @@ The actual OpenTelemetry implementation is in hud.otel.

 from __future__ import annotations

+import logging
 import uuid
 from contextlib import contextmanager
+from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any

 from hud.otel import configure_telemetry
 from hud.otel import trace as OtelTrace
+from hud.settings import settings
+from hud.shared import make_request, make_request_sync

 if TYPE_CHECKING:
     from collections.abc import Generator

-__all__ = ["trace"]
+logger = logging.getLogger(__name__)
+
+__all__ = ["Trace", "trace"]
+
+
+class Trace:
+    """A trace represents a single task execution with telemetry."""
+
+    def __init__(
+        self,
+        trace_id: str,
+        name: str,
+        job_id: str | None = None,
+        task_id: str | None = None,
+    ) -> None:
+        self.id = trace_id
+        self.name = name
+        self.job_id = job_id
+        self.task_id = task_id
+        self.created_at = datetime.now(UTC)
+
+    async def log(self, metrics: dict[str, Any]) -> None:
+        """Log metrics to this trace.
+
+        Args:
+            metrics: Dictionary of metric name to value pairs
+
+        Example:
+            await trace.log({"step": 1, "loss": 0.5, "accuracy": 0.92})
+        """
+        if settings.telemetry_enabled:
+            try:
+                await make_request(
+                    method="POST",
+                    url=f"{settings.hud_telemetry_url}/traces/{self.id}/log",
+                    json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()},
+                    api_key=settings.api_key,
+                )
+            except Exception as e:
+                logger.warning("Failed to log metrics to trace: %s", e)
+
+    def log_sync(self, metrics: dict[str, Any]) -> None:
+        """Synchronously log metrics to this trace.
+
+        Args:
+            metrics: Dictionary of metric name to value pairs
+
+        Example:
+            trace.log_sync({"step": 1, "loss": 0.5, "accuracy": 0.92})
+        """
+        if settings.telemetry_enabled:
+            try:
+                make_request_sync(
+                    method="POST",
+                    url=f"{settings.hud_telemetry_url}/traces/{self.id}/log",
+                    json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()},
+                    api_key=settings.api_key,
+                )
+            except Exception as e:
+                logger.warning("Failed to log metrics to trace: %s", e)
+
+    def __repr__(self) -> str:
+        return f"Trace(id={self.id!r}, name={self.name!r})"


 @contextmanager
@@ -27,7 +93,7 @@ def trace(
     attrs: dict[str, Any] | None = None,
     job_id: str | None = None,
     task_id: str | None = None,
-) -> Generator[str, None, None]:
+) -> Generator[Trace, None, None]:
     """Start a HUD trace context.

     A unique task_run_id is automatically generated for each trace.
@@ -37,24 +103,27 @@
         root: Whether this is a root trace (updates task status)
         attrs: Additional attributes to attach to the trace
         job_id: Optional job ID to associate with this trace
+        task_id: Optional task ID (for custom task identifiers)

     Yields:
-        str: The auto-generated task run ID
+        Trace: The trace object with logging capabilities

     Usage:
         import hud

-        with hud.trace("My Task") as task_run_id:
+        # Basic usage
+        with hud.trace("My Task") as trace:
             # Your code here
-            print(f"Running task: {task_run_id}")
+            trace.log_sync({"step": 1, "progress": 0.5})

-        # Or with default name:
-        with hud.trace() as task_run_id:
-            pass
+        # Async logging
+        async with hud.trace("Async Task") as trace:
+            await trace.log({"loss": 0.23, "accuracy": 0.95})

-        # Or with job_id:
-        with hud.trace("My Task", job_id="550e8400-e29b-41d4-a716-446655440000") as task_run_id:
-            pass
+        # With job association
+        with hud.job("Training Run") as job:
+            with hud.trace("Epoch 1", job_id=job.id) as trace:
+                trace.log_sync({"epoch": 1, "loss": 0.5})
     """
     # Ensure telemetry is configured
     configure_telemetry()
@@ -71,6 +140,9 @@
         # Use a placeholder for custom backends
         task_run_id = "custom-otlp-trace"

+    # Create trace object
+    trace_obj = Trace(task_run_id, name, job_id, task_id)
+
     # Delegate to OpenTelemetry implementation
     with OtelTrace(
         task_run_id,
@@ -79,5 +151,5 @@
         attributes=attrs or {},
         job_id=job_id,
         task_id=task_id,
-    ) as run_id:
-        yield run_id
+    ):
+        yield trace_obj
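Because trace() now yields a Trace object rather than a bare string, callers read the run ID from .id and can attach metrics mid-run. A small sketch of the new surface, mirroring the docstring above (note that trace() remains a synchronous @contextmanager, so it is entered with a plain with statement even in async code; the awaitable path is trace.log()):

import asyncio
import hud

async def main() -> None:
    with hud.trace("Epoch 1") as tr:
        print(tr.id)                                     # the task run ID (previously the yielded str)
        tr.log_sync({"step": 1, "progress": 0.5})        # blocking log
        await tr.log({"loss": 0.23, "accuracy": 0.95})   # non-blocking log

asyncio.run(main())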
hud/tools/tests/test_computer.py CHANGED
@@ -151,7 +151,7 @@ class TestHudComputerToolExtended:
     async def test_type_action(self, base_executor):
         """Test type action with BaseExecutor."""
         tool = HudComputerTool(executor=base_executor)
-        result = await tool(action="type", text="Hello World", enter_after=True)
+        result = await tool(action="write", text="Hello World", enter_after=True)
         assert result
         assert any(
             "[SIMULATED] Type" in content.text
@@ -329,7 +329,7 @@
         assert result

         # Test type without coordinates
-        result = await tool(action="type", text="test")
+        result = await tool(action="write", text="test")
         assert result

     @pytest.mark.asyncio
@@ -360,7 +360,7 @@
         from hud.tools.types import ToolError

         with pytest.raises(ToolError, match="text parameter is required"):
-            await tool(action="type", text=None)
+            await tool(action="write", text=None)

         # Test press without keys
         with pytest.raises(ToolError, match="keys parameter is required"):
hud/tools/tests/test_computer_actions.py CHANGED
@@ -12,7 +12,7 @@ CASES = [
     ("press", {"keys": ["ctrl", "c"]}),
     ("keydown", {"keys": ["shift"]}),
     ("keyup", {"keys": ["shift"]}),
-    ("type", {"text": "hello"}),
+    ("write", {"text": "hello"}),
     ("scroll", {"x": 10, "y": 10, "scroll_y": 20}),  # Added required x,y coordinates
     # Skip move test - it has Field parameter handling issues when called directly
     # ("move", {"x": 5, "y": 5}), # x,y are for absolute positioning
hud/types.py CHANGED
@@ -1,12 +1,120 @@
 from __future__ import annotations

 import json
+import logging
 import uuid
+from collections import defaultdict
+from string import Template
 from typing import Any, Literal

 import mcp.types as types
 from mcp.types import CallToolRequestParams, CallToolResult
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+from hud.settings import settings
+
+logger = logging.getLogger(__name__)
+
+
+class Task(BaseModel):
+    """
+    A task configuration that can be used to create a task.
+
+    The mcp_config field supports environment variable substitution using
+    template placeholders in the format ${VAR_NAME} or ${VAR_NAME:default_value}.
+
+    Example:
+        mcp_config: {
+            "hud": {
+                "url": "${HUD_MCP_URL:https://mcp.hud.so/v3/mcp}",
+                "headers": {
+                    "Authorization": "Bearer ${HUD_API_KEY}",
+                    "Mcp-Image": "your-mcp-image"
+                }
+            }
+        }
+    """
+
+    id: str | None = None
+    prompt: str
+    mcp_config: dict[str, Any]
+    setup_tool: MCPToolCall | list[MCPToolCall] | None = None
+    evaluate_tool: MCPToolCall | list[MCPToolCall] | None = None
+    agent_tools: list[str] | None = None
+    system_prompt: str | None = None
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("mcp_config", "metadata", mode="before")
+    @classmethod
+    def parse_json_strings(cls, v: Any) -> Any:
+        """Parse JSON strings into dictionaries."""
+        if isinstance(v, str):
+            try:
+                return json.loads(v)
+            except json.JSONDecodeError as e:
+                from hud.shared.exceptions import HudConfigError
+
+                raise HudConfigError(f"Invalid JSON string: {e}") from e
+        return v
+
+    @field_validator("setup_tool", "evaluate_tool", mode="before")
+    @classmethod
+    def convert_dict_to_tool_call(cls, v: Any) -> Any:
+        """Convert dict to MCPToolCall instance, parsing JSON strings first."""
+        if v is None:
+            return None
+
+        # Parse JSON string if needed
+        if isinstance(v, str):
+            try:
+                v = json.loads(v)
+            except json.JSONDecodeError as e:
+                from hud.shared.exceptions import HudConfigError
+
+                raise HudConfigError(f"Invalid JSON string: {e}") from e
+
+        if isinstance(v, dict):
+            return MCPToolCall(**v)
+        if isinstance(v, list):
+            return [MCPToolCall(**item) if isinstance(item, dict) else item for item in v]
+        return v
+
+    @field_validator("mcp_config", mode="before")
+    @classmethod
+    def resolve_env_vars(cls, v: dict[str, Any]) -> dict[str, Any]:
+        """
+        Automatically resolve environment variables in mcp_config using Template.
+
+        Supports ${VAR_NAME} syntax with variable substitution from
+        System environment variables (including HUD_API_KEY, etc.)
+
+        Missing variables resolve to empty strings.
+        """
+        import os
+
+        # Start with current environment variables
+        mapping = dict(os.environ)
+        mapping.update(settings.model_dump())
+
+        if settings.api_key:
+            mapping["HUD_API_KEY"] = settings.api_key
+        else:
+            logger.error("HUD_API_KEY is not set, tracing and remote training will not work")
+
+        def substitute_in_value(obj: Any) -> Any:
+            """Recursively substitute variables in nested structures."""
+            if isinstance(obj, str):
+                # Use Template's substitute with defaultdict - missing vars become empty strings
+                safe_mapping = defaultdict(str, mapping)
+                return Template(obj).substitute(safe_mapping)
+            elif isinstance(obj, dict):
+                return {k: substitute_in_value(v) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [substitute_in_value(item) for item in obj]
+            else:
+                return obj
+
+        return substitute_in_value(v)


 class MCPToolCall(CallToolRequestParams):
@@ -150,12 +258,25 @@ class Trace(BaseModel):
     - trace: The steps taken in the run (empty if not tracing)
     """

-    done: bool = Field(default=True)
     reward: float = Field(default=0.0)
+    done: bool = Field(default=True)
     info: dict[str, Any] = Field(default_factory=dict)
     content: str | None = Field(default=None)
     isError: bool = Field(default=False)
+
+    # Metadata
+    task: Task | None = Field(default=None)
+
+    # Trace
     trace: list[TraceStep] = Field(default_factory=list)
+    messages: list[Any] = Field(default_factory=list)
+
+    def __len__(self) -> int:
+        return len(self.trace)
+
+    @property
+    def num_messages(self) -> int:
+        return len(self.messages)

     def append(self, step: TraceStep) -> None:
         self.trace.append(step)
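The new Task model resolves ${VAR} placeholders in mcp_config at validation time. A sketch of that behavior with placeholder values, assuming Task stays importable from hud.types as defined above (per resolve_env_vars, unset variables substitute to empty strings and HUD_API_KEY is injected from hud.settings when available):

import os
from hud.types import Task

os.environ["HUD_MCP_URL"] = "https://mcp.hud.so/v3/mcp"   # example value for the placeholder

task = Task(
    prompt="Example prompt",
    mcp_config={
        "hud": {
            "url": "${HUD_MCP_URL}",
            "headers": {"Authorization": "Bearer ${HUD_API_KEY}"},
        }
    },
)

print(task.mcp_config["hud"]["url"])   # -> https://mcp.hud.so/v3/mcp
# With no API key configured, the Authorization value becomes "Bearer " (empty substitution)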
hud/utils/group_eval.py ADDED
@@ -0,0 +1,223 @@
+"""Utilities for grouped evaluation of tasks, following the RL pattern."""
+
+from __future__ import annotations
+
+import asyncio
+from statistics import mean, stdev
+from typing import Any
+
+import numpy as np
+
+import hud
+from hud.datasets import Task
+from hud.types import Trace
+from hud.utils.hud_console import HUDConsole
+
+hud_console = HUDConsole()
+
+
+async def run_tasks_grouped(
+    tasks: list[Any],
+    agent_class: type | Any,
+    agent_config: dict[str, Any] | None = None,
+    group_size: int = 1,
+    max_parallel_episodes: int = 48,
+    max_steps: int = 10,
+    verbose: bool = False,
+    job_id: str | None = None,
+) -> list[dict[str, Any]]:
+    """
+    Run tasks with grouping, following the RL Actor pattern.
+
+    Args:
+        tasks: List of tasks to run
+        agent_class: Agent class or instance to use
+        agent_config: Configuration for agent instantiation
+        group_size: Number of times to run each task
+        max_parallel_episodes: Maximum parallel episodes to run
+        max_steps: Maximum steps per episode
+        verbose: Whether to show progress
+        job_id: Optional job ID for tracking
+
+    Returns:
+        List of statistics for each task group
+    """
+    agent_config = agent_config or {}
+
+    # Duplicate tasks according to group_size, exactly like RL
+    grouped_tasks = []
+    task_mapping = []  # Track which group each result belongs to
+
+    for i, task in enumerate(tasks):
+        for _ in range(group_size):
+            grouped_tasks.append(task)
+            task_mapping.append(i)
+
+    hud_console.info(
+        f"Running {len(tasks)} tasks with group_size={group_size} ({len(grouped_tasks)} total runs)"
+    )
+
+    # Run all episodes, respecting max_parallel_episodes
+    all_traces = []
+
+    for batch_start in range(0, len(grouped_tasks), max_parallel_episodes):
+        batch_end = min(batch_start + max_parallel_episodes, len(grouped_tasks))
+        batch = grouped_tasks[batch_start:batch_end]
+
+        # Run batch in parallel
+        async def run_single_episode(task_data: dict[str, Any] | Task, idx: int) -> Trace:
+            """Run a single episode."""
+            try:
+                # Create task if needed
+                task = Task(**task_data) if isinstance(task_data, dict) else task_data
+
+                # Create fresh agent instance
+                if isinstance(agent_class, type):
+                    agent = agent_class(**agent_config)
+                else:
+                    # Agent is already instantiated
+                    agent = agent_class
+
+                # Run the task
+                trace_name = f"Eval | {task.id if hasattr(task, 'id') else 'Task'} | Group {task_mapping[idx]}"  # noqa: E501
+                with hud.trace(trace_name, job_id=job_id):
+                    result = await agent.run(task, max_steps=max_steps)
+                    return result
+
+            except Exception as e:
+                hud_console.warning_log(f"Episode failed: {e}")
+                return Trace(isError=True, content=str(e), reward=0.0, done=True)
+
+        # Run batch
+        batch_results = await asyncio.gather(
+            *[run_single_episode(t, batch_start + i) for i, t in enumerate(batch)],
+            return_exceptions=True,
+        )
+
+        # Normalize exceptions to error traces
+        for res in batch_results:
+            if isinstance(res, Exception):
+                hud_console.warning_log(f"Episode error: {res}")
+                all_traces.append(Trace(isError=True, content=str(res), reward=0.0, done=True))
+            else:
+                all_traces.append(res)
+
+        if verbose:
+            hud_console.info(f"Completed batch: {len(all_traces)}/{len(grouped_tasks)} episodes")
+
+    # Group results back by original task and calculate statistics
+    return calculate_group_statistics(tasks, all_traces, task_mapping, group_size)
+
+
+def calculate_group_statistics(
+    original_tasks: list[Any],
+    traces: list[Trace],
+    task_mapping: list[int],
+    group_size: int,
+) -> list[dict[str, Any]]:
+    """
+    Calculate statistics for each group, similar to preprocess_advantages.
+
+    Args:
+        original_tasks: Original task list
+        traces: All traces from grouped runs
+        task_mapping: Mapping of trace index to task index
+        group_size: Number of runs per task
+
+    Returns:
+        List of statistics for each task
+    """
+    stats = []
+
+    # Process each original task
+    for task_idx, task in enumerate(original_tasks):
+        # Get all traces for this task
+        task_traces = [
+            traces[i] for i, mapping_idx in enumerate(task_mapping) if mapping_idx == task_idx
+        ]
+
+        # Extract rewards
+        rewards = np.array([t.reward for t in task_traces])
+        errors = [t for t in task_traces if t.isError]
+
+        # Calculate statistics
+        task_stats = {
+            "task_id": task.id
+            if isinstance(task, Task) and hasattr(task, "id")
+            else f"task_{task_idx}",
+            "prompt": task.prompt if isinstance(task, Task) else task.get("prompt", ""),
+            "group_size": group_size,
+            "rewards": rewards.tolist(),
+            "mean_reward": float(np.mean(rewards)),
+            "std_reward": float(np.std(rewards)) if len(rewards) > 1 else 0.0,
+            "min_reward": float(np.min(rewards)),
+            "max_reward": float(np.max(rewards)),
+            "success_rate": float(np.sum(rewards > 0) / len(rewards)) if len(rewards) > 0 else 0.0,
+            "error_rate": len(errors) / len(task_traces) if len(task_traces) > 0 else 0.0,
+            "traces": task_traces,  # Keep full traces for detailed analysis
+        }
+
+        # Add variance info like RL does
+        if task_stats["std_reward"] > 1e-6:
+            task_stats["normalized_rewards"] = [
+                (r - task_stats["mean_reward"]) / task_stats["std_reward"] for r in rewards
+            ]
+        else:
+            task_stats["normalized_rewards"] = [0.0] * len(rewards)
+
+        stats.append(task_stats)
+
+    return stats
+
+
+def display_group_statistics(stats: list[dict[str, Any]], show_details: bool = True) -> None:
+    """Display statistics from grouped evaluation."""
+    from rich.console import Console
+    from rich.table import Table
+
+    console = Console()
+
+    # Overall statistics
+    all_means = [s["mean_reward"] for s in stats]
+    overall_mean = mean(all_means) if all_means else 0.0
+    overall_std = stdev(all_means) if len(all_means) > 1 else 0.0
+
+    hud_console.success("\n📊 Evaluation Summary")
+    hud_console.info(f"Tasks evaluated: {len(stats)}")
+    hud_console.info(f"Episodes per task: {stats[0]['group_size'] if stats else 0}")
+    hud_console.info(f"Total episodes: {sum(len(s['rewards']) for s in stats)}")
+    hud_console.info(f"Overall mean reward: {overall_mean:.3f} ± {overall_std:.3f}")
+
+    # Detailed table
+    if show_details and len(stats) <= 20:  # Only show for reasonable dataset sizes
+        table = Table(title="\nPer-Task Performance Distribution")
+        table.add_column("Task", style="cyan", no_wrap=True)
+        table.add_column("Mean±Std", justify="right", style="green")
+        table.add_column("Min/Max", justify="right")
+        table.add_column("Success%", justify="right", style="yellow")
+        table.add_column("Rewards", style="dim")
+
+        for stat in stats:
+            task_name = stat["prompt"][:30] + "..." if len(stat["prompt"]) > 30 else stat["prompt"]
+            rewards_str = " ".join([f"{r:.2f}" for r in stat["rewards"][:5]])
+            if len(stat["rewards"]) > 5:
+                rewards_str += " ..."
+
+            table.add_row(
+                task_name,
+                f"{stat['mean_reward']:.3f}±{stat['std_reward']:.3f}",
+                f"{stat['min_reward']:.2f}/{stat['max_reward']:.2f}",
+                f"{stat['success_rate'] * 100:.0f}%",
+                rewards_str,
+            )
+
+        console.print(table)
+
+    # High variance tasks
+    high_variance_tasks = [s for s in stats if s["std_reward"] > 0.3 and s["group_size"] > 1]
+    if high_variance_tasks:
+        hud_console.warning(f"\n⚠️ {len(high_variance_tasks)} tasks show high variance (std > 0.3)")
+        for task in high_variance_tasks[:3]:
+            hud_console.info(
+                f"  • {task['task_id']}: μ={task['mean_reward']:.3f}, σ={task['std_reward']:.3f}"  # noqa: RUF001
+            )
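A sketch of how these helpers might be wired together; the agent class is a stand-in for any class whose instances expose an awaitable run(task, max_steps=...), as run_single_episode assumes above, and the task payload is a placeholder:

import asyncio
from hud.utils.group_eval import display_group_statistics, run_tasks_grouped

from my_agents import MyAgent   # hypothetical agent class, not part of hud

tasks = [
    {"prompt": "Example task", "mcp_config": {"hud": {"url": "${HUD_MCP_URL}"}}},
]

stats = asyncio.run(
    run_tasks_grouped(
        tasks,
        agent_class=MyAgent,
        agent_config={"model": "my-model"},   # forwarded as MyAgent(**agent_config)
        group_size=4,                         # run each task four times
        max_parallel_episodes=8,
        max_steps=10,
    )
)
display_group_statistics(stats)               # summary plus per-task reward distribution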