hud-python 0.4.27__py3-none-any.whl → 0.4.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +2 -1
- hud/agents/base.py +73 -45
- hud/agents/claude.py +8 -4
- hud/agents/openai_chat_generic.py +65 -40
- hud/agents/tests/test_base.py +0 -4
- hud/agents/tests/test_openai.py +1 -1
- hud/cli/__init__.py +182 -52
- hud/cli/dev.py +8 -9
- hud/cli/eval.py +317 -119
- hud/cli/flows/__init__.py +0 -0
- hud/cli/flows/tasks.py +0 -0
- hud/cli/get.py +160 -0
- hud/cli/rl/__init__.py +563 -71
- hud/cli/rl/config.py +94 -0
- hud/cli/rl/display.py +133 -0
- hud/cli/rl/gpu.py +63 -0
- hud/cli/rl/gpu_utils.py +318 -0
- hud/cli/rl/presets.py +96 -0
- hud/cli/rl/remote_runner.py +348 -0
- hud/cli/rl/rl_api.py +150 -0
- hud/cli/rl/vllm.py +177 -0
- hud/cli/tests/test_analyze_metadata.py +0 -1
- hud/cli/utils/tasks.py +26 -0
- hud/clients/base.py +21 -23
- hud/clients/mcp_use.py +36 -44
- hud/clients/tests/test_mcp_use_retry.py +10 -10
- hud/datasets/__init__.py +4 -3
- hud/datasets/{execution/parallel.py → parallel.py} +1 -1
- hud/datasets/{execution/runner.py → runner.py} +1 -1
- hud/datasets/utils.py +1 -1
- hud/native/tests/test_native_init.py +1 -1
- hud/otel/config.py +1 -1
- hud/otel/instrumentation.py +35 -0
- hud/rl/README.md +31 -0
- hud/rl/__init__.py +1 -0
- hud/rl/actor.py +174 -0
- hud/rl/buffer.py +371 -0
- hud/rl/chat_template.jinja +101 -0
- hud/rl/config.py +184 -0
- hud/rl/distributed.py +95 -0
- hud/rl/learner.py +586 -0
- hud/rl/tests/__init__.py +1 -0
- hud/rl/tests/test_learner.py +171 -0
- hud/rl/train.py +354 -0
- hud/rl/types.py +101 -0
- hud/rl/utils/start_vllm_server.sh +30 -0
- hud/rl/utils.py +524 -0
- hud/rl/vllm_adapter.py +125 -0
- hud/settings.py +6 -0
- hud/telemetry/__init__.py +2 -1
- hud/telemetry/job.py +46 -3
- hud/telemetry/tests/test_trace.py +3 -3
- hud/telemetry/trace.py +85 -13
- hud/tools/computer/hud.py +4 -4
- hud/tools/tests/test_computer.py +3 -3
- hud/tools/tests/test_computer_actions.py +1 -1
- hud/types.py +123 -2
- hud/utils/group_eval.py +223 -0
- hud/utils/hud_console.py +113 -13
- hud/utils/tasks.py +119 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/METADATA +20 -2
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/RECORD +67 -47
- hud/cli/hf.py +0 -406
- hud/cli/rl/README.md +0 -243
- hud/cli/rl/init.py +0 -370
- hud/cli/rl/pod.py +0 -501
- hud/cli/rl/ssh.py +0 -322
- hud/cli/rl/train.py +0 -562
- hud/cli/rl/utils.py +0 -165
- hud/datasets/execution/__init__.py +0 -13
- hud/datasets/task.py +0 -116
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/WHEEL +0 -0
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/licenses/LICENSE +0 -0
hud/otel/instrumentation.py
CHANGED

```diff
@@ -55,6 +55,9 @@ def _patch_mcp_instrumentation() -> None:
     try:
         from opentelemetry.instrumentation.mcp.instrumentation import McpInstrumentor
 
+        # First, patch the get_error_type function to handle invalid HTTP status codes
+        _patch_get_error_type()
+
         def patched_transport_wrapper(self: Any, tracer: Any) -> Callable[..., Any]:
             @asynccontextmanager
             async def traced_method(
@@ -98,3 +101,35 @@ def _patch_mcp_instrumentation() -> None:
 
         logger = logging.getLogger(__name__)
         logger.warning("Failed to patch MCP instrumentation: %s", e)
+
+
+def _patch_get_error_type() -> None:
+    """Patch get_error_type to handle invalid HTTP status codes gracefully."""
+    import re
+    from http import HTTPStatus
+
+    try:
+        import opentelemetry.instrumentation.mcp.instrumentation as mcp_inst
+
+        def patched_get_error_type(error_message: str) -> str | None:
+            """Extract HTTP status from error message, handling invalid codes."""
+            if not isinstance(error_message, str):
+                return None
+            match = re.search(r"\b(4\d{2}|5\d{2})\b", error_message)
+            if match:
+                num = int(match.group())
+                try:
+                    # Only return if it's a valid HTTPStatus
+                    if 400 <= num <= 599:
+                        return HTTPStatus(num).name
+                except ValueError:
+                    # Not a valid HTTP status code
+                    logger.debug("Ignoring invalid HTTP status code: %s", num)
+            return None
+
+        # Apply the patch
+        mcp_inst.get_error_type = patched_get_error_type
+        logger.debug("Patched get_error_type to handle invalid HTTP status codes")
+
+    except Exception as e:
+        logger.warning("Failed to patch get_error_type: %s", e)
```
hud/rl/README.md
ADDED

We suggest running `hud rl` (or `hud rl --local`) for optimal hyperparameters and native HuggingFace training.

Install:
```bash
sudo apt-get update -y && sudo apt-get install -y cuda-toolkit-12-6
uv pip install -e .[rl]
uv pip install ninja
uv pip install flash-attn --no-build-isolation
```

However, if you want to run the training directly, launch a vLLM server with:
```bash
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
export TOKENIZERS_PARALLELISM=false
export VLLM_LOGGING_LEVEL=INFO
export CUDA_VISIBLE_DEVICES=7  # Set this to your last GPU

uv run vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
  --api-key token-abc123 --host 0.0.0.0 --port 8000 --tensor-parallel-size 1 --trust-remote-code \
  --max-model-len 16384 --enable-lora --max-lora-rank 64 --max-cpu-loras 4 --enable-auto-tool-choice \
  --tool-call-parser hermes --disable-log-requests --dtype auto
```

Then start training with (replace 2 with the number of spare GPUs you have):
```bash
hud get hud-evals/2048-basic
torchrun --nproc-per-node 2 -m hud.rl.train --tasks 2048-basic.json --verbose
```

Add a `--config path/to/config.json` flag to run a specific configuration (or change the defaults in config.py).
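For the `--config` file, a sketch of what it might contain, with field names taken from the config attributes read in `actor.py` and `buffer.py` (`model.base_model`, `actor.*`, `training.*`); the exact JSON layout and defaults are defined in `hud/rl/config.py`, so treat the values here as illustrative rather than a schema:

```json
{
  "verbose": true,
  "model": { "base_model": "Qwen/Qwen2.5-VL-3B-Instruct" },
  "actor": {
    "max_parallel_episodes": 16,
    "max_steps_per_episode": 10,
    "episode_timeout_sec": 600,
    "temperature": 0.7,
    "max_new_tokens": 1024
  },
  "training": {
    "training_steps": 100,
    "batch_size": 16,
    "group_size": 8,
    "mini_batch_size": 2,
    "shuffle_dataset": true
  }
}
```

Note that `DatasetBuffer` (below) enforces `group_size <= batch_size`, `batch_size % group_size == 0`, and `group_size % mini_batch_size == 0`, so pick values that satisfy all three.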
hud/rl/__init__.py
ADDED

```python
"""RL module for HUD."""
```
hud/rl/actor.py
ADDED

```python
"""Actor for episode collection using vLLM and HUD."""

from __future__ import annotations

import asyncio
import logging

import httpx
from openai import AsyncOpenAI

import hud
from hud.agents.openai_chat_generic import GenericOpenAIChatAgent
from hud.clients.utils.retry_transport import create_retry_httpx_client
from hud.types import Task, Trace
from hud.utils.hud_console import HUDConsole

from .config import Config

logger = logging.getLogger(__name__)
hud_console = HUDConsole(logger)


class Actor:
    """Collects episodes using vLLM-served models via HUD agents."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.actor_config = config.actor
        self.current_adapter = config.model.base_model

        # Setup OpenAI client for vLLM
        base_url = self.actor_config.vllm_base_url.replace("localhost", "127.0.0.1")
        self.openai_client = self._create_openai_client(base_url)

    def _create_openai_client(self, base_url: str) -> AsyncOpenAI:
        """Create OpenAI client with optimized settings for vLLM."""
        # Match connection limits to parallel_episodes to avoid bottlenecks
        # Use shorter per-request timeout and keep retries modest to avoid long blocking
        http_client = create_retry_httpx_client(
            timeout=httpx.Timeout(30.0),
        )
        return AsyncOpenAI(
            base_url=base_url,
            api_key=self.actor_config.vllm_api_key,
            http_client=http_client,
            max_retries=2,
        )

    def create_agent(self) -> GenericOpenAIChatAgent:
        """Create an agent with the current adapter."""
        return GenericOpenAIChatAgent(
            openai_client=self.openai_client,
            model_name=self.current_adapter,
            allowed_tools=self.actor_config.allowed_tools,
            append_setup_output=False,
            system_prompt=self.actor_config.system_prompt,
            verbose=self.config.verbose,
            completion_kwargs={
                "temperature": self.actor_config.temperature,
                "max_tokens": self.actor_config.max_new_tokens,
                "tool_choice": "required" if self.actor_config.force_tool_choice else "auto",
            },
        )

    def update_adapter(self, adapter_name: str) -> None:
        """Update the current adapter being used."""
        self.current_adapter = adapter_name
        hud_console.info(f"[Actor] Using adapter: {adapter_name}")

    async def run_tasks(self, tasks: list[Task], job_id: str) -> list[Trace]:
        """Run tasks and collect traces."""
        traces = []

        # Process tasks in batches respecting max_parallel_episodes limit
        for batch_start in range(0, len(tasks), self.actor_config.max_parallel_episodes):
            batch_end = min(batch_start + self.actor_config.max_parallel_episodes, len(tasks))
            batch = tasks[batch_start:batch_end]

            # Run batch in parallel with per-episode timeout protection
            async def run_with_timeout(t: Task) -> Trace:
                try:
                    return await asyncio.wait_for(
                        self._run_task(t, job_id),
                        timeout=self.actor_config.episode_timeout_sec,
                    )
                except TimeoutError:
                    hud_console.warning_log(f"Episode timed out for task {t.id}")
                    return Trace(isError=True, content="Episode timeout")

            results = await asyncio.gather(
                *[run_with_timeout(t) for t in batch],
                return_exceptions=True,
            )

            # Normalize exceptions to error traces
            for res in results:
                if isinstance(res, Exception):
                    hud_console.warning_log(f"Episode error: {res}")
                    traces.append(Trace(isError=True, content=str(res)))
                else:
                    traces.append(res)

        return traces

    async def _run_task(self, task: Task, job_id: str) -> Trace:
        """Run a single task."""
        agent = self.create_agent()

        # Run the task
        try:
            with hud.trace(f"Training | {task.id}", job_id=job_id):
                result = await agent.run(task, max_steps=self.actor_config.max_steps_per_episode)

        except Exception:
            logger.info("GOT EXCEPTION")
            return Trace(isError=True)

        result.info["tool_spec"] = agent.get_tool_schemas()

        return result


if __name__ == "__main__":
    from hud.types import Task

    async def test_actor() -> None:
        """Test the actor with a single 2048 task using local hud-browser image."""
        config = Config()
        config.actor.max_parallel_episodes = 1
        config.actor.max_steps_per_episode = 6
        config.actor.episode_timeout_sec = 120
        config.verbose = True

        # Create test task with local hud-browser image
        task_data = {
            "id": "test_2048_128",
            "prompt": "Play the browser-based 2048 game and try to reach the 128 tile. Start by taking a screenshot, then make strategic moves using arrow keys.",  # noqa: E501
            "mcp_config": {
                "local": {
                    "command": "sh",
                    "args": [
                        "-c",
                        "docker run --rm --platform linux/amd64 -i hud-browser:latest 2>/dev/null",
                    ],
                }
            },
            "setup_tool": {"name": "launch_app", "arguments": {"app_name": "2048"}},
            "evaluate_tool": {
                "name": "evaluate",
                "arguments": {"name": "game_2048_max_number", "arguments": {"target": 128}},
            },
            "system_prompt": "You are an expert 2048 game player. Use arrow keys to reach the target tile. First take a screenshot, then make strategic moves.",  # noqa: E501
        }

        task = Task(**task_data)
        actor = Actor(config)

        logger.info("Testing actor with task: %s", task.id)
        logger.info("Model: %s", config.model.base_model)
        logger.info("VLLM: %s", config.actor.vllm_base_url)

        traces = await actor.run_tasks([task], job_id="test_2048")

        for trace in traces:
            if trace.isError:
                logger.info("Error: %s", trace.content)
            else:
                logger.info("Success!")
                logger.info("Trace info: %s", trace.info if hasattr(trace, "info") else "No info")
                # Check for evaluation in the trace info
                if hasattr(trace, "info") and "evaluation" in trace.info:
                    logger.info("  Evaluation: %s", trace.info["evaluation"])

    asyncio.run(test_actor())
```
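A sketch of how `Actor` is driven from outside (the real loop lives in `hud/rl/train.py`; how `tasks` is loaded is up to the caller, and the adapter name is hypothetical):

```python
import asyncio

from hud.rl.actor import Actor
from hud.rl.config import Config
from hud.types import Task


async def collect(tasks: list[Task], job_id: str) -> None:
    actor = Actor(Config())

    # Episodes run in parallel batches of max_parallel_episodes,
    # each guarded by episode_timeout_sec; failures come back as error traces.
    traces = await actor.run_tasks(tasks, job_id)
    ok = [t for t in traces if not t.isError]
    print(f"{len(ok)}/{len(traces)} episodes succeeded")

    # After the learner publishes a new LoRA adapter to vLLM,
    # later episodes can be pointed at it:
    actor.update_adapter("lora-step-1")
```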
hud/rl/buffer.py
ADDED

```python
"""Replay buffer for storing and sampling episodes."""

from __future__ import annotations

import logging
import random
from collections import deque
from typing import TYPE_CHECKING, Generic, TypeVar

from hud.types import Task, Trace
from hud.utils.hud_console import HUDConsole

logger = logging.getLogger(__name__)
hud_console = HUDConsole(logger=logger)

T = TypeVar("T")

if TYPE_CHECKING:
    from collections.abc import Callable

    from hud.rl.config import Config


class Buffer(Generic[T]):
    """Simple buffer for a list of tasks, traces or episodes."""

    def __init__(self, max_size: int = 10000) -> None:
        self.max_size = max_size
        self.buffer: deque[T] = deque(maxlen=max_size)

    def add(self, items: list[T] | T, shuffle: bool = False) -> None:
        """Add items to buffer."""
        if isinstance(items, list):
            for item in items:
                self.buffer.append(item)
        else:
            self.buffer.append(items)
        if shuffle:
            random.shuffle(self.buffer)

    def add_fill(self, items: list[T] | T, target_size: int, shuffle: bool = False) -> None:
        """Add items to buffer until the buffer is at least the target size."""
        while len(self.buffer) < target_size:
            self.add(items, shuffle)

    def get(self, n: int = 0) -> list[T]:
        """Get items from the buffer."""
        if n == 0:
            return list(self.buffer)
        if n > len(self.buffer):
            raise ValueError("Not enough items in buffer")
        return list(self.buffer)[-n:]

    def consume(self, n: int = 0) -> list[T]:
        """Consume items from the buffer."""
        if n == 0:
            return list(self.buffer)
        if n > len(self.buffer):
            raise ValueError("Not enough items in buffer")

        return [self.buffer.pop() for _ in range(n)]

    def get_filtered(
        self, n: int = 0, filter_fn: Callable[[T], bool] | None = None, consume: bool = False
    ) -> list[T]:
        """Filter the buffer by a filter function."""
        filtered = (
            [item for item in self.buffer if filter_fn(item)] if filter_fn else list(self.buffer)
        )
        if n == 0:
            return filtered
        return self.consume(n) if consume else self.get(n)

    def sample(
        self,
        batch_size: int,
        n: int = 0,
        filter_fn: Callable[[T], bool] | None = None,
        consume: bool = False,
    ) -> list[T]:
        """Sample a batch of items with optional filtering."""
        items = self.get_filtered(n, filter_fn, consume)

        if len(items) < batch_size:
            hud_console.warning(f"Buffer has {len(items)} items, requested {batch_size}")
            return items

        return random.sample(items, batch_size)

    def clear(self) -> None:
        """Clear the buffer."""
        self.buffer.clear()

    def __len__(self) -> int:
        """Use len() directly on Buffer instances."""
        return len(self.buffer)


class DatasetBuffer(Buffer[Task]):
    """
    Buffer for a dataset.
    Loads in individual tasks that will be trained for a specified number of training steps.
    """

    def __init__(
        self,
        dataset: list[Task] | Task,
        config: Config,
    ) -> None:
        self.config = config

        self.group_size = config.training.group_size
        self.batch_size = config.training.batch_size
        self.training_steps = config.training.training_steps

        if self.group_size > self.batch_size:
            raise ValueError(
                f"Group size is greater than batch size, {self.group_size} > {self.batch_size}"
            )

        if self.batch_size % self.group_size != 0:
            raise ValueError(
                f"A batch cannot have irregular groups, {self.group_size} % {self.batch_size} != 0"
            )

        if self.group_size % config.training.mini_batch_size != 0:
            raise ValueError(
                f"Group size is not a multiple of mini batch size, {self.group_size} % {config.training.mini_batch_size} != 0"  # noqa: E501
            )

        self.groups_per_batch = self.batch_size // self.group_size
        self.number_of_tasks = self.training_steps * self.groups_per_batch

        super().__init__(self.number_of_tasks)

        dataset = dataset if isinstance(dataset, list) else [dataset]
        tasks = self._validate_tasks(dataset)
        if config.training.shuffle_dataset:
            random.shuffle(tasks)
        if len(tasks) > self.number_of_tasks:
            leftovers = len(tasks) - self.number_of_tasks
            hud_console.warning(
                f"Training steps ({self.training_steps}) will lead to {leftovers} tasks not being trained"  # noqa: E501
            )
            tasks = tasks[: self.number_of_tasks]

        # Check if the dataset is imbalanced
        self.dataset_size = len(tasks)
        if self.training_steps % self.dataset_size != 0:
            leftovers = self.number_of_tasks % self.dataset_size
            hud_console.warning(
                f"Dataset imbalanced ({leftovers} tasks will be trained 1 more time)"
            )
            hud_console.warning(
                f"This is because the number of training steps ({self.training_steps}) is not a multiple of the dataset size ({self.dataset_size})"  # noqa: E501
            )

        self.add_fill(tasks, self.number_of_tasks, config.training.shuffle_dataset)

    def _validate_tasks(self, tasks: list[Task]) -> list[Task]:
        """Validate that all tasks are proper HUD Task objects."""
        if not tasks:
            raise ValueError("No tasks provided to DatasetBuffer")

        validated_tasks = []
        for i, task in enumerate(tasks):
            if not isinstance(task, Task):
                raise TypeError(f"Task at index {i} is not a HUD Task object, got {type(task)}")
            validated_tasks.append(task)

        return validated_tasks

    @property
    def info(self) -> dict[str, int | float | str]:
        """Get the info of the buffer."""
        return {
            "total_items": len(self),
            "total_traces": self.number_of_tasks * self.group_size,
            "total_batches": self.training_steps,
            "task_repeats": self.number_of_tasks // self.dataset_size,
            "dataset_size": self.dataset_size,
            "group_size": self.group_size,
            "batch_size": self.batch_size,
        }

    def get_tasks(self, consume: bool = True) -> list[Task]:
        """Get tasks for a batch."""
        tasks = self.consume(self.groups_per_batch) if consume else self.get(self.groups_per_batch)
        # Create groups where each group contains group_size copies of the same task
        result = []
        for task in tasks:
            result.extend([task] * self.group_size)
        return result


class ReplayBuffer(Buffer[Trace]):
    """Buffer for traces."""

    def __init__(self, config: Config) -> None:
        self.config = config

        self.buffer_steps = config.training.buffer_steps
        self.select_strategy = config.training.select_strategy
        self.group_size = config.training.group_size
        self.batch_size = config.training.batch_size

        buffer_size = self.buffer_steps * self.batch_size

        super().__init__(buffer_size)

    def sample_traces(self) -> list[Trace]:
        """Sample traces for a batch."""
        if self.select_strategy == "recent":
            return self.get(self.batch_size)
        elif self.select_strategy == "random":
            return self.sample(self.batch_size)
        elif self.select_strategy == "variance":
            return self._sample_high_variance_traces()
        else:
            raise ValueError(f"Invalid select strategy: {self.select_strategy}")

    def _sample_high_variance_traces(self) -> list[Trace]:
        from collections import Counter, defaultdict, deque

        # Expect recent window to already be grouped by task id

        # Build recent window and earlier lookup (short form)
        buf_list = list(self.buffer)
        if len(buf_list) < self.batch_size:
            hud_console.warning(
                f"[group-sampler] Buffer has only {len(buf_list)} traces, need {self.batch_size}"
            )
            while len(buf_list) < self.batch_size:
                take = min(len(buf_list) or 1, self.batch_size - len(buf_list))
                buf_list.extend(buf_list[:take])
        recent_traces = buf_list[-self.batch_size :]
        hud_console.info(
            f"[group-sampler] recent-window histogram: {Counter(getattr(t.task, 'id', 'NA') for t in recent_traces)}"  # noqa: E501
        )

        hud_console.info(
            f"[group-sampler] Building earlier traces lookup, buffer size: {len(buf_list)}"
        )
        earlier_traces_by_task: dict[str, deque[Trace]] = defaultdict(deque)
        for tr in buf_list[: -self.batch_size]:
            earlier_traces_by_task[getattr(tr.task, "id", "NA")].append(tr)

        # Chunk from the most-recent end
        final_traces: list[Trace] = []
        groups_per_batch = self.batch_size // self.group_size
        hud_console.info(f"[group-sampler] Processing {groups_per_batch} groups")
        for g_idx in range(groups_per_batch):
            start = g_idx * self.group_size
            end = start + self.group_size
            group = recent_traces[start:end]

            # Assert homogeneity: every trace in a group must share the same task id
            cnt = Counter(getattr(t.task, "id", "NA") for t in group)
            if len(cnt) != 1:
                raise RuntimeError(f"Group {g_idx} is not homogeneous: {dict(cnt)}")
            target_tid = next(iter(cnt.keys()))

            # Build homogeneous group of target_tid, filling from earlier traces to increase spread
            homogeneous: list[Trace] = [
                t for t in group if getattr(t.task, "id", "NA") == target_tid
            ]
            needed = self.group_size - len(homogeneous)

            # Greedy fill: choose earlier traces (same task-id) farthest from current mean reward
            def current_mean(homogeneous: list[Trace]) -> float:
                if not homogeneous:
                    return 0.0
                vals = [float(getattr(t, "reward", 0.0) or 0.0) for t in homogeneous]
                return sum(vals) / len(vals)

            while needed > 0:
                pool = earlier_traces_by_task.get(target_tid, deque())
                if pool:
                    mu = current_mean(homogeneous)
                    # pick element farthest from current mean
                    best_i = None
                    best_dist = -1.0
                    for i, tr in enumerate(list(pool)):
                        r = float(getattr(tr, "reward", 0.0) or 0.0)
                        dist = abs(r - mu)
                        if dist > best_dist:
                            best_dist = dist
                            best_i = i
                    # pop selected
                    chosen = list(pool)[best_i]  # type: ignore[index]
                    # remove from deque efficiently by rotating
                    left = list(pool)
                    if best_i is not None:
                        left.pop(best_i)  # O(n) but pool is small in practice
                    earlier_traces_by_task[target_tid] = deque(left)
                    homogeneous.append(chosen)
                else:
                    # duplicate extreme within current homogeneous set
                    if not homogeneous:
                        raise RuntimeError(f"Group {g_idx} has no traces for target {target_tid}")
                    mu = current_mean(homogeneous)
                    extreme = max(
                        homogeneous, key=lambda t: abs(float(getattr(t, "reward", 0.0) or 0.0) - mu)
                    )
                    homogeneous.append(extreme)
                needed -= 1

            # Replacement step: swap in earlier traces to increase reward spread
            pool = earlier_traces_by_task.get(target_tid, deque())
            if pool:
                # Log pool stats
                pool_vals = [float(getattr(tr, "reward", 0.0) or 0.0) for tr in list(pool)]
                if pool_vals:
                    pool_mean = sum(pool_vals) / len(pool_vals)
                    pool_var = sum((v - pool_mean) * (v - pool_mean) for v in pool_vals) / len(
                        pool_vals
                    )
                    hud_console.info(
                        f"[group-sampler] Group {g_idx}: earlier-pool size={len(pool_vals)} mean={pool_mean:.4f} std={(pool_var**0.5):.4f}"  # noqa: E501
                    )

                # Decide how many to replace (up to 1/4 of group, at least 1)
                replace_k = max(1, self.group_size // 4)
                replace_k = min(replace_k, len(pool), self.group_size)

                if replace_k > 0:
                    mu = current_mean(homogeneous)
                    # Select replacement candidates from pool farthest from current mean
                    pool_list = list(pool)
                    pool_indices = list(range(len(pool_list)))
                    pool_indices.sort(
                        key=lambda i: abs(
                            (float(getattr(pool_list[i], "reward", 0.0) or 0.0)) - mu
                        ),
                        reverse=True,
                    )
                    chosen_pool_idx = set(pool_indices[:replace_k])
                    replacements = [pool_list[i] for i in pool_indices[:replace_k]]

                    # Remove chosen from pool deque
                    remaining = [tr for i, tr in enumerate(pool_list) if i not in chosen_pool_idx]
                    earlier_traces_by_task[target_tid] = deque(remaining)

                    # Select current group positions closest to mean to replace
                    group_indices = list(range(len(homogeneous)))
                    group_indices.sort(
                        key=lambda i: abs(
                            (float(getattr(homogeneous[i], "reward", 0.0) or 0.0)) - mu
                        )
                    )
                    target_positions = group_indices[:replace_k]

                    for pos, new_tr in zip(target_positions, replacements, strict=False):
                        homogeneous[pos] = new_tr

            # Validate homogeneity
            if any(getattr(t.task, "id", "NA") != target_tid for t in homogeneous):
                raise RuntimeError(f"Group {g_idx} is not homogeneous after sampling")
            final_traces.extend(homogeneous)

        for i in range(0, len(final_traces), self.group_size):
            block = final_traces[i : i + self.group_size]
            if len({getattr(t.task, "id", "NA") for t in block}) != 1:
                raise RuntimeError(f"Homogeneity validation failed for block starting at index {i}")

        hud_console.info(
            f"[group-sampler] final histogram: {Counter(getattr(t.task, 'id', 'NA') for t in final_traces)}"  # noqa: E501
        )
        return final_traces

    # --------------------------------------------------------------------
```