hud-python 0.4.20__py3-none-any.whl → 0.4.21__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release: this version of hud-python might be problematic.

@@ -0,0 +1,539 @@
+"""Tests for the comparator module."""
+
+from __future__ import annotations
+
+import pytest
+from fastmcp.tools.tool import FunctionTool
+
+from hud.native.comparator import (
+    CompareTool,
+    ComparisonMode,
+    ComparisonResult,
+    DataType,
+    auto_select_mode,
+    comparator_server,
+    detect_type,
+    extract_boolean,
+    extract_json,
+    extract_list,
+    extract_number,
+)
+from hud.tools.submit import set_submission
+from hud.tools.types import EvaluationResult
+
+
+class TestTypeDetection:
+    """Test type detection functionality."""
+
+    @pytest.mark.parametrize(
+        "value,expected_type",
+        [
+            # Booleans
+            ("true", DataType.BOOLEAN),
+            ("True", DataType.BOOLEAN),
+            ("false", DataType.BOOLEAN),
+            ("False", DataType.BOOLEAN),
+            # Integers
+            ("42", DataType.INTEGER),
+            ("-123", DataType.INTEGER),
+            ("0", DataType.INTEGER),
+            # Floats
+            ("3.14", DataType.FLOAT),
+            ("-2.5", DataType.FLOAT),
+            ("1e-6", DataType.FLOAT),
+            # JSON
+            ('{"key": "value"}', DataType.JSON),
+            ("[1, 2, 3]", DataType.JSON),
+            ('{"nested": {"value": 42}}', DataType.JSON),
+            # Text (fallback)
+            ("hello world", DataType.TEXT),
+            ("not a number", DataType.TEXT),
+            ("{invalid json", DataType.TEXT),
+            ("", DataType.TEXT),
+        ],
+    )
+    def test_detect_type(self, value, expected_type):
+        """Test type detection for various inputs."""
+        assert detect_type(value) == expected_type
+
+    def test_detect_type_none(self):
+        """Test type detection for None."""
+        assert detect_type(None) == DataType.TEXT
+
+
+class TestAutoSelectMode:
+    """Test automatic mode selection based on types."""
+
+    @pytest.mark.parametrize(
+        "value_type,ref_type,expected_mode",
+        [
+            # JSON gets semantic
+            (DataType.JSON, DataType.JSON, ComparisonMode.SEMANTIC),
+            (DataType.JSON, DataType.TEXT, ComparisonMode.SEMANTIC),
+            (DataType.TEXT, DataType.JSON, ComparisonMode.SEMANTIC),
+            # Numeric gets numeric
+            (DataType.INTEGER, DataType.INTEGER, ComparisonMode.NUMERIC),
+            (DataType.FLOAT, DataType.FLOAT, ComparisonMode.NUMERIC),
+            (DataType.INTEGER, DataType.FLOAT, ComparisonMode.NUMERIC),
+            (DataType.FLOAT, DataType.TEXT, ComparisonMode.NUMERIC),
+            # Boolean gets exact
+            (DataType.BOOLEAN, DataType.BOOLEAN, ComparisonMode.EXACT),
+            (DataType.BOOLEAN, DataType.TEXT, ComparisonMode.EXACT),
+            # Text gets fuzzy
+            (DataType.TEXT, DataType.TEXT, ComparisonMode.FUZZY),
+        ],
+    )
+    def test_auto_select_mode(self, value_type, ref_type, expected_mode):
+        """Test mode selection logic."""
+        assert auto_select_mode(value_type, ref_type) == expected_mode
+
+
+class TestCompareTool:
+    """Test the main CompareTool class."""
+
+    @pytest.mark.asyncio
+    async def test_scalar_comparison(self):
+        """Test comparing scalar values."""
+        tool = CompareTool()
+
+        # Exact match
+        result = await tool("hello", "hello", mode=ComparisonMode.EXACT)
+        assert isinstance(result, EvaluationResult)
+        assert result.done
+        assert result.reward == 1.0
+        assert "Single Exact: 1/1 matches" in result.content if result.content else False
+
+        # Fuzzy match
+        result = await tool("hello world", "hello wrld", mode=ComparisonMode.FUZZY)
+        assert isinstance(result, EvaluationResult)
+        assert result.done
+        assert 0.8 < result.reward < 1.0
+        assert "Single Fuzzy: 1/1 matches" in result.content if result.content else False
+
+    @pytest.mark.asyncio
+    async def test_list_comparison(self):
+        """Test comparing lists of values."""
+        tool = CompareTool()
+
+        # Exact lists
+        result = await tool(["a", "b", "c"], ["a", "b", "c"], mode=ComparisonMode.EXACT)
+        assert result.done
+        assert result.reward == 1.0
+        assert "Batch Exact: 3/3 matches" in result.content if result.content else False
+
+        # Partial match
+        result = await tool(["a", "b", "c"], ["a", "x", "c"], mode=ComparisonMode.EXACT)
+        assert not result.done
+        assert result.reward < 1.0
+        assert "Batch Exact: 2/3 matches" in result.content if result.content else False
+
+    @pytest.mark.asyncio
+    async def test_broadcast_comparison(self):
+        """Test broadcasting single value against list."""
+        tool = CompareTool()
+
+        # Single value vs list
+        result = await tool("a", ["a", "a", "a"], mode=ComparisonMode.EXACT)
+        assert result.done
+        assert result.reward == 1.0
+        assert "Broadcast Exact: 3/3 matches" in result.content if result.content else False
+
+        # List vs single value
+        result = await tool(["x", "x", "x"], "x", mode=ComparisonMode.EXACT)
+        assert result.done
+        assert result.reward == 1.0
+        assert "Broadcast Exact: 3/3 matches" in result.content if result.content else False
+
+    @pytest.mark.asyncio
+    async def test_submission_fallback(self):
+        """Test using submission when value is None."""
+        tool = CompareTool()
+
+        # Set submission
+        set_submission("test value")
+
+        # Compare without providing value
+        result = await tool(None, "test value", mode=ComparisonMode.EXACT)
+        assert result.done
+        assert result.reward == 1.0
+
+        # Clear submission
+        set_submission(None)
+
+    @pytest.mark.asyncio
+    async def test_error_cases(self):
+        """Test error handling."""
+        tool = CompareTool()
+
+        # Missing reference
+        result = await tool("value", None)
+        assert result.isError
+        assert "Missing value or reference" in result.content if result.content else False
+
+        # Mismatched list lengths
+        result = await tool([1, 2], [1, 2, 3])
+        assert not result.done
+        assert result.reward == 0.0
+        assert "Mismatched lengths" in result.content if result.content else False
+
+
+class TestCompareToolModes:
+    """Test different comparison modes in CompareTool."""
+
+    @pytest.mark.asyncio
+    async def test_auto_mode(self):
+        """Test automatic mode detection."""
+        tool = CompareTool()
+
+        # Numbers should use numeric
+        result = await tool("42", "42.0", mode=ComparisonMode.AUTO)
+        assert result.done
+        assert result.reward == 1.0
+        assert "Numeric" in result.content if result.content else False
+
+        # JSON should use semantic
+        result = await tool('{"a": 1}', '{"a": 1}', mode=ComparisonMode.AUTO)
+        assert result.done
+        assert result.reward == 1.0
+        assert "Semantic" in result.content if result.content else False
+
+    @pytest.mark.asyncio
+    async def test_exact_mode(self):
+        """Test exact string comparison."""
+        tool = CompareTool()
+
+        # Exact match
+        result = await tool("hello", "hello", mode=ComparisonMode.EXACT)
+        assert result.done
+        assert result.reward == 1.0
+
+        # Different strings
+        result = await tool("42", "42.0", mode=ComparisonMode.EXACT)
+        assert not result.done
+        assert result.reward == 0.0
+
+    @pytest.mark.asyncio
+    async def test_fuzzy_mode(self):
+        """Test fuzzy text matching."""
+        tool = CompareTool()
+
+        # High similarity
+        result = await tool("hello world", "hello wrld", mode=ComparisonMode.FUZZY, threshold=0.8)
+        assert result.done
+        assert result.reward > 0.8
+
+        # Low similarity
+        result = await tool("hello", "goodbye", mode=ComparisonMode.FUZZY, threshold=0.9)
+        assert not result.done
+        assert result.reward < 0.5
+
+    @pytest.mark.asyncio
+    async def test_numeric_mode(self):
+        """Test numeric comparison with tolerance."""
+        tool = CompareTool()
+
+        # Within tolerance
+        result = await tool("1.0", "1.000001", mode=ComparisonMode.NUMERIC, tolerance=1e-5)
+        assert result.done
+        assert result.reward > 0.99
+
+        # Outside tolerance
+        result = await tool("1.0", "2.0", mode=ComparisonMode.NUMERIC, tolerance=0.1)
+        assert not result.done
+        assert result.reward <= 0.5
+
+    @pytest.mark.asyncio
+    async def test_semantic_mode(self):
+        """Test semantic comparison for various types."""
+        tool = CompareTool()
+
+        # JSON objects (same structure)
+        result = await tool('{"b": 2, "a": 1}', '{"a": 1, "b": 2}', mode=ComparisonMode.SEMANTIC)
+        assert result.done
+        assert result.reward == 1.0
+
+        # JSON arrays
+        result = await tool("[1, 2, 3]", "[1, 2, 3]", mode=ComparisonMode.SEMANTIC)
+        assert result.done
+        assert result.reward == 1.0
+
+        # Numbers via semantic (uses numeric comparison)
+        result = await tool("42", "42.0", mode=ComparisonMode.SEMANTIC, tolerance=1e-6)
+        assert result.done
+        assert result.reward == 1.0
+
+        # Booleans via semantic
+        result = await tool("true", "true", mode=ComparisonMode.SEMANTIC)
+        assert result.done
+        assert result.reward == 1.0
+
+        # Text fallback when not JSON
+        result = await tool(
+            "hello world", "hello wrld", mode=ComparisonMode.SEMANTIC, threshold=0.8
+        )
+        assert result.done
+        assert result.reward > 0.8
+
+
+class TestComparisonResult:
+    """Test ComparisonResult model."""
+
+    def test_comparison_result_fields(self):
+        """Test ComparisonResult has all expected fields."""
+        result = ComparisonResult(
+            matches=True,
+            similarity=0.95,
+            detected_type=DataType.TEXT,
+            comparison_mode=ComparisonMode.FUZZY,
+            details={"threshold": 0.8},
+        )
+
+        assert result.matches is True
+        assert result.similarity == 0.95
+        assert result.detected_type == DataType.TEXT
+        assert result.comparison_mode == ComparisonMode.FUZZY
+        assert result.details["threshold"] == 0.8
+
+    def test_comparison_result_validation(self):
+        """Test ComparisonResult validation."""
+        # Valid similarity
+        result = ComparisonResult(
+            matches=True,
+            similarity=1.0,
+            detected_type=DataType.TEXT,
+            comparison_mode=ComparisonMode.EXACT,
+        )
+        assert result.similarity == 1.0
+
+        # Invalid similarity should raise
+        with pytest.raises(ValueError):
+            ComparisonResult(
+                matches=True,
+                similarity=1.5,  # > 1.0
+                detected_type=DataType.TEXT,
+                comparison_mode=ComparisonMode.EXACT,
+            )
+
+
+class TestAliasTools:
+    """Test the alias tools created by make_alias_tool."""
+
+    @pytest.mark.asyncio
+    async def test_aliases_work(self):
+        """Test that aliases are properly registered and work."""
+        from hud.native.comparator import comparator_server
+
+        # Check that aliases are registered
+        tool_names = [t.name for t in comparator_server._tool_manager._tools.values()]
+
+        expected_aliases = [
+            "compare_exact",
+            "compare_text",
+            "compare_string",
+            "compare_numeric",
+            "compare_float",
+            "compare_int",
+            "compare_json",
+            "compare_boolean",
+            "compare_list",
+        ]
+
+        for alias in expected_aliases:
+            assert alias in tool_names, f"Alias {alias} not found in registered tools"
+
+
+class TestExtraction:
+    """Test extraction functions for handling LLM outputs."""
+
+    def test_json_extraction(self):
+        """Test JSON extraction from text."""
+        # Already valid JSON
+        assert extract_json('{"a": 1}') == '{"a": 1}'
+        assert extract_json("[1, 2, 3]") == "[1, 2, 3]"
+
+        # JSON embedded in text
+        assert extract_json('The answer is {"result": 42}') == '{"result": 42}'
+        assert extract_json('First {"a": 1} then {"b": 2}') == '{"b": 2}'  # Last one
+        assert extract_json("The list is [1, 2, 3] and done") == "[1, 2, 3]"
+
+        # Complex nested JSON
+        complex_json = """{
+            "status": "success",
+            "data": {
+                "values": [1, 2, 3],
+                "metadata": {
+                    "timestamp": "2024-01-01",
+                    "version": "1.0"
+                }
+            }
+        }"""
+
+        llm_output = f"""
+        Let me analyze this request.
+
+        The final result is:
+        {complex_json}
+
+        This completes the analysis.
+        """
+
+        extracted = extract_json(llm_output)
+        import json
+
+        assert json.loads(extracted) == json.loads(complex_json)
+
+        # No JSON
+        assert extract_json("No JSON here") == "No JSON here"
+
+    def test_number_extraction(self):
+        """Test number extraction from text."""
+        # Plain numbers
+        assert extract_number("42") == "42"
+        assert extract_number("3.14") == "3.14"
+        assert extract_number("-123") == "-123"
+        assert extract_number("1.5e-10") == "1.5e-10"
+
+        # Numbers in text
+        assert extract_number("The answer is 42") == "42"
+        assert extract_number("First 10 then 20") == "20"  # Last one
+        assert extract_number("Value: 3.14159") == "3.14159"
+
+        # Integer extraction
+        assert extract_number("42.0", "int") == "42"
+        assert extract_number("The count is 42.0 items", "int") == "42"
+
+        # No numbers
+        assert extract_number("No numbers here") == "No numbers here"
+
+    def test_boolean_extraction(self):
+        """Test boolean extraction from text."""
+        # Plain booleans
+        assert extract_boolean("true") == "true"
+        assert extract_boolean("false") == "false"
+        assert extract_boolean("True") == "true"
+        assert extract_boolean("FALSE") == "false"
+
+        # Booleans in text
+        assert extract_boolean("The answer is True") == "true"
+        assert extract_boolean("First false then TRUE") == "true"  # Last one
+
+        # No booleans
+        assert extract_boolean("No booleans here") == "No booleans here"
+
+    def test_list_extraction(self):
+        """Test list extraction (uses JSON extraction)."""
+        assert extract_list("[1, 2, 3]") == "[1, 2, 3]"
+        assert extract_list('The array is ["a", "b", "c"]') == '["a", "b", "c"]'
+        assert extract_list("No lists here") == "No lists here"
+
+
+class TestAliasPreprocessing:
+    """Test that alias tools correctly preprocess LLM outputs."""
+
+    @pytest.mark.asyncio
+    async def test_json_alias_preprocessing(self):
+        """Test JSON extraction in compare_json tool."""
+        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
+        json_tool = tools["compare_json"]
+
+        assert isinstance(json_tool, FunctionTool)
+        result = await json_tool.fn(
+            value='The model thinks the answer is {"result": 42, "confidence": 0.9}',
+            reference='{"result": 42, "confidence": 0.9}',
+        )
+        assert result.done
+        assert result.reward == 1.0
+        assert "Semantic" in result.content
+
+    @pytest.mark.asyncio
+    async def test_numeric_alias_preprocessing(self):
+        """Test number extraction in numeric tools."""
+        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
+
+        # Float tool
+        float_tool = tools["compare_float"]
+        assert isinstance(float_tool, FunctionTool)
+        result = await float_tool.fn(
+            value="After careful calculation, the answer is 3.14159", reference="3.14159"
+        )
+        assert result.done
+        assert result.reward == 1.0
+        assert "Numeric" in result.content
+
+        # Integer tool
+        int_tool = tools["compare_int"]
+        assert isinstance(int_tool, FunctionTool)
+        result = await int_tool.fn(value="The count is exactly 42 items", reference="42")
+        assert result.done
+        assert result.reward == 1.0
+        assert "Numeric" in result.content
+
+    @pytest.mark.asyncio
+    async def test_boolean_alias_preprocessing(self):
+        """Test boolean extraction in compare_boolean tool."""
+        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
+        bool_tool = tools["compare_boolean"]
+
+        assert isinstance(bool_tool, FunctionTool)
+        result = await bool_tool.fn(
+            value="Based on the analysis, the statement is TRUE", reference="true"
+        )
+        assert result.done
+        assert result.reward == 1.0
+        assert "Exact" in result.content
+
+    @pytest.mark.asyncio
+    async def test_list_alias_preprocessing(self):
+        """Test list extraction in compare_list tool."""
+        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
+        list_tool = tools["compare_list"]
+
+        assert isinstance(list_tool, FunctionTool)
+        result = await list_tool.fn(
+            value="The sorted results are [1, 2, 3, 4, 5]", reference="[1, 2, 3, 4, 5]"
+        )
+        assert result.done
+        assert result.reward == 1.0
+        assert "Semantic" in result.content
+
+    @pytest.mark.asyncio
+    async def test_complex_llm_output(self):
+        """Test extraction from complex LLM outputs with reasoning."""
+        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
+        json_tool = tools["compare_json"]
+
+        llm_output = """
+        Let me analyze this request step by step.
+
+        First, I'll process the data:
+        - Item 1: processed
+        - Item 2: processed with value 42
+
+        After careful consideration, the final result is:
+        {
+            "status": "success",
+            "data": {
+                "values": [1, 2, 3],
+                "metadata": {
+                    "timestamp": "2024-01-01",
+                    "version": "1.0"
+                }
+            }
+        }
+
+        This completes the analysis. The JSON above contains all the required information.
+        """
+
+        reference = """
+        {"status": "success", "data": {"values": [1, 2, 3],
+        "metadata": {"timestamp": "2024-01-01", "version": "1.0"}}}
+        """
+
+        assert isinstance(json_tool, FunctionTool)
+        result = await json_tool.fn(value=llm_output, reference=reference)
+        assert result.done
+        assert result.reward == 1.0
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
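Read together, the tests above also document how the comparator is meant to be called directly. Below is a minimal usage sketch based only on the calls exercised in these tests; CompareTool, ComparisonMode, and the awaitable tool call with mode/threshold keywords come straight from the tests, while the wrapper script itself is illustrative.

import asyncio

from hud.native.comparator import CompareTool, ComparisonMode

async def main() -> None:
    tool = CompareTool()

    # Exact scalar comparison, as in test_scalar_comparison
    result = await tool("hello", "hello", mode=ComparisonMode.EXACT)
    print(result.done, result.reward)  # True 1.0 when the strings match

    # Fuzzy comparison with a threshold, as in test_fuzzy_mode
    result = await tool("hello world", "hello wrld", mode=ComparisonMode.FUZZY, threshold=0.8)
    print(result.done, result.reward)  # done once similarity clears the threshold

asyncio.run(main())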
@@ -0,0 +1,79 @@
+"""Tests for hud.native.__init__ module."""
+
+from __future__ import annotations
+
+
+class TestNativeInit:
+    """Tests for the native package initialization."""
+
+    def test_comparator_server_import(self):
+        """Test that comparator server can be imported."""
+        from hud.native.comparator import comparator_server
+        from hud.server import MCPServer
+
+        # Verify comparator is an MCPServer instance
+        assert isinstance(comparator_server, MCPServer)
+        assert comparator_server.name == "comparator"
+
+    def test_all_exports(self):
+        """Test that __all__ is properly defined."""
+        import hud.native.comparator
+
+        expected_exports = ["comparator"]
+
+        # Check __all__ exists and contains expected exports
+        assert hasattr(hud.native.comparator, "__all__")
+        assert hud.native.comparator.__all__ == expected_exports
+
+        # Verify all items in __all__ are actually available
+        for item in hud.native.comparator.__all__:
+            assert hasattr(hud.native.comparator, item)
+
+    def test_comparator_tools_registered(self):
+        """Test that comparator server has tools registered."""
+        from hud.native.comparator import comparator_server
+
+        # The server should have tools registered
+        # We can check that the tool manager has tools
+        tool_names = [t.name for t in comparator_server._tool_manager._tools.values()]
+
+        # Should have the main compare tool
+        assert "compare" in tool_names
+
+        # Should have the submit tool
+        assert "submit" in tool_names
+
+        # Should have all the alias tools
+        expected_aliases = [
+            "compare_exact",
+            "compare_text",
+            "compare_string",
+            "compare_numeric",
+            "compare_float",
+            "compare_int",
+            "compare_json",
+            "compare_boolean",
+            "compare_list",
+        ]
+
+        for alias in expected_aliases:
+            assert alias in tool_names, f"Expected alias {alias} not found"
+
+        # Total should be 1 (submit) + 1 (compare) + 9 (aliases) = 11 tools
+        assert len(tool_names) == 11
+
+    def test_comparator_tool_functionality(self):
+        """Test that we can get the CompareTool from the comparator."""
+        from hud.native.comparator import comparator_server
+        from hud.tools import BaseTool
+
+        # Get the compare tool
+        compare_tool = None
+        for tool in comparator_server._tool_manager._tools.values():
+            if tool.name == "compare":
+                compare_tool = tool
+                break
+
+        assert compare_tool is not None
+        assert isinstance(compare_tool, BaseTool)
+        assert hasattr(compare_tool, "__call__")
@@ -16,8 +16,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-LIFECYCLE_TOOLS = {"setup", "evaluate"}
-
 
 def install_mcp_instrumentation(provider: TracerProvider) -> None:
     """Enable community MCP OpenTelemetry instrumentation if present.
hud/server/server.py CHANGED
@@ -150,9 +150,13 @@ class MCPServer(FastMCP):
         super().__init__(name=name, **fastmcp_kwargs)
         self._initializer_fn: Callable | None = None
         self._did_init = False
+        self._replaced_server = False
+
+    def _replace_with_init_server(self) -> None:
+        """Replace the low-level server with init version when needed."""
+        if self._replaced_server:
+            return
 
-        # Replace FastMCP's low-level server with our version that supports
-        # per-server initialization hooks
         def _run_init(ctx: RequestContext | None = None) -> Any:
             if self._initializer_fn is not None and not self._did_init:
                 self._did_init = True
@@ -177,6 +181,7 @@ class MCPServer(FastMCP):
         # Copy handlers from the old server to the new one
         self._mcp_server.request_handlers = old_request_handlers
         self._mcp_server.notification_handlers = old_notification_handlers
+        self._replaced_server = True
 
     # Initializer decorator: runs on the initialize request
     # The decorated function receives a RequestContext object with access to:
@@ -186,6 +191,8 @@ class MCPServer(FastMCP):
     def initialize(self, fn: Callable | None = None) -> Callable | None:
         def decorator(func: Callable) -> Callable:
             self._initializer_fn = func
+            # Only replace server when there's actually an init handler
+            self._replace_with_init_server()
             return func
 
         return decorator(fn) if fn else decorator
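The server.py change makes the low-level server replacement lazy: it now happens only when an initialize handler is registered, via _replace_with_init_server(), and the _replaced_server flag keeps it from running twice. A hedged sketch of the decorator usage this path supports, assuming a synchronous handler; the server name and handler body are illustrative, and the diff only guarantees that the decorated function is stored as the initializer and receives a RequestContext.

from hud.server import MCPServer

server = MCPServer(name="example")  # hypothetical server name

@server.initialize  # registering the handler now triggers _replace_with_init_server()
def on_initialize(ctx):
    # ctx is the RequestContext described in the surrounding comments
    print("initialize request received")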