themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/core/tools.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""Tool use primitives for agentic workflows.
|
|
2
|
+
|
|
3
|
+
This module provides abstractions for defining and executing tools
|
|
4
|
+
(functions) that models can call during generation. This enables
|
|
5
|
+
agentic workflows, function calling, and tool-augmented generation.
|
|
6
|
+
|
|
7
|
+
Examples:
|
|
8
|
+
# Define a tool
|
|
9
|
+
    # The handler receives the call's arguments as a single dict.
    def calculator(args: dict) -> float:
        a, b = float(args["a"]), float(args["b"])
        if args["operation"] == "add":
            return a + b
        elif args["operation"] == "multiply":
            return a * b
        raise ValueError(f"Unknown operation: {args['operation']}")

    tool = ToolDefinition(
        name="calculator",
        description="Perform arithmetic operations",
        parameters={
            "type": "object",
            "properties": {
                "operation": {"type": "string", "enum": ["add", "multiply"]},
                "a": {"type": "number"},
                "b": {"type": "number"},
            },
            "required": ["operation", "a", "b"],
        },
        handler=calculator,
    )

    # Register tool
    registry = ToolRegistry()
    registry.register(tool)

    # Execute tool
    call = ToolCall(tool_name="calculator", arguments={"operation": "add", "a": 2, "b": 3})
    result = registry.execute(call)
    print(result.result)  # 5.0
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
from __future__ import annotations
|
|
42
|
+
|
|
43
|
+
import time
|
|
44
|
+
import uuid
|
|
45
|
+
from dataclasses import dataclass, field
|
|
46
|
+
from typing import Any, Callable
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class ToolDefinition:
    """Defines a tool/function available to the model.

    Attributes:
        name: Tool name (should be unique within a registry)
        description: Human-readable description of what the tool does
        parameters: JSON Schema describing the tool's arguments. Only the
            top-level ``required`` and ``properties`` keys are consulted by
            :meth:`validate_arguments`.
        handler: Function executed when the tool is called; it is invoked
            with the call's arguments as a single dict.
        metadata: Additional metadata
    """

    name: str
    description: str
    parameters: dict[str, Any]
    handler: Callable[[dict[str, Any]], Any]
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert tool definition to dictionary (without handler).

        Returns:
            Dictionary representation suitable for JSON serialization
        """
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
            "metadata": self.metadata,
        }

    def validate_arguments(self, arguments: dict[str, Any]) -> list[str]:
        """Validate arguments against the parameter schema.

        This is a shallow check. It reports required keys that are missing
        and, when the schema declares ``properties``, keys that are not
        declared there, regardless of ``additionalProperties``. Property
        types, enums and nested schemas are not checked.

        Args:
            arguments: Arguments to validate. Anything that is not a mapping
                (e.g. ``None`` or an unparsed JSON string) is reported as an
                error instead of raising.

        Returns:
            List of validation error messages (empty if valid)
        """
        from collections.abc import Mapping

        # Guard before any membership tests. On a str, ``in`` would do
        # substring matching. On None, it would raise TypeError.
        if not isinstance(arguments, Mapping):
            return [f"Arguments must be an object, got {type(arguments).__name__}"]

        errors: list[str] = []

        # ``or ()`` tolerates an explicit ``"required": None`` in the schema.
        for name in self.parameters.get("required") or ():
            if name not in arguments:
                errors.append(f"Missing required field: {name}")

        if "properties" in self.parameters:
            known_fields = set(self.parameters["properties"] or {})
            errors.extend(
                f"Unknown field: {name}" for name in arguments if name not in known_fields
            )

        return errors
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass
class ToolCall:
    """A single request to invoke a tool by name.

    Attributes:
        tool_name: Name of the tool to invoke
        arguments: Arguments handed to the tool
        call_id: Identifier for this call; defaults to a fresh UUID4 string
    """

    tool_name: str
    arguments: dict[str, Any]
    call_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the call into a plain dictionary.

        Returns:
            Mapping with ``tool_name``, ``arguments`` and ``call_id`` keys
            (``arguments`` is the same object held by this call, not a copy).
        """
        serialized: dict[str, Any] = dict(
            tool_name=self.tool_name,
            arguments=self.arguments,
            call_id=self.call_id,
        )
        return serialized
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass
class ToolResult:
    """Outcome of executing a tool call.

    Attributes:
        call: The ToolCall this result belongs to
        result: Value returned by the tool (meaningful only on success)
        error: Error message; ``None`` means the call succeeded
        execution_time_ms: Time taken to execute, in milliseconds
        metadata: Free-form additional metadata
    """

    call: ToolCall
    result: Any | None
    error: str | None
    execution_time_ms: float
    metadata: dict[str, Any] = field(default_factory=dict)

    def is_success(self) -> bool:
        """Report whether the call finished without an error.

        Returns:
            ``True`` exactly when ``error`` is ``None``; any string, even an
            empty one, counts as a failure.
        """
        failed = self.error is not None
        return not failed

    def to_dict(self) -> dict[str, Any]:
        """Serialize the result, including its nested call, into a dictionary.

        Returns:
            Mapping with ``call``, ``result``, ``error``,
            ``execution_time_ms`` and ``metadata`` keys.
        """
        payload: dict[str, Any] = {"call": self.call.to_dict()}
        payload["result"] = self.result
        payload["error"] = self.error
        payload["execution_time_ms"] = self.execution_time_ms
        payload["metadata"] = self.metadata
        return payload
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class ToolRegistry:
    """Registry for managing and executing tools.

    This class maintains a registry of available tools and provides
    methods for registering, retrieving, and executing them.

    Examples:
        registry = ToolRegistry()

        # Register tools
        registry.register(calculator_tool)
        registry.register(search_tool)

        # Execute tool
        call = ToolCall(tool_name="calculator", arguments={...})
        result = registry.execute(call)
    """

    def __init__(self) -> None:
        """Initialize an empty tool registry."""
        # Tool name -> definition; register() keeps names unique.
        self._tools: dict[str, ToolDefinition] = {}

    def register(self, tool: ToolDefinition) -> None:
        """Register a tool.

        Args:
            tool: Tool definition to register

        Raises:
            ValueError: If a tool with the same name is already registered
        """
        if tool.name in self._tools:
            raise ValueError(f"Tool '{tool.name}' already registered")

        self._tools[tool.name] = tool

    def unregister(self, name: str) -> None:
        """Unregister a tool by name.

        Unknown names are silently ignored.

        Args:
            name: Tool name to unregister
        """
        self._tools.pop(name, None)

    def get(self, name: str) -> ToolDefinition | None:
        """Get tool by name.

        Args:
            name: Tool name

        Returns:
            ToolDefinition if found, None otherwise
        """
        return self._tools.get(name)

    def list_tools(self) -> list[ToolDefinition]:
        """Get all registered tools.

        Returns:
            List of tool definitions, in registration order
        """
        return list(self._tools.values())

    def execute(self, call: ToolCall) -> ToolResult:
        """Execute a tool call.

        Problems with the call are reported through ``ToolResult.error``
        rather than raised:
            - an unknown tool name;
            - arguments that are not a mapping;
            - arguments rejected by ``validate_arguments``;
            - any ``Exception`` raised by the handler.
        ``execution_time_ms`` is 0.0 whenever the handler was not invoked.

        Args:
            call: Tool call to execute

        Returns:
            ToolResult with execution result or error
        """
        from collections.abc import Mapping

        tool = self._tools.get(call.tool_name)
        if tool is None:
            return ToolResult(
                call=call,
                result=None,
                error=f"Unknown tool: {call.tool_name}",
                execution_time_ms=0.0,
            )

        # Arguments may come from model output. A non-mapping value (None,
        # an unparsed JSON string, ...) would otherwise make argument
        # validation raise or misbehave instead of yielding a ToolResult.
        if not isinstance(call.arguments, Mapping):
            return ToolResult(
                call=call,
                result=None,
                error=(
                    "Invalid arguments: expected an object, got "
                    f"{type(call.arguments).__name__}"
                ),
                execution_time_ms=0.0,
            )

        validation_errors = tool.validate_arguments(call.arguments)
        if validation_errors:
            return ToolResult(
                call=call,
                result=None,
                error=f"Invalid arguments: {'; '.join(validation_errors)}",
                execution_time_ms=0.0,
            )

        # Only the handler call sits inside ``try``. Handler failures are
        # reported as data rather than propagated.
        start = time.perf_counter()
        try:
            result = tool.handler(call.arguments)
        except Exception as exc:
            failure = exc
            result = None
        else:
            failure = None
        elapsed_ms = (time.perf_counter() - start) * 1000

        return ToolResult(
            call=call,
            result=result,
            error=None if failure is None else f"{failure.__class__.__name__}: {failure}",
            execution_time_ms=elapsed_ms,
        )

    def to_dict_list(self) -> list[dict[str, Any]]:
        """Get all tools as dictionary list (for sending to model).

        Returns:
            List of tool definitions as dictionaries (handlers omitted)
        """
        return [tool.to_dict() for tool in self._tools.values()]
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# Built-in tools for common use cases
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def create_calculator_tool() -> ToolDefinition:
    """Build a ready-made arithmetic tool named ``"calculator"``.

    The handler reads ``operation``, ``a`` and ``b`` from its arguments dict.
    It converts both operands with ``float()`` and then applies one of
    ``add``, ``subtract``, ``multiply`` or ``divide``. Division by zero and
    unrecognised operations raise ``ValueError``.

    Returns:
        ToolDefinition for calculator
    """

    def divide(x: float, y: float) -> float:
        if y == 0:
            raise ValueError("Division by zero")
        return x / y

    # Ordered (name, implementation) pairs. The same order feeds the
    # ``enum`` advertised in the JSON schema below.
    operations = (
        ("add", lambda x, y: x + y),
        ("subtract", lambda x, y: x - y),
        ("multiply", lambda x, y: x * y),
        ("divide", divide),
    )

    def handler(args: dict[str, Any]) -> float:
        requested = args["operation"]
        lhs, rhs = float(args["a"]), float(args["b"])
        # Linear scan with == (not a dict lookup), so unhashable values still
        # end in the "Unknown operation" error.
        for name, apply in operations:
            if requested == name:
                return apply(lhs, rhs)
        raise ValueError(f"Unknown operation: {requested}")

    operation_schema = {
        "type": "string",
        "enum": [name for name, _ in operations],
        "description": "The arithmetic operation to perform",
    }
    return ToolDefinition(
        name="calculator",
        description="Perform basic arithmetic operations (add, subtract, multiply, divide)",
        parameters={
            "type": "object",
            "properties": {
                "operation": operation_schema,
                "a": {"type": "number", "description": "First number"},
                "b": {"type": "number", "description": "Second number"},
            },
            "required": ["operation", "a", "b"],
        },
        handler=handler,
    )
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def create_counter_tool() -> ToolDefinition:
    """Build a stateful ``"counter"`` tool, intended for testing.

    Each call to this factory creates an independent counter starting at 0.
    The value lives in the handler's closure, so it persists across
    invocations of that tool's handler. Every action returns the counter
    value after the action is applied. ``"get"`` leaves it unchanged.

    Returns:
        ToolDefinition for counter
    """
    current = 0

    def handler(args: dict[str, Any]) -> int:
        nonlocal current
        action = args["action"]
        if action == "increment":
            current += 1
        elif action == "decrement":
            current -= 1
        elif action == "reset":
            current = 0
        elif action != "get":
            raise ValueError(f"Unknown action: {action}")
        return current

    action_schema = {
        "type": "string",
        "enum": ["increment", "decrement", "reset", "get"],
        "description": "Action to perform on counter",
    }
    return ToolDefinition(
        name="counter",
        description="Simple counter that can be incremented, decremented, or reset",
        parameters={
            "type": "object",
            "properties": {"action": action_schema},
            "required": ["action"],
        },
        handler=handler,
    )
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
# Public names exported by ``from themis.core.tools import *``.
__all__ = [
    "ToolDefinition", "ToolCall", "ToolResult", "ToolRegistry",
    "create_calculator_tool", "create_counter_tool",
]
|
themis/core/types.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Common type definitions and generic types for Themis.
|
|
2
|
+
|
|
3
|
+
This module provides improved type safety through generic types and protocols.
|
|
4
|
+
All types are designed to be backward compatible with existing code.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any, Protocol, Sequence, TypeVar, runtime_checkable
|
|
10
|
+
|
|
11
|
+
from themis.core import entities
|
|
12
|
+
|
|
13
|
+
# Type variables for generic types.
# T is invariant and parameterizes TypedMetric (predictions and references).
# T_co is covariant and parameterizes TypedExtractor, where it appears only
# as the return type of extract().
T = TypeVar("T")  # Generic type for predictions/references
T_co = TypeVar("T_co", covariant=True)  # Covariant type for outputs
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@runtime_checkable
class TypedExtractor(Protocol[T_co]):
    """Structural interface for extractors that declare their output type.

    Backward-compatible counterpart of the ``Extractor`` protocol. Any object
    with a compatible ``extract`` method conforms without inheriting from
    this class. Because it is runtime-checkable, ``isinstance`` checks work,
    but they verify only that an ``extract`` attribute exists, not its
    signature or return type.
    """

    def extract(self, raw_output: str) -> T_co:
        """Pull a structured value out of raw model text.

        Args:
            raw_output: Unprocessed text produced by the model

        Returns:
            The extracted value, of type ``T_co``

        Raises:
            FieldExtractionError: If extraction fails
        """
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@runtime_checkable
class TypedMetric(Protocol[T]):
    """Structural interface for metrics that declare their prediction type.

    Backward-compatible counterpart of the ``Metric`` interface. Any object
    exposing a ``name`` attribute and a keyword-only ``compute`` method with
    this signature conforms, with no inheritance required.
    """

    name: str

    def compute(
        self,
        *,
        prediction: T,
        references: Sequence[T],
        metadata: dict[str, Any] | None = None,
    ) -> entities.MetricScore:
        """Score a single prediction against its references.

        Args:
            prediction: The model's prediction, of type ``T``
            references: Reference answers, each of type ``T``
            metadata: Optional extra context; may be ``None``

        Returns:
            A ``MetricScore`` holding the computed value
        """
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Descriptively named type variables (these are TypeVars, not type aliases).
# They make generic signatures more self-describing. Each is an independent,
# invariant TypeVar with no bound or constraints.
PredictionType = TypeVar("PredictionType")
ReferenceType = TypeVar("ReferenceType")
ExtractionType = TypeVar("ExtractionType")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ValidationError(ValueError):
    """Signals that a value failed a runtime type check.

    Derives from ``ValueError``, so existing ``except ValueError`` handlers
    also catch it.
    """
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _describe_type(expected_type: Any) -> str:
    """Render a class, tuple of classes, or union for use in error messages."""
    if isinstance(expected_type, tuple):
        return " | ".join(_describe_type(member) for member in expected_type)
    # Unions such as ``int | str`` may lack ``__name__``; fall back to repr().
    name = getattr(expected_type, "__name__", None)
    return name if isinstance(name, str) else repr(expected_type)


def validate_type(
    value: Any,
    expected_type: type[T] | tuple[type[Any], ...],
    field_name: str = "value",
) -> T:
    """Validate value against expected type at runtime.

    ``expected_type`` may be anything ``isinstance`` accepts: a class, a tuple
    of classes, or (Python 3.10+) an ``X | Y`` union. As with ``isinstance``,
    ``bool`` values pass a check for ``int``.

    Args:
        value: Value to validate
        expected_type: Expected type (or tuple/union of types)
        field_name: Name of field for error messages

    Returns:
        ``value`` itself, unchanged. This is a check, not a conversion.

    Raises:
        ValidationError: If ``value`` is not an instance of ``expected_type``
    """
    if not isinstance(value, expected_type):
        raise ValidationError(
            f"{field_name} expected type {_describe_type(expected_type)}, "
            f"got {type(value).__name__}"
        )
    return value


def validate_sequence_type(
    values: Sequence[Any],
    expected_type: type[T] | tuple[type[Any], ...],
    field_name: str = "values",
) -> Sequence[T]:
    """Validate all values in sequence against expected type.

    Elements are checked in order. The first mismatch is reported with its
    index, e.g. ``values[2] expected type int, got str``.

    Args:
        values: Sequence to validate
        expected_type: Expected type (or tuple/union of types) for elements
        field_name: Name of field for error messages

    Returns:
        ``values`` itself, unchanged

    Raises:
        ValidationError: If any element fails validation
    """
    for index, value in enumerate(values):
        validate_type(value, expected_type, f"{field_name}[{index}]")
    return values
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Public names exported by ``from themis.core.types import *``.
__all__ = [
    "T", "T_co",
    "TypedExtractor", "TypedMetric",
    "PredictionType", "ReferenceType", "ExtractionType",
    "ValidationError", "validate_type", "validate_sequence_type",
]
|