themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
"""Generation planning primitives.
|
|
2
|
+
|
|
3
|
+
This module provides generation planning with flexible expansion strategies:
|
|
4
|
+
|
|
5
|
+
1. GenerationPlan: Traditional Cartesian product expansion
|
|
6
|
+
2. FlexibleGenerationPlan: Pluggable expansion strategies
|
|
7
|
+
3. Expansion strategies:
|
|
8
|
+
- CartesianExpansionStrategy: Full Cartesian product (default)
|
|
9
|
+
- FilteredExpansionStrategy: Filter specific combinations
|
|
10
|
+
- ConditionalExpansionStrategy: Route based on conditions
|
|
11
|
+
- ChainedExpansionStrategy: Chain multiple strategies
|
|
12
|
+
|
|
13
|
+
Example (Traditional):
|
|
14
|
+
>>> plan = GenerationPlan(
|
|
15
|
+
... templates=[template1, template2],
|
|
16
|
+
... models=[model1, model2],
|
|
17
|
+
... sampling_parameters=[config1]
|
|
18
|
+
... )
|
|
19
|
+
>>> tasks = list(plan.expand(dataset))
|
|
20
|
+
|
|
21
|
+
Example (Filtered):
|
|
22
|
+
>>> plan = FlexibleGenerationPlan(
|
|
23
|
+
... templates=[template1, template2],
|
|
24
|
+
... models=[model1, model2],
|
|
25
|
+
... sampling_parameters=[config1],
|
|
26
|
+
... expansion_strategy=FilteredExpansionStrategy(
|
|
27
|
+
... task_filter=lambda row, tpl, mdl, smp: mdl.identifier != "gpt-4" or row.get("difficulty") == "hard"
|
|
28
|
+
... )
|
|
29
|
+
... )
|
|
30
|
+
>>> tasks = list(plan.expand(dataset)) # Only creates GPT-4 tasks for hard problems
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations
|
|
34
|
+
|
|
35
|
+
from dataclasses import dataclass, field
|
|
36
|
+
from typing import Any, Callable, Dict, Iterator, Protocol, Sequence
|
|
37
|
+
|
|
38
|
+
from themis.core import entities as core_entities
|
|
39
|
+
from themis.generation import templates
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class GenerationPlan:
    """Cartesian-product generation plan.

    Expands every dataset row against every (template, model, sampling)
    combination, yielding one ``core_entities.GenerationTask`` per
    combination.
    """

    templates: Sequence[templates.PromptTemplate]
    models: Sequence[core_entities.ModelSpec]
    sampling_parameters: Sequence[core_entities.SamplingConfig]
    dataset_id_field: str = "id"
    reference_field: str | None = "expected"
    metadata_fields: Sequence[str] = field(default_factory=tuple)
    context_builder: Callable[[dict[str, Any]], dict[str, Any]] | None = None

    def expand(
        self, dataset: Sequence[dict[str, object]]
    ) -> Iterator[core_entities.GenerationTask]:
        """Yield one task for each (row, template, model, sampling) combination."""
        for raw_row in dataset:
            row = dict(raw_row)
            prompt_context = self._build_context(row)
            row_id = row.get(self.dataset_id_field)
            raw_reference = (
                row.get(self.reference_field) if self.reference_field else None
            )
            for template in self.templates:
                prompt = template.render_prompt(prompt_context)
                shared_metadata = self._build_metadata(template, row_id, row)
                for model in self.models:
                    for sampling in self.sampling_parameters:
                        # Copy the metadata per task so sibling tasks never
                        # share a mutable dict.
                        yield core_entities.GenerationTask(
                            prompt=prompt,
                            model=model,
                            sampling=sampling,
                            metadata=dict(shared_metadata),
                            reference=self._build_reference(raw_reference),
                        )

    def _build_context(self, row: dict[str, Any]) -> dict[str, Any]:
        """Return the template rendering context for *row*.

        Always hands out (or hands the builder) a copy, so the original
        row mapping stays untouched.
        """
        copied = dict(row)
        if self.context_builder is None:
            return copied
        return self.context_builder(copied)

    def _build_metadata(
        self,
        template: templates.PromptTemplate,
        dataset_id: Any,
        row: dict[str, Any],
    ) -> Dict[str, Any]:
        """Collect per-task metadata from the template and the row."""
        metadata: Dict[str, Any] = {}
        # Template metadata is namespaced with a "template_" prefix.
        for key, value in (template.metadata or {}).items():
            metadata[f"template_{key}"] = value
        if dataset_id is not None:
            metadata["dataset_id"] = dataset_id
        # Copy the configured row fields verbatim when present.
        metadata.update(
            (name, row[name]) for name in self.metadata_fields if name in row
        )
        return metadata

    def _build_reference(
        self, raw_reference: Any | None
    ) -> core_entities.Reference | None:
        """Wrap *raw_reference* in a Reference, or return None when absent."""
        if raw_reference is None:
            return None
        kind = self.reference_field or "reference"
        return core_entities.Reference(kind=kind, value=raw_reference)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# ============================================================================
|
|
107
|
+
# Flexible Generation Planning with Expansion Strategies
|
|
108
|
+
# ============================================================================
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@dataclass
class PlanContext:
    """Context passed to expansion strategies.

    Contains all the information needed to expand dataset rows into tasks.
    Built by ``FlexibleGenerationPlan.expand`` from its own fields.
    """

    # Prompt templates to render for each dataset row.
    templates: Sequence[templates.PromptTemplate]
    # Candidate model specifications.
    models: Sequence[core_entities.ModelSpec]
    # Sampling configurations to combine with each (row, template, model).
    sampling_parameters: Sequence[core_entities.SamplingConfig]
    # Row key holding the per-row identifier (e.g. "id").
    dataset_id_field: str
    # Row key holding the expected/reference value; None disables references.
    reference_field: str | None
    # Row keys copied verbatim into each task's metadata when present.
    metadata_fields: Sequence[str]
    # Optional hook transforming a row dict into the template rendering context.
    context_builder: Callable[[dict], dict] | None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class ExpansionStrategy(Protocol):
    """Strategy for expanding dataset into generation tasks.

    Different strategies can control which combinations of
    (row, template, model, sampling) are generated.

    This is a structural (duck-typed) interface: any object with a
    matching ``expand`` method satisfies it — no inheritance required.
    """

    def expand(
        self,
        dataset: Sequence[dict[str, Any]],
        context: PlanContext,
    ) -> Iterator[core_entities.GenerationTask]:
        """Expand dataset rows into generation tasks.

        Args:
            dataset: Dataset rows to expand
            context: Plan context with templates, models, etc.

        Yields:
            Generation tasks
        """
        ...
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class CartesianExpansionStrategy:
    """Full Cartesian-product expansion (the default strategy).

    Yields one task for every combination of:
    - each row in the dataset,
    - each template,
    - each model,
    - each sampling configuration.

    This mirrors the behaviour of ``GenerationPlan.expand``.
    """

    def expand(
        self,
        dataset: Sequence[dict[str, Any]],
        context: PlanContext,
    ) -> Iterator[core_entities.GenerationTask]:
        """Yield every (row, template, model, sampling) combination."""
        for raw_row in dataset:
            row = dict(raw_row)
            render_ctx = (
                context.context_builder(row) if context.context_builder else row
            )
            row_id = row.get(context.dataset_id_field)
            raw_ref = (
                row.get(context.reference_field) if context.reference_field else None
            )

            for template in context.templates:
                prompt = template.render_prompt(render_ctx)
                shared_metadata = self._build_metadata(template, row_id, row, context)

                for model in context.models:
                    for sampling in context.sampling_parameters:
                        # Copy the metadata per task so tasks never share a
                        # mutable dict.
                        yield core_entities.GenerationTask(
                            prompt=prompt,
                            model=model,
                            sampling=sampling,
                            metadata=dict(shared_metadata),
                            reference=self._build_reference(raw_ref, context),
                        )

    def _build_metadata(
        self,
        template: templates.PromptTemplate,
        dataset_id: Any,
        row: dict[str, Any],
        context: PlanContext,
    ) -> Dict[str, Any]:
        """Build the per-task metadata dict for one (row, template) pair."""
        metadata: Dict[str, Any] = {}
        # Template metadata is namespaced with a "template_" prefix.
        for key, value in (template.metadata or {}).items():
            metadata[f"template_{key}"] = value
        if dataset_id is not None:
            metadata["dataset_id"] = dataset_id
        for name in context.metadata_fields:
            if name in row:
                metadata[name] = row[name]
        return metadata

    def _build_reference(
        self, raw_reference: Any | None, context: PlanContext
    ) -> core_entities.Reference | None:
        """Wrap the raw reference value, or return None when missing."""
        if raw_reference is None:
            return None
        return core_entities.Reference(
            kind=context.reference_field or "reference", value=raw_reference
        )
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class FilteredExpansionStrategy:
    """Cartesian expansion with a per-combination predicate.

    Only combinations for which ``task_filter`` returns True become tasks.
    Useful for:
    - Expensive models only on hard problems
    - Specific templates for specific models
    - Conditional generation based on metadata

    Example:
        >>> # Only use GPT-4 on hard problems
        >>> strategy = FilteredExpansionStrategy(
        ...     task_filter=lambda row, tpl, mdl, smp: (
        ...         mdl.identifier != "gpt-4" or row.get("difficulty") == "hard"
        ...     )
        ... )
    """

    def __init__(
        self,
        task_filter: Callable[
            [
                dict[str, Any],  # row
                templates.PromptTemplate,  # template
                core_entities.ModelSpec,  # model
                core_entities.SamplingConfig,  # sampling
            ],
            bool,
        ],
    ):
        """Initialize filtered expansion strategy.

        Args:
            task_filter: Function that returns True if task should be generated
        """
        self._filter = task_filter
        # Reused for its metadata/reference helpers so filtered tasks look
        # exactly like Cartesian tasks.
        self._base_strategy = CartesianExpansionStrategy()

    def expand(
        self,
        dataset: Sequence[dict[str, Any]],
        context: PlanContext,
    ) -> Iterator[core_entities.GenerationTask]:
        """Yield only the combinations accepted by the filter."""
        for raw_row in dataset:
            row = dict(raw_row)
            render_ctx = (
                context.context_builder(row) if context.context_builder else row
            )
            row_id = row.get(context.dataset_id_field)
            raw_ref = (
                row.get(context.reference_field) if context.reference_field else None
            )

            for template in context.templates:
                prompt = template.render_prompt(render_ctx)
                shared_metadata = self._base_strategy._build_metadata(
                    template, row_id, row, context
                )

                for model in context.models:
                    for sampling in context.sampling_parameters:
                        # Skip combinations rejected by the predicate.
                        if not self._filter(row, template, model, sampling):
                            continue
                        yield core_entities.GenerationTask(
                            prompt=prompt,
                            model=model,
                            sampling=sampling,
                            metadata=dict(shared_metadata),
                            reference=self._base_strategy._build_reference(
                                raw_ref, context
                            ),
                        )
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class ConditionalExpansionStrategy:
    """Expansion strategy that routes to different strategies based on conditions.

    Evaluates conditions in order and uses the first matching strategy.
    Falls back to default strategy if no conditions match.

    Note: rows are grouped per strategy before expansion, so each strategy's
    ``expand`` is called once per group; output order follows group creation
    order, not the original dataset order.

    Example:
        >>> # Use different strategies for math vs code problems
        >>> strategy = ConditionalExpansionStrategy(
        ...     rules=[
        ...         (lambda row: row.get("type") == "math", math_strategy),
        ...         (lambda row: row.get("type") == "code", code_strategy),
        ...     ],
        ...     default=default_strategy
        ... )
    """

    def __init__(
        self,
        rules: list[tuple[Callable[[dict], bool], ExpansionStrategy]],
        default: ExpansionStrategy,
    ):
        """Initialize conditional expansion strategy.

        Args:
            rules: List of (condition, strategy) tuples
            default: Default strategy if no conditions match
        """
        self._rules = rules
        self._default = default

    def expand(
        self,
        dataset: Sequence[dict[str, Any]],
        context: PlanContext,
    ) -> Iterator[core_entities.GenerationTask]:
        """Expand using conditional routing."""
        # Bucket rows by the index of the first matching rule; -1 marks the
        # fallback bucket handled by the default strategy.
        buckets: dict[int, list[dict]] = {}

        for row in dataset:
            for rule_idx, (predicate, _strategy) in enumerate(self._rules):
                if predicate(row):
                    buckets.setdefault(rule_idx, []).append(row)
                    break
            else:
                # No rule matched this row.
                buckets.setdefault(-1, []).append(row)

        # Expand each bucket with its designated strategy (dict preserves
        # the order in which buckets were first created).
        for rule_idx, rows in buckets.items():
            chosen = self._default if rule_idx == -1 else self._rules[rule_idx][1]
            yield from chosen.expand(rows, context)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class ChainedExpansionStrategy:
    """Expansion strategy that chains multiple strategies.

    Applies every child strategy in sequence over the same dataset and
    yields the tasks from all of them, in order. Useful for combining
    different expansion approaches (e.g. a baseline pass plus an extra
    filtered pass).

    Example:
        >>> strategy = ChainedExpansionStrategy([
        ...     CartesianExpansionStrategy(),
        ...     FilteredExpansionStrategy(
        ...         task_filter=lambda row, tpl, mdl, smp: (
        ...             row.get("difficulty") == "hard" and smp.temperature > 0.5
        ...         )
        ...     )
        ... ])
    """

    def __init__(self, strategies: Sequence[ExpansionStrategy]):
        """Initialize chained expansion strategy.

        Args:
            strategies: List of strategies to apply in sequence
        """
        self._strategies = strategies

    def expand(
        self,
        dataset: Sequence[dict[str, Any]],
        context: PlanContext,
    ) -> Iterator[core_entities.GenerationTask]:
        """Yield the tasks produced by each child strategy, in order."""
        for member in self._strategies:
            yield from member.expand(dataset, context)
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@dataclass
class FlexibleGenerationPlan:
    """Generation plan with pluggable expansion strategy.

    Allows controlling how dataset rows are expanded into generation tasks
    using different expansion strategies. When no strategy is supplied the
    plan behaves like ``GenerationPlan`` (full Cartesian product).

    Example:
        >>> # Filter expensive model to hard problems only
        >>> plan = FlexibleGenerationPlan(
        ...     templates=[template1, template2],
        ...     models=[cheap_model, expensive_model],
        ...     sampling_parameters=[config],
        ...     expansion_strategy=FilteredExpansionStrategy(
        ...         task_filter=lambda row, tpl, mdl, smp: (
        ...             mdl.identifier != "expensive" or row.get("difficulty") == "hard"
        ...         )
        ...     )
        ... )
    """

    templates: Sequence[templates.PromptTemplate]
    models: Sequence[core_entities.ModelSpec]
    sampling_parameters: Sequence[core_entities.SamplingConfig]
    expansion_strategy: ExpansionStrategy | None = None
    dataset_id_field: str = "id"
    reference_field: str | None = "expected"
    metadata_fields: Sequence[str] = field(default_factory=tuple)
    context_builder: Callable[[dict], dict] | None = None

    def expand(
        self, dataset: Sequence[dict[str, object]]
    ) -> Iterator[core_entities.GenerationTask]:
        """Expand dataset into generation tasks using the configured strategy.

        Args:
            dataset: Dataset rows to expand

        Yields:
            Generation tasks
        """
        # Fall back to the default Cartesian product when no strategy is set.
        active_strategy = self.expansion_strategy or CartesianExpansionStrategy()
        plan_context = PlanContext(
            templates=self.templates,
            models=self.models,
            sampling_parameters=self.sampling_parameters,
            dataset_id_field=self.dataset_id_field,
            reference_field=self.reference_field,
            metadata_fields=self.metadata_fields,
            context_builder=self.context_builder,
        )
        yield from active_strategy.expand(dataset, plan_context)
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""LiteLLM provider supporting 100+ LLM providers through a unified interface."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import threading
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any, Dict
|
|
8
|
+
|
|
9
|
+
from themis.core import entities as core_entities
|
|
10
|
+
from themis.interfaces import ModelProvider
|
|
11
|
+
from themis.providers import register_provider
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class LiteLLMProvider(ModelProvider):
    """
    Universal LLM provider using LiteLLM.

    Supports 100+ providers including:
    - OpenAI (gpt-4, gpt-3.5-turbo, etc.)
    - Anthropic (claude-3-opus, claude-3-sonnet, etc.)
    - Azure OpenAI (azure/<deployment-name>)
    - AWS Bedrock (bedrock/<model-id>)
    - Google AI (gemini-pro, etc.)
    - Cohere, Replicate, Hugging Face, and many more

    Configuration options:
    - api_key: Optional API key (can also use env vars like OPENAI_API_KEY)
    - api_base: Optional custom API base URL
    - timeout: Request timeout in seconds (default: 60)
    - max_retries: Number of retries for failed requests (default: 2)
    - n_parallel: Maximum number of parallel requests (default: 10)
    - drop_params: Whether to drop unsupported params (default: False)
    - custom_llm_provider: Force a specific provider (e.g., "openai", "anthropic")
    - extra_kwargs: Additional kwargs to pass to litellm.completion()
    """

    api_key: str | None = None
    api_base: str | None = None
    timeout: int = 60
    max_retries: int = 2
    n_parallel: int = 10
    drop_params: bool = False
    custom_llm_provider: str | None = None
    extra_kwargs: Dict[str, Any] | None = None

    def __post_init__(self) -> None:
        """Set up concurrency control and import/configure litellm.

        Raises:
            RuntimeError: if the ``litellm`` package is not installed.
        """
        # Bounds the number of simultaneous completion calls across threads.
        self._semaphore = threading.Semaphore(max(1, self.n_parallel))
        self._extra_kwargs = self.extra_kwargs or {}

        # Lazy import to avoid import errors if litellm not installed
        try:
            import litellm

            self._litellm = litellm
            # Configure litellm settings
            # NOTE(review): these assignments mutate module-level litellm
            # state, so they are shared by every LiteLLMProvider instance in
            # the process; the last-constructed instance wins.
            litellm.drop_params = self.drop_params
            if self.max_retries > 0:
                litellm.num_retries = self.max_retries
        except ImportError as exc:
            raise RuntimeError(
                "LiteLLM is not installed. Install via `pip install litellm` or "
                "`uv add litellm` to use LiteLLMProvider."
            ) from exc

    def generate(
        self, task: core_entities.GenerationTask
    ) -> core_entities.GenerationRecord:  # type: ignore[override]
        """Generate a response using LiteLLM.

        Never raises on provider failure: any exception from the completion
        call is converted into a GenerationRecord carrying a ModelError.
        """

        messages = self._build_messages(task)
        completion_kwargs = self._build_completion_kwargs(task, messages)

        try:
            # Semaphore limits concurrent in-flight requests to n_parallel.
            with self._semaphore:
                response = self._litellm.completion(**completion_kwargs)

            # Extract the generated text
            text = response.choices[0].message.content or ""

            # Extract usage information
            usage = response.usage if hasattr(response, "usage") else None
            usage_dict = None
            metrics = {}
            if usage:
                # Individual token counts may be absent; getattr defaults to
                # None rather than raising.
                prompt_tokens = getattr(usage, "prompt_tokens", None)
                completion_tokens = getattr(usage, "completion_tokens", None)
                total_tokens = getattr(usage, "total_tokens", None)

                metrics["prompt_tokens"] = prompt_tokens
                metrics["completion_tokens"] = completion_tokens
                metrics["total_tokens"] = total_tokens
                # Alias for consistency with other providers
                metrics["response_tokens"] = completion_tokens

                # Create usage dict for cost tracking
                if prompt_tokens is not None and completion_tokens is not None:
                    usage_dict = {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        # Derive the total when the provider omitted it.
                        "total_tokens": total_tokens or (prompt_tokens + completion_tokens),
                    }

            # Extract model information
            model_used = getattr(response, "model", task.model.identifier)
            metrics["model_used"] = model_used

            # Convert response to dict for raw storage
            raw_data = response.model_dump() if hasattr(response, "model_dump") else {}

            return core_entities.GenerationRecord(
                task=task,
                output=core_entities.ModelOutput(text=text, raw=raw_data, usage=usage_dict),
                error=None,
                metrics=metrics,
            )

        except Exception as exc:
            # Capture detailed error information
            error_type = type(exc).__name__
            error_message = str(exc)

            # Extract additional context if available
            details: Dict[str, Any] = {
                "error_type": error_type,
                "model": task.model.identifier,
            }

            # Check for specific litellm exceptions
            if hasattr(exc, "status_code"):
                details["status_code"] = exc.status_code  # type: ignore
            if hasattr(exc, "llm_provider"):
                details["llm_provider"] = exc.llm_provider  # type: ignore

            return core_entities.GenerationRecord(
                task=task,
                output=None,
                error=core_entities.ModelError(
                    message=error_message,
                    kind=error_type,
                    details=details,
                ),
                metrics={},
            )

    def _build_messages(
        self, task: core_entities.GenerationTask
    ) -> list[dict[str, str]]:
        """Build messages array from the generation task.

        Returns a list of chat messages: optional system message (from
        prompt metadata), the user prompt, with any conversation history
        from task metadata prepended.
        """
        messages = []

        # Add system message if provided in metadata
        # Task-level prompt metadata takes precedence over the prompt spec's.
        system_prompt = task.prompt.metadata.get(
            "system_prompt"
        ) or task.prompt.spec.metadata.get("system_prompt")
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add the main user prompt
        messages.append({"role": "user", "content": task.prompt.text})

        # Support for conversation history if provided
        conversation_history = task.metadata.get("conversation_history")
        if conversation_history:
            # If conversation history is provided, prepend it
            # NOTE(review): history ends up BEFORE the system message, so the
            # system prompt is not the first message in that case — confirm
            # this ordering is intended for the targeted providers.
            messages = conversation_history + messages

        return messages

    def _build_completion_kwargs(
        self, task: core_entities.GenerationTask, messages: list[dict[str, str]]
    ) -> Dict[str, Any]:
        """Build the kwargs dictionary for litellm.completion().

        Later sources override earlier ones: base sampling kwargs, then
        provider-level extra_kwargs, then task metadata "litellm_kwargs".
        """

        kwargs: Dict[str, Any] = {
            "model": task.model.identifier,
            "messages": messages,
            "temperature": task.sampling.temperature,
            "top_p": task.sampling.top_p,
            "timeout": self.timeout,
        }

        # Add max_tokens if specified (negative values mean no limit)
        if task.sampling.max_tokens >= 0:
            kwargs["max_tokens"] = task.sampling.max_tokens

        # Add API key if provided
        if self.api_key:
            kwargs["api_key"] = self.api_key

        # Add custom API base if provided
        if self.api_base:
            kwargs["api_base"] = self.api_base

        # Add custom provider if specified
        if self.custom_llm_provider:
            kwargs["custom_llm_provider"] = self.custom_llm_provider

        # Merge any extra kwargs provided in configuration
        if self._extra_kwargs:
            kwargs.update(self._extra_kwargs)

        # Allow task-level overrides via metadata
        litellm_kwargs = task.metadata.get("litellm_kwargs", {})
        if litellm_kwargs:
            kwargs.update(litellm_kwargs)

        return kwargs
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# Register the provider with multiple aliases for convenience: callers can
# select it by ecosystem name ("openai", "azure", ...) as well as by
# "litellm" directly. Every alias maps to the same LiteLLMProvider class.
for _alias in (
    "litellm",
    "openai",
    "anthropic",
    "azure",
    "bedrock",
    "gemini",
    "cohere",
):
    register_provider(_alias, LiteLLMProvider)


__all__ = ["LiteLLMProvider"]
|