themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""Serialization helpers for Themis core entities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
from typing import Any, Dict
|
|
7
|
+
|
|
8
|
+
from themis.core import entities as core_entities
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def serialize_sampling(config: core_entities.SamplingConfig) -> Dict[str, Any]:
    """Flatten a sampling configuration into a plain dictionary.

    Args:
        config: Sampling settings to convert.

    Returns:
        A dict holding ``temperature``, ``top_p`` and ``max_tokens`` copied
        verbatim from ``config``.
    """
    return dict(
        temperature=config.temperature,
        top_p=config.top_p,
        max_tokens=config.max_tokens,
    )
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def deserialize_sampling(data: Dict[str, Any]) -> core_entities.SamplingConfig:
    """Rebuild a ``SamplingConfig`` from the output of ``serialize_sampling``.

    Args:
        data: Mapping holding ``temperature``, ``top_p`` and ``max_tokens``.

    Returns:
        A new ``core_entities.SamplingConfig``.

    Raises:
        KeyError: If any of the three keys is missing.
    """
    temperature = data["temperature"]
    top_p = data["top_p"]
    max_tokens = data["max_tokens"]
    return core_entities.SamplingConfig(
        temperature=temperature, top_p=top_p, max_tokens=max_tokens
    )
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def serialize_model_spec(spec: core_entities.ModelSpec) -> Dict[str, Any]:
    """Convert a model specification into a plain dictionary.

    ``metadata`` is deep-copied, so mutating the result never touches
    ``spec``. ``default_sampling`` is serialized with ``serialize_sampling``
    when it is truthy and stored as ``None`` otherwise.

    Args:
        spec: Model specification to convert.

    Returns:
        A dict with ``identifier``, ``provider``, ``metadata`` and
        ``default_sampling`` keys.
    """
    payload: Dict[str, Any] = {
        "identifier": spec.identifier,
        "provider": spec.provider,
        "metadata": copy.deepcopy(spec.metadata),
    }
    sampling = spec.default_sampling
    payload["default_sampling"] = serialize_sampling(sampling) if sampling else None
    return payload
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def deserialize_model_spec(data: Dict[str, Any]) -> core_entities.ModelSpec:
    """Rebuild a ``ModelSpec`` from the output of ``serialize_model_spec``.

    Args:
        data: Mapping with required ``identifier`` and ``provider`` keys and
            optional ``metadata`` / ``default_sampling`` entries.

    Returns:
        A new ``core_entities.ModelSpec``. A missing or falsy
        ``default_sampling`` entry (e.g. ``None`` or ``{}``) yields ``None``.

    Raises:
        KeyError: If ``identifier`` or ``provider`` is missing.
    """
    raw_sampling = data.get("default_sampling")
    if raw_sampling:
        default_sampling = deserialize_sampling(raw_sampling)
    else:
        default_sampling = None

    return core_entities.ModelSpec(
        identifier=data["identifier"],
        provider=data["provider"],
        metadata=copy.deepcopy(data.get("metadata", {})),
        default_sampling=default_sampling,
    )
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def serialize_prompt_spec(spec: core_entities.PromptSpec) -> Dict[str, Any]:
    """Convert a prompt specification into a plain dictionary.

    Args:
        spec: Prompt specification to convert.

    Returns:
        A dict with ``name``, ``template`` and a deep copy of ``metadata``.
    """
    return dict(
        name=spec.name,
        template=spec.template,
        metadata=copy.deepcopy(spec.metadata),
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def deserialize_prompt_spec(data: Dict[str, Any]) -> core_entities.PromptSpec:
    """Rebuild a ``PromptSpec`` from the output of ``serialize_prompt_spec``.

    Args:
        data: Mapping with required ``name`` and ``template`` keys and an
            optional ``metadata`` dict (defaults to empty).

    Returns:
        A new ``core_entities.PromptSpec``.

    Raises:
        KeyError: If ``name`` or ``template`` is missing.
    """
    name = data["name"]
    template = data["template"]
    metadata = copy.deepcopy(data.get("metadata", {}))
    return core_entities.PromptSpec(name=name, template=template, metadata=metadata)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def serialize_prompt_render(render: core_entities.PromptRender) -> Dict[str, Any]:
    """Convert a rendered prompt into a plain dictionary.

    Args:
        render: Rendered prompt to convert.

    Returns:
        A dict with the serialized ``spec``, the rendered ``text``, and deep
        copies of ``context`` and ``metadata``.
    """
    return dict(
        spec=serialize_prompt_spec(render.spec),
        text=render.text,
        context=copy.deepcopy(render.context),
        metadata=copy.deepcopy(render.metadata),
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def deserialize_prompt_render(data: Dict[str, Any]) -> core_entities.PromptRender:
    """Rebuild a ``PromptRender`` from ``serialize_prompt_render`` output.

    Args:
        data: Mapping with required ``spec`` and ``text`` keys and optional
            ``context`` / ``metadata`` dicts (each defaults to empty).

    Returns:
        A new ``core_entities.PromptRender``.

    Raises:
        KeyError: If ``spec`` or ``text`` is missing.
    """
    spec = deserialize_prompt_spec(data["spec"])
    text = data["text"]
    context = copy.deepcopy(data.get("context", {}))
    metadata = copy.deepcopy(data.get("metadata", {}))
    return core_entities.PromptRender(
        spec=spec, text=text, context=context, metadata=metadata
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def serialize_reference(
    reference: core_entities.Reference | None,
) -> Dict[str, Any] | None:
    """Convert an optional reference answer into a plain dictionary.

    Args:
        reference: Reference to convert, or ``None``.

    Returns:
        ``None`` when ``reference`` is ``None``; otherwise a dict with
        ``kind`` and ``value``. ``value`` is stored as-is (not copied).
    """
    return (
        None
        if reference is None
        else {"kind": reference.kind, "value": reference.value}
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def deserialize_reference(
    data: Dict[str, Any] | None,
) -> core_entities.Reference | None:
    """Rebuild an optional ``Reference`` from ``serialize_reference`` output.

    Args:
        data: Mapping with a required ``kind`` and optional ``value``, or
            ``None``.

    Returns:
        ``None`` when ``data`` is ``None``; otherwise a new
        ``core_entities.Reference``.

    Raises:
        KeyError: If ``data`` is given but has no ``kind``.
    """
    if data is not None:
        return core_entities.Reference(kind=data["kind"], value=data.get("value"))
    return None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def serialize_generation_task(task: core_entities.GenerationTask) -> Dict[str, Any]:
    """Convert a generation task into a plain dictionary.

    Each nested entity is delegated to its dedicated serializer, and the
    task's ``metadata`` is deep-copied.

    Args:
        task: Generation task to convert.

    Returns:
        A dict with ``prompt``, ``model``, ``sampling``, ``metadata`` and
        ``reference`` keys (``reference`` may be ``None``).
    """
    return dict(
        prompt=serialize_prompt_render(task.prompt),
        model=serialize_model_spec(task.model),
        sampling=serialize_sampling(task.sampling),
        metadata=copy.deepcopy(task.metadata),
        reference=serialize_reference(task.reference),
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def deserialize_generation_task(data: Dict[str, Any]) -> core_entities.GenerationTask:
    """Rebuild a ``GenerationTask`` from ``serialize_generation_task`` output.

    Args:
        data: Mapping with required ``prompt``, ``model`` and ``sampling``
            entries and optional ``metadata`` / ``reference`` entries.

    Returns:
        A new ``core_entities.GenerationTask``.

    Raises:
        KeyError: If a required entry is missing.
    """
    prompt = deserialize_prompt_render(data["prompt"])
    model = deserialize_model_spec(data["model"])
    sampling = deserialize_sampling(data["sampling"])
    metadata = copy.deepcopy(data.get("metadata", {}))
    reference = deserialize_reference(data.get("reference"))
    return core_entities.GenerationTask(
        prompt=prompt,
        model=model,
        sampling=sampling,
        metadata=metadata,
        reference=reference,
    )
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def serialize_generation_record(
    record: core_entities.GenerationRecord,
) -> Dict[str, Any]:
    """Convert a generation record, including nested attempts, to a dict.

    ``output`` and ``error`` become ``None`` when the record's corresponding
    attribute is falsy. ``metrics`` and the error ``details`` are
    deep-copied. ``attempts`` are serialized recursively with this function.

    NOTE(review): ``output.raw`` is stored by reference (not deep-copied),
    unlike the other payloads here. Confirm this is intended.

    Args:
        record: Generation record to convert.

    Returns:
        A dict with ``task``, ``output``, ``error``, ``metrics`` and
        ``attempts`` keys.
    """
    task_payload = serialize_generation_task(record.task)

    output = record.output
    output_payload = None
    if output:
        output_payload = {"text": output.text, "raw": output.raw}

    error = record.error
    error_payload = None
    if error:
        error_payload = {
            "message": error.message,
            "kind": error.kind,
            "details": copy.deepcopy(error.details),
        }

    return {
        "task": task_payload,
        "output": output_payload,
        "error": error_payload,
        "metrics": copy.deepcopy(record.metrics),
        "attempts": [serialize_generation_record(child) for child in record.attempts],
    }
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def deserialize_generation_record(
    data: Dict[str, Any],
) -> core_entities.GenerationRecord:
    """Rebuild a ``GenerationRecord`` (with nested attempts) from a dict.

    Args:
        data: Mapping produced by ``serialize_generation_record``. ``task`` is
            required; ``output``, ``error``, ``metrics`` and ``attempts`` are
            optional.

    Returns:
        A new ``core_entities.GenerationRecord``. Falsy ``output`` / ``error``
        entries become ``None``. A missing error ``kind`` defaults to
        ``"model_error"``.

    Raises:
        KeyError: If ``task`` is missing, or a present ``output`` / ``error``
            lacks ``text`` / ``message``.
    """
    output_data = data.get("output")
    error_data = data.get("error")

    task = deserialize_generation_task(data["task"])

    output = None
    if output_data:
        output = core_entities.ModelOutput(
            text=output_data["text"], raw=output_data.get("raw")
        )

    error = None
    if error_data:
        error = core_entities.ModelError(
            message=error_data["message"],
            kind=error_data.get("kind", "model_error"),
            details=copy.deepcopy(error_data.get("details", {})),
        )

    return core_entities.GenerationRecord(
        task=task,
        output=output,
        error=error,
        metrics=copy.deepcopy(data.get("metrics", {})),
        attempts=[
            deserialize_generation_record(nested)
            for nested in data.get("attempts", [])
        ],
    )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def serialize_metric_score(score: core_entities.MetricScore) -> Dict[str, Any]:
    """Convert a metric score into a plain dictionary.

    Args:
        score: Metric score to convert.

    Returns:
        A dict with ``metric_name``, ``value``, and deep copies of
        ``details`` and ``metadata``.
    """
    return dict(
        metric_name=score.metric_name,
        value=score.value,
        details=copy.deepcopy(score.details),
        metadata=copy.deepcopy(score.metadata),
    )
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def deserialize_metric_score(data: Dict[str, Any]) -> core_entities.MetricScore:
    """Rebuild a ``MetricScore`` from ``serialize_metric_score`` output.

    Args:
        data: Mapping with required ``metric_name`` and ``value`` keys and
            optional ``details`` / ``metadata`` dicts (each defaults to empty).

    Returns:
        A new ``core_entities.MetricScore``.

    Raises:
        KeyError: If ``metric_name`` or ``value`` is missing.
    """
    metric_name = data["metric_name"]
    value = data["value"]
    details = copy.deepcopy(data.get("details", {}))
    metadata = copy.deepcopy(data.get("metadata", {}))
    return core_entities.MetricScore(
        metric_name=metric_name, value=value, details=details, metadata=metadata
    )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def serialize_evaluation_record(
    record: core_entities.EvaluationRecord,
) -> Dict[str, Any]:
    """Convert an evaluation record into a plain dictionary.

    Args:
        record: Evaluation record to convert.

    Returns:
        A dict with ``sample_id``, the serialized ``scores``, and
        ``failures`` copied into a new list.
    """
    return dict(
        sample_id=record.sample_id,
        scores=[serialize_metric_score(item) for item in record.scores],
        failures=list(record.failures),
    )
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def deserialize_evaluation_record(
    data: Dict[str, Any],
) -> core_entities.EvaluationRecord:
    """Rebuild an ``EvaluationRecord`` from ``serialize_evaluation_record``.

    Args:
        data: Mapping with optional ``sample_id``, ``scores`` and
            ``failures`` entries. Missing lists default to empty, and a
            missing ``sample_id`` becomes ``None``.

    Returns:
        A new ``core_entities.EvaluationRecord``.
    """
    raw_scores = data.get("scores", [])
    return core_entities.EvaluationRecord(
        sample_id=data.get("sample_id"),
        scores=[deserialize_metric_score(item) for item in raw_scores],
        failures=list(data.get("failures", [])),
    )
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Public API: paired serialize_*/deserialize_* helpers for each core entity.
__all__ = [
    "serialize_generation_record",
    "deserialize_generation_record",
    "serialize_generation_task",
    "deserialize_generation_task",
    "serialize_evaluation_record",
    "deserialize_evaluation_record",
    "serialize_metric_score",
    "deserialize_metric_score",
    "serialize_sampling",
    "deserialize_sampling",
    "serialize_model_spec",
    "deserialize_model_spec",
    "serialize_prompt_spec",
    "deserialize_prompt_spec",
    "serialize_prompt_render",
    "deserialize_prompt_render",
    "serialize_reference",
    "deserialize_reference",
]
|
themis/core/tools.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""Tool use primitives for agentic workflows.

This module provides abstractions for defining and executing tools
(functions) that models can call during generation. This enables
agentic workflows, function calling, and tool-augmented generation.

Examples:
    # Define a tool. ToolRegistry.execute passes the call's arguments to
    # the handler as a single dict.
    def calculator(args: dict) -> float:
        operation = args["operation"]
        a, b = float(args["a"]), float(args["b"])
        if operation == "add":
            return a + b
        elif operation == "multiply":
            return a * b
        raise ValueError(f"Unknown operation: {operation}")

    tool = ToolDefinition(
        name="calculator",
        description="Perform arithmetic operations",
        parameters={
            "type": "object",
            "properties": {
                "operation": {"type": "string", "enum": ["add", "multiply"]},
                "a": {"type": "number"},
                "b": {"type": "number"},
            },
            "required": ["operation", "a", "b"],
        },
        handler=calculator,
    )

    # Register tool
    registry = ToolRegistry()
    registry.register(tool)

    # Execute tool
    call = ToolCall(tool_name="calculator", arguments={"operation": "add", "a": 2, "b": 3})
    result = registry.execute(call)
    print(result.result)  # 5.0
"""
|
|
40
|
+
|
|
41
|
+
from __future__ import annotations
|
|
42
|
+
|
|
43
|
+
import time
|
|
44
|
+
import uuid
|
|
45
|
+
from dataclasses import dataclass, field
|
|
46
|
+
from typing import Any, Callable
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class ToolDefinition:
    """Defines a tool/function available to the model.

    Attributes:
        name: Tool name (should be unique)
        description: Human-readable description of what tool does
        parameters: JSON Schema describing parameters
        handler: Function to execute when tool is called; it receives the
            call's arguments as a single dict
        metadata: Additional metadata
    """

    name: str
    description: str
    parameters: dict[str, Any]
    handler: Callable[[dict[str, Any]], Any]
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert tool definition to dictionary (without handler).

        Returns:
            Dictionary representation suitable for JSON serialization
        """
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
            "metadata": self.metadata,
        }

    def validate_arguments(self, arguments: dict[str, Any]) -> list[str]:
        """Validate arguments against the parameter schema.

        Only a lightweight subset of JSON Schema is checked. Every name in
        ``required`` must be present. When ``properties`` is declared,
        argument names outside it are rejected unless the schema explicitly
        opts in to extra keys via ``additionalProperties`` (any value other
        than ``False``). Value types are not checked.

        Args:
            arguments: Arguments to validate. A non-mapping value (e.g.
                ``None`` or a raw string from a malformed model response) is
                reported as a validation error instead of raising.

        Returns:
            List of validation error messages (empty if valid)
        """
        # Local import keeps the module's top-level imports unchanged.
        from collections.abc import Mapping

        if not isinstance(arguments, Mapping):
            # Without this guard, `in` / iteration below would raise (or, for
            # a str, silently do substring matching). ToolRegistry.execute
            # calls this outside its try block, so an exception here would
            # escape instead of producing an error ToolResult.
            return [f"Arguments must be an object, got {type(arguments).__name__}"]

        errors: list[str] = []

        for key in self.parameters.get("required") or ():
            if key not in arguments:
                errors.append(f"Missing required field: {key}")

        properties = self.parameters.get("properties")
        # JSON Schema allows extra keys by default, but this validator has
        # always been strict; it only relaxes when the schema explicitly asks.
        allow_extra = self.parameters.get("additionalProperties", False) is not False
        if properties is not None and not allow_extra:
            known_fields = set(properties)
            errors.extend(
                f"Unknown field: {key}" for key in arguments if key not in known_fields
            )

        return errors
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass
class ToolCall:
    """A model's request to run one tool.

    Attributes:
        tool_name: Name of the tool to execute
        arguments: Arguments for the tool; ToolRegistry.execute hands this
            dict to the tool's handler
        call_id: Identifier for this call; a fresh random UUID4 string by
            default
    """

    tool_name: str
    arguments: dict[str, Any]
    call_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self) -> dict[str, Any]:
        """Return this call as a plain dictionary.

        ``arguments`` is included by reference, not copied.

        Returns:
            Dict with ``tool_name``, ``arguments`` and ``call_id`` keys
        """
        return dict(
            tool_name=self.tool_name,
            arguments=self.arguments,
            call_id=self.call_id,
        )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass
class ToolResult:
    """Outcome of executing a tool call.

    Attributes:
        call: The tool call this result belongs to
        result: Value returned by the handler (``None`` on failure)
        error: Error message (``None`` on success)
        execution_time_ms: Time taken to execute (milliseconds)
        metadata: Additional metadata
    """

    call: ToolCall
    result: Any | None
    error: str | None
    execution_time_ms: float
    metadata: dict[str, Any] = field(default_factory=dict)

    def is_success(self) -> bool:
        """Report whether the call finished without an error.

        Returns:
            True exactly when ``error`` is ``None``
        """
        failed = self.error is not None
        return not failed

    def to_dict(self) -> dict[str, Any]:
        """Return this result as a plain dictionary.

        The nested call is converted with its own ``to_dict``. ``result`` and
        ``metadata`` are included by reference.

        Returns:
            Dict with ``call``, ``result``, ``error``, ``execution_time_ms``
            and ``metadata`` keys
        """
        payload: dict[str, Any] = {"call": self.call.to_dict()}
        payload["result"] = self.result
        payload["error"] = self.error
        payload["execution_time_ms"] = self.execution_time_ms
        payload["metadata"] = self.metadata
        return payload
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class ToolRegistry:
    """Name-indexed collection of tools that can also run tool calls.

    Tools are kept in registration order and their names must be unique.

    Examples:
        registry = ToolRegistry()

        # Register tools
        registry.register(calculator_tool)
        registry.register(search_tool)

        # Execute tool
        call = ToolCall(tool_name="calculator", arguments={...})
        result = registry.execute(call)
    """

    def __init__(self):
        """Create a registry containing no tools."""
        self._tools: dict[str, ToolDefinition] = {}

    def register(self, tool: ToolDefinition) -> None:
        """Add ``tool`` under its own name.

        Args:
            tool: Tool definition to register

        Raises:
            ValueError: If a tool with the same name is already registered
        """
        name = tool.name
        if name not in self._tools:
            self._tools[name] = tool
            return
        raise ValueError(f"Tool '{name}' already registered")

    def unregister(self, name: str) -> None:
        """Remove the tool called ``name``; unknown names are ignored.

        Args:
            name: Tool name to unregister
        """
        if name in self._tools:
            del self._tools[name]

    def get(self, name: str) -> ToolDefinition | None:
        """Look up a tool by name.

        Args:
            name: Tool name

        Returns:
            The registered ToolDefinition, or None if absent
        """
        return self._tools.get(name, None)

    def list_tools(self) -> list[ToolDefinition]:
        """Return every registered tool, in registration order.

        Returns:
            New list of tool definitions
        """
        return [*self._tools.values()]

    def execute(self, call: ToolCall) -> ToolResult:
        """Run ``call`` against the matching registered tool.

        Unknown tools and argument-validation failures are reported as error
        results with ``execution_time_ms=0.0``. Exceptions raised by the
        handler (subclasses of ``Exception``) are captured as
        ``"<ExceptionName>: <message>"``.

        Args:
            call: Tool call to execute

        Returns:
            ToolResult holding the handler's return value or an error
        """
        tool = self._tools.get(call.tool_name)
        if tool is None:
            return ToolResult(
                call=call,
                result=None,
                error=f"Unknown tool: {call.tool_name}",
                execution_time_ms=0.0,
            )

        problems = tool.validate_arguments(call.arguments)
        if problems:
            return ToolResult(
                call=call,
                result=None,
                error=f"Invalid arguments: {'; '.join(problems)}",
                execution_time_ms=0.0,
            )

        started = time.perf_counter()
        try:
            value = tool.handler(call.arguments)
        except Exception as exc:
            outcome, failure = None, f"{exc.__class__.__name__}: {str(exc)}"
        else:
            outcome, failure = value, None
        elapsed_ms = (time.perf_counter() - started) * 1000

        return ToolResult(
            call=call,
            result=outcome,
            error=failure,
            execution_time_ms=elapsed_ms,
        )

    def to_dict_list(self) -> list[dict[str, Any]]:
        """Describe every registered tool as a dict (e.g. to send to a model).

        Returns:
            One ``to_dict()`` result per tool, in registration order
        """
        return [definition.to_dict() for definition in self._tools.values()]
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# Built-in tools for common use cases
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def create_calculator_tool() -> ToolDefinition:
    """Build the built-in ``calculator`` tool.

    The handler reads ``operation``, ``a`` and ``b`` from its argument dict,
    converts ``a`` and ``b`` with ``float()``, and supports add, subtract,
    multiply and divide. At call time it raises ``ValueError`` on division
    by zero or an unrecognised operation.

    Returns:
        ToolDefinition for calculator
    """

    def handler(args: dict[str, Any]) -> float:
        op = args["operation"]
        lhs = float(args["a"])
        rhs = float(args["b"])

        if op == "add":
            return lhs + rhs
        if op == "subtract":
            return lhs - rhs
        if op == "multiply":
            return lhs * rhs
        if op == "divide":
            if rhs == 0:
                raise ValueError("Division by zero")
            return lhs / rhs
        raise ValueError(f"Unknown operation: {op}")

    schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["add", "subtract", "multiply", "divide"],
                "description": "The arithmetic operation to perform",
            },
            "a": {"type": "number", "description": "First number"},
            "b": {"type": "number", "description": "Second number"},
        },
        "required": ["operation", "a", "b"],
    }

    return ToolDefinition(
        name="calculator",
        description="Perform basic arithmetic operations (add, subtract, multiply, divide)",
        parameters=schema,
        handler=handler,
    )
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def create_counter_tool() -> ToolDefinition:
    """Build a stateful ``counter`` tool, mainly for testing.

    The count lives in this factory's closure and starts at 0. Each call to
    the factory therefore yields an independent counter, whose value
    persists across calls to that tool's handler. The handler returns the
    count after applying the action and raises ``ValueError`` for an
    unrecognised action.

    Returns:
        ToolDefinition for counter
    """
    count = 0

    def handler(args: dict[str, Any]) -> int:
        nonlocal count
        action = args["action"]

        if action == "increment":
            count += 1
        elif action == "decrement":
            count -= 1
        elif action == "reset":
            count = 0
        elif action == "get":
            pass  # read-only: just report the current value below
        else:
            raise ValueError(f"Unknown action: {action}")

        return count

    schema = {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": ["increment", "decrement", "reset", "get"],
                "description": "Action to perform on counter",
            },
        },
        "required": ["action"],
    }

    return ToolDefinition(
        name="counter",
        description="Simple counter that can be incremented, decremented, or reset",
        parameters=schema,
        handler=handler,
    )
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
# Public API: tool primitives plus factories for the built-in example tools.
__all__ = [
    "ToolDefinition",
    "ToolCall",
    "ToolResult",
    "ToolRegistry",
    "create_calculator_tool",
    "create_counter_tool",
]
|