hud-python 0.4.19__py3-none-any.whl → 0.4.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

@@ -0,0 +1,546 @@
1
+ """Universal type-aware comparator MCP server.
2
+
3
+ This server provides comparison capabilities that automatically detect
4
+ and handle different data types (text, int, float, json) with lenient
5
+ parsing and multiple comparison strategies.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import logging
12
+ import re
13
+ import sys
14
+ from difflib import SequenceMatcher
15
+ from enum import Enum
16
+ from typing import Any, Literal
17
+
18
+ from pydantic import BaseModel, Field
19
+
20
+ from hud.server import MCPServer
21
+ from hud.tools import BaseTool, SubmitTool
22
+ from hud.tools.submit import get_submission
23
+ from hud.tools.types import EvaluationResult
24
+
25
# Configure logging
# Runs at import time: records at INFO level and above are written to stderr
# using the "[LEVEL] time | logger name | message" layout below.
# NOTE(review): presumably stderr is chosen so stdout stays free for the
# server's own traffic - confirm against how MCPServer communicates.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    format="[%(levelname)s] %(asctime)s | %(name)s | %(message)s",
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
32
+
33
+
34
class DataType(str, Enum):
    """Detected data types.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (for example ``DataType.TEXT == "text"``).
    """

    TEXT = "text"  # Fallback when no other type applies
    INTEGER = "integer"  # Whole numbers
    FLOAT = "float"  # Non-integer numbers
    JSON = "json"  # Structured JSON values (objects and arrays)
    BOOLEAN = "boolean"  # true / false
42
+
43
+
44
class ComparisonMode(str, Enum):
    """Available comparison modes.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (for example ``ComparisonMode.AUTO == "auto"``).
    """

    AUTO = "auto"  # Auto-detect based on type
    EXACT = "exact"  # Exact match
    FUZZY = "fuzzy"  # Fuzzy text matching with similarity threshold
    NUMERIC = "numeric"  # Numeric comparison with tolerance
    SEMANTIC = "semantic"  # Semantic equivalence (JSON structures)
52
+
53
+
54
class ComparisonResult(BaseModel):
    """Result of a comparison operation.

    One instance describes the outcome for a single value/reference pair.
    """

    # Whether the pair was judged to match by the comparison that ran.
    matches: bool
    # Similarity score; validation restricts it to the closed range [0.0, 1.0].
    similarity: float = Field(ge=0.0, le=1.0)
    # Data type attributed to the compared pair.
    detected_type: DataType
    # Comparison mode recorded for this result.
    comparison_mode: ComparisonMode
    # Free-form extra information (e.g. threshold, differences, error text);
    # defaults to a fresh empty dict per instance.
    details: dict[str, Any] = Field(default_factory=dict)
62
+
63
+
64
+ # Extraction functions for handling LLM outputs
65
def extract_json(text: str) -> str:
    """Extract the last valid JSON object or array from text.

    Args:
        text: Arbitrary text (typically model output) that may embed JSON.

    Returns:
        ``text`` unchanged when it is empty or is, as a whole, valid JSON
        (any JSON value, including bare numbers and strings). Otherwise the
        substring holding the valid JSON object/array that ends furthest to
        the right. If no object or array can be parsed, ``text`` is returned
        unchanged.
    """
    if not text:
        return text

    # Fast path: the whole string already parses, so keep it byte-for-byte.
    # ValueError is the base class of JSONDecodeError and also covers other
    # parse-time failures such as the interpreter's integer digit limit.
    try:
        json.loads(text)
        return text
    except (ValueError, TypeError):
        pass

    # Let the stdlib decoder find where a value starting at each "{" or "["
    # ends, instead of hand-matching brackets and re-parsing the slice.
    decoder = json.JSONDecoder()
    best_end = -1
    best_candidate = None
    for start, char in enumerate(text):
        if char not in "{[":
            continue
        try:
            _, end = decoder.raw_decode(text, start)
        except ValueError:
            continue
        # ">=" keeps the later starting point on a tie, so the candidate
        # that ends last (and, among those, starts last) wins.
        if end >= best_end:
            best_end = end
            best_candidate = text[start:end]

    return text if best_candidate is None else best_candidate
125
+
126
+
127
def extract_number(text: str, number_type: Literal["int", "float"] = "float") -> str:
    """Extract the last number from text.

    Args:
        text: Arbitrary text that may contain numbers.
        number_type: ``"float"`` returns the matched text as written.
            ``"int"`` additionally normalises integral values (``"42.0"``,
            ``"1e3"``, ``"007"``) to a plain integer string.

    Returns:
        The last number found, as a string. In ``"int"`` mode a non-integral
        match (e.g. ``"3.5"``) is returned as written. If ``text`` is empty
        or contains no number, ``text`` is returned unchanged.
    """
    if not text:
        return text

    # Pattern for numbers (including scientific notation)
    number_pattern = r"-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?"

    matches = re.findall(number_pattern, text)
    if not matches:
        return text

    last_number = matches[-1]
    if number_type != "int":
        return last_number

    # Plain integers are parsed directly: going through float() would
    # silently corrupt values above 2**53 (e.g. long identifiers).
    try:
        return str(int(last_number))
    except ValueError:
        pass

    # Decimal / exponent forms: collapse to an integer only when integral.
    try:
        float_val = float(last_number)
        if float_val.is_integer():
            return str(int(float_val))
    except ValueError:
        pass

    return last_number
152
+
153
+
154
def extract_boolean(text: str) -> str:
    """Extract the last boolean value from text.

    Args:
        text: Arbitrary text that may contain the words "true" / "false".

    Returns:
        ``"true"`` or ``"false"`` (always lowercase) for the last standalone
        occurrence of either word, matched in any letter case. If ``text`` is
        empty or contains neither word, ``text`` is returned unchanged.
    """
    if not text:
        return text

    # Whole-word match in any casing; the previous pattern only listed three
    # spellings, so mixed-case input such as "tRue" was not recognised.
    found = re.findall(r"\b(?:true|false)\b", text, flags=re.IGNORECASE)
    if found:
        return found[-1].lower()

    return text
167
+
168
+
169
def extract_list(text: str) -> str:
    """Extract the last list/array from text.

    This is a direct delegation to ``extract_json``; the argument is passed
    through unchanged and its result is returned as-is.

    NOTE(review): because of that delegation the result is not restricted to
    arrays - a JSON object can be returned as well. Confirm this is intended.
    """
    # For lists, we use the same logic as JSON extraction
    # since lists are JSON arrays
    return extract_json(text)
174
+
175
+
176
+ def _compare_exact(value: Any, reference: Any, **kwargs: Any) -> tuple[bool, float, dict]:
177
+ """Exact comparison."""
178
+ matches = value == reference
179
+ return matches, 1.0 if matches else 0.0, {}
180
+
181
+
182
+ def _compare_fuzzy(
183
+ value: Any, reference: Any, threshold: float = 0.8, **kwargs: Any
184
+ ) -> tuple[bool, float, dict]:
185
+ """Fuzzy text comparison."""
186
+ str_value = str(value).strip()
187
+ str_reference = str(reference).strip()
188
+ similarity = SequenceMatcher(None, str_value, str_reference).ratio()
189
+ return similarity >= threshold, similarity, {"threshold": threshold}
190
+
191
+
192
+ def _compare_numeric(
193
+ value: Any, reference: Any, tolerance: float = 1e-6, **kwargs: Any
194
+ ) -> tuple[bool, float, dict]:
195
+ """Numeric comparison with tolerance."""
196
+ try:
197
+ num_value = float(value)
198
+ num_reference = float(reference)
199
+
200
+ abs_diff = abs(num_value - num_reference)
201
+ rel_diff = abs_diff / max(abs(num_reference), 1e-10)
202
+
203
+ matches = abs_diff <= tolerance or rel_diff <= tolerance
204
+ similarity = max(0.0, 1.0 - rel_diff)
205
+
206
+ return (
207
+ matches,
208
+ similarity,
209
+ {
210
+ "absolute_difference": abs_diff,
211
+ "relative_difference": rel_diff,
212
+ "tolerance": tolerance,
213
+ },
214
+ )
215
+ except (ValueError, TypeError):
216
+ return False, 0.0, {"error": "non-numeric values"}
217
+
218
+
219
def _compare_semantic(
    value: Any, reference: Any, threshold: float = 0.8, tolerance: float = 1e-6, **kwargs: Any
) -> tuple[bool, float, dict]:
    """Semantic comparison for JSON-parseable structures.

    String inputs are parsed as JSON; non-string inputs are used as given.
    The parsed pair is then compared according to its types:

    * dict / dict   - match when the key-sorted JSON dumps are equal;
      similarity averages top-level key overlap and equal values on shared keys.
    * list / list   - match on equality; similarity is the share of equal
      positions, or 0.0 when the lengths differ.
    * bool / bool   - exact comparison.
    * number / number - numeric comparison using ``tolerance``.
    * str / str     - fuzzy comparison using ``threshold``.
    * anything else - exact comparison.

    If either string is not valid JSON the raw inputs are compared fuzzily.
    Extra keyword arguments are ignored.

    Returns:
        ``(matches, similarity, details)``.
    """
    # Try to parse as JSON. ValueError is the base class of JSONDecodeError
    # and also covers other parse-time failures (e.g. the integer digit limit).
    try:
        v_obj = json.loads(value) if isinstance(value, str) else value
        r_obj = json.loads(reference) if isinstance(reference, str) else reference
    except (ValueError, TypeError):
        # Not valid JSON - fall back to fuzzy text comparison
        return _compare_fuzzy(value, reference, threshold)

    # Now dispatch based on parsed types
    if isinstance(v_obj, dict) and isinstance(r_obj, dict):
        # Dictionary comparison
        try:
            # Key-sorted dumps make the match independent of key order.
            norm_value = json.dumps(v_obj, sort_keys=True)
            norm_reference = json.dumps(r_obj, sort_keys=True)
            matches = norm_value == norm_reference

            # Calculate similarity based on keys and values
            common_keys = set(v_obj.keys()) & set(r_obj.keys())
            all_keys = set(v_obj.keys()) | set(r_obj.keys())
            key_similarity = len(common_keys) / len(all_keys) if all_keys else 1.0

            if common_keys:
                matching_values = sum(1 for k in common_keys if v_obj.get(k) == r_obj.get(k))
                value_similarity = matching_values / len(common_keys)
                similarity = (key_similarity + value_similarity) / 2
            else:
                similarity = key_similarity

            return matches, similarity, {"type": "dict"}
        except Exception as e:
            # Deliberate best-effort: any failure while normalising is
            # reported as a non-match rather than raised to the caller.
            return False, 0.0, {"error": str(e)}

    if isinstance(v_obj, list) and isinstance(r_obj, list):
        # List comparison - element by element
        matches = v_obj == r_obj
        if len(v_obj) == len(r_obj) == 0:
            return True, 1.0, {"type": "list", "length": 0}

        # Similarity based on matching positions
        if len(v_obj) == len(r_obj):
            matching = sum(1 for a, b in zip(v_obj, r_obj, strict=False) if a == b)
            similarity = matching / len(v_obj)
        else:
            similarity = 0.0

        return matches, similarity, {"type": "list", "length": len(v_obj)}

    # bool is a subclass of int, so this check has to run before the numeric
    # one; in the opposite order the boolean branch can never be reached.
    if isinstance(v_obj, bool) and isinstance(r_obj, bool):
        # Boolean comparison
        return _compare_exact(v_obj, r_obj)

    if isinstance(v_obj, (int, float)) and isinstance(r_obj, (int, float)):
        # Numeric comparison
        return _compare_numeric(v_obj, r_obj, tolerance)

    if isinstance(v_obj, str) and isinstance(r_obj, str):
        # String comparison - fuzzy on the decoded JSON strings
        return _compare_fuzzy(v_obj, r_obj, threshold)

    # Different types or other cases - exact comparison
    return _compare_exact(v_obj, r_obj)
285
+
286
+
287
# Map modes to comparison functions
# Each function returns a (matches, similarity, details) tuple.
# ComparisonMode.AUTO is intentionally absent: it has no comparison function
# of its own and must be resolved to one of the modes below before lookup.
COMPARISON_FUNCTIONS = {
    ComparisonMode.EXACT: _compare_exact,
    ComparisonMode.FUZZY: _compare_fuzzy,
    ComparisonMode.NUMERIC: _compare_numeric,
    ComparisonMode.SEMANTIC: _compare_semantic,
}
294
+
295
+
296
def detect_type(value: str | None = None) -> DataType:
    """Detect the data type of a string value.

    Checks run in this order and the first hit wins:

    1. ``None`` is reported as TEXT.
    2. "true" / "false" in any letter case is BOOLEAN.
    3. Text that parses as a JSON object or array is JSON. A JSON primitive
       is converted back to text and falls through to the numeric checks.
    4. Text accepted by ``int()`` is INTEGER.
    5. Text accepted by ``float()`` is FLOAT.
    6. Anything else is TEXT.

    Args:
        value: The text to classify.

    Returns:
        The detected ``DataType``.
    """
    if value is None:
        return DataType.TEXT

    # Try boolean
    if value.lower() in ("true", "false"):
        return DataType.BOOLEAN

    # Try JSON (dict or list). ValueError is the base class of
    # JSONDecodeError and also covers other parse-time failures such as the
    # interpreter's integer digit limit, which would otherwise escape.
    try:
        parsed = json.loads(value)
    except (ValueError, TypeError):
        pass
    else:
        if isinstance(parsed, (dict, list)):
            return DataType.JSON
        # Continue checking if it's a JSON primitive
        value = str(parsed)

    # Try integer first, then float; the first successful conversion wins.
    for convert, detected in ((int, DataType.INTEGER), (float, DataType.FLOAT)):
        try:
            convert(value)
        except ValueError:
            continue
        return detected

    # Default to text
    return DataType.TEXT
331
+
332
+
333
def auto_select_mode(value_type: DataType, ref_type: DataType) -> ComparisonMode:
    """Pick a comparison mode from the detected types of both inputs.

    Precedence, applied when at least one side has the given type:
    JSON -> SEMANTIC, INTEGER/FLOAT -> NUMERIC, BOOLEAN -> EXACT.
    When none applies (both sides are text) the mode is FUZZY.
    """
    detected = (value_type, ref_type)

    # JSON on either side wins over everything else.
    if DataType.JSON in detected:
        return ComparisonMode.SEMANTIC

    # Next, a number on either side.
    numeric_types = (DataType.INTEGER, DataType.FLOAT)
    if any(kind in numeric_types for kind in detected):
        return ComparisonMode.NUMERIC

    # Then a boolean on either side.
    if DataType.BOOLEAN in detected:
        return ComparisonMode.EXACT

    # Plain text on both sides.
    return ComparisonMode.FUZZY
352
+
353
+
354
class CompareTool(BaseTool):
    """Universal comparison tool with mode selection.

    Compares a value (or list of values) against a reference (or list of
    references) and reports the average similarity as the reward.
    """

    name = "compare"
    title = "Compare Tool"
    description = "Compare values with explicit or automatic mode selection"

    async def __call__(
        self,
        value: Any | list[Any] | None = None,
        reference: Any | list[Any] | None = None,
        mode: ComparisonMode = ComparisonMode.AUTO,
        threshold: float = 0.8,
        tolerance: float = 1e-6,
    ) -> EvaluationResult:
        """Compare values with specified or auto-detected mode.

        Args:
            value: Value or list of values to check. When ``None`` the stored
                submission is used instead.
            reference: Expected value or list of expected values.
            mode: Comparison mode (a ``ComparisonMode`` or its string value).
                ``AUTO`` detects the type of each pair and picks a mode.
            threshold: Minimum similarity for a fuzzy match.
            tolerance: Allowed absolute or relative difference for numbers.

        Returns:
            An ``EvaluationResult`` whose reward is the average similarity of
            all compared pairs and whose ``done`` flag is true only when every
            pair matched. Missing inputs or an unknown mode yield an error
            result; two lists of different lengths yield a zero reward.

        Pairing rules: two lists are compared element by element ("batch");
        a list against a single item compares every element to that item
        ("broadcast"); two single items form one pair ("scalar").
        """

        def as_text(item: Any) -> str:
            # Containers are serialised as JSON so they can be re-parsed by
            # the semantic comparison; str() would give a Python repr with
            # single quotes, which is not valid JSON.
            if isinstance(item, (dict, list)):
                try:
                    return json.dumps(item)
                except (TypeError, ValueError):
                    return str(item)
            return str(item)

        # Get value from submission if not provided
        if value is None:
            value = get_submission()

        if value is None or reference is None:
            return EvaluationResult(
                reward=0.0, done=False, content="Missing value or reference", isError=True
            )

        # Accept plain strings such as "numeric": a bare string does not hash
        # like the enum member and would fail the dispatch-table lookup below.
        try:
            mode = ComparisonMode(mode)
        except ValueError:
            return EvaluationResult(
                reward=0.0,
                done=False,
                content=f"Unknown comparison mode: {mode!r}",
                isError=True,
            )

        # Convert to lists
        val_list = value if isinstance(value, list) else [value]
        ref_list = reference if isinstance(reference, list) else [reference]

        # Check list compatibility
        comp_type = "scalar"
        if isinstance(value, list) and isinstance(reference, list):
            if len(val_list) != len(ref_list):
                return EvaluationResult(
                    reward=0.0,
                    done=False,
                    content=f"Error: Mismatched lengths - {len(val_list)} values vs {len(ref_list)} references",  # noqa: E501
                )
            comp_type = "batch"
        elif isinstance(value, list):
            ref_list = [reference] * len(val_list)
            comp_type = "broadcast"
        elif isinstance(reference, list):
            val_list = [value] * len(ref_list)
            comp_type = "broadcast"

        # Process each pair
        results = []
        for v, r in zip(val_list, ref_list, strict=False):
            # Convert to strings
            v_str, r_str = as_text(v), as_text(r)

            # Determine comparison mode
            if mode == ComparisonMode.AUTO:
                v_type = detect_type(v_str)
                r_type = detect_type(r_str)
                comparison_mode = auto_select_mode(v_type, r_type)
                detected_type = v_type if v_type == r_type else DataType.TEXT
                if detected_type == DataType.BOOLEAN:
                    # Detection ignores letter case, so the comparison must
                    # too: "True" and "true" are the same boolean.
                    v_str, r_str = v_str.lower(), r_str.lower()
            else:
                comparison_mode = mode
                detected_type = DataType.TEXT

            if mode == ComparisonMode.EXACT:
                # Explicit exact mode: skip parsing and compare raw strings
                matches = v_str == r_str
                result = ComparisonResult(
                    matches=matches,
                    similarity=1.0 if matches else 0.0,
                    detected_type=detected_type,
                    comparison_mode=comparison_mode,
                )
            else:
                # Get comparison function and run it
                compare_fn = COMPARISON_FUNCTIONS[comparison_mode]
                matches, similarity, details = compare_fn(
                    v_str, r_str, threshold=threshold, tolerance=tolerance
                )

                result = ComparisonResult(
                    matches=matches,
                    similarity=similarity,
                    detected_type=detected_type,
                    comparison_mode=comparison_mode,
                    details=details,
                )

            results.append(result)

        # Aggregate results (empty lists produce no pairs at all)
        if not results:
            return EvaluationResult(reward=0.0, done=False, content="No comparisons performed")

        total_similarity = sum(r.similarity for r in results)
        avg_similarity = total_similarity / len(results)
        all_match = all(r.matches for r in results)
        match_count = sum(1 for r in results if r.matches)

        # Format content; the mode shown is the one used for the first pair
        prefix = {"scalar": "Single", "batch": "Batch", "broadcast": "Broadcast"}.get(
            comp_type, "Unknown"
        )
        mode_name = results[0].comparison_mode.value.capitalize()

        return EvaluationResult(
            reward=avg_similarity,
            done=all_match,
            content=f"{prefix} {mode_name}: {match_count}/{len(results)} matches, avg={avg_similarity:.3f}",  # noqa: E501
        )
463
+
464
+
465
# Map of specific aliases to their preprocessing needs
# Keys are alias tool names; each callable takes text and returns the
# extracted text to compare. Alias names missing from this map get no
# preprocessing. The lambdas resolve the extract_* functions at call time.
ALIAS_PREPROCESSORS = {
    "compare_json": lambda v: extract_json(v),
    "compare_int": lambda v: extract_number(v, "int"),
    "compare_float": lambda v: extract_number(v, "float"),
    "compare_boolean": lambda v: extract_boolean(v),
    "compare_list": lambda v: extract_list(v),
}
473
+
474
+
475
+ # Helper to create alias tool classes
476
def make_alias_tool(name: str, preset_mode: ComparisonMode, description: str) -> type[BaseTool]:
    """Create an alias tool class that presets the mode.

    Args:
        name: Tool name; also selects the preprocessor in
            ``ALIAS_PREPROCESSORS`` (if any) applied to the compared value.
        preset_mode: Comparison mode forwarded to ``CompareTool``.
        description: Human-readable description; a fixed suffix is appended.

    Returns:
        A ``BaseTool`` subclass whose instances forward to ``CompareTool``
        with ``mode`` fixed to ``preset_mode``.
    """

    def as_text(item: Any) -> str:
        # Containers are serialised as JSON so JSON extraction can parse
        # them; str() would give a Python repr, which is not valid JSON.
        if isinstance(item, (dict, list)):
            try:
                return json.dumps(item)
            except (TypeError, ValueError):
                return str(item)
        return str(item)

    class AliasTool(BaseTool):
        def __init__(self) -> None:
            super().__init__(
                name=name,
                title=f"Compare ({preset_mode.capitalize()})",
                description=description + " (auto-handles lists, extracts from outputs)",
            )

        async def __call__(
            self,
            value: Any | list[Any] | None = None,
            reference: Any | list[Any] | None = None,
            threshold: float = 0.8,
            tolerance: float = 1e-6,
        ) -> EvaluationResult:
            """Alias that calls compare with preset mode."""
            # Resolve the submission here rather than inside CompareTool, so
            # the extraction below also runs on submitted answers.
            if value is None:
                value = get_submission()

            # Apply specific preprocessing if this alias has one
            preprocessor = ALIAS_PREPROCESSORS.get(name)
            if value is not None and preprocessor is not None:
                if isinstance(value, list):
                    value = [preprocessor(as_text(v)) for v in value]
                else:
                    value = preprocessor(as_text(value))

                # The boolean alias lower-cases the value and then compares
                # exactly, so the reference needs the same normalisation or
                # a reference of True ("True") can never match "true".
                # Other preprocessors are not applied to the reference: they
                # may rewrite text that is already a valid expected answer.
                if name == "compare_boolean" and reference is not None:
                    if isinstance(reference, list):
                        reference = [preprocessor(str(r)) for r in reference]
                    else:
                        reference = preprocessor(str(reference))

            tool = CompareTool()
            return await tool(
                value=value,
                reference=reference,
                mode=preset_mode,
                threshold=threshold,
                tolerance=tolerance,
            )

    return AliasTool
513
+
514
+
515
# Create MCP server
comparator_server = MCPServer(name="comparator")

# Register main tool
comparator_server.add_tool(SubmitTool())
comparator_server.add_tool(CompareTool())

# Register aliases - these are just thin wrappers
# Each entry is (tool name, preset comparison mode, description).
ALIASES = [
    ("compare_exact", ComparisonMode.EXACT, "Exact string comparison"),
    ("compare_text", ComparisonMode.FUZZY, "Fuzzy text comparison"),
    ("compare_string", ComparisonMode.FUZZY, "Fuzzy string comparison (alias for text)"),
    ("compare_numeric", ComparisonMode.NUMERIC, "Numeric comparison with tolerance"),
    ("compare_float", ComparisonMode.NUMERIC, "Float comparison (alias for numeric)"),
    ("compare_int", ComparisonMode.NUMERIC, "Integer comparison (alias for numeric)"),
    ("compare_json", ComparisonMode.SEMANTIC, "Semantic JSON comparison"),
    ("compare_boolean", ComparisonMode.EXACT, "Boolean comparison (exact match)"),
    ("compare_list", ComparisonMode.SEMANTIC, "List comparison (alias for CompareTool)"),
]

# Build one tool class per alias and register an instance of it.
# The loop variables (name, mode, desc, AliasTool) stay bound at module level
# after the loop finishes.
for name, mode, desc in ALIASES:
    AliasTool = make_alias_tool(name, mode, desc)
    comparator_server.add_tool(AliasTool())

# Export for mounting
__all__ = ["comparator_server"]


if __name__ == "__main__":
    # Run as standalone server
    logger.info("Starting Comparator MCP Server...")
    comparator_server.run()
@@ -0,0 +1 @@
1
+ """Tests for the native MCP servers."""