hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +70 -5
- hud/agents/base.py +238 -500
- hud/agents/claude.py +236 -247
- hud/agents/gateway.py +42 -0
- hud/agents/gemini.py +264 -0
- hud/agents/gemini_cua.py +324 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +48 -36
- hud/agents/openai.py +282 -296
- hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
- hud/agents/operator.py +199 -0
- hud/agents/resolver.py +70 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +381 -214
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +377 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_resolver.py +192 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/agents/types.py +148 -0
- hud/cli/__init__.py +493 -546
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +699 -113
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +889 -732
- hud/cli/eval.py +793 -667
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/pull.py +1 -1
- hud/cli/push.py +38 -13
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +110 -8
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push.py +1 -1
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +70 -1
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +45 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +326 -0
- hud/datasets/runner.py +198 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +52 -0
- hud/environment/connection.py +258 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +137 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +835 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +263 -0
- hud/environment/scenarios.py +620 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +205 -0
- hud/environment/tests/test_environment.py +593 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +242 -0
- hud/environment/tests/test_scenarios.py +1086 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +727 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +187 -0
- hud/eval/manager.py +533 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +372 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +291 -0
- hud/eval/types.py +65 -0
- hud/eval/utils.py +194 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +308 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +165 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +18 -2
- hud/tools/agent.py +223 -0
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +36 -3
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_agent_tool.py +355 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +194 -56
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +89 -18
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.13.dist-info/METADATA +264 -0
- hud_python-0.5.13.dist-info/RECORD +305 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/rl/utils.py
DELETED
|
@@ -1,524 +0,0 @@
|
|
|
1
|
-
"""Utility functions for RL training."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import base64
|
|
6
|
-
import io
|
|
7
|
-
import logging
|
|
8
|
-
import os
|
|
9
|
-
import random
|
|
10
|
-
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING, Any
|
|
12
|
-
|
|
13
|
-
import numpy as np
|
|
14
|
-
import torch
|
|
15
|
-
from PIL import Image
|
|
16
|
-
from transformers.utils.chat_template_utils import render_jinja_template
|
|
17
|
-
|
|
18
|
-
from hud.utils.hud_console import HUDConsole
|
|
19
|
-
|
|
20
|
-
from .types import TrainingSample
|
|
21
|
-
|
|
22
|
-
if TYPE_CHECKING:
|
|
23
|
-
from hud.types import Trace
|
|
24
|
-
|
|
25
|
-
from .config import Config
|
|
26
|
-
|
|
27
|
-
# Module-level logger, wrapped in HUDConsole for the console helpers
# (info_log / info / warning) used by the functions in this module.
logger = logging.getLogger(__name__)
hud_console = HUDConsole(logger)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def set_seed(seed: int) -> None:
    """Seed every random number generator used during RL training.

    Seeds Python's ``random`` module, NumPy's global generator, and torch
    (CPU, plus all CUDA devices when CUDA is available) so runs are
    reproducible.

    Args:
        seed: Seed value applied to every generator. Any integer is accepted.
    """
    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32 - 1]; fold the value into that
    # range so negative or very large seeds (valid for random/torch) still work.
    np.random.seed(seed % 2**32)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def load_chat_template(path: str) -> str:
    """Read a Jinja chat template file and return its text.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Args:
        path: Filesystem path of the template file.

    Returns:
        The template source as a string.
    """
    with open(path, encoding="utf-8") as f:
        return f.read()
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def ensure_dir(path: str) -> None:
    """Make sure the directory ``path`` exists, creating parents as needed.

    Calling this on an already existing directory is a no-op.
    """
    os.makedirs(name=path, exist_ok=True)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def get_memory_usage() -> float:
    """Return the CUDA memory currently allocated by tensors, in GiB.

    Returns 0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    # Wait for pending device work before reading the allocator counter.
    torch.cuda.synchronize()
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1024**3
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def get_gpu_utilization() -> float:
    """Get current GPU utilization percentage (0-100).

    Queries NVML (the ``pynvml`` module installed by the ``nvidia-ml-py``
    package) for the current CUDA device. If NVML is not installed or the
    query fails, falls back to a rough estimate: currently allocated CUDA
    memory as a percentage of the peak allocation.

    Returns:
        Utilization in [0, 100]. Returns 0.0 when CUDA is unavailable, or when
        the fallback has no peak allocation to compare against.
    """
    if not torch.cuda.is_available():
        return 0.0

    try:
        # The "nvidia-ml-py" distribution is imported as "pynvml".
        import pynvml as nvml  # type: ignore
    except ImportError:
        nvml = None

    if nvml is not None:
        try:
            nvml.nvmlInit()
            try:
                # NOTE(review): assumes torch's device index matches the NVML
                # index (may differ under CUDA_VISIBLE_DEVICES) — confirm.
                handle = nvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
                return float(nvml.nvmlDeviceGetUtilizationRates(handle).gpu)
            finally:
                nvml.nvmlShutdown()
        except Exception:
            # Best effort: fall through to the memory-based estimate.
            logging.getLogger(__name__).debug(
                "NVML utilization query failed; using memory-based estimate", exc_info=True
            )

    # Fallback: less accurate, but works without NVML.
    peak = torch.cuda.max_memory_allocated()
    if peak <= 0:
        return 0.0
    return min(100.0, (torch.cuda.memory_allocated() / peak) * 100)
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def aggregate_metrics_across_ranks(
    metrics: Any, metrics_to_aggregate: list[str] | None = None
) -> None:
    """Average the latest metric values over all distributed ranks.

    Every rank contributes the most recent entry of each requested metric.
    The values are exchanged with ``torch.distributed.all_gather``, so this
    must be called on every rank. On the main process only, the last entry of
    each metric's ``values`` is then replaced by the cross-rank mean, and the
    metric's ``mean``/``std`` are recomputed over its full history.

    Args:
        metrics: Object (e.g. a TrainingMetrics instance) whose attributes,
            named in ``metrics_to_aggregate``, expose a ``values`` list and
            writable ``mean``/``std`` fields. Updated in-place.
        metrics_to_aggregate: Attribute names to aggregate. When None, a
            default set of timing, GPU, and core training scalars is used.
    """
    from hud.rl.distributed import get_local_rank, get_world_size, is_main_process

    # A single process has nothing to combine with.
    if get_world_size() <= 1:
        return

    names = metrics_to_aggregate
    if names is None:
        names = [
            "training_time",
            "samples_per_second",
            "gpu_util",
            "gpu_memory",
            "grad_norm",
            # Core training scalars
            "loss",
            "kl",
            "entropy",
            "tokens",
            "policy_ratio",
        ]

    # Latest value of every requested metric this rank tracks (0.0 if no history yet).
    local_values: dict[str, float] = {}
    for name in names:
        if not hasattr(metrics, name):
            continue
        history = getattr(metrics, name).values
        local_values[name] = history[-1] if history else 0.0

    local_tensor = torch.tensor(
        list(local_values.values()), device=f"cuda:{get_local_rank()}", dtype=torch.float32
    )

    # all_gather is supported by the NCCL backend; every rank receives every row.
    world_size = get_world_size()
    gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, local_tensor)

    if not is_main_process():
        return

    # Rows are ranks, columns are metrics (in local_values order).
    per_rank = torch.stack(gathered).cpu().numpy()

    for column, name in enumerate(local_values.keys()):
        metric_obj = getattr(metrics, name)
        rank_values = per_rank[:, column].tolist()

        history = metric_obj.values
        if len(history) == 0:
            history.append(0.0)
        history[-1] = float(sum(rank_values) / len(rank_values))

        # Refresh summary statistics over the whole history (population std).
        count = len(history)
        mean = float(sum(history) / count)
        metric_obj.mean = mean
        variance = sum((x - mean) ** 2 for x in history) / count
        metric_obj.std = float(variance**0.5)
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
def b64_to_pil(b64_str: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    raw_bytes = base64.b64decode(b64_str)
    decoded_image = Image.open(io.BytesIO(raw_bytes))
    return decoded_image.convert("RGB")
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
def build_assistant_masks(
    input_ids: list[list[int]],
    tokenizer: Any,
) -> list[list[int]]:
    """Build per-token masks marking assistant content in ChatML-style sequences.

    Each sequence is scanned for assistant turns that start with the tokens
    ``<|im_start|>`` followed by ``assistant``. The content tokens of such a
    turn are marked 1, up to (but excluding) the next ``<|im_end|>`` or the end
    of the sequence. Left at 0 are:

    * the ``<|im_start|>``/``assistant`` header;
    * a newline token directly after the header;
    * leading and trailing whitespace-only tokens;
    * the closing ``<|im_end|>``.

    Args:
        input_ids: Token sequences, one list of token ids per conversation.
        tokenizer: Tokenizer providing ``convert_tokens_to_ids`` (to resolve
            the special tokens) and ``decode`` (to detect whitespace tokens).

    Returns:
        One 0/1 mask per input sequence, each the same length as its sequence.
    """
    id_im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
    id_im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    id_assistant = tokenizer.convert_tokens_to_ids("assistant")

    # Single-token decodes are repeated many times; memoize them per token id.
    decoded_cache: dict[int, str] = {}

    def decode_one(token_id: int) -> str:
        # Decode one token id, caching the result.
        if token_id not in decoded_cache:
            decoded_cache[token_id] = tokenizer.decode([token_id])
        return decoded_cache[token_id]

    assistant_masks: list[list[int]] = []

    for seq in input_ids:
        seq_len = len(seq)
        mask = [0] * seq_len
        i_tok = 0

        while i_tok < seq_len:
            starts_assistant_turn = (
                seq[i_tok] == id_im_start
                and i_tok + 1 < seq_len
                and seq[i_tok + 1] == id_assistant
            )
            if not starts_assistant_turn:
                i_tok += 1
                continue

            # Skip '<|im_start|>' and 'assistant', then an optional newline.
            i_tok += 2
            if i_tok < seq_len and decode_one(seq[i_tok]) == "\n":
                i_tok += 1

            # Skip leading whitespace-only tokens.
            while i_tok < seq_len and decode_one(seq[i_tok]).strip() == "":
                i_tok += 1

            content_start = i_tok

            # Mark content tokens until <|im_end|> (or end of sequence).
            while i_tok < seq_len and seq[i_tok] != id_im_end:
                mask[i_tok] = 1
                i_tok += 1
            content_end = i_tok

            # Unmark trailing whitespace-only tokens of this turn.
            while content_end > content_start and decode_one(seq[content_end - 1]).strip() == "":
                mask[content_end - 1] = 0
                content_end -= 1

            # Skip the <|im_end|> token.
            i_tok += 1

        assistant_masks.append(mask)

    return assistant_masks
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
def prepare_conversation_history(
    conversation_history: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[Image.Image]]:
    """Sanitize conversation history to avoid vLLM errors.

    Returns new message dicts; the input messages are not modified.

    * Messages with a ``tool_calls`` key are reduced to ``role``, ``content``
      and ``tool_calls``. Non-dict tool calls are converted with
      ``model_dump()``.
    * In user messages whose ``content`` is a list, ``image_url`` parts are
      handled as follows:
        - If the URL is a ``data:image`` base64 URL or raw bytes, the image is
          decoded to RGB, collected, and the part is replaced by
          ``{"type": "image"}``.
        - Other URLs (e.g. http) are left untouched and no image is
          collected for them.
    * All other messages are passed through unchanged.

    Args:
        conversation_history: Chat messages as dicts.

    Returns:
        The sanitized messages, and the decoded images in the order they
        appear in the conversation.
    """
    sanitized_messages: list[dict[str, Any]] = []
    images: list[Image.Image] = []

    for m in conversation_history:
        if "tool_calls" in m:
            m = {
                "role": m["role"],
                "content": m.get("content", ""),
                "tool_calls": [
                    tc.model_dump() if not isinstance(tc, dict) else tc
                    for tc in (m.get("tool_calls") or [])
                ],
            }
        elif m.get("role") == "user" and isinstance(m.get("content"), list):
            new_content: list[Any] = []
            for c in m["content"]:
                if isinstance(c, dict) and c.get("type") == "image_url":
                    image_url = c.get("image_url") or {}
                    url = image_url.get("url", "") if isinstance(image_url, dict) else image_url

                    image = None
                    if isinstance(url, (bytes, bytearray)):
                        image = Image.open(io.BytesIO(url)).convert("RGB")
                    elif isinstance(url, str) and url.startswith("data:image"):
                        # Strip the "data:image/...;base64," prefix when present.
                        data = url.split(",", 1)[1] if "," in url else url
                        image = b64_to_pil(data)

                    if image is not None:
                        images.append(image)
                        # Placeholder paired with the collected image.
                        c = {"type": "image"}
                new_content.append(c)
            # Copy rather than mutate the caller's message dict.
            m = {**m, "content": new_content}
        sanitized_messages.append(m)

    return sanitized_messages, images
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
def prepare_inputs(trace: Trace, processor: Any) -> dict[str, torch.Tensor]:
    """Tokenize a trace's conversation for training.

    Steps:

    1. Render ``trace.messages`` with the bundled ``chat_template.jinja``.
    2. Run the processor. Images are passed only to processors that expose a
       ``tokenizer`` attribute (VL processors); a plain tokenizer gets text
       only.
    3. Add an ``assistant_mask`` entry marking assistant content tokens,
       shifted by one position so it lines up with next-token targets.

    Args:
        trace: Trace whose ``messages`` are rendered. ``trace.info["tool_spec"]``,
            when present and non-empty, is passed as the template's tools.
        processor: A VL processor (with a ``tokenizer`` attribute) or a plain
            tokenizer.

    Returns:
        The processor output converted to PyTorch tensors, plus a boolean
        ``assistant_mask`` of shape [B, T-1]. Returns an empty dict if the
        trace has no messages.
    """
    if len(trace.messages) == 0:
        return {}

    # Extract images and normalize messages for the template.
    conversation, images = prepare_conversation_history(trace.messages)

    chat_template_path = Path(__file__).parent / "chat_template.jinja"

    # VL processors wrap a tokenizer; for text models the processor IS the tokenizer.
    is_vl_processor = hasattr(processor, "tokenizer")
    tokenizer = processor.tokenizer if is_vl_processor else processor

    # A missing or empty tool spec renders the conversation without tools.
    tool_spec = (trace.info or {}).get("tool_spec")

    text_list, _ = render_jinja_template(
        conversations=[conversation],
        chat_template=load_chat_template(str(chat_template_path)),
        tools=tool_spec if tool_spec else None,  # mcp_tools
        return_assistant_tokens_mask=True,
        **tokenizer.special_tokens_map,
    )

    processor_kwargs: dict[str, Any] = {
        "text": text_list,
        "return_offsets_mapping": False,  # char offsets are not needed
    }
    if is_vl_processor:
        # Only VL processors accept an images argument.
        processor_kwargs["images"] = images if len(images) > 0 else None
    inputs = processor(**processor_kwargs)

    assistant_masks = build_assistant_masks(inputs["input_ids"], tokenizer)
    mask_tensor = torch.tensor(assistant_masks, dtype=torch.long)

    # Ensure mask_tensor is 2D before slicing.
    if mask_tensor.dim() == 1:
        mask_tensor = mask_tensor.unsqueeze(0)

    # Drop the first position to align with targets [B, T-1].
    inputs["assistant_mask"] = mask_tensor[:, 1:].bool()

    inputs.convert_to_tensors(tensor_type="pt")

    return inputs
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
def render_assistant_tokens(
    mask_tensor: torch.Tensor, input_ids: torch.Tensor, processor: Any
) -> list[str]:
    """Decode each contiguous run of masked tokens in the first sequence.

    Only row 0 of ``mask_tensor`` and ``input_ids`` is inspected. Every
    maximal run of non-zero mask positions is decoded with
    ``processor.decode``. A run that reaches the end of the row is included.

    Returns:
        The decoded strings, in order of appearance.
    """
    row_mask = mask_tensor[0]
    row_ids = input_ids[0]

    spans: list[str] = []
    run_start: int | None = None

    for pos, flag in enumerate(row_mask):
        if flag != 0 and run_start is None:
            run_start = pos
        elif flag == 0 and run_start is not None:
            spans.append(processor.decode(row_ids[run_start:pos].tolist()))
            run_start = None

    # A run still open at the end extends to the last token.
    if run_start is not None:
        spans.append(processor.decode(row_ids[run_start:].tolist()))

    return spans
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Return the entropy (in nats) of softmax(logits) over the last dimension.

    Probabilities are derived from ``log_softmax``, and the result has the
    input's shape without its last dimension.
    """
    log_p = torch.nn.functional.log_softmax(logits, dim=-1)
    p = torch.exp(log_p)
    return -(p * log_p).sum(dim=-1)
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
def preprocess_advantages(group: list[Trace], config: Config) -> list[TrainingSample]:
    """Convert traces into TrainingSamples with per-group normalized advantages.

    Grouping, per ``config.training.batch_level``:

    * ``"group"`` splits ``group`` into consecutive chunks of
      ``config.training.group_size`` traces.
    * ``"batch"`` treats all traces as a single group.

    Advantage of each trace, within its group:

    * Start from its reward minus the group mean.
    * Divide by the group std, unless ``config.training.no_std`` is set.
      When dividing and the std is below 1e-6, the advantage is 0.
    * If ``config.training.leave_one_out`` is set, scale by n / (n - 1) for a
      group of n > 1 traces (RLOO/LOOP).
    * Traces flagged ``isError`` get advantage 0, but their rewards still
      enter the group mean/std.

    Args:
        group: All traces to process, concatenated.
        config: Training configuration.

    Returns:
        One TrainingSample per trace, in input order, each with a float32
        ``advantage`` tensor of shape [1].

    Raises:
        ValueError: If ``config.training.batch_level`` is neither "group" nor "batch".
    """
    group_size = config.training.group_size
    if config.training.batch_level == "group":
        groups = [group[i : i + group_size] for i in range(0, len(group), group_size)]
    elif config.training.batch_level == "batch":
        groups = [group]
    else:
        raise ValueError(f"Invalid batch level: {config.training.batch_level}")

    all_samples: list[TrainingSample] = []
    for group_idx, traces in enumerate(groups):
        if not traces:
            # No rewards -> no statistics to compute (np.mean would be NaN).
            continue

        rewards = np.array([trace.reward for trace in traces])
        mean_reward = np.mean(rewards)
        std_reward = np.std(rewards)
        n_traces = len(traces)

        samples = [TrainingSample(**trace.model_dump()) for trace in traces]
        for sample, reward in zip(samples, rewards, strict=True):
            if sample.isError:
                advantage_value = 0.0
            else:
                if config.training.no_std:
                    # No std (non-baseline GRPO)
                    advantage_value = float(reward - mean_reward)
                elif std_reward < 1e-6:
                    # Avoid division by zero when all rewards are (nearly) equal.
                    advantage_value = 0.0
                else:
                    advantage_value = float((reward - mean_reward) / std_reward)
                # Leave one out (RLOO/LOOP); undefined for a single trace.
                if config.training.leave_one_out and n_traces > 1:
                    advantage_value = advantage_value * n_traces / (n_traces - 1)
            # Always a plain float -> shape [1] tensor, so advantages from
            # every branch can be stacked together later.
            sample.advantage = torch.tensor([advantage_value], dtype=torch.float32)

        hud_console.info_log(
            f"Advantages for group {group_idx} [{mean_reward:.4f} ± {std_reward:.4f}]:"
            f"{[round(sample.advantage.item(), 4) for sample in samples if sample.advantage is not None]}"  # noqa: E501
        )

        all_samples.extend(samples)

    return all_samples
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
def batch_training_samples(samples: list[TrainingSample]) -> list[TrainingSample]:
    """Collate a list of TrainingSamples into a single batched TrainingSample.

    Filtering: samples without an ``assistant_mask``, with an all-zero
    assistant mask, or with an all-zero advantage are dropped first. At least
    one sample is always kept (the first one, if every sample would be
    dropped). If only one sample remains, it is returned as-is.

    Otherwise, with T the longest ``input_ids`` length:

    * ``input_ids`` and ``attention_mask`` are right-padded with 0 to length T
      and stacked to [B, T].
    * ``assistant_mask`` is right-padded with 0 to length T-1 and stacked to
      [B, T-1].
    * ``pixel_values`` and ``image_grid_thw``, when present, are concatenated
      along dim 0.
    * ``old_logprobs`` and ``ref_logprobs`` are right-padded with -inf to
      length T-1 and stacked to [B, T-1]. A missing tensor becomes an all -inf
      row.
    * Advantages are stacked along a new leading batch dimension.

    Args:
        samples: Samples to batch. The list itself is not modified.

    Returns:
        An empty list if ``samples`` is empty, otherwise a one-element list.

    Raises:
        ValueError: If a sequence or logprob tensor has an unexpected rank, or
            a kept sample has no advantage.
    """
    if not samples:
        hud_console.warning("No samples to batch.")
        return []

    import torch.nn.functional as F

    def has_no_signal(sample: TrainingSample) -> bool:
        # True for samples without assistant tokens or with an all-zero advantage.
        mask = sample.inputs.get("assistant_mask")
        if mask is None or mask.sum() == 0:
            return True
        return sample.advantage is not None and bool(
            torch.all(torch.as_tensor(sample.advantage) == 0.0)
        )

    # Build a new list instead of removing while iterating.
    kept = [s for s in samples if not has_no_signal(s)]
    if not kept:
        kept = samples[:1]
    for _ in range(len(samples) - len(kept)):
        hud_console.info("Removing sample with zero advantage.")

    if len(kept) == 1:
        return kept

    batched = TrainingSample()

    input_keys_to_expand = ["input_ids", "attention_mask", "assistant_mask"]
    input_keys_to_cat = ["pixel_values", "image_grid_thw"]
    collected: dict[str, list[torch.Tensor]] = {
        k: [] for k in input_keys_to_expand + input_keys_to_cat
    }

    # Sanity check dimensions: sequence tensors must be [T] or [1, T].
    for s in kept:
        for k in input_keys_to_expand + input_keys_to_cat:
            val = s.inputs.get(k)
            if val is None:
                continue
            if k in input_keys_to_expand:
                if val.dim() == 2 and val.size(0) == 1:
                    val = val[0]
                elif val.dim() != 1:
                    raise ValueError(f"{k} has unexpected dimensions: {val.shape}")
            collected[k].append(val)

    max_len = max(t.size(-1) for t in collected["input_ids"])

    def pad_1d(x: torch.Tensor, pad_to: int, pad_value: int) -> torch.Tensor:
        # Right-pad a 1D tensor to length pad_to.
        pad = pad_to - x.size(-1)
        return F.pad(x, (0, pad), value=pad_value) if pad > 0 else x

    stacked_inputs: dict[str, torch.Tensor] = {}
    for k in input_keys_to_expand:
        if collected[k]:
            # assistant_mask is T-1 (aligned with targets); the others are T.
            target_len = max_len - 1 if k == "assistant_mask" else max_len
            stacked_inputs[k] = torch.stack(
                [pad_1d(x, target_len, 0) for x in collected[k]], dim=0
            )

    for k in input_keys_to_cat:
        # Image tensors are concatenated across all images of all samples;
        # they are absent for text-only models, in which case nothing is added.
        if collected[k]:
            stacked_inputs[k] = torch.cat(collected[k], dim=0)

    batched.inputs = stacked_inputs

    def pad_logprobs(logprobs: torch.Tensor | None, target_len: int) -> torch.Tensor:
        # Normalize to 1D and right-pad with -inf (log of 0 probability).
        if logprobs is None:
            # Full-length row so every sample stacks to the same shape.
            return torch.full((target_len,), float("-inf"), dtype=torch.float32)
        if logprobs.dim() == 2 and logprobs.size(0) == 1:
            logprobs = logprobs.squeeze(0)
        elif logprobs.dim() != 1:
            raise ValueError(
                f"Expected logprobs to have 1 or 2 dimensions, got {logprobs.dim()} with shape {logprobs.shape}"  # noqa: E501
            )
        pad_size = target_len - logprobs.size(0)
        if pad_size > 0:
            return F.pad(logprobs, (0, pad_size), value=float("-inf"))
        return logprobs

    # Logprobs are T-1 long (one per predicted token).
    batched.old_logprobs = torch.stack(
        [pad_logprobs(s.old_logprobs, max_len - 1) for s in kept], dim=0
    )
    batched.ref_logprobs = torch.stack(
        [pad_logprobs(s.ref_logprobs, max_len - 1) for s in kept], dim=0
    )

    advantages = [s.advantage for s in kept]
    if any(adv is None for adv in advantages):
        raise ValueError(
            "Some samples have None advantages. Make sure advantages are computed before batching."
        )
    batched.advantage = torch.stack(advantages, dim=0)  # type: ignore

    return [batched]
|
hud/rl/vllm_adapter.py
DELETED
|
@@ -1,143 +0,0 @@
|
|
|
1
|
-
"""vLLM adapter management for LoRA hot-swapping."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import json
|
|
6
|
-
import logging
|
|
7
|
-
|
|
8
|
-
import requests
|
|
9
|
-
|
|
10
|
-
from hud.utils.hud_console import HUDConsole
|
|
11
|
-
|
|
12
|
-
# Module-wide HUDConsole built on this module's logger; VLLMAdapter emits all of
# its info/warning/error messages through this instance.
hud_console = HUDConsole(logging.getLogger(__name__))
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class VLLMAdapter:
|
|
16
|
-
"""Manages LoRA adapter loading/unloading in vLLM."""
|
|
17
|
-
|
|
18
|
-
def __init__(self, base_url: str, api_key: str) -> None:
|
|
19
|
-
self.base_url = base_url
|
|
20
|
-
self.api_key = api_key
|
|
21
|
-
self.current_adapter = None
|
|
22
|
-
|
|
23
|
-
def load_adapter(self, adapter_name: str, adapter_path: str, timeout: int = 30) -> bool:
|
|
24
|
-
"""
|
|
25
|
-
Hot-load a LoRA adapter to vLLM.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
adapter_name: Name to register the adapter as
|
|
29
|
-
adapter_path: Path to the adapter checkpoint
|
|
30
|
-
timeout: Request timeout in seconds
|
|
31
|
-
|
|
32
|
-
Returns:
|
|
33
|
-
True if successful, False otherwise
|
|
34
|
-
"""
|
|
35
|
-
url = f"{self.base_url}/load_lora_adapter"
|
|
36
|
-
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
|
37
|
-
payload = {"lora_name": adapter_name, "lora_path": adapter_path}
|
|
38
|
-
# Implement exponential backoff for retrying the adapter load request.
|
|
39
|
-
max_retries = 8
|
|
40
|
-
backoff_factor = 2
|
|
41
|
-
delay = 1 # initial delay in seconds
|
|
42
|
-
|
|
43
|
-
for attempt in range(1, max_retries + 1):
|
|
44
|
-
try:
|
|
45
|
-
response = requests.post(
|
|
46
|
-
url, headers=headers, data=json.dumps(payload), timeout=timeout
|
|
47
|
-
)
|
|
48
|
-
response.raise_for_status()
|
|
49
|
-
|
|
50
|
-
self.current_adapter = adapter_name
|
|
51
|
-
hud_console.info(f"[VLLMAdapter] Loaded adapter: {adapter_name}")
|
|
52
|
-
return True
|
|
53
|
-
|
|
54
|
-
except requests.exceptions.RequestException as e:
|
|
55
|
-
if attempt == max_retries:
|
|
56
|
-
hud_console.error(
|
|
57
|
-
f"[VLLMAdapter] Failed to load adapter {adapter_name} after {attempt} attempts: {e}" # noqa: E501
|
|
58
|
-
)
|
|
59
|
-
return False
|
|
60
|
-
else:
|
|
61
|
-
hud_console.warning(
|
|
62
|
-
f"[VLLMAdapter] Load adapter {adapter_name} failed (attempt {attempt}/{max_retries}): {e}. Retrying in {delay} seconds...", # noqa: E501
|
|
63
|
-
)
|
|
64
|
-
import time
|
|
65
|
-
|
|
66
|
-
time.sleep(delay)
|
|
67
|
-
delay *= backoff_factor
|
|
68
|
-
|
|
69
|
-
return False
|
|
70
|
-
|
|
71
|
-
def unload_adapter(self, adapter_name: str) -> bool:
|
|
72
|
-
"""
|
|
73
|
-
Unload a LoRA adapter from vLLM.
|
|
74
|
-
|
|
75
|
-
Args:
|
|
76
|
-
adapter_name: Name of the adapter to unload
|
|
77
|
-
|
|
78
|
-
Returns:
|
|
79
|
-
True if successful, False otherwise
|
|
80
|
-
"""
|
|
81
|
-
url = f"{self.base_url}/unload_lora_adapter"
|
|
82
|
-
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
|
83
|
-
payload = {"lora_name": adapter_name}
|
|
84
|
-
|
|
85
|
-
try:
|
|
86
|
-
response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=30)
|
|
87
|
-
response.raise_for_status()
|
|
88
|
-
|
|
89
|
-
if self.current_adapter == adapter_name:
|
|
90
|
-
self.current_adapter = None
|
|
91
|
-
|
|
92
|
-
hud_console.info(f"[VLLMAdapter] Unloaded adapter: {adapter_name}")
|
|
93
|
-
return True
|
|
94
|
-
|
|
95
|
-
except requests.exceptions.RequestException as e:
|
|
96
|
-
hud_console.error(f"[VLLMAdapter] Failed to unload adapter {adapter_name}: {e}")
|
|
97
|
-
return False
|
|
98
|
-
|
|
99
|
-
def list_adapters(self) -> list | None:
|
|
100
|
-
"""
|
|
101
|
-
List all loaded LoRA adapters in vLLM.
|
|
102
|
-
|
|
103
|
-
Returns:
|
|
104
|
-
List of adapter names, or None if failed
|
|
105
|
-
"""
|
|
106
|
-
url = f"{self.base_url}/list_lora_adapters"
|
|
107
|
-
headers = {"Authorization": f"Bearer {self.api_key}"}
|
|
108
|
-
|
|
109
|
-
try:
|
|
110
|
-
response = requests.get(url, headers=headers, timeout=10)
|
|
111
|
-
response.raise_for_status()
|
|
112
|
-
return response.json().get("adapters", [])
|
|
113
|
-
|
|
114
|
-
except requests.exceptions.RequestException as e:
|
|
115
|
-
hud_console.error(f"[VLLMAdapter] Failed to list adapters: {e}")
|
|
116
|
-
return None
|
|
117
|
-
|
|
118
|
-
def get_current(self) -> str | None:
|
|
119
|
-
"""Get the name of the currently loaded adapter."""
|
|
120
|
-
return self.current_adapter
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
# Convenience function for standalone use
def hotload_lora(
    adapter_name: str,
    adapter_path: str,
    base_url: str = "http://localhost:8000/v1",
    api_key: str = "token-abc123",
) -> bool:
    """
    One-shot helper: build a throwaway VLLMAdapter and hot-load a LoRA adapter.

    Args:
        adapter_name: Name for the adapter
        adapter_path: Path to adapter checkpoint
        base_url: vLLM server URL
        api_key: API key for vLLM

    Returns:
        True if the adapter was loaded, False otherwise
    """
    client = VLLMAdapter(base_url=base_url, api_key=api_key)
    loaded = client.load_adapter(adapter_name=adapter_name, adapter_path=adapter_path)
    return loaded
|