hud-python 0.4.35__py3-none-any.whl → 0.4.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/agents/__init__.py +2 -0
- hud/agents/lite_llm.py +72 -0
- hud/agents/openai_chat_generic.py +21 -7
- hud/agents/tests/test_claude.py +32 -7
- hud/agents/tests/test_openai.py +29 -6
- hud/cli/__init__.py +228 -79
- hud/cli/build.py +26 -6
- hud/cli/dev.py +21 -40
- hud/cli/eval.py +96 -15
- hud/cli/flows/tasks.py +198 -65
- hud/cli/init.py +222 -629
- hud/cli/pull.py +6 -0
- hud/cli/push.py +11 -1
- hud/cli/rl/__init__.py +14 -4
- hud/cli/rl/celebrate.py +187 -0
- hud/cli/rl/config.py +15 -8
- hud/cli/rl/local_runner.py +44 -20
- hud/cli/rl/remote_runner.py +166 -87
- hud/cli/rl/viewer.py +141 -0
- hud/cli/rl/wait_utils.py +89 -0
- hud/cli/tests/test_build.py +3 -27
- hud/cli/tests/test_mcp_server.py +1 -12
- hud/cli/utils/config.py +85 -0
- hud/cli/utils/docker.py +21 -39
- hud/cli/utils/env_check.py +196 -0
- hud/cli/utils/environment.py +4 -3
- hud/cli/utils/interactive.py +2 -1
- hud/cli/utils/local_runner.py +204 -0
- hud/cli/utils/metadata.py +3 -1
- hud/cli/utils/package_runner.py +292 -0
- hud/cli/utils/remote_runner.py +4 -1
- hud/cli/utils/source_hash.py +108 -0
- hud/clients/base.py +1 -1
- hud/clients/fastmcp.py +1 -1
- hud/clients/mcp_use.py +30 -7
- hud/datasets/parallel.py +3 -1
- hud/datasets/runner.py +4 -1
- hud/otel/config.py +1 -1
- hud/otel/context.py +40 -6
- hud/rl/buffer.py +3 -0
- hud/rl/tests/test_learner.py +1 -1
- hud/rl/vllm_adapter.py +1 -1
- hud/server/server.py +234 -7
- hud/server/tests/test_add_tool.py +60 -0
- hud/server/tests/test_context.py +128 -0
- hud/server/tests/test_mcp_server_handlers.py +44 -0
- hud/server/tests/test_mcp_server_integration.py +405 -0
- hud/server/tests/test_mcp_server_more.py +247 -0
- hud/server/tests/test_run_wrapper.py +53 -0
- hud/server/tests/test_server_extra.py +166 -0
- hud/server/tests/test_sigterm_runner.py +78 -0
- hud/settings.py +38 -0
- hud/shared/hints.py +2 -2
- hud/telemetry/job.py +2 -2
- hud/types.py +9 -2
- hud/utils/tasks.py +32 -24
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/METADATA +43 -23
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/RECORD +63 -46
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/WHEEL +0 -0
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/licenses/LICENSE +0 -0
hud/server/server.py
CHANGED
|
@@ -13,13 +13,14 @@ from typing import TYPE_CHECKING, Any
|
|
|
13
13
|
|
|
14
14
|
import anyio
|
|
15
15
|
from fastmcp.server.server import FastMCP, Transport
|
|
16
|
+
from starlette.responses import JSONResponse, Response
|
|
16
17
|
|
|
17
18
|
from hud.server.low_level import LowLevelServerWithInit
|
|
18
19
|
|
|
19
20
|
if TYPE_CHECKING:
|
|
20
21
|
from collections.abc import AsyncGenerator, Callable
|
|
21
22
|
|
|
22
|
-
from
|
|
23
|
+
from starlette.requests import Request
|
|
23
24
|
|
|
24
25
|
__all__ = ["MCPServer"]
|
|
25
26
|
|
|
@@ -35,6 +36,31 @@ def _run_with_sigterm(coro_fn: Callable[..., Any], *args: Any, **kwargs: Any) ->
|
|
|
35
36
|
|
|
36
37
|
sys.stderr.flush()
|
|
37
38
|
|
|
39
|
+
# Check if we're already in an event loop
|
|
40
|
+
try:
|
|
41
|
+
loop = asyncio.get_running_loop()
|
|
42
|
+
logger.warning(
|
|
43
|
+
"HUD server is running in an existing event loop. "
|
|
44
|
+
"SIGTERM handling may be limited. "
|
|
45
|
+
"Consider using await hub.run_async() instead of hub.run() in async contexts."
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
task = loop.create_task(coro_fn(*args, **kwargs))
|
|
49
|
+
|
|
50
|
+
# Try to handle SIGTERM if possible
|
|
51
|
+
if sys.platform != "win32":
|
|
52
|
+
|
|
53
|
+
def handle_sigterm(signum: Any, frame: Any) -> None:
|
|
54
|
+
logger.info("SIGTERM received in async context, cancelling task...")
|
|
55
|
+
loop.call_soon_threadsafe(task.cancel)
|
|
56
|
+
|
|
57
|
+
signal.signal(signal.SIGTERM, handle_sigterm)
|
|
58
|
+
|
|
59
|
+
return
|
|
60
|
+
|
|
61
|
+
except RuntimeError:
|
|
62
|
+
pass
|
|
63
|
+
|
|
38
64
|
async def _runner() -> None:
|
|
39
65
|
stop_evt: asyncio.Event | None = None
|
|
40
66
|
if sys.platform != "win32" and os.getenv("FASTMCP_DISABLE_SIGTERM_HANDLER") != "1":
|
|
@@ -125,7 +151,11 @@ class MCPServer(FastMCP):
|
|
|
125
151
|
# Force flush logs to ensure they're visible
|
|
126
152
|
sys.stderr.flush()
|
|
127
153
|
|
|
128
|
-
if
|
|
154
|
+
if (
|
|
155
|
+
self._shutdown_fn is not None
|
|
156
|
+
and _sigterm_received
|
|
157
|
+
and not self._shutdown_has_run
|
|
158
|
+
):
|
|
129
159
|
logger.info("SIGTERM detected! Calling @mcp.shutdown handler...")
|
|
130
160
|
sys.stderr.flush()
|
|
131
161
|
try:
|
|
@@ -135,7 +165,9 @@ class MCPServer(FastMCP):
|
|
|
135
165
|
except Exception as e:
|
|
136
166
|
logger.error("Error during @mcp.shutdown: %s", e)
|
|
137
167
|
sys.stderr.flush()
|
|
138
|
-
|
|
168
|
+
finally:
|
|
169
|
+
self._shutdown_has_run = True
|
|
170
|
+
_sigterm_received = False
|
|
139
171
|
elif self._shutdown_fn is not None:
|
|
140
172
|
logger.info(
|
|
141
173
|
"No SIGTERM. This is a hot reload (SIGINT) or normal exit. Skipping @mcp.shutdown handler." # noqa: E501
|
|
@@ -151,19 +183,53 @@ class MCPServer(FastMCP):
|
|
|
151
183
|
self._initializer_fn: Callable | None = None
|
|
152
184
|
self._did_init = False
|
|
153
185
|
self._replaced_server = False
|
|
186
|
+
self._shutdown_has_run = False # Guard against double-execution of shutdown hook
|
|
154
187
|
|
|
155
188
|
def _replace_with_init_server(self) -> None:
|
|
156
189
|
"""Replace the low-level server with init version when needed."""
|
|
157
190
|
if self._replaced_server:
|
|
158
191
|
return
|
|
159
192
|
|
|
160
|
-
def _run_init(ctx:
|
|
193
|
+
async def _run_init(ctx: object | None = None) -> None:
|
|
194
|
+
"""Run the user initializer exactly once, with stdout redirected."""
|
|
161
195
|
if self._initializer_fn is not None and not self._did_init:
|
|
162
196
|
self._did_init = True
|
|
163
|
-
#
|
|
164
|
-
# any library prints from corrupting the MCP protocol
|
|
197
|
+
# Prevent stdout from polluting the MCP protocol on stdio/HTTP
|
|
165
198
|
with contextlib.redirect_stdout(sys.stderr):
|
|
166
|
-
|
|
199
|
+
import inspect
|
|
200
|
+
|
|
201
|
+
fn = self._initializer_fn
|
|
202
|
+
sig = inspect.signature(fn)
|
|
203
|
+
params = sig.parameters
|
|
204
|
+
|
|
205
|
+
ctx_param = params.get("ctx") or params.get("_ctx")
|
|
206
|
+
if ctx_param is not None:
|
|
207
|
+
if ctx_param.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
208
|
+
result = fn(**{ctx_param.name: ctx})
|
|
209
|
+
else:
|
|
210
|
+
result = fn(ctx)
|
|
211
|
+
else:
|
|
212
|
+
required_params = [
|
|
213
|
+
p
|
|
214
|
+
for p in params.values()
|
|
215
|
+
if p.default is inspect._empty
|
|
216
|
+
and p.kind
|
|
217
|
+
in (
|
|
218
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
219
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
220
|
+
inspect.Parameter.KEYWORD_ONLY,
|
|
221
|
+
)
|
|
222
|
+
]
|
|
223
|
+
if required_params:
|
|
224
|
+
param_list = ", ".join(p.name for p in required_params)
|
|
225
|
+
raise TypeError(
|
|
226
|
+
"Initializer must accept no args or a single `ctx` argument; "
|
|
227
|
+
f"received required parameters: {param_list}"
|
|
228
|
+
)
|
|
229
|
+
result = fn()
|
|
230
|
+
if inspect.isawaitable(result):
|
|
231
|
+
await result
|
|
232
|
+
return None
|
|
167
233
|
return None
|
|
168
234
|
|
|
169
235
|
# Save the old server's handlers before replacing it
|
|
@@ -233,6 +299,38 @@ class MCPServer(FastMCP):
|
|
|
233
299
|
|
|
234
300
|
_run_with_sigterm(_bootstrap)
|
|
235
301
|
|
|
302
|
+
async def run_async(
    self,
    transport: Transport | None = None,
    show_banner: bool = True,
    **transport_kwargs: Any,
) -> None:
    """Run the server with HUD enhancements.

    Args:
        transport: FastMCP transport name; defaults to ``"stdio"`` when omitted.
        show_banner: Forwarded unchanged to ``FastMCP.run_async``.
        **transport_kwargs: Extra transport options forwarded to ``FastMCP.run_async``.

    For HTTP-based transports the ``/hud/*`` helper routes are registered before the
    server starts. When the underlying server returns (normally or by raising), the
    ``@mcp.shutdown`` hook is run as a fallback if a SIGTERM was recorded and the hook
    has not already executed.
    """
    # Declared up front: the fallback below both reads and resets the module flag.
    global _sigterm_received

    if transport is None:
        transport = "stdio"

    # Custom routes are served by every HTTP-based FastMCP transport, including
    # "streamable-http", so the helpers must be registered for all of them.
    if transport in ("http", "sse", "streamable-http"):
        self._register_hud_helpers()
        logger.info("Registered HUD helper endpoints at /hud/*")

    try:
        await super().run_async(
            transport=transport, show_banner=show_banner, **transport_kwargs
        )
    finally:
        # Fallback: ensure SIGTERM-triggered shutdown runs even when a custom
        # lifespan bypasses our default fastmcp shutdown path.
        if self._shutdown_fn is not None and _sigterm_received and not self._shutdown_has_run:
            try:
                await self._shutdown_fn()
            except Exception as e:  # pragma: no cover - defensive logging
                logger.error("Error during @mcp.shutdown (fallback): %s", e)
            finally:
                # Mark as done regardless of outcome so the hook never runs twice.
                self._shutdown_has_run = True
                _sigterm_received = False
|
|
333
|
+
|
|
236
334
|
# Tool registration helper -- appends BaseTool to FastMCP
|
|
237
335
|
def add_tool(self, obj: Any, **kwargs: Any) -> None:
|
|
238
336
|
from hud.tools.base import BaseTool
|
|
@@ -242,3 +340,132 @@ class MCPServer(FastMCP):
|
|
|
242
340
|
return
|
|
243
341
|
|
|
244
342
|
super().add_tool(obj, **kwargs)
|
|
343
|
+
|
|
344
|
+
# Override to keep original callables when used as a decorator
|
|
345
|
+
def tool(self, name_or_fn: Any = None, **kwargs: Any) -> Any:  # type: ignore[override]
    """Register a tool with FastMCP while handing back exactly what the caller passed.

    Supported forms:
    - ``@mcp.tool`` / ``@mcp.tool("name")`` / ``@mcp.tool(name="name")``: returns a
      decorator that registers the function and then returns it unchanged.
    - ``mcp.tool(fn, ...)``: registers ``fn`` right away and returns ``fn``.
    - ``mcp.tool(base_tool, ...)``: registers ``base_tool.mcp`` and returns ``base_tool``.
    - ``mcp.tool(fastmcp_tool, ...)``: registers the FastMCP Tool as-is and returns it.
    """
    if not (name_or_fn is None or isinstance(name_or_fn, str)):
        # Resolve the optional tool types lazily. An empty tuple makes isinstance()
        # return False, which disables that branch when the import is unavailable.
        try:
            from hud.tools.base import BaseTool  # lazy import
        except Exception:
            BaseTool = ()  # type: ignore[assignment]
        try:
            from fastmcp.tools.tool import Tool as _FastMcpTool
        except Exception:
            _FastMcpTool = ()  # type: ignore[assignment]

        handled = True
        if isinstance(name_or_fn, BaseTool):
            # HUD tools carry their FastMCP tool on `.mcp`.
            super().add_tool(name_or_fn.mcp, **kwargs)
        elif isinstance(name_or_fn, _FastMcpTool):
            super().add_tool(name_or_fn, **kwargs)
        elif callable(name_or_fn):
            super().tool(name_or_fn, **kwargs)
        else:
            # Unrecognised object: defer to FastMCP's decorator path below.
            handled = False
        if handled:
            return name_or_fn

    # Decorator form: obtain FastMCP's registering decorator up front.
    register = super().tool(name_or_fn, **kwargs)

    def _decorate(fn: Any) -> Any:
        # Whatever FastMCP's decorator returns is discarded; callers keep `fn`.
        register(fn)
        return fn

    return _decorate
|
|
384
|
+
|
|
385
|
+
def _register_hud_helpers(self) -> None:
    """Register HUD helper HTTP routes.

    This adds:
    - GET /hud - Overview of available endpoints
    - GET /hud/tools - List all registered tools with their schemas
    - GET /hud/resources - List all registered resources
    - GET /hud/prompts - List all registered prompts

    The method is idempotent. FastMCP's ``custom_route`` appends a new route on every
    call, so without this guard a second call (e.g. ``run_async`` invoked again on the
    same server) would stack duplicate ``/hud/*`` routes.
    """
    if getattr(self, "_hud_helpers_registered", False):
        return

    @self.custom_route("/hud/tools", methods=["GET"])
    async def list_tools(request: Request) -> Response:
        """List all registered tools with their names, descriptions, and schemas."""
        tools = []
        # _tools is a mapping of tool_name -> FunctionTool/Tool instance
        for tool_key, tool in self._tool_manager._tools.items():
            tool_data: dict[str, Any] = {"name": tool_key}
            try:
                # Prefer converting to MCP model for consistent fields
                mcp_tool = tool.to_mcp_tool()
                tool_data["description"] = getattr(mcp_tool, "description", "")
                if getattr(mcp_tool, "inputSchema", None):
                    tool_data["input_schema"] = mcp_tool.inputSchema
                if getattr(mcp_tool, "outputSchema", None):
                    tool_data["output_schema"] = mcp_tool.outputSchema
            except Exception:
                # Best-effort fallback to direct attributes on FunctionTool; log at
                # debug level so conversion failures are not silently lost.
                logger.debug(
                    "to_mcp_tool() failed for tool %s; using raw attributes",
                    tool_key,
                    exc_info=True,
                )
                tool_data["description"] = getattr(tool, "description", "")
                params = getattr(tool, "parameters", None)
                if params:
                    tool_data["input_schema"] = params
            tools.append(tool_data)

        return JSONResponse({"server": self.name, "tools": tools, "count": len(tools)})

    @self.custom_route("/hud/resources", methods=["GET"])
    async def list_resources(request: Request) -> Response:
        """List all registered resources."""
        resources = []
        for resource_key, resource in self._resource_manager._resources.items():
            resource_data: dict[str, Any] = {
                "uri": resource_key,
                "name": resource.name,
                "description": resource.description,
                "mimeType": resource.mime_type,
            }
            resources.append(resource_data)

        return JSONResponse(
            {"server": self.name, "resources": resources, "count": len(resources)}
        )

    @self.custom_route("/hud/prompts", methods=["GET"])
    async def list_prompts(request: Request) -> Response:
        """List all registered prompts."""
        prompts = []
        for prompt_key, prompt in self._prompt_manager._prompts.items():
            prompt_data: dict[str, Any] = {
                "name": prompt_key,
                "description": prompt.description,
            }
            # Only include an "arguments" list when the prompt declares some.
            if getattr(prompt, "arguments", None):
                prompt_data["arguments"] = [
                    {"name": arg.name, "description": arg.description, "required": arg.required}
                    for arg in prompt.arguments
                ]
            prompts.append(prompt_data)

        return JSONResponse({"server": self.name, "prompts": prompts, "count": len(prompts)})

    @self.custom_route("/hud", methods=["GET"])
    async def hud_info(request: Request) -> Response:
        """Show available HUD helper endpoints."""
        base_url = str(request.base_url).rstrip("/")
        return JSONResponse(
            {
                "name": "HUD MCP Development Helpers",
                "server": self.name,
                "endpoints": {
                    "tools": f"{base_url}/hud/tools",
                    "resources": f"{base_url}/hud/resources",
                    "prompts": f"{base_url}/hud/prompts",
                },
                "description": "These endpoints help you inspect your MCP server during development.",  # noqa: E501
            }
        )

    # Set only after every route registered successfully.
    self._hud_helpers_registered = True
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import types
|
|
5
|
+
from typing import Any, cast
|
|
6
|
+
|
|
7
|
+
from hud.server import MCPServer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_add_tool_accepts_base_tool(monkeypatch):
    """A BaseTool given to add_tool is unwrapped: FastMCP.add_tool receives its `.mcp`."""

    class FakeBaseTool:
        """Stand-in type for the isinstance() check performed inside add_tool."""

    # Install a throwaway `hud.tools.base` module exposing the stub type; the cast
    # tells the type checker we are mutating a dynamically created module.
    stub_module = types.ModuleType("hud.tools.base")
    cast("Any", stub_module).BaseTool = FakeBaseTool
    monkeypatch.setitem(sys.modules, "hud.tools.base", stub_module)

    # Record whatever reaches the parent FastMCP.add_tool.
    seen: dict[str, object | None] = {"obj": None, "kwargs": None}

    def recording_add_tool(self, obj: object, **kwargs: object) -> None:
        seen.update(obj=obj, kwargs=kwargs)

    monkeypatch.setattr("hud.server.server.FastMCP.add_tool", recording_add_tool, raising=True)

    server = MCPServer(name="AddTool")
    inner = object()

    class MyTool(FakeBaseTool):
        def __init__(self) -> None:
            self.mcp = inner

    server.add_tool(MyTool(), extra="yes")

    assert seen["obj"] is inner
    forwarded = seen["kwargs"]
    assert isinstance(forwarded, dict)
    assert forwarded["extra"] == "yes"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_add_tool_plain_falls_back_to_super(monkeypatch):
    """Objects that are not BaseTool instances reach FastMCP.add_tool untouched."""
    forwarded: list = []

    def recording_add_tool(self, obj, **kwargs):
        forwarded.append((obj, kwargs))

    monkeypatch.setattr("hud.server.server.FastMCP.add_tool", recording_add_tool, raising=True)

    server = MCPServer(name="AddToolPlain")

    async def handler():  # pragma: no cover - never awaited by FastMCP here
        return "ok"

    server.add_tool(handler, desc="x")

    assert forwarded
    first_obj, first_kwargs = forwarded[0]
    assert first_obj is handler
    assert first_kwargs["desc"] == "x"
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
import multiprocessing.connection as _mp_conn
|
|
8
|
+
|
|
9
|
+
# Pull the exception dynamically; fall back to OSError if missing in stubs/runtime
|
|
10
|
+
MPAuthenticationError: type[BaseException] = getattr(_mp_conn, "AuthenticationError", OSError)
|
|
11
|
+
except Exception: # pragma: no cover
|
|
12
|
+
MPAuthenticationError = OSError
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
import pytest
|
|
18
|
+
|
|
19
|
+
from hud.server.context import attach_context, serve_context
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
# Every test in this module talks to the context server over a UNIX domain socket,
# so the whole module is skipped on Windows.
_IS_WINDOWS = sys.platform == "win32"
pytestmark = pytest.mark.skipif(_IS_WINDOWS, reason="Context server uses UNIX domain sockets")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class CounterCtx:
    """Minimal shared-state object: an integer counter that starts at zero."""

    def __init__(self) -> None:
        self._count = 0

    def inc(self) -> int:
        """Increment the counter by one and return the new value."""
        self._count = self._count + 1
        return self.get()

    def get(self) -> int:
        """Return the current counter value."""
        return self._count
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_serve_and_attach_shared_state(tmp_path: Path) -> None:
    """Two clients attached to one served context observe and mutate the same counter."""
    socket_path = str(tmp_path / "hud_ctx.sock")
    manager = serve_context(CounterCtx(), sock_path=socket_path)
    try:
        first = attach_context(sock_path=socket_path)
        assert first.get() == 0
        assert first.inc() == 1

        # A second client must see the increment made through the first one,
        # and the first must see the increment made through the second.
        second = attach_context(sock_path=socket_path)
        assert second.get() == 1
        assert second.inc() == 2
        assert first.get() == 2
    finally:
        manager.shutdown()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_env_var_socket_path_overrides(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """With sock_path=None, serve and attach are expected to resolve HUD_CTX_SOCK."""
    env_socket = str(tmp_path / "env_ctx.sock")
    monkeypatch.setenv("HUD_CTX_SOCK", env_socket)

    manager = serve_context(CounterCtx(), sock_path=None)
    try:
        client = attach_context(sock_path=None)
        assert client.inc() == 1
        assert client.get() == 1
    finally:
        manager.shutdown()
        # Redundant with monkeypatch's automatic teardown; kept for explicitness.
        monkeypatch.delenv("HUD_CTX_SOCK", raising=False)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_wrong_authkey_rejected(tmp_path: Path) -> None:
    """Attaching with an authkey that differs from the server's must fail."""
    socket_path = str(tmp_path / "auth_ctx.sock")
    # ConnectionRefusedError and BrokenPipeError are OSError subclasses, so OSError
    # already covers them alongside multiprocessing's authentication error.
    expected_errors = (MPAuthenticationError, OSError)

    manager = serve_context(CounterCtx(), sock_path=socket_path, authkey=b"correct")
    try:
        with pytest.raises(expected_errors):
            attach_context(sock_path=socket_path, authkey=b"wrong")
    finally:
        manager.shutdown()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def test_attach_nonexistent_raises(tmp_path: Path) -> None:
    """Attaching to a socket path with nothing behind it raises an OSError."""
    missing = str(tmp_path / "missing.sock")
    # Defensive: guarantee nothing exists at the path before attaching.
    if os.path.exists(missing):
        os.unlink(missing)

    # FileNotFoundError and ConnectionRefusedError are both OSError subclasses.
    with pytest.raises(OSError):
        attach_context(sock_path=missing)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@pytest.mark.asyncio
async def test_run_context_server_handles_keyboardinterrupt(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """A KeyboardInterrupt while waiting makes run_context_server shut the manager down."""
    state: dict[str, object] = {"served": False, "shutdown": False, "addr": None}

    class _RecordingManager:
        def shutdown(self) -> None:
            state["shutdown"] = True

    # Parameter names mirror serve_context's (sock_path=, authkey=) so keyword
    # calls from run_context_server still bind.
    def fake_serve(ctx, sock_path, authkey):
        state.update(served=True, addr=sock_path)
        return _RecordingManager()

    class _InterruptingEvent:
        """asyncio.Event stand-in whose wait() immediately simulates Ctrl-C."""

        async def wait(self) -> None:
            raise KeyboardInterrupt

    monkeypatch.setattr("hud.server.context.serve_context", fake_serve)
    monkeypatch.setattr("hud.server.context.asyncio.Event", lambda: _InterruptingEvent())

    from hud.server.context import run_context_server

    await run_context_server(object(), sock_path=str(tmp_path / "ctx.sock"))

    assert state["served"] is True
    assert state["shutdown"] is True
    assert str(state["addr"]).endswith("ctx.sock")
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, cast
|
|
4
|
+
|
|
5
|
+
from hud.server import MCPServer
|
|
6
|
+
from hud.server.low_level import LowLevelServerWithInit
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def test_notification_handlers_preserved_on_replacement():
    """Swapping in the init-aware low-level server keeps existing notification handlers."""
    server = MCPServer(name="PreserveNotif")

    original_low_level = server._mcp_server
    handlers = cast("dict[Any, Any]", original_low_level.notification_handlers)
    handlers["foo/notify"] = object()  # marker handler that must survive the swap

    @server.initialize
    async def _init(_ctx) -> None:
        pass

    replaced_low_level = server._mcp_server
    assert isinstance(replaced_low_level, LowLevelServerWithInit)
    assert replaced_low_level is not original_low_level, "low-level server should be replaced once"
    # The handler mapping is carried over to the replacement server.
    assert "foo/notify" in replaced_low_level.notification_handlers
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_init_server_replacement_is_idempotent():
    """A second @initialize must not swap the low-level server again.

    The later initializer is expected to take effect, but that is not asserted here;
    only the identity of the low-level server object is checked.
    """
    server = MCPServer(name="InitIdempotent")

    async def _first_init(_ctx) -> None:
        pass

    async def _second_init(_ctx) -> None:
        pass

    # Apply each initializer (equivalent to decorating it) and snapshot the server.
    snapshots = []
    for initializer in (_first_init, _second_init):
        server.initialize(initializer)
        snapshots.append(server._mcp_server)

    first, second = snapshots
    assert first is second, "Server replacement should occur at most once per instance"
|