hud-python 0.4.35__py3-none-any.whl → 0.4.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

Files changed (63)
  1. hud/agents/__init__.py +2 -0
  2. hud/agents/lite_llm.py +72 -0
  3. hud/agents/openai_chat_generic.py +21 -7
  4. hud/agents/tests/test_claude.py +32 -7
  5. hud/agents/tests/test_openai.py +29 -6
  6. hud/cli/__init__.py +228 -79
  7. hud/cli/build.py +26 -6
  8. hud/cli/dev.py +21 -40
  9. hud/cli/eval.py +96 -15
  10. hud/cli/flows/tasks.py +198 -65
  11. hud/cli/init.py +222 -629
  12. hud/cli/pull.py +6 -0
  13. hud/cli/push.py +11 -1
  14. hud/cli/rl/__init__.py +14 -4
  15. hud/cli/rl/celebrate.py +187 -0
  16. hud/cli/rl/config.py +15 -8
  17. hud/cli/rl/local_runner.py +44 -20
  18. hud/cli/rl/remote_runner.py +166 -87
  19. hud/cli/rl/viewer.py +141 -0
  20. hud/cli/rl/wait_utils.py +89 -0
  21. hud/cli/tests/test_build.py +3 -27
  22. hud/cli/tests/test_mcp_server.py +1 -12
  23. hud/cli/utils/config.py +85 -0
  24. hud/cli/utils/docker.py +21 -39
  25. hud/cli/utils/env_check.py +196 -0
  26. hud/cli/utils/environment.py +4 -3
  27. hud/cli/utils/interactive.py +2 -1
  28. hud/cli/utils/local_runner.py +204 -0
  29. hud/cli/utils/metadata.py +3 -1
  30. hud/cli/utils/package_runner.py +292 -0
  31. hud/cli/utils/remote_runner.py +4 -1
  32. hud/cli/utils/source_hash.py +108 -0
  33. hud/clients/base.py +1 -1
  34. hud/clients/fastmcp.py +1 -1
  35. hud/clients/mcp_use.py +30 -7
  36. hud/datasets/parallel.py +3 -1
  37. hud/datasets/runner.py +4 -1
  38. hud/otel/config.py +1 -1
  39. hud/otel/context.py +40 -6
  40. hud/rl/buffer.py +3 -0
  41. hud/rl/tests/test_learner.py +1 -1
  42. hud/rl/vllm_adapter.py +1 -1
  43. hud/server/server.py +234 -7
  44. hud/server/tests/test_add_tool.py +60 -0
  45. hud/server/tests/test_context.py +128 -0
  46. hud/server/tests/test_mcp_server_handlers.py +44 -0
  47. hud/server/tests/test_mcp_server_integration.py +405 -0
  48. hud/server/tests/test_mcp_server_more.py +247 -0
  49. hud/server/tests/test_run_wrapper.py +53 -0
  50. hud/server/tests/test_server_extra.py +166 -0
  51. hud/server/tests/test_sigterm_runner.py +78 -0
  52. hud/settings.py +38 -0
  53. hud/shared/hints.py +2 -2
  54. hud/telemetry/job.py +2 -2
  55. hud/types.py +9 -2
  56. hud/utils/tasks.py +32 -24
  57. hud/utils/tests/test_version.py +1 -1
  58. hud/version.py +1 -1
  59. {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/METADATA +43 -23
  60. {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/RECORD +63 -46
  61. {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/WHEEL +0 -0
  62. {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/entry_points.txt +0 -0
  63. {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/licenses/LICENSE +0 -0
hud/server/server.py CHANGED
@@ -13,13 +13,14 @@ from typing import TYPE_CHECKING, Any
13
13
 
14
14
  import anyio
15
15
  from fastmcp.server.server import FastMCP, Transport
16
+ from starlette.responses import JSONResponse, Response
16
17
 
17
18
  from hud.server.low_level import LowLevelServerWithInit
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  from collections.abc import AsyncGenerator, Callable
21
22
 
22
- from mcp.shared.context import RequestContext
23
+ from starlette.requests import Request
23
24
 
24
25
  __all__ = ["MCPServer"]
25
26
 
@@ -35,6 +36,31 @@ def _run_with_sigterm(coro_fn: Callable[..., Any], *args: Any, **kwargs: Any) ->
35
36
 
36
37
  sys.stderr.flush()
37
38
 
39
+ # Check if we're already in an event loop
40
+ try:
41
+ loop = asyncio.get_running_loop()
42
+ logger.warning(
43
+ "HUD server is running in an existing event loop. "
44
+ "SIGTERM handling may be limited. "
45
+ "Consider using await hub.run_async() instead of hub.run() in async contexts."
46
+ )
47
+
48
+ task = loop.create_task(coro_fn(*args, **kwargs))
49
+
50
+ # Try to handle SIGTERM if possible
51
+ if sys.platform != "win32":
52
+
53
+ def handle_sigterm(signum: Any, frame: Any) -> None:
54
+ logger.info("SIGTERM received in async context, cancelling task...")
55
+ loop.call_soon_threadsafe(task.cancel)
56
+
57
+ signal.signal(signal.SIGTERM, handle_sigterm)
58
+
59
+ return
60
+
61
+ except RuntimeError:
62
+ pass
63
+
38
64
  async def _runner() -> None:
39
65
  stop_evt: asyncio.Event | None = None
40
66
  if sys.platform != "win32" and os.getenv("FASTMCP_DISABLE_SIGTERM_HANDLER") != "1":
@@ -125,7 +151,11 @@ class MCPServer(FastMCP):
125
151
  # Force flush logs to ensure they're visible
126
152
  sys.stderr.flush()
127
153
 
128
- if self._shutdown_fn is not None and _sigterm_received:
154
+ if (
155
+ self._shutdown_fn is not None
156
+ and _sigterm_received
157
+ and not self._shutdown_has_run
158
+ ):
129
159
  logger.info("SIGTERM detected! Calling @mcp.shutdown handler...")
130
160
  sys.stderr.flush()
131
161
  try:
@@ -135,7 +165,9 @@ class MCPServer(FastMCP):
135
165
  except Exception as e:
136
166
  logger.error("Error during @mcp.shutdown: %s", e)
137
167
  sys.stderr.flush()
138
- _sigterm_received = False
168
+ finally:
169
+ self._shutdown_has_run = True
170
+ _sigterm_received = False
139
171
  elif self._shutdown_fn is not None:
140
172
  logger.info(
141
173
  "No SIGTERM. This is a hot reload (SIGINT) or normal exit. Skipping @mcp.shutdown handler." # noqa: E501
@@ -151,19 +183,53 @@ class MCPServer(FastMCP):
151
183
  self._initializer_fn: Callable | None = None
152
184
  self._did_init = False
153
185
  self._replaced_server = False
186
+ self._shutdown_has_run = False # Guard against double-execution of shutdown hook
154
187
 
155
188
  def _replace_with_init_server(self) -> None:
156
189
  """Replace the low-level server with init version when needed."""
157
190
  if self._replaced_server:
158
191
  return
159
192
 
160
- def _run_init(ctx: RequestContext | None = None) -> Any:
193
+ async def _run_init(ctx: object | None = None) -> None:
194
+ """Run the user initializer exactly once, with stdout redirected."""
161
195
  if self._initializer_fn is not None and not self._did_init:
162
196
  self._did_init = True
163
- # Redirect stdout to stderr during initialization to prevent
164
- # any library prints from corrupting the MCP protocol
197
+ # Prevent stdout from polluting the MCP protocol on stdio/HTTP
165
198
  with contextlib.redirect_stdout(sys.stderr):
166
- return self._initializer_fn(ctx)
199
+ import inspect
200
+
201
+ fn = self._initializer_fn
202
+ sig = inspect.signature(fn)
203
+ params = sig.parameters
204
+
205
+ ctx_param = params.get("ctx") or params.get("_ctx")
206
+ if ctx_param is not None:
207
+ if ctx_param.kind == inspect.Parameter.KEYWORD_ONLY:
208
+ result = fn(**{ctx_param.name: ctx})
209
+ else:
210
+ result = fn(ctx)
211
+ else:
212
+ required_params = [
213
+ p
214
+ for p in params.values()
215
+ if p.default is inspect._empty
216
+ and p.kind
217
+ in (
218
+ inspect.Parameter.POSITIONAL_ONLY,
219
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
220
+ inspect.Parameter.KEYWORD_ONLY,
221
+ )
222
+ ]
223
+ if required_params:
224
+ param_list = ", ".join(p.name for p in required_params)
225
+ raise TypeError(
226
+ "Initializer must accept no args or a single `ctx` argument; "
227
+ f"received required parameters: {param_list}"
228
+ )
229
+ result = fn()
230
+ if inspect.isawaitable(result):
231
+ await result
232
+ return None
167
233
  return None
168
234
 
169
235
  # Save the old server's handlers before replacing it
@@ -233,6 +299,38 @@ class MCPServer(FastMCP):
233
299
 
234
300
  _run_with_sigterm(_bootstrap)
235
301
 
302
async def run_async(
    self,
    transport: Transport | None = None,
    show_banner: bool = True,
    **transport_kwargs: Any,
) -> None:
    """Run the server with HUD enhancements.

    For HTTP-based transports the ``/hud/*`` helper routes are registered
    before the server starts. Once the underlying run returns or raises, the
    ``@mcp.shutdown`` hook is invoked as a fallback when SIGTERM was received
    and the hook has not already run.

    Args:
        transport: Transport to serve on; ``None`` selects ``"stdio"``.
        show_banner: Forwarded unchanged to ``FastMCP.run_async``.
        **transport_kwargs: Extra keyword arguments forwarded to
            ``FastMCP.run_async``.
    """
    # Declared up front so every use of the flag below refers to the module global.
    global _sigterm_received

    if transport is None:
        transport = "stdio"

    # "streamable-http" is an HTTP transport too, so it needs the helper
    # routes just like "http" and "sse".
    if transport in ("http", "sse", "streamable-http"):
        self._register_hud_helpers()
        logger.info("Registered HUD helper endpoints at /hud/*")

    try:
        await super().run_async(
            transport=transport, show_banner=show_banner, **transport_kwargs
        )
    finally:
        # Fallback: ensure SIGTERM-triggered shutdown runs even when a custom
        # lifespan bypasses our default fastmcp shutdown path.
        if self._shutdown_fn is not None and _sigterm_received and not self._shutdown_has_run:
            try:
                await self._shutdown_fn()
            except Exception as e:  # pragma: no cover - defensive logging
                logger.error("Error during @mcp.shutdown (fallback): %s", e)
            finally:
                # Mark the hook as done even on failure so it never runs twice.
                self._shutdown_has_run = True
                _sigterm_received = False
333
+
236
334
  # Tool registration helper -- appends BaseTool to FastMCP
237
335
  def add_tool(self, obj: Any, **kwargs: Any) -> None:
238
336
  from hud.tools.base import BaseTool
@@ -242,3 +340,132 @@ class MCPServer(FastMCP):
242
340
  return
243
341
 
244
342
  super().add_tool(obj, **kwargs)
343
+
344
+ # Override to keep original callables when used as a decorator
345
def tool(self, name_or_fn: Any = None, **kwargs: Any) -> Any:  # type: ignore[override]
    """Register a tool while handing the caller back what they passed in.

    - Decorator usage (``@mcp.tool``, ``@mcp.tool("name")``,
      ``@mcp.tool(name="name")``) registers with FastMCP and returns the
      original function so it stays usable for composition.
    - Call-form (``mcp.tool(obj, ...)``) registers ``obj`` and returns ``obj``.
    """
    is_call_form = not (name_or_fn is None or isinstance(name_or_fn, str))

    if is_call_form:
        # Lazy imports; an empty tuple makes the isinstance() checks below
        # evaluate to False when a tool type cannot be imported.
        try:
            from hud.tools.base import BaseTool as hud_tool_type
        except Exception:
            hud_tool_type = ()  # type: ignore[assignment]
        try:
            from fastmcp.tools.tool import Tool as fastmcp_tool_type
        except Exception:
            fastmcp_tool_type = ()  # type: ignore[assignment]

        if isinstance(name_or_fn, hud_tool_type):
            # HUD BaseTool: register the tool object it exposes as ``.mcp``.
            super().add_tool(name_or_fn.mcp, **kwargs)
            return name_or_fn
        elif isinstance(name_or_fn, fastmcp_tool_type):
            # FastMCP Tool instance: register it as-is.
            super().add_tool(name_or_fn, **kwargs)
            return name_or_fn
        elif callable(name_or_fn):
            # Plain callable: let FastMCP.tool register it, return the callable.
            super().tool(name_or_fn, **kwargs)
            return name_or_fn

    # Decorator form: obtain FastMCP's decorator and wrap it so the decorated
    # function itself is returned.
    fastmcp_decorator = super().tool(name_or_fn, **kwargs)

    def _wrapper(fn: Any) -> Any:
        fastmcp_decorator(fn)
        return fn

    return _wrapper
384
+
385
def _register_hud_helpers(self) -> None:
    """Attach the HUD inspection routes to this server.

    Routes registered (all GET):
    - /hud/tools - registered tools with descriptions and schemas
    - /hud/resources - registered resources
    - /hud/prompts - registered prompts
    - /hud - index of the endpoints above
    """

    @self.custom_route("/hud/tools", methods=["GET"])
    async def list_tools(request: Request) -> Response:
        """Return every registered tool with its name, description, and schemas."""
        entries = []
        # _tools maps tool name -> registered tool object
        for name, registered in self._tool_manager._tools.items():
            entry = {"name": name}
            try:
                # Prefer the MCP model so the reported fields are consistent
                as_mcp = registered.to_mcp_tool()
                entry["description"] = getattr(as_mcp, "description", "")
                input_schema = getattr(as_mcp, "inputSchema", None)
                if input_schema:
                    entry["input_schema"] = input_schema  # type: ignore[assignment]
                output_schema = getattr(as_mcp, "outputSchema", None)
                if output_schema:
                    entry["output_schema"] = output_schema  # type: ignore[assignment]
            except Exception:
                # Conversion failed: read the attributes off the tool object itself
                entry["description"] = getattr(registered, "description", "")
                parameters = getattr(registered, "parameters", None)
                if parameters:
                    entry["input_schema"] = parameters
            entries.append(entry)

        return JSONResponse({"server": self.name, "tools": entries, "count": len(entries)})

    @self.custom_route("/hud/resources", methods=["GET"])
    async def list_resources(request: Request) -> Response:
        """Return every registered resource."""
        entries = [
            {
                "uri": uri,
                "name": registered.name,
                "description": registered.description,
                "mimeType": registered.mime_type,
            }
            for uri, registered in self._resource_manager._resources.items()
        ]
        body = {"server": self.name, "resources": entries, "count": len(entries)}
        return JSONResponse(body)

    @self.custom_route("/hud/prompts", methods=["GET"])
    async def list_prompts(request: Request) -> Response:
        """Return every registered prompt."""
        entries = []
        for name, registered in self._prompt_manager._prompts.items():
            entry = {"name": name, "description": registered.description}
            # Arguments are reported only when the prompt declares some
            declared_args = getattr(registered, "arguments", None)
            if declared_args:
                entry["arguments"] = [
                    {"name": arg.name, "description": arg.description, "required": arg.required}
                    for arg in declared_args
                ]
            entries.append(entry)

        return JSONResponse({"server": self.name, "prompts": entries, "count": len(entries)})

    @self.custom_route("/hud", methods=["GET"])
    async def hud_info(request: Request) -> Response:
        """Return an index of the HUD helper endpoints."""
        root = str(request.base_url).rstrip("/")
        endpoints = {
            "tools": f"{root}/hud/tools",
            "resources": f"{root}/hud/resources",
            "prompts": f"{root}/hud/prompts",
        }
        return JSONResponse(
            {
                "name": "HUD MCP Development Helpers",
                "server": self.name,
                "endpoints": endpoints,
                "description": "These endpoints help you inspect your MCP server during development.",  # noqa: E501
            }
        )
@@ -0,0 +1,60 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ import types
5
+ from typing import Any, cast
6
+
7
+ from hud.server import MCPServer
8
+
9
+
10
def test_add_tool_accepts_base_tool(monkeypatch):
    """A BaseTool instance is unwrapped: its `.mcp` is what reaches FastMCP.add_tool."""

    class FakeBaseTool:
        """Stand-in for the type that add_tool checks with isinstance()."""

    # Install a stub hud.tools.base module exposing the fake BaseTool.
    stub_module = types.ModuleType("hud.tools.base")
    cast("Any", stub_module).BaseTool = FakeBaseTool  # dynamic attribute on a module
    monkeypatch.setitem(sys.modules, "hud.tools.base", stub_module)

    # Record what the patched FastMCP.add_tool receives.
    recorded: dict[str, object | None] = {"obj": None, "kwargs": None}

    def fake_super_add(self, obj: object, **kwargs: object) -> None:
        recorded["obj"] = obj
        recorded["kwargs"] = kwargs

    monkeypatch.setattr("hud.server.server.FastMCP.add_tool", fake_super_add, raising=True)

    server = MCPServer(name="AddTool")
    sentinel = object()

    class MyTool(FakeBaseTool):
        def __init__(self) -> None:
            self.mcp = sentinel

    server.add_tool(MyTool(), extra="yes")

    forwarded_kwargs = recorded["kwargs"]
    assert recorded["obj"] is sentinel
    assert isinstance(forwarded_kwargs, dict)
    assert forwarded_kwargs["extra"] == "yes"
42
+
43
+
44
def test_add_tool_plain_falls_back_to_super(monkeypatch):
    """Objects that are not BaseTool reach FastMCP.add_tool untouched."""
    received = []

    def fake_super_add(self, obj, **kwargs):
        received.append((obj, kwargs))

    monkeypatch.setattr("hud.server.server.FastMCP.add_tool", fake_super_add, raising=True)

    async def fn():  # pragma: no cover - never awaited by FastMCP here
        return "ok"

    server = MCPServer(name="AddToolPlain")
    server.add_tool(fn, desc="x")

    assert received and received[0][0] is fn
    forwarded_kwargs = received[0][1]
    assert forwarded_kwargs["desc"] == "x"
@@ -0,0 +1,128 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+
6
# Resolve the exception type used for authkey failures without hard-coding it:
# look it up on multiprocessing.connection and fall back to OSError whenever
# the module or the attribute is unavailable.
try:
    import multiprocessing.connection as _mp_conn
except Exception:  # pragma: no cover
    MPAuthenticationError: type[BaseException] = OSError
else:
    MPAuthenticationError = getattr(_mp_conn, "AuthenticationError", OSError)
13
+
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ import pytest
18
+
19
+ from hud.server.context import attach_context, serve_context
20
+
21
+ if TYPE_CHECKING:
22
+ from pathlib import Path
23
+
24
+ pytestmark = pytest.mark.skipif(
25
+ sys.platform == "win32",
26
+ reason="Context server uses UNIX domain sockets",
27
+ )
28
+
29
+
30
class CounterCtx:
    """Minimal context object for the tests: an integer counter starting at zero."""

    def __init__(self) -> None:
        self._n = 0

    def inc(self) -> int:
        """Add one to the counter and return the new value."""
        bumped = self._n + 1
        self._n = bumped
        return bumped

    def get(self) -> int:
        """Return the current counter value without changing it."""
        return self._n
40
+
41
+
42
def test_serve_and_attach_shared_state(tmp_path: Path) -> None:
    """Two attachments to one served context observe each other's updates."""
    socket_path = str(tmp_path / "hud_ctx.sock")
    manager = serve_context(CounterCtx(), sock_path=socket_path)

    try:
        first = attach_context(sock_path=socket_path)
        assert first.get() == 0
        assert first.inc() == 1

        # A second attachment must see the increment made through the first
        second = attach_context(sock_path=socket_path)
        assert second.get() == 1
        assert second.inc() == 2

        # ...and the first must see the increment made through the second
        assert first.get() == 2
    finally:
        manager.shutdown()
58
+
59
+
60
def test_env_var_socket_path_overrides(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """With sock_path=None, serve and attach still work when HUD_CTX_SOCK is set."""
    monkeypatch.setenv("HUD_CTX_SOCK", str(tmp_path / "env_ctx.sock"))
    manager = serve_context(CounterCtx(), sock_path=None)

    try:
        attached = attach_context(sock_path=None)
        assert attached.inc() == 1
        assert attached.get() == 1
    finally:
        manager.shutdown()
        monkeypatch.delenv("HUD_CTX_SOCK", raising=False)
72
+
73
+
74
def test_wrong_authkey_rejected(tmp_path: Path) -> None:
    """Attaching with an authkey different from the server's must raise."""
    socket_path = str(tmp_path / "auth_ctx.sock")
    # Any of these is accepted as a rejection.
    rejection_errors = (MPAuthenticationError, ConnectionRefusedError, BrokenPipeError, OSError)

    manager = serve_context(CounterCtx(), sock_path=socket_path, authkey=b"correct")
    try:
        with pytest.raises(rejection_errors):
            attach_context(sock_path=socket_path, authkey=b"wrong")
    finally:
        manager.shutdown()
84
+
85
+
86
def test_attach_nonexistent_raises(tmp_path: Path) -> None:
    """Attaching to a socket path that does not exist must raise."""
    missing_socket = str(tmp_path / "missing.sock")

    # Guarantee the path is absent before attempting to attach
    try:
        os.unlink(missing_socket)
    except FileNotFoundError:
        pass

    expected_errors = (FileNotFoundError, ConnectionRefusedError, OSError)
    with pytest.raises(expected_errors):
        attach_context(sock_path=missing_socket)
94
+
95
+
96
@pytest.mark.asyncio
async def test_run_context_server_handles_keyboardinterrupt(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """run_context_server should call manager.shutdown() when KeyboardInterrupt occurs."""
    # Flags flipped by the fakes below so the test can see what was called
    called = {"served": False, "shutdown": False, "addr": None}

    class _Mgr:
        def shutdown(self) -> None:
            called["shutdown"] = True

    def fake_serve(ctx, sock_path, authkey):
        called.update(served=True, addr=sock_path)
        return _Mgr()

    class _FakeEvent:
        # Stand-in for asyncio.Event whose wait() fails immediately
        async def wait(self) -> None:
            raise KeyboardInterrupt

    monkeypatch.setattr("hud.server.context.serve_context", fake_serve)
    monkeypatch.setattr("hud.server.context.asyncio.Event", _FakeEvent)

    from hud.server.context import run_context_server

    socket_path = str(tmp_path / "ctx.sock")
    await run_context_server(object(), sock_path=socket_path)

    assert called["served"] is True
    assert called["shutdown"] is True
    assert str(called["addr"]).endswith("ctx.sock")
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, cast
4
+
5
+ from hud.server import MCPServer
6
+ from hud.server.low_level import LowLevelServerWithInit
7
+
8
+
9
def test_notification_handlers_preserved_on_replacement():
    """Notification handlers must survive the swap to the init-aware low-level server."""
    mcp = MCPServer(name="PreserveNotif")

    # Plant a marker handler on the server that exists before @initialize
    original_server = mcp._mcp_server
    seeded_handlers = cast("dict[Any, Any]", original_server.notification_handlers)
    seeded_handlers["foo/notify"] = object()

    @mcp.initialize
    async def _init(_ctx) -> None:
        pass

    replacement_server = mcp._mcp_server
    assert isinstance(replacement_server, LowLevelServerWithInit)
    assert replacement_server is not original_server, "low-level server should be replaced once"
    # The marker handler must still be registered on the replacement
    assert "foo/notify" in replacement_server.notification_handlers
26
+
27
+
28
def test_init_server_replacement_is_idempotent():
    """Registering a second @initialize must not swap the low-level server again."""
    mcp = MCPServer(name="InitIdempotent")

    @mcp.initialize
    async def _a(_ctx) -> None:
        pass

    server_after_first = mcp._mcp_server

    @mcp.initialize
    async def _b(_ctx) -> None:
        # Registering again must leave the server object untouched
        pass

    server_after_second = mcp._mcp_server
    assert (
        server_after_first is server_after_second
    ), "Server replacement should occur at most once per instance"