hud-python 0.4.20__py3-none-any.whl → 0.4.21__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
- hud/__init__.py +7 -0
- hud/agents/base.py +40 -10
- hud/agents/claude.py +13 -8
- hud/agents/tests/test_client.py +6 -27
- hud/cli/__init__.py +50 -20
- hud/cli/build.py +3 -44
- hud/cli/eval.py +25 -6
- hud/cli/init.py +4 -4
- hud/cli/push.py +3 -1
- hud/cli/tests/test_push.py +6 -6
- hud/clients/__init__.py +3 -2
- hud/clients/base.py +20 -9
- hud/clients/mcp_use.py +44 -22
- hud/datasets/task.py +6 -2
- hud/native/__init__.py +6 -0
- hud/native/comparator.py +546 -0
- hud/native/tests/__init__.py +1 -0
- hud/native/tests/test_comparator.py +539 -0
- hud/native/tests/test_native_init.py +79 -0
- hud/otel/instrumentation.py +0 -2
- hud/server/server.py +9 -2
- hud/shared/exceptions.py +204 -31
- hud/shared/hints.py +177 -0
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +385 -144
- hud/tools/__init__.py +2 -0
- hud/tools/submit.py +66 -0
- hud/types.py +33 -5
- hud/utils/design.py +57 -0
- hud/utils/mcp.py +6 -0
- hud/utils/pretty_errors.py +68 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/METADATA +2 -4
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/RECORD +38 -30
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/WHEEL +0 -0
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/licenses/LICENSE +0 -0
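The headline addition in this release is hud/native/comparator.py, an MCP server exposing a compare tool (plus typed aliases such as compare_json and compare_float) for scoring agent outputs against references; the new tests below document its behavior. As a minimal sketch, and assuming it behaves like any other FastMCP-based server, it could presumably be run standalone:

from hud.native.comparator import comparator_server

if __name__ == "__main__":
    # FastMCP servers default to the stdio transport when run directly
    comparator_server.run()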
hud/native/tests/test_comparator.py ADDED
@@ -0,0 +1,539 @@
"""Tests for the comparator module."""

from __future__ import annotations

import pytest
from fastmcp.tools.tool import FunctionTool

from hud.native.comparator import (
    CompareTool,
    ComparisonMode,
    ComparisonResult,
    DataType,
    auto_select_mode,
    comparator_server,
    detect_type,
    extract_boolean,
    extract_json,
    extract_list,
    extract_number,
)
from hud.tools.submit import set_submission
from hud.tools.types import EvaluationResult


class TestTypeDetection:
    """Test type detection functionality."""

    @pytest.mark.parametrize(
        "value,expected_type",
        [
            # Booleans
            ("true", DataType.BOOLEAN),
            ("True", DataType.BOOLEAN),
            ("false", DataType.BOOLEAN),
            ("False", DataType.BOOLEAN),
            # Integers
            ("42", DataType.INTEGER),
            ("-123", DataType.INTEGER),
            ("0", DataType.INTEGER),
            # Floats
            ("3.14", DataType.FLOAT),
            ("-2.5", DataType.FLOAT),
            ("1e-6", DataType.FLOAT),
            # JSON
            ('{"key": "value"}', DataType.JSON),
            ("[1, 2, 3]", DataType.JSON),
            ('{"nested": {"value": 42}}', DataType.JSON),
            # Text (fallback)
            ("hello world", DataType.TEXT),
            ("not a number", DataType.TEXT),
            ("{invalid json", DataType.TEXT),
            ("", DataType.TEXT),
        ],
    )
    def test_detect_type(self, value, expected_type):
        """Test type detection for various inputs."""
        assert detect_type(value) == expected_type

    def test_detect_type_none(self):
        """Test type detection for None."""
        assert detect_type(None) == DataType.TEXT


class TestAutoSelectMode:
    """Test automatic mode selection based on types."""

    @pytest.mark.parametrize(
        "value_type,ref_type,expected_mode",
        [
            # JSON gets semantic
            (DataType.JSON, DataType.JSON, ComparisonMode.SEMANTIC),
            (DataType.JSON, DataType.TEXT, ComparisonMode.SEMANTIC),
            (DataType.TEXT, DataType.JSON, ComparisonMode.SEMANTIC),
            # Numeric gets numeric
            (DataType.INTEGER, DataType.INTEGER, ComparisonMode.NUMERIC),
            (DataType.FLOAT, DataType.FLOAT, ComparisonMode.NUMERIC),
            (DataType.INTEGER, DataType.FLOAT, ComparisonMode.NUMERIC),
            (DataType.FLOAT, DataType.TEXT, ComparisonMode.NUMERIC),
            # Boolean gets exact
            (DataType.BOOLEAN, DataType.BOOLEAN, ComparisonMode.EXACT),
            (DataType.BOOLEAN, DataType.TEXT, ComparisonMode.EXACT),
            # Text gets fuzzy
            (DataType.TEXT, DataType.TEXT, ComparisonMode.FUZZY),
        ],
    )
    def test_auto_select_mode(self, value_type, ref_type, expected_mode):
        """Test mode selection logic."""
        assert auto_select_mode(value_type, ref_type) == expected_mode


class TestCompareTool:
    """Test the main CompareTool class."""

    @pytest.mark.asyncio
    async def test_scalar_comparison(self):
        """Test comparing scalar values."""
        tool = CompareTool()

        # Exact match
        result = await tool("hello", "hello", mode=ComparisonMode.EXACT)
        assert isinstance(result, EvaluationResult)
        assert result.done
        assert result.reward == 1.0
        assert "Single Exact: 1/1 matches" in result.content if result.content else False

        # Fuzzy match
        result = await tool("hello world", "hello wrld", mode=ComparisonMode.FUZZY)
        assert isinstance(result, EvaluationResult)
        assert result.done
        assert 0.8 < result.reward < 1.0
        assert "Single Fuzzy: 1/1 matches" in result.content if result.content else False

    @pytest.mark.asyncio
    async def test_list_comparison(self):
        """Test comparing lists of values."""
        tool = CompareTool()

        # Exact lists
        result = await tool(["a", "b", "c"], ["a", "b", "c"], mode=ComparisonMode.EXACT)
        assert result.done
        assert result.reward == 1.0
        assert "Batch Exact: 3/3 matches" in result.content if result.content else False

        # Partial match
        result = await tool(["a", "b", "c"], ["a", "x", "c"], mode=ComparisonMode.EXACT)
        assert not result.done
        assert result.reward < 1.0
        assert "Batch Exact: 2/3 matches" in result.content if result.content else False

    @pytest.mark.asyncio
    async def test_broadcast_comparison(self):
        """Test broadcasting single value against list."""
        tool = CompareTool()

        # Single value vs list
        result = await tool("a", ["a", "a", "a"], mode=ComparisonMode.EXACT)
        assert result.done
        assert result.reward == 1.0
        assert "Broadcast Exact: 3/3 matches" in result.content if result.content else False

        # List vs single value
        result = await tool(["x", "x", "x"], "x", mode=ComparisonMode.EXACT)
        assert result.done
        assert result.reward == 1.0
        assert "Broadcast Exact: 3/3 matches" in result.content if result.content else False

    @pytest.mark.asyncio
    async def test_submission_fallback(self):
        """Test using submission when value is None."""
        tool = CompareTool()

        # Set submission
        set_submission("test value")

        # Compare without providing value
        result = await tool(None, "test value", mode=ComparisonMode.EXACT)
        assert result.done
        assert result.reward == 1.0

        # Clear submission
        set_submission(None)

    @pytest.mark.asyncio
    async def test_error_cases(self):
        """Test error handling."""
        tool = CompareTool()

        # Missing reference
        result = await tool("value", None)
        assert result.isError
        assert "Missing value or reference" in result.content if result.content else False

        # Mismatched list lengths
        result = await tool([1, 2], [1, 2, 3])
        assert not result.done
        assert result.reward == 0.0
        assert "Mismatched lengths" in result.content if result.content else False


class TestCompareToolModes:
    """Test different comparison modes in CompareTool."""

    @pytest.mark.asyncio
    async def test_auto_mode(self):
        """Test automatic mode detection."""
        tool = CompareTool()

        # Numbers should use numeric
        result = await tool("42", "42.0", mode=ComparisonMode.AUTO)
        assert result.done
        assert result.reward == 1.0
        assert "Numeric" in result.content if result.content else False

        # JSON should use semantic
        result = await tool('{"a": 1}', '{"a": 1}', mode=ComparisonMode.AUTO)
        assert result.done
        assert result.reward == 1.0
        assert "Semantic" in result.content if result.content else False

    @pytest.mark.asyncio
    async def test_exact_mode(self):
        """Test exact string comparison."""
        tool = CompareTool()

        # Exact match
        result = await tool("hello", "hello", mode=ComparisonMode.EXACT)
        assert result.done
        assert result.reward == 1.0

        # Different strings
        result = await tool("42", "42.0", mode=ComparisonMode.EXACT)
        assert not result.done
        assert result.reward == 0.0

    @pytest.mark.asyncio
    async def test_fuzzy_mode(self):
        """Test fuzzy text matching."""
        tool = CompareTool()

        # High similarity
        result = await tool("hello world", "hello wrld", mode=ComparisonMode.FUZZY, threshold=0.8)
        assert result.done
        assert result.reward > 0.8

        # Low similarity
        result = await tool("hello", "goodbye", mode=ComparisonMode.FUZZY, threshold=0.9)
        assert not result.done
        assert result.reward < 0.5

    @pytest.mark.asyncio
    async def test_numeric_mode(self):
        """Test numeric comparison with tolerance."""
        tool = CompareTool()

        # Within tolerance
        result = await tool("1.0", "1.000001", mode=ComparisonMode.NUMERIC, tolerance=1e-5)
        assert result.done
        assert result.reward > 0.99

        # Outside tolerance
        result = await tool("1.0", "2.0", mode=ComparisonMode.NUMERIC, tolerance=0.1)
        assert not result.done
        assert result.reward <= 0.5

    @pytest.mark.asyncio
    async def test_semantic_mode(self):
        """Test semantic comparison for various types."""
        tool = CompareTool()

        # JSON objects (same structure)
        result = await tool('{"b": 2, "a": 1}', '{"a": 1, "b": 2}', mode=ComparisonMode.SEMANTIC)
        assert result.done
        assert result.reward == 1.0

        # JSON arrays
        result = await tool("[1, 2, 3]", "[1, 2, 3]", mode=ComparisonMode.SEMANTIC)
        assert result.done
        assert result.reward == 1.0

        # Numbers via semantic (uses numeric comparison)
        result = await tool("42", "42.0", mode=ComparisonMode.SEMANTIC, tolerance=1e-6)
        assert result.done
        assert result.reward == 1.0

        # Booleans via semantic
        result = await tool("true", "true", mode=ComparisonMode.SEMANTIC)
        assert result.done
        assert result.reward == 1.0

        # Text fallback when not JSON
        result = await tool(
            "hello world", "hello wrld", mode=ComparisonMode.SEMANTIC, threshold=0.8
        )
        assert result.done
        assert result.reward > 0.8


class TestComparisonResult:
    """Test ComparisonResult model."""

    def test_comparison_result_fields(self):
        """Test ComparisonResult has all expected fields."""
        result = ComparisonResult(
            matches=True,
            similarity=0.95,
            detected_type=DataType.TEXT,
            comparison_mode=ComparisonMode.FUZZY,
            details={"threshold": 0.8},
        )

        assert result.matches is True
        assert result.similarity == 0.95
        assert result.detected_type == DataType.TEXT
        assert result.comparison_mode == ComparisonMode.FUZZY
        assert result.details["threshold"] == 0.8

    def test_comparison_result_validation(self):
        """Test ComparisonResult validation."""
        # Valid similarity
        result = ComparisonResult(
            matches=True,
            similarity=1.0,
            detected_type=DataType.TEXT,
            comparison_mode=ComparisonMode.EXACT,
        )
        assert result.similarity == 1.0

        # Invalid similarity should raise
        with pytest.raises(ValueError):
            ComparisonResult(
                matches=True,
                similarity=1.5,  # > 1.0
                detected_type=DataType.TEXT,
                comparison_mode=ComparisonMode.EXACT,
            )


class TestAliasTools:
    """Test the alias tools created by make_alias_tool."""

    @pytest.mark.asyncio
    async def test_aliases_work(self):
        """Test that aliases are properly registered and work."""
        from hud.native.comparator import comparator_server

        # Check that aliases are registered
        tool_names = [t.name for t in comparator_server._tool_manager._tools.values()]

        expected_aliases = [
            "compare_exact",
            "compare_text",
            "compare_string",
            "compare_numeric",
            "compare_float",
            "compare_int",
            "compare_json",
            "compare_boolean",
            "compare_list",
        ]

        for alias in expected_aliases:
            assert alias in tool_names, f"Alias {alias} not found in registered tools"


class TestExtraction:
    """Test extraction functions for handling LLM outputs."""

    def test_json_extraction(self):
        """Test JSON extraction from text."""
        # Already valid JSON
        assert extract_json('{"a": 1}') == '{"a": 1}'
        assert extract_json("[1, 2, 3]") == "[1, 2, 3]"

        # JSON embedded in text
        assert extract_json('The answer is {"result": 42}') == '{"result": 42}'
        assert extract_json('First {"a": 1} then {"b": 2}') == '{"b": 2}'  # Last one
        assert extract_json("The list is [1, 2, 3] and done") == "[1, 2, 3]"

        # Complex nested JSON
        complex_json = """{
    "status": "success",
    "data": {
        "values": [1, 2, 3],
        "metadata": {
            "timestamp": "2024-01-01",
            "version": "1.0"
        }
    }
}"""

        llm_output = f"""
Let me analyze this request.

The final result is:
{complex_json}

This completes the analysis.
"""

        extracted = extract_json(llm_output)
        import json

        assert json.loads(extracted) == json.loads(complex_json)

        # No JSON
        assert extract_json("No JSON here") == "No JSON here"

    def test_number_extraction(self):
        """Test number extraction from text."""
        # Plain numbers
        assert extract_number("42") == "42"
        assert extract_number("3.14") == "3.14"
        assert extract_number("-123") == "-123"
        assert extract_number("1.5e-10") == "1.5e-10"

        # Numbers in text
        assert extract_number("The answer is 42") == "42"
        assert extract_number("First 10 then 20") == "20"  # Last one
        assert extract_number("Value: 3.14159") == "3.14159"

        # Integer extraction
        assert extract_number("42.0", "int") == "42"
        assert extract_number("The count is 42.0 items", "int") == "42"

        # No numbers
        assert extract_number("No numbers here") == "No numbers here"

    def test_boolean_extraction(self):
        """Test boolean extraction from text."""
        # Plain booleans
        assert extract_boolean("true") == "true"
        assert extract_boolean("false") == "false"
        assert extract_boolean("True") == "true"
        assert extract_boolean("FALSE") == "false"

        # Booleans in text
        assert extract_boolean("The answer is True") == "true"
        assert extract_boolean("First false then TRUE") == "true"  # Last one

        # No booleans
        assert extract_boolean("No booleans here") == "No booleans here"

    def test_list_extraction(self):
        """Test list extraction (uses JSON extraction)."""
        assert extract_list("[1, 2, 3]") == "[1, 2, 3]"
        assert extract_list('The array is ["a", "b", "c"]') == '["a", "b", "c"]'
        assert extract_list("No lists here") == "No lists here"


class TestAliasPreprocessing:
    """Test that alias tools correctly preprocess LLM outputs."""

    @pytest.mark.asyncio
    async def test_json_alias_preprocessing(self):
        """Test JSON extraction in compare_json tool."""
        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
        json_tool = tools["compare_json"]

        assert isinstance(json_tool, FunctionTool)
        result = await json_tool.fn(
            value='The model thinks the answer is {"result": 42, "confidence": 0.9}',
            reference='{"result": 42, "confidence": 0.9}',
        )
        assert result.done
        assert result.reward == 1.0
        assert "Semantic" in result.content

    @pytest.mark.asyncio
    async def test_numeric_alias_preprocessing(self):
        """Test number extraction in numeric tools."""
        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}

        # Float tool
        float_tool = tools["compare_float"]
        assert isinstance(float_tool, FunctionTool)
        result = await float_tool.fn(
            value="After careful calculation, the answer is 3.14159", reference="3.14159"
        )
        assert result.done
        assert result.reward == 1.0
        assert "Numeric" in result.content

        # Integer tool
        int_tool = tools["compare_int"]
        assert isinstance(int_tool, FunctionTool)
        result = await int_tool.fn(value="The count is exactly 42 items", reference="42")
        assert result.done
        assert result.reward == 1.0
        assert "Numeric" in result.content

    @pytest.mark.asyncio
    async def test_boolean_alias_preprocessing(self):
        """Test boolean extraction in compare_boolean tool."""
        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
        bool_tool = tools["compare_boolean"]

        assert isinstance(bool_tool, FunctionTool)
        result = await bool_tool.fn(
            value="Based on the analysis, the statement is TRUE", reference="true"
        )
        assert result.done
        assert result.reward == 1.0
        assert "Exact" in result.content

    @pytest.mark.asyncio
    async def test_list_alias_preprocessing(self):
        """Test list extraction in compare_list tool."""
        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
        list_tool = tools["compare_list"]

        assert isinstance(list_tool, FunctionTool)
        result = await list_tool.fn(
            value="The sorted results are [1, 2, 3, 4, 5]", reference="[1, 2, 3, 4, 5]"
        )
        assert result.done
        assert result.reward == 1.0
        assert "Semantic" in result.content

    @pytest.mark.asyncio
    async def test_complex_llm_output(self):
        """Test extraction from complex LLM outputs with reasoning."""
        tools = {t.name: t for t in comparator_server._tool_manager._tools.values()}
        json_tool = tools["compare_json"]

        llm_output = """
Let me analyze this request step by step.

First, I'll process the data:
- Item 1: processed
- Item 2: processed with value 42

After careful consideration, the final result is:
{
    "status": "success",
    "data": {
        "values": [1, 2, 3],
        "metadata": {
            "timestamp": "2024-01-01",
            "version": "1.0"
        }
    }
}

This completes the analysis. The JSON above contains all the required information.
"""

        reference = """
{"status": "success", "data": {"values": [1, 2, 3],
"metadata": {"timestamp": "2024-01-01", "version": "1.0"}}}
"""

        assert isinstance(json_tool, FunctionTool)
        result = await json_tool.fn(value=llm_output, reference=reference)
        assert result.done
        assert result.reward == 1.0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
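Since comparator_server is a FastMCP-based MCPServer, the tools tested above can also be exercised end-to-end rather than via the private _tool_manager. A minimal sketch, assuming fastmcp's Client supports in-memory connections to a server instance (as FastMCP 2.x documents):

import asyncio

from fastmcp import Client

from hud.native.comparator import comparator_server

async def main():
    async with Client(comparator_server) as client:
        # Broadcast semantics: a single reference checked against each value
        result = await client.call_tool(
            "compare", {"value": ["a", "a"], "reference": "a"}
        )
        print(result)

asyncio.run(main())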
hud/native/tests/test_native_init.py ADDED
@@ -0,0 +1,79 @@
"""Tests for hud.native.__init__ module."""

from __future__ import annotations


class TestNativeInit:
    """Tests for the native package initialization."""

    def test_comparator_server_import(self):
        """Test that comparator server can be imported."""
        from hud.native.comparator import comparator_server
        from hud.server import MCPServer

        # Verify comparator is an MCPServer instance
        assert isinstance(comparator_server, MCPServer)
        assert comparator_server.name == "comparator"

    def test_all_exports(self):
        """Test that __all__ is properly defined."""
        import hud.native.comparator

        expected_exports = ["comparator"]

        # Check __all__ exists and contains expected exports
        assert hasattr(hud.native.comparator, "__all__")
        assert hud.native.comparator.__all__ == expected_exports

        # Verify all items in __all__ are actually available
        for item in hud.native.comparator.__all__:
            assert hasattr(hud.native.comparator, item)

    def test_comparator_tools_registered(self):
        """Test that comparator server has tools registered."""
        from hud.native.comparator import comparator_server

        # The server should have tools registered
        # We can check that the tool manager has tools
        tool_names = [t.name for t in comparator_server._tool_manager._tools.values()]

        # Should have the main compare tool
        assert "compare" in tool_names

        # Should have the submit tool
        assert "submit" in tool_names

        # Should have all the alias tools
        expected_aliases = [
            "compare_exact",
            "compare_text",
            "compare_string",
            "compare_numeric",
            "compare_float",
            "compare_int",
            "compare_json",
            "compare_boolean",
            "compare_list",
        ]

        for alias in expected_aliases:
            assert alias in tool_names, f"Expected alias {alias} not found"

        # Total should be 1 (submit) + 1 (compare) + 9 (aliases) = 11 tools
        assert len(tool_names) == 11

    def test_comparator_tool_functionality(self):
        """Test that we can get the CompareTool from the comparator."""
        from hud.native.comparator import comparator_server
        from hud.tools import BaseTool

        # Get the compare tool
        compare_tool = None
        for tool in comparator_server._tool_manager._tools.values():
            if tool.name == "compare":
                compare_tool = tool
                break

        assert compare_tool is not None
        assert isinstance(compare_tool, BaseTool)
        assert hasattr(compare_tool, "__call__")
hud/otel/instrumentation.py CHANGED
hud/server/server.py CHANGED
@@ -150,9 +150,13 @@ class MCPServer(FastMCP):
         super().__init__(name=name, **fastmcp_kwargs)
         self._initializer_fn: Callable | None = None
         self._did_init = False
+        self._replaced_server = False
+
+    def _replace_with_init_server(self) -> None:
+        """Replace the low-level server with init version when needed."""
+        if self._replaced_server:
+            return
 
-        # Replace FastMCP's low-level server with our version that supports
-        # per-server initialization hooks
         def _run_init(ctx: RequestContext | None = None) -> Any:
             if self._initializer_fn is not None and not self._did_init:
                 self._did_init = True
@@ -177,6 +181,7 @@ class MCPServer(FastMCP):
         # Copy handlers from the old server to the new one
         self._mcp_server.request_handlers = old_request_handlers
         self._mcp_server.notification_handlers = old_notification_handlers
+        self._replaced_server = True
 
         # Initializer decorator: runs on the initialize request
         # The decorated function receives a RequestContext object with access to:
@@ -186,6 +191,8 @@ class MCPServer(FastMCP):
     def initialize(self, fn: Callable | None = None) -> Callable | None:
         def decorator(func: Callable) -> Callable:
             self._initializer_fn = func
+            # Only replace server when there's actually an init handler
+            self._replace_with_init_server()
             return func
 
         return decorator(fn) if fn else decorator
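The practical effect of this change is that FastMCP's low-level server is only swapped for the init-aware version once an initialize hook is actually registered, instead of unconditionally in the constructor. A minimal sketch of the decorator in use (the exact contents of the RequestContext passed to the hook are an assumption based on the comments above):

from hud.server import MCPServer

mcp = MCPServer(name="my-env")

@mcp.initialize  # registering a hook triggers _replace_with_init_server()
def setup(ctx):
    # Runs once, when the client sends the MCP initialize request;
    # ctx is the RequestContext for that request.
    ...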