themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
"""Agentic runner with tool use support.
|
|
2
|
+
|
|
3
|
+
This module provides a runner that supports agentic workflows where
|
|
4
|
+
models can call tools/functions to augment their capabilities.
|
|
5
|
+
|
|
6
|
+
Examples:
|
|
7
|
+
from themis.generation import agentic_runner
|
|
8
|
+
from themis.core import tools, entities
|
|
9
|
+
|
|
10
|
+
# Create registry with tools
|
|
11
|
+
registry = tools.ToolRegistry()
|
|
12
|
+
registry.register(tools.create_calculator_tool())
|
|
13
|
+
|
|
14
|
+
# Create runner
|
|
15
|
+
runner = agentic_runner.AgenticRunner(
|
|
16
|
+
provider=provider,
|
|
17
|
+
tool_registry=registry,
|
|
18
|
+
max_iterations=10
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Create task
|
|
22
|
+
task = entities.GenerationTask(...)
|
|
23
|
+
|
|
24
|
+
# Run with tool use
|
|
25
|
+
record = runner.run_agentic(task)
|
|
26
|
+
print(f"Used {len(record.iterations)} iterations")
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import json
|
|
32
|
+
import logging
|
|
33
|
+
import re
|
|
34
|
+
from dataclasses import dataclass, field
|
|
35
|
+
from typing import Any, Callable
|
|
36
|
+
|
|
37
|
+
from themis.core import conversation as conv
|
|
38
|
+
from themis.core import entities as core_entities
|
|
39
|
+
from themis.core import tools as tool_primitives
|
|
40
|
+
from themis.interfaces import ModelProvider
|
|
41
|
+
from themis.utils import tracing
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class AgenticIteration:
    """One step of an agentic loop: a model generation plus any tool activity.

    Attributes:
        iteration_number: Zero-based index of this step within the loop.
        generation_record: The model generation produced during this step.
        tool_calls: Tool invocations parsed out of the generation output.
        tool_results: Outcomes from executing those tool calls.
        context_snapshot: Copy of the conversation context taken at this step.
    """

    iteration_number: int
    generation_record: core_entities.GenerationRecord
    tool_calls: list[tool_primitives.ToolCall] = field(default_factory=list)
    tool_results: list[tool_primitives.ToolResult] = field(default_factory=list)
    context_snapshot: conv.ConversationContext | None = None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class AgenticRecord:
    """Full trace of one agentic workflow execution.

    Attributes:
        task: Original generation task
        final_output: Final model output
        iterations: List of iterations executed
        context: Conversation context (if used)
        metadata: Additional metadata
    """

    task: core_entities.GenerationTask
    final_output: core_entities.ModelOutput | None
    iterations: list[AgenticIteration] = field(default_factory=list)
    context: conv.ConversationContext | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def total_iterations(self) -> int:
        """Return how many iterations were executed."""
        return len(self.iterations)

    def total_tool_calls(self) -> int:
        """Return the total number of tool calls across every iteration."""
        return sum(len(step.tool_calls) for step in self.iterations)

    def successful_tool_calls(self) -> int:
        """Return the number of tool executions that reported success."""
        return sum(
            1
            for step in self.iterations
            for outcome in step.tool_results
            if outcome.is_success()
        )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class AgenticRunner:
    """Runner supporting tool use and agentic workflows.

    This runner executes an agentic loop where the model can make tool calls,
    receive results, and continue processing until it stops requesting tools
    or ``max_iterations`` is reached.

    Attributes:
        provider: Model provider for generation
        tool_registry: Registry of available tools
        max_iterations: Maximum number of iterations
        tool_call_parser: Function to parse tool calls from model output
    """

    def __init__(
        self,
        *,
        provider: ModelProvider,
        tool_registry: tool_primitives.ToolRegistry,
        max_iterations: int = 10,
        tool_call_parser: Callable[[str], list[tool_primitives.ToolCall]] | None = None,
    ):
        """Initialize agentic runner.

        Args:
            provider: Model provider for generation
            tool_registry: Registry of available tools
            max_iterations: Maximum number of iterations
            tool_call_parser: Optional custom parser for tool calls
        """
        self._provider = provider
        self._tools = tool_registry
        self._max_iterations = max_iterations
        self._tool_call_parser = tool_call_parser or self._default_tool_call_parser

    def run_agentic(self, task: core_entities.GenerationTask) -> AgenticRecord:
        """Run agentic loop with tool use.

        Args:
            task: Generation task to execute

        Returns:
            AgenticRecord with full iteration history
        """
        task_id = task.metadata.get("dataset_id", "unknown")

        with tracing.span(
            "run_agentic",
            task_id=task_id,
            model=task.model.identifier,
            max_iterations=self._max_iterations,
        ):
            # Initialize conversation context with the user prompt.
            context = conv.ConversationContext()
            context.add_message("user", task.prompt.text)

            # Advertise the available tools via a system message (if any).
            tool_descriptions = self._format_tool_descriptions()
            if tool_descriptions:
                context.add_message("system", tool_descriptions)

            iterations: list[AgenticIteration] = []

            for i in range(self._max_iterations):
                with tracing.span("agentic_iteration", iteration=i):
                    logger.debug(
                        "Starting agentic iteration %d/%d", i + 1, self._max_iterations
                    )

                    # Generate with the full conversation so far as the prompt.
                    with tracing.span("generate"):
                        prompt_text = context.to_prompt()
                        gen_task = self._update_task_prompt(task, prompt_text, i)
                        generation = self._provider.generate(gen_task)

                    # Parse tool calls from the model output.
                    with tracing.span("parse_tool_calls"):
                        tool_calls = self._parse_tool_calls(generation)

                    # No tool calls means the model is done; record and exit.
                    if not tool_calls:
                        logger.debug("No tool calls found, ending agentic loop")
                        iteration = AgenticIteration(
                            iteration_number=i,
                            generation_record=generation,
                            tool_calls=[],
                            tool_results=[],
                            context_snapshot=self._snapshot_context(context),
                        )
                        iterations.append(iteration)

                        # Add final assistant message
                        if generation.output:
                            context.add_message("assistant", generation.output.text)

                        break

                    # Execute the requested tools.
                    with tracing.span("execute_tools", num_tools=len(tool_calls)):
                        tool_results = self._execute_tools(tool_calls)

                    # Snapshot is taken *before* this turn's messages are appended,
                    # so it reflects the context the model saw for this iteration.
                    iteration = AgenticIteration(
                        iteration_number=i,
                        generation_record=generation,
                        tool_calls=tool_calls,
                        tool_results=tool_results,
                        context_snapshot=self._snapshot_context(context),
                    )
                    iterations.append(iteration)

                    # Add assistant response to context
                    if generation.output:
                        context.add_message("assistant", generation.output.text)

                    # Feed tool results back so the next generation can use them.
                    for result in tool_results:
                        result_text = self._format_tool_result(result)
                        context.add_message("tool", result_text)

                    logger.debug(
                        "Iteration %d: %d tool calls, %d successful",
                        i,
                        len(tool_calls),
                        sum(1 for r in tool_results if r.is_success()),
                    )

            # Final output is whatever the last generation produced (if any).
            final_output = None
            if iterations:
                final_output = iterations[-1].generation_record.output

            agentic_record = AgenticRecord(
                task=task,
                final_output=final_output,
                iterations=iterations,
                context=context,
                metadata={
                    "total_iterations": len(iterations),
                    "total_tool_calls": sum(len(it.tool_calls) for it in iterations),
                    "max_iterations_reached": len(iterations) >= self._max_iterations,
                },
            )

            logger.info(
                "Agentic execution completed: %d iterations, %d tool calls",
                len(iterations),
                agentic_record.total_tool_calls(),
            )

            return agentic_record

    def _format_tool_descriptions(self) -> str:
        """Format tool descriptions for system message.

        Returns:
            Formatted tool descriptions, or "" when no tools are registered
        """
        tools = self._tools.list_tools()
        if not tools:
            return ""

        lines = ["Available tools:"]
        for tool in tools:
            lines.append(f"\n- {tool.name}: {tool.description}")
            lines.append(f"  Parameters: {json.dumps(tool.parameters, indent=2)}")

        lines.append(
            '\nTo use a tool, output: TOOL_CALL: {"name": "tool_name", "arguments": {...}}'
        )

        return "\n".join(lines)

    def _parse_tool_calls(
        self, record: core_entities.GenerationRecord
    ) -> list[tool_primitives.ToolCall]:
        """Parse tool calls from generation record.

        Args:
            record: Generation record

        Returns:
            List of parsed tool calls (empty when there is no output)
        """
        if not record.output:
            return []

        return self._tool_call_parser(record.output.text)

    def _default_tool_call_parser(self, text: str) -> list[tool_primitives.ToolCall]:
        """Default parser for tool calls.

        Looks for markers like: TOOL_CALL: {"name": "...", "arguments": {...}}

        Uses ``json.JSONDecoder.raw_decode`` to consume a complete JSON object
        starting at each marker. A regex such as ``\\{.*?\\}`` cannot be used
        here: the non-greedy match stops at the first ``}``, which always
        belongs to the nested "arguments" object, truncating the JSON and
        making every well-formed tool call unparseable.

        Args:
            text: Model output text

        Returns:
            List of parsed tool calls
        """
        calls: list[tool_primitives.ToolCall] = []
        decoder = json.JSONDecoder()

        for marker in re.finditer(r"TOOL_CALL:\s*", text):
            start = marker.end()
            # Only attempt to decode when a JSON object follows the marker.
            if start >= len(text) or text[start] != "{":
                continue
            try:
                data, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                logger.warning(
                    "Failed to parse tool call JSON at offset %d", start
                )
                continue
            if isinstance(data, dict) and "name" in data and "arguments" in data:
                calls.append(
                    tool_primitives.ToolCall(
                        tool_name=data["name"],
                        arguments=data["arguments"],
                    )
                )

        return calls

    def _execute_tools(
        self, calls: list[tool_primitives.ToolCall]
    ) -> list[tool_primitives.ToolResult]:
        """Execute tool calls.

        Args:
            calls: Tool calls to execute

        Returns:
            List of tool results, one per call, in order
        """
        results = []
        for call in calls:
            with tracing.span("execute_tool", tool_name=call.tool_name):
                result = self._tools.execute(call)
                results.append(result)

        return results

    def _format_tool_result(self, result: tool_primitives.ToolResult) -> str:
        """Format tool result for context.

        Args:
            result: Tool result

        Returns:
            Formatted string (success value or error message)
        """
        if result.is_success():
            return f"Tool {result.call.tool_name} result: {result.result}"
        else:
            return f"Tool {result.call.tool_name} error: {result.error}"

    def _update_task_prompt(
        self,
        task: core_entities.GenerationTask,
        prompt_text: str,
        iteration: int,
    ) -> core_entities.GenerationTask:
        """Build a copy of *task* whose prompt is the full conversation text.

        Args:
            task: Original task
            prompt_text: New prompt text
            iteration: Iteration number

        Returns:
            Updated task (original is left untouched)
        """
        # Local import mirrors the module's convention of importing entities lazily here.
        from themis.core.entities import PromptRender, PromptSpec

        new_prompt = PromptRender(
            spec=PromptSpec(
                name=f"agentic_iter_{iteration}",
                template="",
                metadata={"iteration": iteration},
            ),
            text=prompt_text,
            context={"iteration": iteration},
            metadata={"iteration": iteration},
        )

        # Copy metadata so the caller's task is not mutated.
        metadata = dict(task.metadata)
        metadata["iteration"] = iteration
        metadata["agentic"] = True

        return core_entities.GenerationTask(
            prompt=new_prompt,
            model=task.model,
            sampling=task.sampling,
            metadata=metadata,
            reference=task.reference,
        )

    def _snapshot_context(
        self, context: conv.ConversationContext
    ) -> conv.ConversationContext:
        """Create an independent copy of *context* via dict round-trip.

        Args:
            context: Context to snapshot

        Returns:
            Copy of context
        """
        return conv.ConversationContext.from_dict(context.to_dict())
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
__all__ = ["AgenticIteration", "AgenticRecord", "AgenticRunner"]
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""Batch optimization for generation.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for batching generation tasks to improve efficiency:
|
|
4
|
+
- BatchConfig: Configuration for batching behavior
|
|
5
|
+
- TaskBatcher: Groups tasks for efficient batch processing
|
|
6
|
+
- Batch-aware runner patterns
|
|
7
|
+
|
|
8
|
+
Batching can reduce:
|
|
9
|
+
- API call overhead
|
|
10
|
+
- Network latency
|
|
11
|
+
- Total generation time
|
|
12
|
+
- Cost (for providers with batch APIs)
|
|
13
|
+
|
|
14
|
+
Example:
|
|
15
|
+
>>> config = BatchConfig(max_batch_size=10, group_by=lambda t: t.model.identifier)
|
|
16
|
+
>>> batcher = TaskBatcher(config)
|
|
17
|
+
>>>
|
|
18
|
+
>>> # Group tasks
|
|
19
|
+
>>> for batch in batcher.create_batches(tasks):
|
|
20
|
+
... results = provider.generate_batch(batch)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from dataclasses import dataclass
|
|
26
|
+
from typing import Any, Callable, Iterator, Sequence
|
|
27
|
+
|
|
28
|
+
from themis.core import entities as core_entities
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class BatchConfig:
    """Settings that control how generation tasks are grouped into batches.

    Attributes:
        max_batch_size: Upper bound on the number of tasks per batch.
        group_by: Optional key function; tasks with equal keys share batches.
        timeout_ms: Maximum time to wait for a batch to fill (future use).
    """

    max_batch_size: int = 10
    group_by: Callable[[core_entities.GenerationTask], str] | None = None
    timeout_ms: float = 100

    def __post_init__(self):
        """Reject nonsensical settings immediately after construction."""
        if self.max_batch_size < 1:
            raise ValueError("max_batch_size must be >= 1")
        if self.timeout_ms < 0:
            raise ValueError("timeout_ms must be >= 0")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class TaskBatcher:
    """Splits generation tasks into batches for efficient processing.

    Tasks can optionally be partitioned by a grouping key (model, prompt
    length, ...) before being cut into size-limited batches.

    Example:
        >>> batcher = TaskBatcher(BatchConfig(max_batch_size=5))
        >>> tasks = [...]  # 20 tasks
        >>> batches = list(batcher.create_batches(tasks))
        >>> len(batches)  # 4 batches of 5 each
        4
    """

    def __init__(self, config: BatchConfig):
        """Store the batching configuration.

        Args:
            config: Batch configuration
        """
        self._config = config

    def create_batches(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> Iterator[list[core_entities.GenerationTask]]:
        """Yield batches of tasks.

        With no ``group_by`` configured, tasks are batched in order. With a
        grouper, each group is batched independently (groups never mix).

        Args:
            tasks: Tasks to batch

        Yields:
            Batches of tasks
        """
        if self._config.group_by is None:
            yield from self._batch_sequential(tasks)
            return
        for members in self._group_tasks(tasks).values():
            yield from self._batch_sequential(members)

    def _group_tasks(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> dict[str, list[core_entities.GenerationTask]]:
        """Partition tasks by the configured grouping function.

        Args:
            tasks: Tasks to group

        Returns:
            Mapping from group key to the tasks sharing that key
        """
        grouper = self._config.group_by
        if grouper is None:
            # No grouper configured: everything lands in one bucket.
            return {"default": list(tasks)}

        buckets: dict[str, list[core_entities.GenerationTask]] = {}
        for task in tasks:
            buckets.setdefault(grouper(task), []).append(task)
        return buckets

    def _batch_sequential(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> Iterator[list[core_entities.GenerationTask]]:
        """Cut tasks into consecutive batches of at most max_batch_size.

        Args:
            tasks: Tasks to batch

        Yields:
            Batches of tasks (the last one may be smaller)
        """
        limit = self._config.max_batch_size
        pending: list[core_entities.GenerationTask] = []
        for item in tasks:
            pending.append(item)
            if len(pending) == limit:
                yield pending
                pending = []
        if pending:
            yield pending

    def get_batch_count(self, tasks: Sequence[core_entities.GenerationTask]) -> int:
        """Return how many batches :meth:`create_batches` would produce.

        Args:
            tasks: Tasks to batch

        Returns:
            Number of batches
        """
        return sum(1 for _ in self.create_batches(tasks))

    def get_batch_stats(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> dict[str, Any]:
        """Summarize how *tasks* would be batched.

        Args:
            tasks: Tasks to batch

        Returns:
            Dictionary with batch counts and size statistics; includes group
            information when a grouping function is configured
        """
        sizes = [len(batch) for batch in self.create_batches(tasks)]

        stats: dict[str, Any] = {
            "total_tasks": len(tasks),
            "num_batches": len(sizes),
            "max_batch_size": max(sizes) if sizes else 0,
            "min_batch_size": min(sizes) if sizes else 0,
            "avg_batch_size": sum(sizes) / len(sizes) if sizes else 0,
        }

        if self._config.group_by is not None:
            grouped = self._group_tasks(tasks)
            stats["num_groups"] = len(grouped)
            stats["group_sizes"] = {name: len(members) for name, members in grouped.items()}

        return stats
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# ============================================================================
|
|
184
|
+
# Batch-Aware Helpers
|
|
185
|
+
# ============================================================================
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def group_by_model(task: core_entities.GenerationTask) -> str:
    """Return the task's model identifier, for use as a batching key.

    Args:
        task: Task to group

    Returns:
        Model identifier as grouping key
    """
    model = task.model
    return model.identifier
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def group_by_prompt_length(
    task: core_entities.GenerationTask, bucket_size: int = 100
) -> str:
    """Return a length-bucket key for the task's prompt.

    Tasks whose prompt lengths fall in the same bucket share a key, which
    can help batch processing when prompt length affects performance.

    Args:
        task: Task to group
        bucket_size: Size of length buckets

    Returns:
        Bucket identifier as grouping key
    """
    chars = len(task.prompt.text)
    # Round down to the bucket's lower bound (equivalent to floor-divide-multiply).
    lower = chars - chars % bucket_size
    return f"length_{lower}-{lower + bucket_size}"
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def group_by_model_and_sampling(task: core_entities.GenerationTask) -> str:
    """Return a combined model + sampling-settings batching key.

    Args:
        task: Task to group

    Returns:
        Combined model and sampling key
    """
    sampling = task.sampling
    suffix = f"t{sampling.temperature}_p{sampling.top_p}"
    return f"{task.model.identifier}_{suffix}"
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def create_grouping_function(
    *groupers: Callable[[core_entities.GenerationTask], str],
) -> Callable[[core_entities.GenerationTask], str]:
    """Compose several grouping functions into one.

    The returned function applies every grouper to the task and joins the
    resulting keys with underscores.

    Args:
        *groupers: Grouping functions to combine

    Returns:
        Combined grouping function

    Example:
        >>> # Group by both model and prompt length
        >>> grouper = create_grouping_function(group_by_model, group_by_prompt_length)
        >>> config = BatchConfig(group_by=grouper)
    """

    def composite_key(task: core_entities.GenerationTask) -> str:
        return "_".join(fn(task) for fn in groupers)

    return composite_key
|