hud-python 0.4.20__py3-none-any.whl → 0.4.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +7 -0
- hud/agents/base.py +40 -10
- hud/agents/claude.py +13 -8
- hud/agents/tests/test_client.py +6 -27
- hud/cli/__init__.py +50 -20
- hud/cli/build.py +3 -44
- hud/cli/eval.py +25 -6
- hud/cli/init.py +4 -4
- hud/cli/push.py +3 -1
- hud/cli/tests/test_push.py +6 -6
- hud/clients/__init__.py +3 -2
- hud/clients/base.py +20 -9
- hud/clients/mcp_use.py +44 -22
- hud/datasets/task.py +6 -2
- hud/native/__init__.py +6 -0
- hud/native/comparator.py +546 -0
- hud/native/tests/__init__.py +1 -0
- hud/native/tests/test_comparator.py +539 -0
- hud/native/tests/test_native_init.py +79 -0
- hud/otel/instrumentation.py +0 -2
- hud/server/server.py +9 -2
- hud/shared/exceptions.py +204 -31
- hud/shared/hints.py +177 -0
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +385 -144
- hud/tools/__init__.py +2 -0
- hud/tools/submit.py +66 -0
- hud/types.py +33 -5
- hud/utils/design.py +57 -0
- hud/utils/mcp.py +6 -0
- hud/utils/pretty_errors.py +68 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/METADATA +2 -4
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/RECORD +38 -30
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/WHEEL +0 -0
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.20.dist-info → hud_python-0.4.21.dist-info}/licenses/LICENSE +0 -0
hud/native/comparator.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
1
|
+
"""Universal type-aware comparator MCP server.
|
|
2
|
+
|
|
3
|
+
This server provides comparison capabilities that automatically detect
|
|
4
|
+
and handle different data types (text, int, float, json) with lenient
|
|
5
|
+
parsing and multiple comparison strategies.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import re
|
|
13
|
+
import sys
|
|
14
|
+
from difflib import SequenceMatcher
|
|
15
|
+
from enum import Enum
|
|
16
|
+
from typing import Any, Literal
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel, Field
|
|
19
|
+
|
|
20
|
+
from hud.server import MCPServer
|
|
21
|
+
from hud.tools import BaseTool, SubmitTool
|
|
22
|
+
from hud.tools.submit import get_submission
|
|
23
|
+
from hud.tools.types import EvaluationResult
|
|
24
|
+
|
|
25
|
+
# Configure logging.
# Records go to stderr at INFO level.
# NOTE(review): presumably stderr is used so stdout stays free for the MCP
# transport -- confirm against MCPServer.run().
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    format="[%(levelname)s] %(asctime)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class DataType(str, Enum):
    """Data types the comparator can detect in a value.

    Members are also ``str`` instances, so each one compares equal to its
    string value (for example ``DataType.TEXT == "text"``).
    """

    TEXT = "text"
    INTEGER = "integer"
    FLOAT = "float"
    JSON = "json"
    BOOLEAN = "boolean"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ComparisonMode(str, Enum):
    """Comparison strategies a caller can request.

    Members are also ``str`` instances, so each one compares equal to its
    string value.
    """

    AUTO = "auto"  # Pick a mode from the detected types of both operands
    EXACT = "exact"  # Exact match
    FUZZY = "fuzzy"  # Fuzzy text matching with a similarity threshold
    NUMERIC = "numeric"  # Numeric comparison with tolerance
    SEMANTIC = "semantic"  # Semantic equivalence (JSON structures)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ComparisonResult(BaseModel):
    """Outcome of comparing one value against one reference."""

    # Whether the pair counts as a match under the chosen mode.
    matches: bool
    # Similarity score; the Field bounds restrict it to [0.0, 1.0].
    similarity: float = Field(ge=0.0, le=1.0)
    # Type detected for the pair.
    detected_type: DataType
    # Mode that was actually used for this pair.
    comparison_mode: ComparisonMode
    # Mode-specific extras (e.g. tolerance, threshold, error text).
    details: dict[str, Any] = Field(default_factory=dict)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Extraction functions for handling LLM outputs
|
|
65
|
+
def extract_json(text: str) -> str:
    """Extract the last valid JSON object or array embedded in ``text``.

    Args:
        text: Free-form text that may contain JSON, e.g. an LLM answer.

    Returns:
        ``text`` unchanged when it is empty, when it is valid JSON as a
        whole (including JSON primitives), or when no embedded object/array
        parses. Otherwise the substring of the valid object/array that ends
        furthest to the right, so an enclosing structure wins over the
        structures nested inside it.
    """
    if not text:
        return text

    # Fast path: the whole string is already valid JSON.
    try:
        json.loads(text)
    except (json.JSONDecodeError, TypeError):
        pass
    else:
        return text

    # Try to decode a JSON value starting at every "{" or "[". raw_decode
    # parses exactly one value and reports where it ends, which replaces
    # manual bracket counting and string/escape tracking.
    decoder = json.JSONDecoder()
    best_end = -1
    best_candidate = None
    for start, char in enumerate(text):
        if char not in "{[":
            continue
        try:
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        if end >= best_end:
            best_end = end
            best_candidate = text[start:end]

    return text if best_candidate is None else best_candidate
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def extract_number(text: str, number_type: Literal["int", "float"] = "float") -> str:
    """Extract the last number that appears in ``text``.

    Recognises optionally signed decimal numbers with an optional fraction
    and exponent. A ``-`` directly in front of digits is read as a sign.

    Args:
        text: Free-form text that may contain numbers.
        number_type: With ``"int"``, an integral match is normalised to plain
            integer digits (``"5.0"`` -> ``"5"``, ``"1e3"`` -> ``"1000"``);
            a non-integral match is returned as written.

    Returns:
        The last number as a string, or ``text`` unchanged when it is empty
        or contains no number.
    """
    if not text:
        return text

    # Pattern for numbers (including scientific notation). It has no
    # capturing groups, so findall yields the full matches.
    found = re.findall(r"-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?", text)
    if not found:
        return text

    last_number = found[-1]

    if number_type == "int":
        # Parse plain digit strings directly: going through float() would
        # silently corrupt integers above 2**53.
        try:
            return str(int(last_number))
        except ValueError:
            pass

        # Fraction/exponent forms: accept only when the value is integral.
        as_float = float(last_number)
        if as_float.is_integer():
            return str(int(as_float))

    return last_number
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def extract_boolean(text: str) -> str:
    """Extract the last boolean word from ``text``.

    Matching is case-insensitive and restricted to whole words, so
    ``"untrue"`` does not count.

    Returns:
        ``"true"`` or ``"false"`` (lowercase) for the last match, or ``text``
        unchanged when it is empty or contains no boolean word.
    """
    if not text:
        return text

    # IGNORECASE accepts any casing, not just true/True/TRUE.
    found = re.findall(r"\b(?:true|false)\b", text, flags=re.IGNORECASE)
    if found:
        return found[-1].lower()

    return text
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def extract_list(text: str) -> str:
    """Extract the last list/array from ``text``.

    Delegates to ``extract_json``, so the result is whatever that function
    returns for ``text`` (it also accepts JSON objects, not only arrays).
    """
    return extract_json(text)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _compare_exact(value: Any, reference: Any, **kwargs: Any) -> tuple[bool, float, dict]:
|
|
177
|
+
"""Exact comparison."""
|
|
178
|
+
matches = value == reference
|
|
179
|
+
return matches, 1.0 if matches else 0.0, {}
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _compare_fuzzy(
|
|
183
|
+
value: Any, reference: Any, threshold: float = 0.8, **kwargs: Any
|
|
184
|
+
) -> tuple[bool, float, dict]:
|
|
185
|
+
"""Fuzzy text comparison."""
|
|
186
|
+
str_value = str(value).strip()
|
|
187
|
+
str_reference = str(reference).strip()
|
|
188
|
+
similarity = SequenceMatcher(None, str_value, str_reference).ratio()
|
|
189
|
+
return similarity >= threshold, similarity, {"threshold": threshold}
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _compare_numeric(
|
|
193
|
+
value: Any, reference: Any, tolerance: float = 1e-6, **kwargs: Any
|
|
194
|
+
) -> tuple[bool, float, dict]:
|
|
195
|
+
"""Numeric comparison with tolerance."""
|
|
196
|
+
try:
|
|
197
|
+
num_value = float(value)
|
|
198
|
+
num_reference = float(reference)
|
|
199
|
+
|
|
200
|
+
abs_diff = abs(num_value - num_reference)
|
|
201
|
+
rel_diff = abs_diff / max(abs(num_reference), 1e-10)
|
|
202
|
+
|
|
203
|
+
matches = abs_diff <= tolerance or rel_diff <= tolerance
|
|
204
|
+
similarity = max(0.0, 1.0 - rel_diff)
|
|
205
|
+
|
|
206
|
+
return (
|
|
207
|
+
matches,
|
|
208
|
+
similarity,
|
|
209
|
+
{
|
|
210
|
+
"absolute_difference": abs_diff,
|
|
211
|
+
"relative_difference": rel_diff,
|
|
212
|
+
"tolerance": tolerance,
|
|
213
|
+
},
|
|
214
|
+
)
|
|
215
|
+
except (ValueError, TypeError):
|
|
216
|
+
return False, 0.0, {"error": "non-numeric values"}
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _compare_semantic(
|
|
220
|
+
value: Any, reference: Any, threshold: float = 0.8, tolerance: float = 1e-6, **kwargs: Any
|
|
221
|
+
) -> tuple[bool, float, dict]:
|
|
222
|
+
"""Semantic comparison for JSON-parseable structures."""
|
|
223
|
+
# Try to parse as JSON
|
|
224
|
+
try:
|
|
225
|
+
v_obj = json.loads(value) if isinstance(value, str) else value
|
|
226
|
+
r_obj = json.loads(reference) if isinstance(reference, str) else reference
|
|
227
|
+
except (json.JSONDecodeError, TypeError):
|
|
228
|
+
# Not valid JSON - fall back to fuzzy text comparison
|
|
229
|
+
return _compare_fuzzy(value, reference, threshold)
|
|
230
|
+
|
|
231
|
+
# Now dispatch based on parsed types
|
|
232
|
+
if isinstance(v_obj, dict) and isinstance(r_obj, dict):
|
|
233
|
+
# Dictionary comparison
|
|
234
|
+
try:
|
|
235
|
+
norm_value = json.dumps(v_obj, sort_keys=True)
|
|
236
|
+
norm_reference = json.dumps(r_obj, sort_keys=True)
|
|
237
|
+
matches = norm_value == norm_reference
|
|
238
|
+
|
|
239
|
+
# Calculate similarity based on keys and values
|
|
240
|
+
common_keys = set(v_obj.keys()) & set(r_obj.keys())
|
|
241
|
+
all_keys = set(v_obj.keys()) | set(r_obj.keys())
|
|
242
|
+
key_similarity = len(common_keys) / len(all_keys) if all_keys else 1.0
|
|
243
|
+
|
|
244
|
+
if common_keys:
|
|
245
|
+
matching_values = sum(1 for k in common_keys if v_obj.get(k) == r_obj.get(k))
|
|
246
|
+
value_similarity = matching_values / len(common_keys)
|
|
247
|
+
similarity = (key_similarity + value_similarity) / 2
|
|
248
|
+
else:
|
|
249
|
+
similarity = key_similarity
|
|
250
|
+
|
|
251
|
+
return matches, similarity, {"type": "dict"}
|
|
252
|
+
except Exception as e:
|
|
253
|
+
return False, 0.0, {"error": str(e)}
|
|
254
|
+
|
|
255
|
+
elif isinstance(v_obj, list) and isinstance(r_obj, list):
|
|
256
|
+
# List comparison - element by element
|
|
257
|
+
matches = v_obj == r_obj
|
|
258
|
+
if len(v_obj) == len(r_obj) == 0:
|
|
259
|
+
return True, 1.0, {"type": "list", "length": 0}
|
|
260
|
+
|
|
261
|
+
# Similarity based on matching positions
|
|
262
|
+
if len(v_obj) == len(r_obj):
|
|
263
|
+
matching = sum(1 for a, b in zip(v_obj, r_obj, strict=False) if a == b)
|
|
264
|
+
similarity = matching / len(v_obj)
|
|
265
|
+
else:
|
|
266
|
+
similarity = 0.0
|
|
267
|
+
|
|
268
|
+
return matches, similarity, {"type": "list", "length": len(v_obj)}
|
|
269
|
+
|
|
270
|
+
elif isinstance(v_obj, (int | float)) and isinstance(r_obj, (int | float)):
|
|
271
|
+
# Numeric comparison
|
|
272
|
+
return _compare_numeric(v_obj, r_obj, tolerance)
|
|
273
|
+
|
|
274
|
+
elif isinstance(v_obj, bool) and isinstance(r_obj, bool):
|
|
275
|
+
# Boolean comparison
|
|
276
|
+
return _compare_exact(v_obj, r_obj)
|
|
277
|
+
|
|
278
|
+
elif isinstance(v_obj, str) and isinstance(r_obj, str):
|
|
279
|
+
# String comparison - could be fuzzy or exact
|
|
280
|
+
return _compare_fuzzy(v_obj, r_obj, threshold)
|
|
281
|
+
|
|
282
|
+
else:
|
|
283
|
+
# Different types or other cases - exact comparison
|
|
284
|
+
return _compare_exact(v_obj, r_obj)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# Map modes to comparison functions.
# ComparisonMode.AUTO has no entry: callers resolve it to one of these
# concrete modes before looking up a function.
COMPARISON_FUNCTIONS = {
    ComparisonMode.EXACT: _compare_exact,
    ComparisonMode.FUZZY: _compare_fuzzy,
    ComparisonMode.NUMERIC: _compare_numeric,
    ComparisonMode.SEMANTIC: _compare_semantic,
}
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def detect_type(value: str | None = None) -> DataType:
    """Detect the data type of a string value.

    Checks run in this order: boolean word (any casing), JSON object/array,
    integer, float; anything else is text. ``None`` is treated as text.

    A JSON-quoted string such as ``'"42"'`` is reported as text: the
    comparison functions receive the quoted text itself, which is not a
    number.
    """
    if value is None:
        return DataType.TEXT

    if value.lower() in ("true", "false"):
        return DataType.BOOLEAN

    try:
        parsed = json.loads(value)
    except (json.JSONDecodeError, TypeError):
        pass
    else:
        if isinstance(parsed, (dict, list)):
            return DataType.JSON
        if isinstance(parsed, str):
            # Do not reinterpret the unwrapped contents of a quoted string.
            return DataType.TEXT
        # Other JSON primitive: keep checking its normalised text form.
        value = str(parsed)

    try:
        int(value)
    except ValueError:
        pass
    else:
        return DataType.INTEGER

    try:
        float(value)
    except ValueError:
        pass
    else:
        return DataType.FLOAT

    return DataType.TEXT
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def auto_select_mode(value_type: DataType, ref_type: DataType) -> ComparisonMode:
    """Choose a comparison mode from the detected types of both operands.

    The first rule that applies to either operand wins:
    JSON -> semantic, integer/float -> numeric, boolean -> exact,
    otherwise fuzzy text.
    """
    detected = (value_type, ref_type)

    if DataType.JSON in detected:
        return ComparisonMode.SEMANTIC

    if DataType.INTEGER in detected or DataType.FLOAT in detected:
        return ComparisonMode.NUMERIC

    if DataType.BOOLEAN in detected:
        return ComparisonMode.EXACT

    return ComparisonMode.FUZZY
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _operand_to_text(item: Any) -> str:
    """Render one operand as text for the string-based comparators.

    Dicts and lists are dumped as JSON so they can be parsed again by the
    semantic comparator; ``str()`` would produce a Python repr with single
    quotes, which is not valid JSON.
    """
    if isinstance(item, (dict, list)):
        try:
            return json.dumps(item)
        except (TypeError, ValueError):
            pass  # not JSON-serialisable; fall back to str()
    return str(item)


class CompareTool(BaseTool):
    """Universal comparison tool with mode selection."""

    name = "compare"
    title = "Compare Tool"
    description = "Compare values with explicit or automatic mode selection"

    async def __call__(
        self,
        value: Any | list[Any] | None = None,
        reference: Any | list[Any] | None = None,
        mode: ComparisonMode = ComparisonMode.AUTO,
        threshold: float = 0.8,
        tolerance: float = 1e-6,
    ) -> EvaluationResult:
        """Compare values with the specified or auto-detected mode.

        Args:
            value: Value(s) to check. When ``None``, ``get_submission()`` is
                used instead.
            reference: Expected value(s).
            mode: Comparison mode; ``AUTO`` picks one per pair from the
                detected types.
            threshold: Similarity threshold passed to the comparators.
            tolerance: Numeric tolerance passed to the comparators.

        Returns:
            An ``EvaluationResult`` whose reward is the average similarity
            over all pairs and whose ``done`` flag is true when every pair
            matches. Two lists are compared pairwise and must have equal
            length; a scalar against a list is repeated for each element.
        """
        if value is None:
            value = get_submission()

        if value is None or reference is None:
            return EvaluationResult(
                reward=0.0, done=False, content="Missing value or reference", isError=True
            )

        value_is_list = isinstance(value, list)
        reference_is_list = isinstance(reference, list)
        val_list = value if value_is_list else [value]
        ref_list = reference if reference_is_list else [reference]

        comp_type = "scalar"
        if value_is_list and reference_is_list:
            if len(val_list) != len(ref_list):
                return EvaluationResult(
                    reward=0.0,
                    done=False,
                    content=f"Error: Mismatched lengths - {len(val_list)} values vs {len(ref_list)} references",  # noqa: E501
                )
            comp_type = "batch"
        elif value_is_list:
            ref_list = [reference] * len(val_list)
            comp_type = "broadcast"
        elif reference_is_list:
            val_list = [value] * len(ref_list)
            comp_type = "broadcast"

        results = []
        for v, r in zip(val_list, ref_list, strict=False):
            v_str, r_str = _operand_to_text(v), _operand_to_text(r)

            if mode == ComparisonMode.AUTO:
                v_type = detect_type(v_str)
                r_type = detect_type(r_str)
                comparison_mode = auto_select_mode(v_type, r_type)
                detected_type = v_type if v_type == r_type else DataType.TEXT
                if v_type == r_type == DataType.BOOLEAN:
                    # Detection ignores casing, so the comparison must too.
                    v_str, r_str = v_str.lower(), r_str.lower()
            else:
                comparison_mode = mode
                detected_type = DataType.TEXT

            if mode == ComparisonMode.EXACT:
                # Explicit exact mode: skip parsing and compare raw strings.
                matches = v_str == r_str
                result = ComparisonResult(
                    matches=matches,
                    similarity=1.0 if matches else 0.0,
                    detected_type=detected_type,
                    comparison_mode=comparison_mode,
                )
            else:
                compare_fn = COMPARISON_FUNCTIONS[comparison_mode]
                matches, similarity, details = compare_fn(
                    v_str, r_str, threshold=threshold, tolerance=tolerance
                )
                result = ComparisonResult(
                    matches=matches,
                    similarity=similarity,
                    detected_type=detected_type,
                    comparison_mode=comparison_mode,
                    details=details,
                )

            results.append(result)

        if not results:
            return EvaluationResult(reward=0.0, done=False, content="No comparisons performed")

        avg_similarity = sum(r.similarity for r in results) / len(results)
        all_match = all(r.matches for r in results)
        match_count = sum(1 for r in results if r.matches)

        prefix = {"scalar": "Single", "batch": "Batch", "broadcast": "Broadcast"}.get(
            comp_type, "Unknown"
        )
        # The label reports the mode of the first pair only.
        mode_name = results[0].comparison_mode.value.capitalize()

        return EvaluationResult(
            reward=avg_similarity,
            done=all_match,
            content=f"{prefix} {mode_name}: {match_count}/{len(results)} matches, avg={avg_similarity:.3f}",  # noqa: E501
        )
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
# Map of specific aliases to their preprocessing needs.
# Each entry takes the raw text and returns the extracted text to compare.
# Aliases without an entry get no preprocessing.
ALIAS_PREPROCESSORS = {
    "compare_json": extract_json,
    "compare_int": lambda v: extract_number(v, "int"),
    "compare_float": lambda v: extract_number(v, "float"),
    "compare_boolean": extract_boolean,
    "compare_list": extract_list,
}
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
# Helper to create alias tool classes
|
|
476
|
+
def make_alias_tool(name: str, preset_mode: ComparisonMode, description: str) -> type[BaseTool]:
    """Create an alias tool class that calls ``CompareTool`` with a preset mode.

    Args:
        name: Tool name; also the key used to look up an optional
            preprocessor in ``ALIAS_PREPROCESSORS``.
        preset_mode: Comparison mode the alias always passes on.
        description: Base description; a fixed suffix is appended.

    Returns:
        A ``BaseTool`` subclass; instantiate it to get the tool.
    """

    class AliasTool(BaseTool):
        def __init__(self) -> None:
            super().__init__(
                name=name,
                title=f"Compare ({preset_mode.capitalize()})",
                description=description + " (auto-handles lists, extracts from outputs)",
            )

        async def __call__(
            self,
            value: Any | list[Any] | None = None,
            reference: Any | list[Any] | None = None,
            threshold: float = 0.8,
            tolerance: float = 1e-6,
        ) -> EvaluationResult:
            """Alias that calls compare with preset mode."""
            preprocessor = ALIAS_PREPROCESSORS.get(name)
            if preprocessor is not None:

                def _clean(item: Any) -> Any:
                    # Lists are cleaned element-wise so batching still works.
                    if isinstance(item, list):
                        return [preprocessor(str(v)) for v in item]
                    return preprocessor(str(item))

                # Resolve the stored submission here so it gets the same
                # extraction as an explicitly passed value.
                if value is None:
                    value = get_submission()
                if value is not None:
                    value = _clean(value)
                # Normalise the reference the same way; otherwise e.g. a
                # reference of "True" can never equal an extracted "true".
                if reference is not None:
                    reference = _clean(reference)

            tool = CompareTool()
            return await tool(
                value=value,
                reference=reference,
                mode=preset_mode,
                threshold=threshold,
                tolerance=tolerance,
            )

    return AliasTool
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
# Create MCP server
comparator_server = MCPServer(name="comparator")

# Register main tools
comparator_server.add_tool(SubmitTool())
comparator_server.add_tool(CompareTool())

# Register aliases - these are just thin wrappers.
# Each entry is (tool name, preset comparison mode, description).
ALIASES = [
    ("compare_exact", ComparisonMode.EXACT, "Exact string comparison"),
    ("compare_text", ComparisonMode.FUZZY, "Fuzzy text comparison"),
    ("compare_string", ComparisonMode.FUZZY, "Fuzzy string comparison (alias for text)"),
    ("compare_numeric", ComparisonMode.NUMERIC, "Numeric comparison with tolerance"),
    ("compare_float", ComparisonMode.NUMERIC, "Float comparison (alias for numeric)"),
    ("compare_int", ComparisonMode.NUMERIC, "Integer comparison (alias for numeric)"),
    ("compare_json", ComparisonMode.SEMANTIC, "Semantic JSON comparison"),
    ("compare_boolean", ComparisonMode.EXACT, "Boolean comparison (exact match)"),
    ("compare_list", ComparisonMode.SEMANTIC, "List comparison (alias for CompareTool)"),
]

for name, mode, desc in ALIASES:
    AliasTool = make_alias_tool(name, mode, desc)
    comparator_server.add_tool(AliasTool())

# Export for mounting
__all__ = ["comparator_server"]


if __name__ == "__main__":
    # Run as standalone server
    logger.info("Starting Comparator MCP Server...")
    comparator_server.run()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Tests for the native MCP servers."""
|