themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158) hide show
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/__init__.py +5 -0
  8. themis/cli/__main__.py +6 -0
  9. themis/cli/commands/__init__.py +19 -0
  10. themis/cli/commands/benchmarks.py +221 -0
  11. themis/cli/commands/comparison.py +394 -0
  12. themis/cli/commands/config_commands.py +244 -0
  13. themis/cli/commands/cost.py +214 -0
  14. themis/cli/commands/demo.py +68 -0
  15. themis/cli/commands/info.py +90 -0
  16. themis/cli/commands/leaderboard.py +362 -0
  17. themis/cli/commands/math_benchmarks.py +318 -0
  18. themis/cli/commands/mcq_benchmarks.py +207 -0
  19. themis/cli/commands/results.py +252 -0
  20. themis/cli/commands/sample_run.py +244 -0
  21. themis/cli/commands/visualize.py +299 -0
  22. themis/cli/main.py +463 -0
  23. themis/cli/new_project.py +33 -0
  24. themis/cli/utils.py +51 -0
  25. themis/comparison/__init__.py +25 -0
  26. themis/comparison/engine.py +348 -0
  27. themis/comparison/reports.py +283 -0
  28. themis/comparison/statistics.py +402 -0
  29. themis/config/__init__.py +19 -0
  30. themis/config/loader.py +27 -0
  31. themis/config/registry.py +34 -0
  32. themis/config/runtime.py +214 -0
  33. themis/config/schema.py +112 -0
  34. themis/core/__init__.py +5 -0
  35. themis/core/conversation.py +354 -0
  36. themis/core/entities.py +184 -0
  37. themis/core/serialization.py +231 -0
  38. themis/core/tools.py +393 -0
  39. themis/core/types.py +141 -0
  40. themis/datasets/__init__.py +273 -0
  41. themis/datasets/base.py +264 -0
  42. themis/datasets/commonsense_qa.py +174 -0
  43. themis/datasets/competition_math.py +265 -0
  44. themis/datasets/coqa.py +133 -0
  45. themis/datasets/gpqa.py +190 -0
  46. themis/datasets/gsm8k.py +123 -0
  47. themis/datasets/gsm_symbolic.py +124 -0
  48. themis/datasets/math500.py +122 -0
  49. themis/datasets/med_qa.py +179 -0
  50. themis/datasets/medmcqa.py +169 -0
  51. themis/datasets/mmlu_pro.py +262 -0
  52. themis/datasets/piqa.py +146 -0
  53. themis/datasets/registry.py +201 -0
  54. themis/datasets/schema.py +245 -0
  55. themis/datasets/sciq.py +150 -0
  56. themis/datasets/social_i_qa.py +151 -0
  57. themis/datasets/super_gpqa.py +263 -0
  58. themis/evaluation/__init__.py +1 -0
  59. themis/evaluation/conditional.py +410 -0
  60. themis/evaluation/extractors/__init__.py +19 -0
  61. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  62. themis/evaluation/extractors/exceptions.py +7 -0
  63. themis/evaluation/extractors/identity_extractor.py +29 -0
  64. themis/evaluation/extractors/json_field_extractor.py +45 -0
  65. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  66. themis/evaluation/extractors/regex_extractor.py +43 -0
  67. themis/evaluation/math_verify_utils.py +87 -0
  68. themis/evaluation/metrics/__init__.py +21 -0
  69. themis/evaluation/metrics/code/__init__.py +19 -0
  70. themis/evaluation/metrics/code/codebleu.py +144 -0
  71. themis/evaluation/metrics/code/execution.py +280 -0
  72. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  73. themis/evaluation/metrics/composite_metric.py +47 -0
  74. themis/evaluation/metrics/consistency_metric.py +80 -0
  75. themis/evaluation/metrics/exact_match.py +51 -0
  76. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  77. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  78. themis/evaluation/metrics/nlp/__init__.py +21 -0
  79. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  80. themis/evaluation/metrics/nlp/bleu.py +129 -0
  81. themis/evaluation/metrics/nlp/meteor.py +153 -0
  82. themis/evaluation/metrics/nlp/rouge.py +136 -0
  83. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  84. themis/evaluation/metrics/response_length.py +33 -0
  85. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  86. themis/evaluation/pipeline.py +49 -0
  87. themis/evaluation/pipelines/__init__.py +15 -0
  88. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  89. themis/evaluation/pipelines/standard_pipeline.py +348 -0
  90. themis/evaluation/reports.py +293 -0
  91. themis/evaluation/statistics/__init__.py +53 -0
  92. themis/evaluation/statistics/bootstrap.py +79 -0
  93. themis/evaluation/statistics/confidence_intervals.py +121 -0
  94. themis/evaluation/statistics/distributions.py +207 -0
  95. themis/evaluation/statistics/effect_sizes.py +124 -0
  96. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  97. themis/evaluation/statistics/types.py +139 -0
  98. themis/evaluation/strategies/__init__.py +13 -0
  99. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  100. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  101. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  102. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  103. themis/experiment/__init__.py +5 -0
  104. themis/experiment/builder.py +151 -0
  105. themis/experiment/cache_manager.py +134 -0
  106. themis/experiment/comparison.py +631 -0
  107. themis/experiment/cost.py +310 -0
  108. themis/experiment/definitions.py +62 -0
  109. themis/experiment/export.py +798 -0
  110. themis/experiment/export_csv.py +159 -0
  111. themis/experiment/integration_manager.py +104 -0
  112. themis/experiment/math.py +192 -0
  113. themis/experiment/mcq.py +169 -0
  114. themis/experiment/orchestrator.py +415 -0
  115. themis/experiment/pricing.py +317 -0
  116. themis/experiment/storage.py +1458 -0
  117. themis/experiment/visualization.py +588 -0
  118. themis/generation/__init__.py +1 -0
  119. themis/generation/agentic_runner.py +420 -0
  120. themis/generation/batching.py +254 -0
  121. themis/generation/clients.py +143 -0
  122. themis/generation/conversation_runner.py +236 -0
  123. themis/generation/plan.py +456 -0
  124. themis/generation/providers/litellm_provider.py +221 -0
  125. themis/generation/providers/vllm_provider.py +135 -0
  126. themis/generation/router.py +34 -0
  127. themis/generation/runner.py +207 -0
  128. themis/generation/strategies.py +98 -0
  129. themis/generation/templates.py +71 -0
  130. themis/generation/turn_strategies.py +393 -0
  131. themis/generation/types.py +9 -0
  132. themis/integrations/__init__.py +0 -0
  133. themis/integrations/huggingface.py +72 -0
  134. themis/integrations/wandb.py +77 -0
  135. themis/interfaces/__init__.py +169 -0
  136. themis/presets/__init__.py +10 -0
  137. themis/presets/benchmarks.py +354 -0
  138. themis/presets/models.py +190 -0
  139. themis/project/__init__.py +20 -0
  140. themis/project/definitions.py +98 -0
  141. themis/project/patterns.py +230 -0
  142. themis/providers/__init__.py +5 -0
  143. themis/providers/registry.py +39 -0
  144. themis/server/__init__.py +28 -0
  145. themis/server/app.py +337 -0
  146. themis/utils/api_generator.py +379 -0
  147. themis/utils/cost_tracking.py +376 -0
  148. themis/utils/dashboard.py +452 -0
  149. themis/utils/logging_utils.py +41 -0
  150. themis/utils/progress.py +58 -0
  151. themis/utils/tracing.py +320 -0
  152. themis_eval-0.2.0.dist-info/METADATA +596 -0
  153. themis_eval-0.2.0.dist-info/RECORD +157 -0
  154. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  155. themis_eval-0.1.0.dist-info/METADATA +0 -758
  156. themis_eval-0.1.0.dist-info/RECORD +0 -8
  157. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  158. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,231 @@
1
+ """Serialization helpers for Themis core entities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ from typing import Any, Dict
7
+
8
+ from themis.core import entities as core_entities
9
+
10
+
11
def serialize_sampling(config: core_entities.SamplingConfig) -> Dict[str, Any]:
    """Flatten a sampling configuration into a plain dictionary.

    Args:
        config: Object exposing ``temperature``, ``top_p`` and ``max_tokens``.

    Returns:
        Dictionary holding those three values under keys of the same name.
    """
    exported = ("temperature", "top_p", "max_tokens")
    return {attr: getattr(config, attr) for attr in exported}
17
+
18
+
19
def deserialize_sampling(data: Dict[str, Any]) -> core_entities.SamplingConfig:
    """Rebuild a ``SamplingConfig`` from its dictionary form.

    Args:
        data: Mapping that must contain ``temperature``, ``top_p`` and
            ``max_tokens``.

    Returns:
        A ``core_entities.SamplingConfig`` built from those values.

    Raises:
        KeyError: If any of the three keys is absent.
    """
    temperature = data["temperature"]
    top_p = data["top_p"]
    max_tokens = data["max_tokens"]
    return core_entities.SamplingConfig(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )
25
+
26
+
27
def serialize_model_spec(spec: core_entities.ModelSpec) -> Dict[str, Any]:
    """Convert a model specification into a plain dictionary.

    Args:
        spec: Object exposing ``identifier``, ``provider``, ``metadata`` and
            ``default_sampling``.

    Returns:
        Dictionary with a deep copy of the metadata; ``default_sampling`` is
        the serialized sampling config, or ``None`` when the spec's value is
        falsy.
    """
    payload: Dict[str, Any] = {
        "identifier": spec.identifier,
        "provider": spec.provider,
        "metadata": copy.deepcopy(spec.metadata),
    }
    sampling = spec.default_sampling
    if sampling:
        payload["default_sampling"] = serialize_sampling(sampling)
    else:
        payload["default_sampling"] = None
    return payload
36
+
37
+
38
def deserialize_model_spec(data: Dict[str, Any]) -> core_entities.ModelSpec:
    """Rebuild a ``ModelSpec`` from its dictionary form.

    Args:
        data: Mapping with required ``identifier`` and ``provider`` keys and
            optional ``metadata`` and ``default_sampling`` keys.

    Returns:
        A ``core_entities.ModelSpec``; its metadata is a deep copy of the
        input (empty dict when missing) and its default sampling is ``None``
        when the stored value is missing or falsy.

    Raises:
        KeyError: If ``identifier`` or ``provider`` is absent.
    """
    sampling = None
    if data.get("default_sampling"):
        sampling = deserialize_sampling(data["default_sampling"])

    return core_entities.ModelSpec(
        identifier=data["identifier"],
        provider=data["provider"],
        metadata=copy.deepcopy(data.get("metadata", {})),
        default_sampling=sampling,
    )
50
+
51
+
52
def serialize_prompt_spec(spec: core_entities.PromptSpec) -> Dict[str, Any]:
    """Convert a prompt specification into a plain dictionary.

    Args:
        spec: Object exposing ``name``, ``template`` and ``metadata``.

    Returns:
        Dictionary with the name, the template and a deep copy of the
        metadata.
    """
    payload: Dict[str, Any] = {"name": spec.name, "template": spec.template}
    # Deep copy so later edits to the result never leak back into the spec.
    payload["metadata"] = copy.deepcopy(spec.metadata)
    return payload
58
+
59
+
60
def deserialize_prompt_spec(data: Dict[str, Any]) -> core_entities.PromptSpec:
    """Rebuild a ``PromptSpec`` from its dictionary form.

    Args:
        data: Mapping with required ``name`` and ``template`` keys and an
            optional ``metadata`` key.

    Returns:
        A ``core_entities.PromptSpec`` whose metadata is a deep copy of the
        input metadata (empty dict when missing).

    Raises:
        KeyError: If ``name`` or ``template`` is absent.
    """
    name = data["name"]
    template = data["template"]
    metadata = copy.deepcopy(data.get("metadata", {}))
    return core_entities.PromptSpec(name=name, template=template, metadata=metadata)
66
+
67
+
68
def serialize_prompt_render(render: core_entities.PromptRender) -> Dict[str, Any]:
    """Convert a rendered prompt into a plain dictionary.

    Args:
        render: Object exposing ``spec``, ``text``, ``context`` and
            ``metadata``.

    Returns:
        Dictionary with the serialized spec, the rendered text, and deep
        copies of the context and metadata.
    """
    payload: Dict[str, Any] = {"spec": serialize_prompt_spec(render.spec)}
    payload["text"] = render.text
    payload["context"] = copy.deepcopy(render.context)
    payload["metadata"] = copy.deepcopy(render.metadata)
    return payload
75
+
76
+
77
def deserialize_prompt_render(data: Dict[str, Any]) -> core_entities.PromptRender:
    """Rebuild a ``PromptRender`` from its dictionary form.

    Args:
        data: Mapping with required ``spec`` and ``text`` keys and optional
            ``context`` and ``metadata`` keys.

    Returns:
        A ``core_entities.PromptRender``; context and metadata are deep
        copies of the input values (empty dicts when missing).

    Raises:
        KeyError: If ``spec`` or ``text`` is absent.
    """
    spec = deserialize_prompt_spec(data["spec"])
    text = data["text"]
    context = copy.deepcopy(data.get("context", {}))
    metadata = copy.deepcopy(data.get("metadata", {}))
    return core_entities.PromptRender(
        spec=spec,
        text=text,
        context=context,
        metadata=metadata,
    )
84
+
85
+
86
def serialize_reference(
    reference: core_entities.Reference | None,
) -> Dict[str, Any] | None:
    """Convert an optional reference into a plain dictionary.

    Args:
        reference: Object exposing ``kind`` and ``value``, or ``None``.

    Returns:
        ``None`` when no reference is given, otherwise a dictionary with the
        reference's ``kind`` and ``value``.
    """
    if reference is not None:
        return {"kind": reference.kind, "value": reference.value}
    return None
92
+
93
+
94
def deserialize_reference(
    data: Dict[str, Any] | None,
) -> core_entities.Reference | None:
    """Rebuild an optional ``Reference`` from its dictionary form.

    Args:
        data: Mapping with a required ``kind`` key and an optional ``value``
            key, or ``None``.

    Returns:
        ``None`` when ``data`` is ``None``, otherwise a
        ``core_entities.Reference`` (``value`` defaults to ``None``).

    Raises:
        KeyError: If ``data`` is given but lacks ``kind``.
    """
    if data is not None:
        return core_entities.Reference(kind=data["kind"], value=data.get("value"))
    return None
100
+
101
+
102
def serialize_generation_task(task: core_entities.GenerationTask) -> Dict[str, Any]:
    """Convert a generation task into a plain dictionary.

    Args:
        task: Object exposing ``prompt``, ``model``, ``sampling``,
            ``metadata`` and ``reference``.

    Returns:
        Dictionary with the serialized prompt, model, sampling config and
        reference, plus a deep copy of the task metadata.
    """
    payload: Dict[str, Any] = {}
    payload["prompt"] = serialize_prompt_render(task.prompt)
    payload["model"] = serialize_model_spec(task.model)
    payload["sampling"] = serialize_sampling(task.sampling)
    payload["metadata"] = copy.deepcopy(task.metadata)
    payload["reference"] = serialize_reference(task.reference)
    return payload
110
+
111
+
112
def deserialize_generation_task(data: Dict[str, Any]) -> core_entities.GenerationTask:
    """Rebuild a ``GenerationTask`` from its dictionary form.

    Args:
        data: Mapping with required ``prompt``, ``model`` and ``sampling``
            keys and optional ``metadata`` and ``reference`` keys.

    Returns:
        A ``core_entities.GenerationTask``; metadata is a deep copy of the
        input (empty dict when missing) and the reference is ``None`` when
        missing.

    Raises:
        KeyError: If ``prompt``, ``model`` or ``sampling`` is absent.
    """
    prompt = deserialize_prompt_render(data["prompt"])
    model = deserialize_model_spec(data["model"])
    sampling = deserialize_sampling(data["sampling"])
    metadata = copy.deepcopy(data.get("metadata", {}))
    reference = deserialize_reference(data.get("reference"))
    return core_entities.GenerationTask(
        prompt=prompt,
        model=model,
        sampling=sampling,
        metadata=metadata,
        reference=reference,
    )
120
+
121
+
122
def serialize_generation_record(
    record: core_entities.GenerationRecord,
) -> Dict[str, Any]:
    """Convert a generation record, including nested attempts, to a dictionary.

    Args:
        record: Object exposing ``task``, ``output``, ``error``, ``metrics``
            and ``attempts``.

    Returns:
        Dictionary with the serialized task, the output (``text`` and ``raw``)
        or ``None``, the error (``message``, ``kind``, deep-copied
        ``details``) or ``None``, a deep copy of the metrics, and the list of
        recursively serialized attempts.
    """
    payload: Dict[str, Any] = {"task": serialize_generation_task(record.task)}

    # Output and error are optional; a falsy value is stored as None.
    if record.output:
        payload["output"] = {
            "text": record.output.text,
            "raw": record.output.raw,
        }
    else:
        payload["output"] = None

    if record.error:
        payload["error"] = {
            "message": record.error.message,
            "kind": record.error.kind,
            "details": copy.deepcopy(record.error.details),
        }
    else:
        payload["error"] = None

    payload["metrics"] = copy.deepcopy(record.metrics)
    # Each attempt is itself a generation record, so recurse.
    payload["attempts"] = [
        serialize_generation_record(nested) for nested in record.attempts
    ]
    return payload
145
+
146
+
147
def deserialize_generation_record(
    data: Dict[str, Any],
) -> core_entities.GenerationRecord:
    """Rebuild a ``GenerationRecord``, including nested attempts.

    Args:
        data: Mapping with a required ``task`` key and optional ``output``,
            ``error``, ``metrics`` and ``attempts`` keys.

    Returns:
        A ``core_entities.GenerationRecord``. Output and error are ``None``
        when their stored value is missing or falsy; the error kind defaults
        to ``"model_error"``; metrics and error details are deep copies.

    Raises:
        KeyError: If ``task`` is absent, or a present output/error mapping
            lacks ``text``/``message`` respectively.
    """
    raw_output = data.get("output")
    raw_error = data.get("error")

    # The task is rebuilt first, before any optional part is touched.
    task = deserialize_generation_task(data["task"])

    output = None
    if raw_output:
        output = core_entities.ModelOutput(
            text=raw_output["text"], raw=raw_output.get("raw")
        )

    error = None
    if raw_error:
        error = core_entities.ModelError(
            message=raw_error["message"],
            kind=raw_error.get("kind", "model_error"),
            details=copy.deepcopy(raw_error.get("details", {})),
        )

    metrics = copy.deepcopy(data.get("metrics", {}))
    # Attempts are stored as nested record dictionaries, so recurse.
    attempts = [
        deserialize_generation_record(nested)
        for nested in data.get("attempts", [])
    ]
    return core_entities.GenerationRecord(
        task=task,
        output=output,
        error=error,
        metrics=metrics,
        attempts=attempts,
    )
172
+
173
+
174
def serialize_metric_score(score: core_entities.MetricScore) -> Dict[str, Any]:
    """Convert a metric score into a plain dictionary.

    Args:
        score: Object exposing ``metric_name``, ``value``, ``details`` and
            ``metadata``.

    Returns:
        Dictionary with the metric name, its value, and deep copies of the
        details and metadata.
    """
    payload: Dict[str, Any] = {
        "metric_name": score.metric_name,
        "value": score.value,
    }
    # Copy the nested containers so the result is independent of the score.
    for attr in ("details", "metadata"):
        payload[attr] = copy.deepcopy(getattr(score, attr))
    return payload
181
+
182
+
183
def deserialize_metric_score(data: Dict[str, Any]) -> core_entities.MetricScore:
    """Rebuild a ``MetricScore`` from its dictionary form.

    Args:
        data: Mapping with required ``metric_name`` and ``value`` keys and
            optional ``details`` and ``metadata`` keys.

    Returns:
        A ``core_entities.MetricScore``; details and metadata are deep copies
        of the input values (empty dicts when missing).

    Raises:
        KeyError: If ``metric_name`` or ``value`` is absent.
    """
    metric_name = data["metric_name"]
    value = data["value"]
    details = copy.deepcopy(data.get("details", {}))
    metadata = copy.deepcopy(data.get("metadata", {}))
    return core_entities.MetricScore(
        metric_name=metric_name,
        value=value,
        details=details,
        metadata=metadata,
    )
190
+
191
+
192
def serialize_evaluation_record(
    record: core_entities.EvaluationRecord,
) -> Dict[str, Any]:
    """Convert an evaluation record into a plain dictionary.

    Args:
        record: Object exposing ``sample_id``, ``scores`` and ``failures``.

    Returns:
        Dictionary with the sample id, the list of serialized scores, and a
        new list holding the record's failures.
    """
    payload: Dict[str, Any] = {"sample_id": record.sample_id}
    payload["scores"] = [serialize_metric_score(item) for item in record.scores]
    # list() yields a fresh list regardless of the source iterable's type.
    payload["failures"] = list(record.failures)
    return payload
200
+
201
+
202
def deserialize_evaluation_record(
    data: Dict[str, Any],
) -> core_entities.EvaluationRecord:
    """Rebuild an ``EvaluationRecord`` from its dictionary form.

    Args:
        data: Mapping with optional ``sample_id``, ``scores`` and
            ``failures`` keys.

    Returns:
        A ``core_entities.EvaluationRecord``; ``sample_id`` is ``None`` when
        missing, and scores and failures default to empty lists.
    """
    sample_id = data.get("sample_id")
    scores = [deserialize_metric_score(item) for item in data.get("scores", [])]
    failures = list(data.get("failures", []))
    return core_entities.EvaluationRecord(
        sample_id=sample_id,
        scores=scores,
        failures=failures,
    )
210
+
211
+
212
# Public API of this module: one serialize/deserialize pair per entity.
__all__ = [
    # Generation records and tasks
    "serialize_generation_record",
    "deserialize_generation_record",
    "serialize_generation_task",
    "deserialize_generation_task",
    # Evaluation records and metric scores
    "serialize_evaluation_record",
    "deserialize_evaluation_record",
    "serialize_metric_score",
    "deserialize_metric_score",
    # Sampling and model configuration
    "serialize_sampling",
    "deserialize_sampling",
    "serialize_model_spec",
    "deserialize_model_spec",
    # Prompts and references
    "serialize_prompt_spec",
    "deserialize_prompt_spec",
    "serialize_prompt_render",
    "deserialize_prompt_render",
    "serialize_reference",
    "deserialize_reference",
]
themis/core/tools.py ADDED
@@ -0,0 +1,393 @@
1
+ """Tool use primitives for agentic workflows.
2
+
3
+ This module provides abstractions for defining and executing tools
4
+ (functions) that models can call during generation. This enables
5
+ agentic workflows, function calling, and tool-augmented generation.
6
+
7
+ Examples:
8
+ # Define a tool
9
+ def calculator(operation: str, a: float, b: float) -> float:
10
+ if operation == "add":
11
+ return a + b
12
+ elif operation == "multiply":
13
+ return a * b
14
+ raise ValueError(f"Unknown operation: {operation}")
15
+
16
+ tool = ToolDefinition(
17
+ name="calculator",
18
+ description="Perform arithmetic operations",
19
+ parameters={
20
+ "type": "object",
21
+ "properties": {
22
+ "operation": {"type": "string", "enum": ["add", "multiply"]},
23
+ "a": {"type": "number"},
24
+ "b": {"type": "number"},
25
+ },
26
+ "required": ["operation", "a", "b"],
27
+ },
28
+ handler=lambda args: calculator(**args),  # handlers receive the arguments dict
29
+ )
30
+
31
+ # Register tool
32
+ registry = ToolRegistry()
33
+ registry.register(tool)
34
+
35
+ # Execute tool
36
+ call = ToolCall(tool_name="calculator", arguments={"operation": "add", "a": 2, "b": 3})
37
+ result = registry.execute(call)
38
+ print(result.result) # 5
39
+ """
40
+
41
+ from __future__ import annotations
42
+
43
+ import time
44
+ import uuid
45
+ from dataclasses import dataclass, field
46
+ from typing import Any, Callable
47
+
48
+
49
@dataclass
class ToolDefinition:
    """Describes a tool (function) that can be offered to a model.

    Attributes:
        name: Tool name (should be unique)
        description: Human-readable description of what the tool does
        parameters: JSON Schema describing the tool's parameters
        handler: Callable invoked with the arguments dictionary
        metadata: Additional metadata
    """

    name: str
    description: str
    parameters: dict[str, Any]
    handler: Callable[[dict[str, Any]], Any]
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the definition as a dictionary, leaving out the handler.

        Returns:
            Dictionary with the name, description, parameters and metadata.
            The parameters and metadata objects are returned as-is, not
            copied.
        """
        exported = ("name", "description", "parameters", "metadata")
        return {attr: getattr(self, attr) for attr in exported}

    def validate_arguments(self, arguments: dict[str, Any]) -> list[str]:
        """Check arguments against the ``required``/``properties`` schema keys.

        Only presence is checked; argument types and values are not
        validated.

        Args:
            arguments: Arguments to validate

        Returns:
            List of validation error messages (empty if valid). Missing
            required fields are reported before unknown fields.
        """
        schema = self.parameters
        problems: list[str] = []

        # Every name listed under "required" must be supplied.
        if "required" in schema:
            problems.extend(
                f"Missing required field: {key}"
                for key in schema["required"]
                if key not in arguments
            )

        # When properties are declared, anything outside them is rejected.
        if "properties" in schema:
            declared = set(schema["properties"].keys())
            problems.extend(
                f"Unknown field: {key}"
                for key in arguments.keys()
                if key not in declared
            )

        return problems
105
+
106
+
107
@dataclass
class ToolCall:
    """A request to run one tool with a set of arguments.

    Attributes:
        tool_name: Name of the tool to execute
        arguments: Arguments to pass to the tool
        call_id: Identifier for this call; defaults to a fresh UUID4 string
    """

    tool_name: str
    arguments: dict[str, Any]
    call_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self) -> dict[str, Any]:
        """Return the call as a dictionary.

        Returns:
            Dictionary with the tool name, arguments and call id.
        """
        payload: dict[str, Any] = {}
        payload["tool_name"] = self.tool_name
        payload["arguments"] = self.arguments
        payload["call_id"] = self.call_id
        return payload
132
+
133
+
134
@dataclass
class ToolResult:
    """Outcome of executing a tool call.

    Attributes:
        call: The tool call this result belongs to
        result: Value produced by the tool (if successful)
        error: Error message (if failed)
        execution_time_ms: Time taken to execute, in milliseconds
        metadata: Additional metadata
    """

    call: ToolCall
    result: Any | None
    error: str | None
    execution_time_ms: float
    metadata: dict[str, Any] = field(default_factory=dict)

    def is_success(self) -> bool:
        """Report whether the execution finished without an error.

        Returns:
            True if ``error`` is ``None``
        """
        has_error = self.error is not None
        return not has_error

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a dictionary.

        Returns:
            Dictionary with the call (via its own ``to_dict``), result,
            error, execution time and metadata.
        """
        payload: dict[str, Any] = {"call": self.call.to_dict()}
        for attr in ("result", "error", "execution_time_ms", "metadata"):
            payload[attr] = getattr(self, attr)
        return payload
173
+
174
+
175
class ToolRegistry:
    """Keeps a name-indexed collection of tools and runs calls against them.

    Examples:
        registry = ToolRegistry()

        # Register tools
        registry.register(calculator_tool)
        registry.register(search_tool)

        # Execute tool
        call = ToolCall(tool_name="calculator", arguments={...})
        result = registry.execute(call)
    """

    def __init__(self):
        """Create a registry with no tools."""
        self._tools: dict[str, ToolDefinition] = {}

    def register(self, tool: ToolDefinition) -> None:
        """Add a tool under its name.

        Args:
            tool: Tool definition to register

        Raises:
            ValueError: If a tool with the same name is already registered
        """
        key = tool.name
        if key in self._tools:
            raise ValueError(f"Tool '{tool.name}' already registered")
        self._tools[key] = tool

    def unregister(self, name: str) -> None:
        """Remove a tool by name; unknown names are ignored.

        Args:
            name: Tool name to unregister
        """
        self._tools.pop(name, None)

    def get(self, name: str) -> ToolDefinition | None:
        """Look up a tool by name.

        Args:
            name: Tool name

        Returns:
            ToolDefinition if found, None otherwise
        """
        return self._tools.get(name)

    def list_tools(self) -> list[ToolDefinition]:
        """Return all registered tools.

        Returns:
            New list of tool definitions, in registration order
        """
        return [*self._tools.values()]

    def execute(self, call: ToolCall) -> ToolResult:
        """Run a tool call and wrap the outcome in a ToolResult.

        Unknown tools, invalid arguments and exceptions raised by the handler
        are all reported through the result's ``error`` field instead of
        being raised.

        Args:
            call: Tool call to execute

        Returns:
            ToolResult with the handler's return value, or an error message.
            ``execution_time_ms`` is 0.0 when the handler was never invoked.
        """
        definition = self._tools.get(call.tool_name)
        if definition is None:
            return ToolResult(
                call=call,
                result=None,
                error=f"Unknown tool: {call.tool_name}",
                execution_time_ms=0.0,
            )

        # Reject the call before running the handler if the arguments are bad.
        problems = definition.validate_arguments(call.arguments)
        if problems:
            return ToolResult(
                call=call,
                result=None,
                error=f"Invalid arguments: {'; '.join(problems)}",
                execution_time_ms=0.0,
            )

        started = time.perf_counter()
        try:
            value = definition.handler(call.arguments)
            duration_ms = (time.perf_counter() - started) * 1000
            return ToolResult(
                call=call,
                result=value,
                error=None,
                execution_time_ms=duration_ms,
            )
        except Exception as exc:
            # Handler failures become an error result, timed up to the raise.
            duration_ms = (time.perf_counter() - started) * 1000
            return ToolResult(
                call=call,
                result=None,
                error=f"{exc.__class__.__name__}: {exc!s}",
                execution_time_ms=duration_ms,
            )

    def to_dict_list(self) -> list[dict[str, Any]]:
        """Return every tool as a dictionary (for sending to a model).

        Returns:
            List with each registered tool's ``to_dict()`` output
        """
        return [definition.to_dict() for definition in self._tools.values()]
294
+
295
+
296
+ # Built-in tools for common use cases
297
+
298
+
299
def create_calculator_tool() -> ToolDefinition:
    """Build the built-in four-function calculator tool.

    The handler converts ``a`` and ``b`` with ``float()`` before dispatching
    on ``operation``, and raises ``ValueError`` for division by zero or an
    unrecognised operation.

    Returns:
        ToolDefinition named "calculator"
    """

    def handler(args: dict[str, Any]) -> float:
        op = args["operation"]
        lhs = float(args["a"])
        rhs = float(args["b"])

        if op == "add":
            return lhs + rhs
        if op == "subtract":
            return lhs - rhs
        if op == "multiply":
            return lhs * rhs
        if op == "divide":
            # Guard explicitly so the error is a ValueError with a clear message.
            if rhs == 0:
                raise ValueError("Division by zero")
            return lhs / rhs
        raise ValueError(f"Unknown operation: {op}")

    schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["add", "subtract", "multiply", "divide"],
                "description": "The arithmetic operation to perform",
            },
            "a": {"type": "number", "description": "First number"},
            "b": {"type": "number", "description": "Second number"},
        },
        "required": ["operation", "a", "b"],
    }
    return ToolDefinition(
        name="calculator",
        description="Perform basic arithmetic operations (add, subtract, multiply, divide)",
        parameters=schema,
        handler=handler,
    )
342
+
343
+
344
def create_counter_tool() -> ToolDefinition:
    """Build a stateful counter tool for testing.

    Each call to this factory creates a counter starting at 0; the state
    lives in the handler's closure, so it persists across calls to that
    handler.

    Returns:
        ToolDefinition named "counter"
    """
    count = 0

    def handler(args: dict[str, Any]) -> int:
        nonlocal count
        action = args["action"]

        if action == "increment":
            count += 1
        elif action == "decrement":
            count -= 1
        elif action == "reset":
            count = 0
        elif action != "get":
            # "get" changes nothing; anything else is rejected.
            raise ValueError(f"Unknown action: {action}")

        return count

    schema = {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": ["increment", "decrement", "reset", "get"],
                "description": "Action to perform on counter",
            },
        },
        "required": ["action"],
    }
    return ToolDefinition(
        name="counter",
        description="Simple counter that can be incremented, decremented, or reset",
        parameters=schema,
        handler=handler,
    )
384
+
385
+
386
# Public API of this module: the tool primitives plus the built-in factories.
__all__ = [
    # Core primitives
    "ToolDefinition",
    "ToolCall",
    "ToolResult",
    "ToolRegistry",
    # Built-in tool factories
    "create_calculator_tool",
    "create_counter_tool",
]