themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,420 @@
1
+ """Agentic runner with tool use support.
2
+
3
+ This module provides a runner that supports agentic workflows where
4
+ models can call tools/functions to augment their capabilities.
5
+
6
+ Examples:
7
+ from themis.generation import agentic_runner
8
+ from themis.core import tools, entities
9
+
10
+ # Create registry with tools
11
+ registry = tools.ToolRegistry()
12
+ registry.register(tools.create_calculator_tool())
13
+
14
+ # Create runner
15
+ runner = agentic_runner.AgenticRunner(
16
+ provider=provider,
17
+ tool_registry=registry,
18
+ max_iterations=10
19
+ )
20
+
21
+ # Create task
22
+ task = entities.GenerationTask(...)
23
+
24
+ # Run with tool use
25
+ record = runner.run_agentic(task)
26
+ print(f"Used {len(record.iterations)} iterations")
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import logging
33
+ import re
34
+ from dataclasses import dataclass, field
35
+ from typing import Any, Callable
36
+
37
+ from themis.core import conversation as conv
38
+ from themis.core import entities as core_entities
39
+ from themis.core import tools as tool_primitives
40
+ from themis.interfaces import ModelProvider
41
+ from themis.utils import tracing
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
@dataclass
class AgenticIteration:
    """Single iteration in an agentic workflow.

    One record is created per generate -> parse -> execute-tools round trip
    inside ``AgenticRunner.run_agentic``.

    Attributes:
        iteration_number: Iteration index (0-based)
        generation_record: Model generation for this iteration
        tool_calls: Tool calls extracted from generation
        tool_results: Results from executing tools
        context_snapshot: Conversation context at this iteration
    """

    # 0-based position of this iteration within the agentic loop.
    iteration_number: int
    # Raw provider generation produced during this iteration.
    generation_record: core_entities.GenerationRecord
    # Tool calls parsed from the generation output (empty on the final turn).
    tool_calls: list[tool_primitives.ToolCall] = field(default_factory=list)
    # Results of executing `tool_calls`, in the same order.
    tool_results: list[tool_primitives.ToolResult] = field(default_factory=list)
    # Copy of the conversation context taken when this record was created
    # (before this iteration's assistant/tool messages were appended).
    context_snapshot: conv.ConversationContext | None = None
63
+
64
+
65
@dataclass
class AgenticRecord:
    """Complete record of an agentic workflow execution.

    Attributes:
        task: Original generation task
        final_output: Final model output
        iterations: List of iterations executed
        context: Conversation context (if used)
        metadata: Additional metadata
    """

    task: core_entities.GenerationTask
    final_output: core_entities.ModelOutput | None
    iterations: list[AgenticIteration] = field(default_factory=list)
    context: conv.ConversationContext | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def total_iterations(self) -> int:
        """Return how many iterations the workflow executed."""
        return len(self.iterations)

    def total_tool_calls(self) -> int:
        """Return the number of tool calls issued across every iteration."""
        total = 0
        for step in self.iterations:
            total += len(step.tool_calls)
        return total

    def successful_tool_calls(self) -> int:
        """Return the number of tool executions that reported success."""
        return sum(
            1
            for step in self.iterations
            for outcome in step.tool_results
            if outcome.is_success()
        )
109
+
110
+
111
class AgenticRunner:
    """Runner supporting tool use and agentic workflows.

    This runner executes an agentic loop where the model can make tool calls,
    receive results, and continue processing until completion or max iterations.

    Attributes:
        provider: Model provider for generation
        tool_registry: Registry of available tools
        max_iterations: Maximum number of iterations
        tool_call_parser: Function to parse tool calls from model output
    """

    def __init__(
        self,
        *,
        provider: ModelProvider,
        tool_registry: tool_primitives.ToolRegistry,
        max_iterations: int = 10,
        tool_call_parser: Callable[[str], list[tool_primitives.ToolCall]] | None = None,
    ):
        """Initialize agentic runner.

        Args:
            provider: Model provider for generation
            tool_registry: Registry of available tools
            max_iterations: Maximum number of iterations (no generation happens
                when this is 0)
            tool_call_parser: Optional custom parser for tool calls; defaults to
                the ``TOOL_CALL: {...}`` format advertised to the model
        """
        self._provider = provider
        self._tools = tool_registry
        self._max_iterations = max_iterations
        self._tool_call_parser = tool_call_parser or self._default_tool_call_parser

    def run_agentic(self, task: core_entities.GenerationTask) -> AgenticRecord:
        """Run agentic loop with tool use.

        Args:
            task: Generation task to execute

        Returns:
            AgenticRecord with full iteration history
        """
        task_id = task.metadata.get("dataset_id", "unknown")

        with tracing.span(
            "run_agentic",
            task_id=task_id,
            model=task.model.identifier,
            max_iterations=self._max_iterations,
        ):
            # Seed the conversation with the user prompt, then advertise the
            # available tools via a system message.
            context = conv.ConversationContext()
            context.add_message("user", task.prompt.text)

            tool_descriptions = self._format_tool_descriptions()
            if tool_descriptions:
                context.add_message("system", tool_descriptions)

            iterations: list[AgenticIteration] = []
            # True once the model stops requesting tools on its own; stays
            # False when the loop is cut off by the iteration budget.
            completed_naturally = False

            for i in range(self._max_iterations):
                with tracing.span("agentic_iteration", iteration=i):
                    logger.debug(
                        "Starting agentic iteration %d/%d", i + 1, self._max_iterations
                    )

                    # Generate against the conversation rendered as a prompt.
                    with tracing.span("generate"):
                        prompt_text = context.to_prompt()
                        gen_task = self._update_task_prompt(task, prompt_text, i)
                        gen_record = self._provider.generate(gen_task)

                    # Parse tool calls from output
                    with tracing.span("parse_tool_calls"):
                        tool_calls = self._parse_tool_calls(gen_record)

                    # If no tool calls, the model produced a final answer.
                    if not tool_calls:
                        logger.debug("No tool calls found, ending agentic loop")
                        iterations.append(
                            AgenticIteration(
                                iteration_number=i,
                                generation_record=gen_record,
                                tool_calls=[],
                                tool_results=[],
                                # Snapshot taken before this turn's messages
                                # are appended to the context.
                                context_snapshot=self._snapshot_context(context),
                            )
                        )
                        if gen_record.output:
                            context.add_message("assistant", gen_record.output.text)
                        completed_naturally = True
                        break

                    # Execute tool calls
                    with tracing.span("execute_tools", num_tools=len(tool_calls)):
                        tool_results = self._execute_tools(tool_calls)

                    iterations.append(
                        AgenticIteration(
                            iteration_number=i,
                            generation_record=gen_record,
                            tool_calls=tool_calls,
                            tool_results=tool_results,
                            context_snapshot=self._snapshot_context(context),
                        )
                    )

                    # Feed the assistant turn and each tool result back into
                    # the conversation for the next iteration.
                    if gen_record.output:
                        context.add_message("assistant", gen_record.output.text)
                    for result in tool_results:
                        context.add_message("tool", self._format_tool_result(result))

                    logger.debug(
                        "Iteration %d: %d tool calls, %d successful",
                        i,
                        len(tool_calls),
                        sum(1 for r in tool_results if r.is_success()),
                    )

            # The last generation (whether or not it requested tools) is the
            # final output; None when max_iterations == 0.
            final_output = iterations[-1].generation_record.output if iterations else None

            agentic_record = AgenticRecord(
                task=task,
                final_output=final_output,
                iterations=iterations,
                context=context,
                metadata={
                    "total_iterations": len(iterations),
                    "total_tool_calls": sum(len(it.tool_calls) for it in iterations),
                    # Bug fix: this was `len(iterations) >= max_iterations`,
                    # which reported True even when the model finished
                    # naturally on the last allowed iteration.
                    "max_iterations_reached": not completed_naturally,
                },
            )

            logger.info(
                "Agentic execution completed: %d iterations, %d tool calls",
                len(iterations),
                agentic_record.total_tool_calls(),
            )

            return agentic_record

    def _format_tool_descriptions(self) -> str:
        """Format tool descriptions for system message.

        Returns:
            Formatted tool descriptions, or "" when no tools are registered
        """
        tools = self._tools.list_tools()
        if not tools:
            return ""

        lines = ["Available tools:"]
        for tool in tools:
            lines.append(f"\n- {tool.name}: {tool.description}")
            lines.append(f"  Parameters: {json.dumps(tool.parameters, indent=2)}")

        lines.append(
            '\nTo use a tool, output: TOOL_CALL: {"name": "tool_name", "arguments": {...}}'
        )

        return "\n".join(lines)

    def _parse_tool_calls(
        self, record: core_entities.GenerationRecord
    ) -> list[tool_primitives.ToolCall]:
        """Parse tool calls from generation record.

        Args:
            record: Generation record

        Returns:
            List of parsed tool calls (empty when the record has no output)
        """
        if not record.output:
            return []

        return self._tool_call_parser(record.output.text)

    def _default_tool_call_parser(self, text: str) -> list[tool_primitives.ToolCall]:
        """Default parser for tool calls.

        Looks for occurrences of ``TOOL_CALL: {...}`` and decodes the JSON
        payload with ``json.JSONDecoder.raw_decode``, which consumes exactly
        one complete JSON value and therefore handles nested objects.

        Bug fix: the previous implementation used the regex
        ``TOOL_CALL:\\s*(\\{.*?\\})`` whose non-greedy group stops at the
        FIRST ``}``. Any payload whose ``arguments`` value is itself a JSON
        object — i.e. the exact format advertised in the system message —
        was truncated to invalid JSON and silently dropped.

        Args:
            text: Model output text

        Returns:
            List of parsed tool calls
        """
        calls: list[tool_primitives.ToolCall] = []
        decoder = json.JSONDecoder()

        for match in re.finditer(r"TOOL_CALL:\s*", text):
            start = match.end()
            # Only JSON objects are valid payloads (matches the old behavior
            # of requiring a leading '{').
            if start >= len(text) or text[start] != "{":
                continue
            try:
                data, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                logger.warning(
                    "Failed to parse tool call JSON: %s", text[start : start + 200]
                )
                continue
            if isinstance(data, dict) and "name" in data and "arguments" in data:
                calls.append(
                    tool_primitives.ToolCall(
                        tool_name=data["name"],
                        arguments=data["arguments"],
                    )
                )

        return calls

    def _execute_tools(
        self, calls: list[tool_primitives.ToolCall]
    ) -> list[tool_primitives.ToolResult]:
        """Execute tool calls.

        Args:
            calls: Tool calls to execute

        Returns:
            List of tool results, one per call, in order
        """
        results = []
        for call in calls:
            with tracing.span("execute_tool", tool_name=call.tool_name):
                results.append(self._tools.execute(call))

        return results

    def _format_tool_result(self, result: tool_primitives.ToolResult) -> str:
        """Format tool result for context.

        Args:
            result: Tool result

        Returns:
            Formatted string
        """
        if result.is_success():
            return f"Tool {result.call.tool_name} result: {result.result}"
        else:
            return f"Tool {result.call.tool_name} error: {result.error}"

    def _update_task_prompt(
        self,
        task: core_entities.GenerationTask,
        prompt_text: str,
        iteration: int,
    ) -> core_entities.GenerationTask:
        """Build a copy of ``task`` carrying the rendered conversation prompt.

        Args:
            task: Original task
            prompt_text: New prompt text
            iteration: Iteration number

        Returns:
            Updated task (original is left untouched)
        """
        from themis.core.entities import PromptRender, PromptSpec

        new_prompt = PromptRender(
            spec=PromptSpec(
                name=f"agentic_iter_{iteration}",
                template="",
                metadata={"iteration": iteration},
            ),
            text=prompt_text,
            context={"iteration": iteration},
            metadata={"iteration": iteration},
        )

        metadata = dict(task.metadata)
        metadata["iteration"] = iteration
        metadata["agentic"] = True

        return core_entities.GenerationTask(
            prompt=new_prompt,
            model=task.model,
            sampling=task.sampling,
            metadata=metadata,
            reference=task.reference,
        )

    def _snapshot_context(
        self, context: conv.ConversationContext
    ) -> conv.ConversationContext:
        """Create an independent copy of ``context`` via dict round-trip.

        Args:
            context: Context to snapshot

        Returns:
            Copy of context
        """
        return conv.ConversationContext.from_dict(context.to_dict())
418
+
419
+
420
# Explicit public API for `from themis.generation.agentic_runner import *`.
__all__ = ["AgenticIteration", "AgenticRecord", "AgenticRunner"]
@@ -0,0 +1,254 @@
1
+ """Batch optimization for generation.
2
+
3
+ This module provides utilities for batching generation tasks to improve efficiency:
4
+ - BatchConfig: Configuration for batching behavior
5
+ - TaskBatcher: Groups tasks for efficient batch processing
6
+ - Batch-aware runner patterns
7
+
8
+ Batching can reduce:
9
+ - API call overhead
10
+ - Network latency
11
+ - Total generation time
12
+ - Cost (for providers with batch APIs)
13
+
14
+ Example:
15
+ >>> config = BatchConfig(max_batch_size=10, group_by=lambda t: t.model.identifier)
16
+ >>> batcher = TaskBatcher(config)
17
+ >>>
18
+ >>> # Group tasks
19
+ >>> for batch in batcher.create_batches(tasks):
20
+ ... results = provider.generate_batch(batch)
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from dataclasses import dataclass
26
+ from typing import Any, Callable, Iterator, Sequence
27
+
28
+ from themis.core import entities as core_entities
29
+
30
+
31
@dataclass
class BatchConfig:
    """Configuration for batch processing.

    Attributes:
        max_batch_size: Maximum number of tasks per batch
        group_by: Function to group compatible tasks (same return value = same batch)
        timeout_ms: Maximum time to wait for batch to fill (future use)
    """

    max_batch_size: int = 10
    group_by: Callable[[core_entities.GenerationTask], str] | None = None
    timeout_ms: float = 100

    def __post_init__(self):
        """Reject out-of-range configuration values at construction time."""
        # Batches must hold at least one task.
        if not self.max_batch_size >= 1:
            raise ValueError("max_batch_size must be >= 1")
        # A negative wait makes no sense even though timeout is unused today.
        if not self.timeout_ms >= 0:
            raise ValueError("timeout_ms must be >= 0")
51
+
52
+
53
class TaskBatcher:
    """Groups generation tasks into batches for efficient processing.

    The batcher can group tasks by various criteria (model, prompt length, etc.)
    and create batches within size limits.

    Example:
        >>> batcher = TaskBatcher(BatchConfig(max_batch_size=5))
        >>> tasks = [...]  # 20 tasks
        >>> batches = list(batcher.create_batches(tasks))
        >>> len(batches)  # 4 batches of 5 each
        4
    """

    def __init__(self, config: BatchConfig):
        """Initialize batcher.

        Args:
            config: Batch configuration
        """
        self._config = config

    def create_batches(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> Iterator[list[core_entities.GenerationTask]]:
        """Create batches from tasks.

        When ``group_by`` is configured, tasks are first bucketed by key and
        each bucket is batched separately; otherwise tasks are batched in
        their original order up to ``max_batch_size``.

        Args:
            tasks: Tasks to batch

        Yields:
            Batches of tasks
        """
        if self._config.group_by is None:
            yield from self._batch_sequential(tasks)
            return
        for bucket in self._group_tasks(tasks).values():
            yield from self._batch_sequential(bucket)

    def _group_tasks(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> dict[str, list[core_entities.GenerationTask]]:
        """Bucket tasks by the configured grouping key.

        Args:
            tasks: Tasks to group

        Returns:
            Dictionary mapping group keys to task lists
        """
        key_of = self._config.group_by
        if key_of is None:
            # No grouping configured: everything lands in a single bucket.
            return {"default": list(tasks)}

        buckets: dict[str, list[core_entities.GenerationTask]] = {}
        for item in tasks:
            buckets.setdefault(key_of(item), []).append(item)

        return buckets

    def _batch_sequential(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> Iterator[list[core_entities.GenerationTask]]:
        """Slice tasks into consecutive chunks of at most ``max_batch_size``.

        Args:
            tasks: Tasks to batch

        Yields:
            Batches of tasks
        """
        items = list(tasks)
        size = self._config.max_batch_size
        for start in range(0, len(items), size):
            yield items[start : start + size]

    def get_batch_count(self, tasks: Sequence[core_entities.GenerationTask]) -> int:
        """Get number of batches that would be created.

        Args:
            tasks: Tasks to batch

        Returns:
            Number of batches
        """
        return sum(1 for _ in self.create_batches(tasks))

    def get_batch_stats(
        self, tasks: Sequence[core_entities.GenerationTask]
    ) -> dict[str, Any]:
        """Get statistics about batching.

        Args:
            tasks: Tasks to batch

        Returns:
            Dictionary with batching statistics
        """
        sizes = [len(batch) for batch in self.create_batches(tasks)]

        stats: dict[str, Any] = {
            "total_tasks": len(tasks),
            "num_batches": len(sizes),
            "max_batch_size": max(sizes) if sizes else 0,
            "min_batch_size": min(sizes) if sizes else 0,
            "avg_batch_size": sum(sizes) / len(sizes) if sizes else 0,
        }

        # Only report group statistics when grouping is actually enabled.
        if self._config.group_by is not None:
            buckets = self._group_tasks(tasks)
            stats["num_groups"] = len(buckets)
            stats["group_sizes"] = {
                name: len(members) for name, members in buckets.items()
            }

        return stats
181
+
182
+
183
+ # ============================================================================
184
+ # Batch-Aware Helpers
185
+ # ============================================================================
186
+
187
+
188
def group_by_model(task: core_entities.GenerationTask) -> str:
    """Group tasks by model identifier.

    Args:
        task: Task to group

    Returns:
        Model identifier as grouping key
    """
    model = task.model
    return model.identifier
198
+
199
+
200
def group_by_prompt_length(
    task: core_entities.GenerationTask, bucket_size: int = 100
) -> str:
    """Group tasks by prompt length (bucketed).

    Groups tasks into buckets based on prompt length. This can help
    optimize batch processing when prompt length affects performance.

    Args:
        task: Task to group
        bucket_size: Size of length buckets

    Returns:
        Bucket identifier as grouping key
    """
    # Floor the prompt length to the nearest bucket boundary.
    lower = len(task.prompt.text) // bucket_size * bucket_size
    upper = lower + bucket_size
    return f"length_{lower}-{upper}"
218
+
219
+
220
def group_by_model_and_sampling(task: core_entities.GenerationTask) -> str:
    """Group tasks by model and sampling configuration.

    Args:
        task: Task to group

    Returns:
        Combined model and sampling key
    """
    sampling = task.sampling
    # Produces "<model>_t<temperature>_p<top_p>", identical to joining the
    # model identifier with a separately built sampling key.
    return f"{task.model.identifier}_t{sampling.temperature}_p{sampling.top_p}"
231
+
232
+
233
def create_grouping_function(
    *groupers: Callable[[core_entities.GenerationTask], str],
) -> Callable[[core_entities.GenerationTask], str]:
    """Create a composite grouping function from multiple groupers.

    Args:
        *groupers: Grouping functions to combine

    Returns:
        Combined grouping function

    Example:
        >>> # Group by both model and prompt length
        >>> grouper = create_grouping_function(group_by_model, group_by_prompt_length)
        >>> config = BatchConfig(group_by=grouper)
    """

    def _composite(task: core_entities.GenerationTask) -> str:
        # Apply every grouper in order and join the keys with underscores.
        return "_".join(fn(task) for fn in groupers)

    return _composite