hud-python 0.4.35__py3-none-any.whl → 0.4.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/agents/tests/test_claude.py +32 -7
- hud/agents/tests/test_openai.py +29 -6
- hud/cli/__init__.py +209 -75
- hud/cli/build.py +9 -4
- hud/cli/dev.py +20 -39
- hud/cli/eval.py +3 -2
- hud/cli/flows/tasks.py +1 -0
- hud/cli/init.py +222 -629
- hud/cli/pull.py +6 -0
- hud/cli/push.py +2 -1
- hud/cli/rl/remote_runner.py +3 -1
- hud/cli/tests/test_build.py +3 -27
- hud/cli/tests/test_mcp_server.py +1 -12
- hud/cli/utils/config.py +85 -0
- hud/cli/utils/docker.py +21 -39
- hud/cli/utils/environment.py +4 -3
- hud/cli/utils/interactive.py +2 -1
- hud/cli/utils/local_runner.py +204 -0
- hud/cli/utils/metadata.py +3 -1
- hud/cli/utils/package_runner.py +292 -0
- hud/cli/utils/remote_runner.py +4 -1
- hud/clients/mcp_use.py +30 -7
- hud/datasets/parallel.py +3 -1
- hud/datasets/runner.py +4 -1
- hud/otel/context.py +38 -4
- hud/rl/buffer.py +3 -0
- hud/rl/tests/test_learner.py +1 -1
- hud/server/server.py +157 -1
- hud/settings.py +38 -0
- hud/shared/hints.py +1 -1
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.35.dist-info → hud_python-0.4.36.dist-info}/METADATA +30 -12
- {hud_python-0.4.35.dist-info → hud_python-0.4.36.dist-info}/RECORD +37 -34
- {hud_python-0.4.35.dist-info → hud_python-0.4.36.dist-info}/WHEEL +0 -0
- {hud_python-0.4.35.dist-info → hud_python-0.4.36.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.35.dist-info → hud_python-0.4.36.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
"""Run Python modules or commands as MCP servers.
|
|
2
|
+
|
|
3
|
+
This module handles direct execution of MCP servers, including:
|
|
4
|
+
- Python modules with an 'mcp' attribute
|
|
5
|
+
- External commands via FastMCP proxy
|
|
6
|
+
- Auto-reload functionality for development
|
|
7
|
+
|
|
8
|
+
For Docker container execution, see hud dev command.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import importlib
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import shlex
|
|
17
|
+
import signal
|
|
18
|
+
import subprocess
|
|
19
|
+
import sys
|
|
20
|
+
import threading
|
|
21
|
+
import time
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
from fastmcp import FastMCP
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
async def run_package_as_mcp(
|
|
31
|
+
command: str | list[str],
|
|
32
|
+
transport: str = "stdio",
|
|
33
|
+
port: int = 8765,
|
|
34
|
+
verbose: bool = False,
|
|
35
|
+
reload: bool = False,
|
|
36
|
+
watch_paths: list[str] | None = None,
|
|
37
|
+
server_attr: str = "mcp",
|
|
38
|
+
**extra_kwargs: Any,
|
|
39
|
+
) -> None:
|
|
40
|
+
"""Run a command as an MCP server.
|
|
41
|
+
|
|
42
|
+
Can run:
|
|
43
|
+
- Python modules: 'controller' (imports and looks for mcp attribute)
|
|
44
|
+
- Python -m commands: 'python -m controller'
|
|
45
|
+
- Docker commands: 'docker run -it my-mcp-server'
|
|
46
|
+
- Any executable: './my-mcp-binary'
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
command: Command to run (string or list)
|
|
50
|
+
transport: Transport type ("stdio" or "http")
|
|
51
|
+
port: Port for HTTP transport
|
|
52
|
+
verbose: Enable verbose logging
|
|
53
|
+
reload: Enable auto-reload on file changes
|
|
54
|
+
watch_paths: Paths to watch for changes (defaults to ['.'])
|
|
55
|
+
**extra_kwargs: Additional arguments
|
|
56
|
+
"""
|
|
57
|
+
# Set up logging
|
|
58
|
+
if verbose:
|
|
59
|
+
logging.basicConfig(level=logging.DEBUG)
|
|
60
|
+
else:
|
|
61
|
+
logging.basicConfig(level=logging.INFO)
|
|
62
|
+
|
|
63
|
+
# Handle reload mode
|
|
64
|
+
if reload:
|
|
65
|
+
if watch_paths is None:
|
|
66
|
+
watch_paths = ["."]
|
|
67
|
+
|
|
68
|
+
# Detect external command vs module reliably.
|
|
69
|
+
# If command is a string and contains spaces (e.g., "uv run python -m controller")
|
|
70
|
+
# treat as external command. Otherwise, detect common launchers or paths.
|
|
71
|
+
is_external_cmd = False
|
|
72
|
+
if isinstance(command, list):
|
|
73
|
+
is_external_cmd = True
|
|
74
|
+
elif isinstance(command, str):
|
|
75
|
+
stripped = command.strip()
|
|
76
|
+
if " " in stripped or any(
|
|
77
|
+
stripped.startswith(x)
|
|
78
|
+
for x in ["python", "uv ", "docker", "./", "/", ".\\", "C:\\"]
|
|
79
|
+
):
|
|
80
|
+
is_external_cmd = True
|
|
81
|
+
|
|
82
|
+
if is_external_cmd:
|
|
83
|
+
# External command - pass command list directly
|
|
84
|
+
cmd_list = shlex.split(command) if isinstance(command, str) else command
|
|
85
|
+
run_with_reload(cmd_list, watch_paths, verbose)
|
|
86
|
+
else:
|
|
87
|
+
# Python module - use sys.argv approach
|
|
88
|
+
run_with_reload(None, watch_paths, verbose)
|
|
89
|
+
return
|
|
90
|
+
|
|
91
|
+
# Determine if it's a module import or a command
|
|
92
|
+
if isinstance(command, str) and not any(
|
|
93
|
+
command.startswith(x) for x in ["python", "docker", "./", "/", ".\\", "C:\\"]
|
|
94
|
+
):
|
|
95
|
+
# Treat as Python module for backwards compatibility
|
|
96
|
+
logger.info("Importing module: %s", command)
|
|
97
|
+
module = importlib.import_module(command)
|
|
98
|
+
|
|
99
|
+
# Look for server attribute in the module
|
|
100
|
+
if not hasattr(module, server_attr):
|
|
101
|
+
logger.error(
|
|
102
|
+
"Module '%s' does not have an '%s' attribute (MCPServer instance)",
|
|
103
|
+
command,
|
|
104
|
+
server_attr,
|
|
105
|
+
)
|
|
106
|
+
sys.exit(1)
|
|
107
|
+
|
|
108
|
+
server = getattr(module, server_attr)
|
|
109
|
+
|
|
110
|
+
# Configure server options
|
|
111
|
+
run_kwargs = {
|
|
112
|
+
"transport": transport,
|
|
113
|
+
"show_banner": False,
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
if transport == "http":
|
|
117
|
+
# FastMCP expects port/path directly
|
|
118
|
+
run_kwargs["port"] = port
|
|
119
|
+
run_kwargs["path"] = "/mcp"
|
|
120
|
+
|
|
121
|
+
# Merge any extra kwargs
|
|
122
|
+
run_kwargs.update(extra_kwargs)
|
|
123
|
+
|
|
124
|
+
# Run the server
|
|
125
|
+
logger.info("Running %s on %s transport", server.name, transport)
|
|
126
|
+
await server.run_async(**run_kwargs)
|
|
127
|
+
else:
|
|
128
|
+
# Run as external command using shared proxy utility
|
|
129
|
+
# Parse command if string
|
|
130
|
+
cmd_list = shlex.split(command) if isinstance(command, str) else command
|
|
131
|
+
|
|
132
|
+
# Replace 'python' with the current interpreter to preserve venv
|
|
133
|
+
if cmd_list[0] == "python":
|
|
134
|
+
cmd_list[0] = sys.executable
|
|
135
|
+
logger.info("Replaced 'python' with: %s", sys.executable)
|
|
136
|
+
|
|
137
|
+
logger.info("Running command: %s", " ".join(cmd_list))
|
|
138
|
+
|
|
139
|
+
# Create MCP config for the command
|
|
140
|
+
config = {
|
|
141
|
+
"mcpServers": {
|
|
142
|
+
"default": {
|
|
143
|
+
"command": cmd_list[0],
|
|
144
|
+
"args": cmd_list[1:] if len(cmd_list) > 1 else [],
|
|
145
|
+
# transport defaults to stdio
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
# Create proxy server
|
|
151
|
+
proxy = FastMCP.as_proxy(config, name=f"HUD Run - {cmd_list[0]}")
|
|
152
|
+
|
|
153
|
+
# Run the proxy
|
|
154
|
+
await proxy.run_async(
|
|
155
|
+
transport=transport if transport == "http" or transport == "stdio" else None,
|
|
156
|
+
port=port if transport == "http" else None,
|
|
157
|
+
show_banner=False,
|
|
158
|
+
**extra_kwargs,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def run_with_reload(
|
|
163
|
+
target_func: Any,
|
|
164
|
+
watch_paths: list[str],
|
|
165
|
+
verbose: bool = False,
|
|
166
|
+
) -> None:
|
|
167
|
+
"""Run a function or command with file watching and auto-reload.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
target_func: Function to run (sync) or command list
|
|
171
|
+
watch_paths: Paths to watch for changes
|
|
172
|
+
verbose: Enable verbose logging
|
|
173
|
+
"""
|
|
174
|
+
try:
|
|
175
|
+
import watchfiles
|
|
176
|
+
except ImportError:
|
|
177
|
+
logger.error("watchfiles is required for --reload. Install with: pip install watchfiles")
|
|
178
|
+
sys.exit(1)
|
|
179
|
+
|
|
180
|
+
# Resolve watch paths
|
|
181
|
+
resolved_paths = []
|
|
182
|
+
for path_str in watch_paths:
|
|
183
|
+
path = Path(path_str).resolve()
|
|
184
|
+
if path.is_file():
|
|
185
|
+
# Watch the directory containing the file
|
|
186
|
+
resolved_paths.append(str(path.parent))
|
|
187
|
+
else:
|
|
188
|
+
resolved_paths.append(str(path))
|
|
189
|
+
|
|
190
|
+
def run_and_restart() -> None:
|
|
191
|
+
"""Run the target function in a loop, restarting on file changes."""
|
|
192
|
+
|
|
193
|
+
process = None
|
|
194
|
+
|
|
195
|
+
def handle_signal(signum: int, frame: Any) -> None:
|
|
196
|
+
"""Handle signals by terminating the subprocess."""
|
|
197
|
+
if process:
|
|
198
|
+
process.terminate()
|
|
199
|
+
sys.exit(0)
|
|
200
|
+
|
|
201
|
+
signal.signal(signal.SIGTERM, handle_signal)
|
|
202
|
+
signal.signal(signal.SIGINT, handle_signal)
|
|
203
|
+
|
|
204
|
+
stop_event = threading.Event() # Define stop_event at the start
|
|
205
|
+
|
|
206
|
+
while True:
|
|
207
|
+
# Run the target function or command
|
|
208
|
+
if target_func is None:
|
|
209
|
+
# Use sys.argv approach for Python modules
|
|
210
|
+
child_args = [a for a in sys.argv[1:] if a != "--reload"]
|
|
211
|
+
# If first arg is already 'run', don't inject it again
|
|
212
|
+
if child_args and child_args[0] == "run":
|
|
213
|
+
cmd = [sys.executable, "-m", "hud", *child_args]
|
|
214
|
+
else:
|
|
215
|
+
cmd = [sys.executable, "-m", "hud", "run", *child_args]
|
|
216
|
+
elif isinstance(target_func, list):
|
|
217
|
+
# It's a command list
|
|
218
|
+
cmd = target_func
|
|
219
|
+
else:
|
|
220
|
+
# It's a callable - run it directly
|
|
221
|
+
target_func()
|
|
222
|
+
# Wait for file changes before restarting
|
|
223
|
+
stop_event.wait()
|
|
224
|
+
continue
|
|
225
|
+
|
|
226
|
+
if verbose:
|
|
227
|
+
logger.info("Starting process: %s", " ".join(cmd))
|
|
228
|
+
|
|
229
|
+
process = subprocess.Popen(cmd, env=os.environ) # noqa: S603
|
|
230
|
+
|
|
231
|
+
# Watch for changes
|
|
232
|
+
try:
|
|
233
|
+
# Use a proper threading.Event for stop_event as required by watchfiles
|
|
234
|
+
stop_event = threading.Event()
|
|
235
|
+
|
|
236
|
+
def _wait_and_set(
|
|
237
|
+
stop_event: threading.Event, process: subprocess.Popen[bytes]
|
|
238
|
+
) -> None:
|
|
239
|
+
try:
|
|
240
|
+
if process is not None:
|
|
241
|
+
process.wait()
|
|
242
|
+
finally:
|
|
243
|
+
stop_event.set()
|
|
244
|
+
|
|
245
|
+
threading.Thread(
|
|
246
|
+
target=_wait_and_set, args=(stop_event, process), daemon=True
|
|
247
|
+
).start()
|
|
248
|
+
|
|
249
|
+
for changes in watchfiles.watch(*resolved_paths, stop_event=stop_event):
|
|
250
|
+
logger.info("Raw changes detected: %s", changes)
|
|
251
|
+
# Filter for relevant file types
|
|
252
|
+
relevant_changes = [
|
|
253
|
+
(change_type, path)
|
|
254
|
+
for change_type, path in changes
|
|
255
|
+
if any(path.endswith(ext) for ext in [".py", ".json", ".toml", ".yaml"])
|
|
256
|
+
and "__pycache__" not in path
|
|
257
|
+
and not Path(path).name.startswith(".")
|
|
258
|
+
]
|
|
259
|
+
|
|
260
|
+
if relevant_changes:
|
|
261
|
+
logger.info("File changes detected, restarting server...")
|
|
262
|
+
if verbose:
|
|
263
|
+
for change_type, path in relevant_changes:
|
|
264
|
+
logger.debug(" %s: %s", change_type, path)
|
|
265
|
+
|
|
266
|
+
# Terminate the process
|
|
267
|
+
if process is not None:
|
|
268
|
+
process.terminate()
|
|
269
|
+
try:
|
|
270
|
+
if process is not None:
|
|
271
|
+
process.wait(timeout=5)
|
|
272
|
+
except subprocess.TimeoutExpired:
|
|
273
|
+
if process is not None:
|
|
274
|
+
process.kill()
|
|
275
|
+
process.wait()
|
|
276
|
+
|
|
277
|
+
# Brief pause before restart
|
|
278
|
+
time.sleep(0.1)
|
|
279
|
+
break
|
|
280
|
+
else:
|
|
281
|
+
logger.debug("Changes detected but filtered out: %s", changes)
|
|
282
|
+
except KeyboardInterrupt:
|
|
283
|
+
# Handle Ctrl+C gracefully
|
|
284
|
+
if process:
|
|
285
|
+
process.terminate()
|
|
286
|
+
process.wait()
|
|
287
|
+
break
|
|
288
|
+
|
|
289
|
+
# Always act as the parent. The child is launched without --reload,
|
|
290
|
+
# so it won't re-enter this function.
|
|
291
|
+
|
|
292
|
+
run_and_restart()
|
hud/cli/utils/remote_runner.py
CHANGED
|
@@ -293,7 +293,10 @@ def run_remote_server(
|
|
|
293
293
|
if not api_key:
|
|
294
294
|
api_key = settings.api_key
|
|
295
295
|
if not api_key:
|
|
296
|
-
click.echo(
|
|
296
|
+
click.echo(
|
|
297
|
+
"❌ API key required. Set HUD_API_KEY in your environment or run: hud set HUD_API_KEY=your-key-here", # noqa: E501
|
|
298
|
+
err=True,
|
|
299
|
+
)
|
|
297
300
|
sys.exit(1)
|
|
298
301
|
|
|
299
302
|
# Build headers
|
hud/clients/mcp_use.py
CHANGED
|
@@ -5,19 +5,22 @@ from __future__ import annotations
|
|
|
5
5
|
import logging
|
|
6
6
|
import traceback
|
|
7
7
|
from typing import Any
|
|
8
|
+
from urllib.parse import urlparse
|
|
8
9
|
|
|
9
10
|
from mcp import Implementation, types
|
|
10
11
|
from mcp.shared.exceptions import McpError
|
|
11
12
|
from mcp_use.client import MCPClient as MCPUseClient
|
|
12
13
|
from mcp_use.session import MCPSession as MCPUseSession
|
|
14
|
+
from mcp_use.types.http import HttpOptions
|
|
13
15
|
from pydantic import AnyUrl
|
|
14
16
|
|
|
17
|
+
from hud.settings import settings
|
|
15
18
|
from hud.types import MCPToolCall, MCPToolResult
|
|
16
19
|
from hud.utils.hud_console import HUDConsole
|
|
17
20
|
from hud.version import __version__ as hud_version
|
|
18
21
|
|
|
19
22
|
from .base import BaseHUDClient
|
|
20
|
-
from .utils.
|
|
23
|
+
from .utils.retry_transport import create_retry_httpx_client
|
|
21
24
|
|
|
22
25
|
logger = logging.getLogger(__name__)
|
|
23
26
|
hud_console = HUDConsole(logger=logger)
|
|
@@ -30,7 +33,11 @@ class MCPUseHUDClient(BaseHUDClient):
|
|
|
30
33
|
name="hud-mcp-use", title="hud MCP-use Client", version=hud_version
|
|
31
34
|
)
|
|
32
35
|
|
|
33
|
-
def __init__(
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
mcp_config: dict[str, dict[str, Any]] | None = None,
|
|
39
|
+
**kwargs: Any,
|
|
40
|
+
) -> None:
|
|
34
41
|
"""
|
|
35
42
|
Initialize MCP-use client.
|
|
36
43
|
|
|
@@ -51,6 +58,12 @@ class MCPUseHUDClient(BaseHUDClient):
|
|
|
51
58
|
str, tuple[str, types.Tool, types.Tool]
|
|
52
59
|
] = {} # server_name, original_tool, prefixed_tool
|
|
53
60
|
self._client: Any | None = None # Will be MCPUseClient when available
|
|
61
|
+
# Transport options for MCP-use (disable_sse_fallback, httpx_client_factory, etc.)
|
|
62
|
+
# Default to retry-enabled HTTPX client if factory not provided
|
|
63
|
+
self._http_options: HttpOptions = HttpOptions(
|
|
64
|
+
httpx_client_factory=create_retry_httpx_client,
|
|
65
|
+
disable_sse_fallback=True,
|
|
66
|
+
)
|
|
54
67
|
|
|
55
68
|
async def _connect(self, mcp_config: dict[str, dict[str, Any]]) -> None:
|
|
56
69
|
"""Create all sessions for MCP-use client."""
|
|
@@ -58,19 +71,29 @@ class MCPUseHUDClient(BaseHUDClient):
|
|
|
58
71
|
logger.warning("Client is already connected, cannot connect again")
|
|
59
72
|
return
|
|
60
73
|
|
|
74
|
+
# If a server target matches HUD's MCP host and no auth is provided,
|
|
75
|
+
# inject the HUD API key as a Bearer token to avoid OAuth browser flow.
|
|
76
|
+
try:
|
|
77
|
+
hud_mcp_host = urlparse(settings.hud_mcp_url).netloc
|
|
78
|
+
if mcp_config and settings.api_key and hud_mcp_host:
|
|
79
|
+
for server_cfg in mcp_config.values():
|
|
80
|
+
server_url = server_cfg.get("url")
|
|
81
|
+
if not server_url:
|
|
82
|
+
continue
|
|
83
|
+
if urlparse(server_url).netloc == hud_mcp_host and not server_cfg.get("auth"):
|
|
84
|
+
server_cfg["auth"] = settings.api_key
|
|
85
|
+
except Exception:
|
|
86
|
+
logger.warning("Failed to parse HUD MCP URL")
|
|
87
|
+
|
|
61
88
|
config = {"mcpServers": mcp_config}
|
|
62
89
|
if MCPUseClient is None:
|
|
63
90
|
raise ImportError("MCPUseClient is not available")
|
|
64
|
-
self._client = MCPUseClient.from_dict(config)
|
|
91
|
+
self._client = MCPUseClient.from_dict(config, http_options=self._http_options)
|
|
65
92
|
try:
|
|
66
93
|
assert self._client is not None # noqa: S101
|
|
67
94
|
self._sessions = await self._client.create_all_sessions()
|
|
68
95
|
hud_console.info(f"Created {len(self._sessions)} MCP sessions")
|
|
69
96
|
|
|
70
|
-
# Patch all sessions with retry logic
|
|
71
|
-
patch_all_sessions(self._sessions)
|
|
72
|
-
hud_console.debug("Applied retry logic to all MCP sessions")
|
|
73
|
-
|
|
74
97
|
# Configure validation for all sessions based on client setting
|
|
75
98
|
try:
|
|
76
99
|
for session in self._sessions.values():
|
hud/datasets/parallel.py
CHANGED
|
@@ -115,7 +115,9 @@ def _process_worker(
|
|
|
115
115
|
task_name = task_dict.get("prompt") or f"Task {index}"
|
|
116
116
|
|
|
117
117
|
# Use the job_id to group all tasks under the same job
|
|
118
|
-
|
|
118
|
+
raw_task_id = task_dict.get("id")
|
|
119
|
+
safe_task_id = str(raw_task_id) if raw_task_id is not None else None
|
|
120
|
+
with hud.trace(task_name, job_id=job_id, task_id=safe_task_id):
|
|
119
121
|
# Convert dict to Task
|
|
120
122
|
task = Task(**task_dict)
|
|
121
123
|
|
hud/datasets/runner.py
CHANGED
|
@@ -104,7 +104,10 @@ async def run_dataset(
|
|
|
104
104
|
task_name = task_dict.get("prompt") or f"Task {index}"
|
|
105
105
|
if custom_system_prompt and "system_prompt" not in task_dict:
|
|
106
106
|
task_dict["system_prompt"] = custom_system_prompt
|
|
107
|
-
|
|
107
|
+
# Ensure task_id is a string for baggage propagation
|
|
108
|
+
raw_task_id = task_dict.get("id")
|
|
109
|
+
safe_task_id = str(raw_task_id) if raw_task_id is not None else None
|
|
110
|
+
with hud.trace(task_name, job_id=job_obj.id, task_id=safe_task_id):
|
|
108
111
|
# Convert dict to Task here, at trace level
|
|
109
112
|
task = Task(**task_dict)
|
|
110
113
|
|
hud/otel/context.py
CHANGED
|
@@ -239,8 +239,25 @@ async def _update_task_status_async(
|
|
|
239
239
|
|
|
240
240
|
try:
|
|
241
241
|
data: dict[str, Any] = {"status": status}
|
|
242
|
-
|
|
243
|
-
|
|
242
|
+
|
|
243
|
+
# Resolve effective job_id from explicit param, OTel baggage, or current job context
|
|
244
|
+
effective_job_id: str | None = job_id
|
|
245
|
+
if not effective_job_id:
|
|
246
|
+
bj = baggage.get_baggage("hud.job_id")
|
|
247
|
+
if isinstance(bj, str) and bj:
|
|
248
|
+
effective_job_id = bj
|
|
249
|
+
if not effective_job_id:
|
|
250
|
+
try:
|
|
251
|
+
from hud.telemetry.job import get_current_job # Local import to avoid cycles
|
|
252
|
+
|
|
253
|
+
current_job = get_current_job()
|
|
254
|
+
if current_job:
|
|
255
|
+
effective_job_id = current_job.id
|
|
256
|
+
except Exception:
|
|
257
|
+
effective_job_id = None
|
|
258
|
+
|
|
259
|
+
if effective_job_id:
|
|
260
|
+
data["job_id"] = effective_job_id
|
|
244
261
|
if error_message:
|
|
245
262
|
data["error_message"] = error_message
|
|
246
263
|
|
|
@@ -302,8 +319,25 @@ def _update_task_status_sync(
|
|
|
302
319
|
|
|
303
320
|
try:
|
|
304
321
|
data: dict[str, Any] = {"status": status}
|
|
305
|
-
|
|
306
|
-
|
|
322
|
+
|
|
323
|
+
# Resolve effective job_id from explicit param, OTel baggage, or current job context
|
|
324
|
+
effective_job_id: str | None = job_id
|
|
325
|
+
if not effective_job_id:
|
|
326
|
+
bj = baggage.get_baggage("hud.job_id")
|
|
327
|
+
if isinstance(bj, str) and bj:
|
|
328
|
+
effective_job_id = bj
|
|
329
|
+
if not effective_job_id:
|
|
330
|
+
try:
|
|
331
|
+
from hud.telemetry.job import get_current_job # Local import to avoid cycles
|
|
332
|
+
|
|
333
|
+
current_job = get_current_job()
|
|
334
|
+
if current_job:
|
|
335
|
+
effective_job_id = current_job.id
|
|
336
|
+
except Exception:
|
|
337
|
+
effective_job_id = None
|
|
338
|
+
|
|
339
|
+
if effective_job_id:
|
|
340
|
+
data["job_id"] = effective_job_id
|
|
307
341
|
if error_message:
|
|
308
342
|
data["error_message"] = error_message
|
|
309
343
|
|
hud/rl/buffer.py
CHANGED
|
@@ -155,6 +155,9 @@ class DatasetBuffer(Buffer[Task]):
|
|
|
155
155
|
f"This is because the number of training steps ({self.training_steps}) is not a multiple of the dataset size ({self.dataset_size})" # noqa: E501
|
|
156
156
|
)
|
|
157
157
|
|
|
158
|
+
if config.verbose:
|
|
159
|
+
hud_console.info(f"Sample task: {tasks[0]}")
|
|
160
|
+
|
|
158
161
|
self.add_fill(tasks, self.number_of_tasks, config.training.shuffle_dataset)
|
|
159
162
|
|
|
160
163
|
def _validate_tasks(self, tasks: list[Task]) -> list[Task]:
|
hud/rl/tests/test_learner.py
CHANGED
|
@@ -163,7 +163,7 @@ def test_skip_update_when_zero_adv(monkeypatch, learner_stub: GRPOLearner):
|
|
|
163
163
|
# Return a zero scalar loss that *depends* on params so backward works,
|
|
164
164
|
# but has zero gradients (no update signal).
|
|
165
165
|
def _zero_loss(self, sample) -> torch.Tensor:
|
|
166
|
-
return sum(p.sum() for p in self.policy.parameters()) * 0.0
|
|
166
|
+
return sum(p.sum() for p in self.policy.parameters()) * 0.0 # type: ignore
|
|
167
167
|
|
|
168
168
|
monkeypatch.setattr(GRPOLearner, "compute_loss", _zero_loss, raising=True)
|
|
169
169
|
|
hud/server/server.py
CHANGED
|
@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any
|
|
|
13
13
|
|
|
14
14
|
import anyio
|
|
15
15
|
from fastmcp.server.server import FastMCP, Transport
|
|
16
|
+
from starlette.responses import JSONResponse, Response
|
|
16
17
|
|
|
17
18
|
from hud.server.low_level import LowLevelServerWithInit
|
|
18
19
|
|
|
@@ -20,6 +21,7 @@ if TYPE_CHECKING:
|
|
|
20
21
|
from collections.abc import AsyncGenerator, Callable
|
|
21
22
|
|
|
22
23
|
from mcp.shared.context import RequestContext
|
|
24
|
+
from starlette.requests import Request
|
|
23
25
|
|
|
24
26
|
__all__ = ["MCPServer"]
|
|
25
27
|
|
|
@@ -163,7 +165,15 @@ class MCPServer(FastMCP):
|
|
|
163
165
|
# Redirect stdout to stderr during initialization to prevent
|
|
164
166
|
# any library prints from corrupting the MCP protocol
|
|
165
167
|
with contextlib.redirect_stdout(sys.stderr):
|
|
166
|
-
|
|
168
|
+
# Check if function accepts ctx parameter
|
|
169
|
+
import inspect
|
|
170
|
+
|
|
171
|
+
sig = inspect.signature(self._initializer_fn)
|
|
172
|
+
if "ctx" in sig.parameters:
|
|
173
|
+
return self._initializer_fn(ctx)
|
|
174
|
+
else:
|
|
175
|
+
# Call without ctx for simpler usage
|
|
176
|
+
return self._initializer_fn()
|
|
167
177
|
return None
|
|
168
178
|
|
|
169
179
|
# Save the old server's handlers before replacing it
|
|
@@ -233,6 +243,23 @@ class MCPServer(FastMCP):
|
|
|
233
243
|
|
|
234
244
|
_run_with_sigterm(_bootstrap)
|
|
235
245
|
|
|
246
|
+
async def run_async(
    self,
    transport: Transport | None = None,
    show_banner: bool = True,
    **transport_kwargs: Any,
) -> None:
    """Run the server with HUD enhancements.

    Defaults to the stdio transport. For HTTP-style transports, the HUD helper
    routes (``/hud``, ``/hud/tools``, ``/hud/resources``, ``/hud/prompts``) are
    registered once per server instance before delegating to FastMCP.

    Args:
        transport: FastMCP transport name; ``None`` means ``"stdio"``.
        show_banner: Forwarded to FastMCP's ``run_async``.
        **transport_kwargs: Forwarded to FastMCP's ``run_async``.
    """
    if transport is None:
        transport = "stdio"

    # NOTE(review): FastMCP also accepts "streamable-http" as an HTTP transport
    # name, so helpers are registered for it too. The flag prevents duplicate
    # routes when run_async is called more than once on the same instance.
    if transport in ("http", "sse", "streamable-http") and not getattr(
        self, "_hud_helpers_registered", False
    ):
        self._register_hud_helpers()
        self._hud_helpers_registered = True
        logger.info("Registered HUD helper endpoints at /hud/*")

    await super().run_async(transport=transport, show_banner=show_banner, **transport_kwargs)
|
|
262
|
+
|
|
236
263
|
# Tool registration helper -- appends BaseTool to FastMCP
|
|
237
264
|
def add_tool(self, obj: Any, **kwargs: Any) -> None:
|
|
238
265
|
from hud.tools.base import BaseTool
|
|
@@ -242,3 +269,132 @@ class MCPServer(FastMCP):
|
|
|
242
269
|
return
|
|
243
270
|
|
|
244
271
|
super().add_tool(obj, **kwargs)
|
|
272
|
+
|
|
273
|
+
# Override to keep original callables when used as a decorator
|
|
274
|
+
def tool(self, name_or_fn: Any = None, **kwargs: Any) -> Any:  # type: ignore[override]
    """Register a tool and hand back the object that was passed in.

    Supported forms:
    - Decorator: ``@mcp.tool``, ``@mcp.tool("name")``, ``@mcp.tool(name="name")``
      registers through FastMCP and returns the undecorated function.
    - Call form: ``mcp.tool(obj, ...)`` where *obj* is a HUD ``BaseTool``, a
      FastMCP ``Tool``, or a plain callable registers *obj* and returns it as-is.
    """
    is_call_form = name_or_fn is not None and not isinstance(name_or_fn, str)
    if is_call_form:
        # Lazy imports; an empty tuple makes the isinstance checks below always False.
        try:
            from hud.tools.base import BaseTool as hud_tool_type
        except Exception:
            hud_tool_type = ()  # type: ignore[assignment]
        try:
            from fastmcp.tools.tool import Tool as fastmcp_tool_type
        except Exception:
            fastmcp_tool_type = ()  # type: ignore[assignment]

        if isinstance(name_or_fn, hud_tool_type):
            # HUD BaseTool: register its underlying FunctionTool.
            super().add_tool(name_or_fn.mcp, **kwargs)
            return name_or_fn
        if isinstance(name_or_fn, fastmcp_tool_type):
            # Already a FastMCP Tool/FunctionTool: add it directly.
            super().add_tool(name_or_fn, **kwargs)
            return name_or_fn
        if callable(name_or_fn):
            super().tool(name_or_fn, **kwargs)
            return name_or_fn

    # Decorator form: let FastMCP build its decorator, but return the original fn.
    register = super().tool(name_or_fn, **kwargs)

    def _decorate(fn: Any) -> Any:
        register(fn)
        return fn

    return _decorate
|
|
313
|
+
|
|
314
|
+
def _register_hud_helpers(self) -> None:
    """Attach read-only HUD inspection routes to the server's HTTP app.

    Routes added (all GET):
    - /hud            - index of the helper endpoints
    - /hud/tools      - registered tools with descriptions and schemas
    - /hud/resources  - registered resources
    - /hud/prompts    - registered prompts (with arguments when present)
    """

    def _tool_entry(key: Any, tool: Any) -> dict[str, Any]:
        """Describe one tool; prefers its MCP model, falls back to raw attributes."""
        entry: dict[str, Any] = {"name": key}
        try:
            as_mcp = tool.to_mcp_tool()
            entry["description"] = getattr(as_mcp, "description", "")
            if hasattr(as_mcp, "inputSchema") and as_mcp.inputSchema:
                entry["input_schema"] = as_mcp.inputSchema
            if hasattr(as_mcp, "outputSchema") and as_mcp.outputSchema:
                entry["output_schema"] = as_mcp.outputSchema
        except Exception:
            # Fallback to attributes exposed directly on the FunctionTool.
            entry["description"] = getattr(tool, "description", "")
            schema = getattr(tool, "parameters", None)
            if schema:
                entry["input_schema"] = schema
        return entry

    def _prompt_entry(key: Any, prompt: Any) -> dict[str, Any]:
        """Describe one prompt, adding its arguments only when it has any."""
        entry: dict[str, Any] = {"name": key, "description": prompt.description}
        if hasattr(prompt, "arguments") and prompt.arguments:
            entry["arguments"] = [
                {"name": a.name, "description": a.description, "required": a.required}
                for a in prompt.arguments
            ]
        return entry

    @self.custom_route("/hud/tools", methods=["GET"])
    async def list_tools(request: Request) -> Response:
        """List all registered tools with their names, descriptions, and schemas."""
        registry = self._tool_manager._tools
        tools = [_tool_entry(key, tool) for key, tool in registry.items()]
        return JSONResponse({"server": self.name, "tools": tools, "count": len(tools)})

    @self.custom_route("/hud/resources", methods=["GET"])
    async def list_resources(request: Request) -> Response:
        """List all registered resources."""
        registry = self._resource_manager._resources
        resources = [
            {
                "uri": uri,
                "name": res.name,
                "description": res.description,
                "mimeType": res.mime_type,
            }
            for uri, res in registry.items()
        ]
        payload = {"server": self.name, "resources": resources, "count": len(resources)}
        return JSONResponse(payload)

    @self.custom_route("/hud/prompts", methods=["GET"])
    async def list_prompts(request: Request) -> Response:
        """List all registered prompts."""
        registry = self._prompt_manager._prompts
        prompts = [_prompt_entry(key, prompt) for key, prompt in registry.items()]
        return JSONResponse({"server": self.name, "prompts": prompts, "count": len(prompts)})

    @self.custom_route("/hud", methods=["GET"])
    async def hud_info(request: Request) -> Response:
        """Show available HUD helper endpoints."""
        root = str(request.base_url).rstrip("/")
        endpoints = {
            "tools": f"{root}/hud/tools",
            "resources": f"{root}/hud/resources",
            "prompts": f"{root}/hud/prompts",
        }
        payload = {
            "name": "HUD MCP Development Helpers",
            "server": self.name,
            "endpoints": endpoints,
            "description": "These endpoints help you inspect your MCP server during development.",  # noqa: E501
        }
        return JSONResponse(payload)
|