themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/generation/plan.py
@@ -0,0 +1,456 @@
+"""Generation planning primitives.
+
+This module provides generation planning with flexible expansion strategies:
+
+1. GenerationPlan: Traditional Cartesian product expansion
+2. FlexibleGenerationPlan: Pluggable expansion strategies
+3. Expansion strategies:
+   - CartesianExpansionStrategy: Full Cartesian product (default)
+   - FilteredExpansionStrategy: Filter specific combinations
+   - ConditionalExpansionStrategy: Route based on conditions
+   - ChainedExpansionStrategy: Chain multiple strategies
+
+Example (Traditional):
+    >>> plan = GenerationPlan(
+    ...     templates=[template1, template2],
+    ...     models=[model1, model2],
+    ...     sampling_parameters=[config1]
+    ... )
+    >>> tasks = list(plan.expand(dataset))
+
+Example (Filtered):
+    >>> plan = FlexibleGenerationPlan(
+    ...     templates=[template1, template2],
+    ...     models=[model1, model2],
+    ...     sampling_parameters=[config1],
+    ...     expansion_strategy=FilteredExpansionStrategy(
+    ...         task_filter=lambda row, tpl, mdl, smp: mdl.identifier != "gpt-4" or row.get("difficulty") == "hard"
+    ...     )
+    ... )
+    >>> tasks = list(plan.expand(dataset))  # Only creates GPT-4 tasks for hard problems
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, Iterator, Protocol, Sequence
+
+from themis.core import entities as core_entities
+from themis.generation import templates
+
+
+@dataclass
+class GenerationPlan:
+    templates: Sequence[templates.PromptTemplate]
+    models: Sequence[core_entities.ModelSpec]
+    sampling_parameters: Sequence[core_entities.SamplingConfig]
+    dataset_id_field: str = "id"
+    reference_field: str | None = "expected"
+    metadata_fields: Sequence[str] = field(default_factory=tuple)
+    context_builder: Callable[[dict[str, Any]], dict[str, Any]] | None = None
+
+    def expand(
+        self, dataset: Sequence[dict[str, object]]
+    ) -> Iterator[core_entities.GenerationTask]:
+        for row in dataset:
+            row_dict = dict(row)
+            context = self._build_context(row_dict)
+            dataset_id = row_dict.get(self.dataset_id_field)
+            reference = (
+                row_dict.get(self.reference_field) if self.reference_field else None
+            )
+            for template in self.templates:
+                rendered_prompt = template.render_prompt(context)
+                base_metadata = self._build_metadata(template, dataset_id, row_dict)
+                for model in self.models:
+                    for sampling in self.sampling_parameters:
+                        yield core_entities.GenerationTask(
+                            prompt=rendered_prompt,
+                            model=model,
+                            sampling=sampling,
+                            metadata=dict(base_metadata),
+                            reference=self._build_reference(reference),
+                        )
+
+    def _build_context(self, row: dict[str, Any]) -> dict[str, Any]:
+        if self.context_builder is None:
+            return dict(row)
+        return self.context_builder(dict(row))
+
+    def _build_metadata(
+        self,
+        template: templates.PromptTemplate,
+        dataset_id: Any,
+        row: dict[str, Any],
+    ) -> Dict[str, Any]:
+        metadata = {
+            f"template_{key}": value for key, value in (template.metadata or {}).items()
+        }
+        if dataset_id is not None:
+            metadata["dataset_id"] = dataset_id
+        for field_name in self.metadata_fields:
+            if field_name in row:
+                metadata[field_name] = row[field_name]
+        return metadata
+
+    def _build_reference(
+        self, raw_reference: Any | None
+    ) -> core_entities.Reference | None:
+        if raw_reference is None:
+            return None
+        return core_entities.Reference(
+            kind=self.reference_field or "reference", value=raw_reference
+        )
+
+
+# ============================================================================
+# Flexible Generation Planning with Expansion Strategies
+# ============================================================================
+
+
+@dataclass
+class PlanContext:
+    """Context passed to expansion strategies.
+
+    Contains all the information needed to expand dataset rows into tasks.
+    """
+
+    templates: Sequence[templates.PromptTemplate]
+    models: Sequence[core_entities.ModelSpec]
+    sampling_parameters: Sequence[core_entities.SamplingConfig]
+    dataset_id_field: str
+    reference_field: str | None
+    metadata_fields: Sequence[str]
+    context_builder: Callable[[dict], dict] | None
+
+
+class ExpansionStrategy(Protocol):
+    """Strategy for expanding a dataset into generation tasks.
+
+    Different strategies can control which combinations of
+    (row, template, model, sampling) are generated.
+    """
+
+    def expand(
+        self,
+        dataset: Sequence[dict[str, Any]],
+        context: PlanContext,
+    ) -> Iterator[core_entities.GenerationTask]:
+        """Expand dataset rows into generation tasks.
+
+        Args:
+            dataset: Dataset rows to expand
+            context: Plan context with templates, models, etc.
+
+        Yields:
+            Generation tasks
+        """
+        ...
+
+
+class CartesianExpansionStrategy:
+    """Traditional Cartesian product expansion (default behavior).
+
+    Generates all possible combinations of:
+    - Each row in the dataset
+    - Each template
+    - Each model
+    - Each sampling configuration
+
+    This is the default expansion strategy used by GenerationPlan.
+    """
+
+    def expand(
+        self,
+        dataset: Sequence[dict[str, Any]],
+        context: PlanContext,
+    ) -> Iterator[core_entities.GenerationTask]:
+        """Expand using the Cartesian product."""
+        for row in dataset:
+            row_dict = dict(row)
+            ctx = (
+                context.context_builder(row_dict)
+                if context.context_builder
+                else row_dict
+            )
+            dataset_id = row_dict.get(context.dataset_id_field)
+            reference = (
+                row_dict.get(context.reference_field)
+                if context.reference_field
+                else None
+            )
+
+            for template in context.templates:
+                rendered = template.render_prompt(ctx)
+                base_metadata = self._build_metadata(
+                    template, dataset_id, row_dict, context
+                )
+
+                for model in context.models:
+                    for sampling in context.sampling_parameters:
+                        yield core_entities.GenerationTask(
+                            prompt=rendered,
+                            model=model,
+                            sampling=sampling,
+                            metadata=dict(base_metadata),
+                            reference=self._build_reference(reference, context),
+                        )
+
+    def _build_metadata(
+        self,
+        template: templates.PromptTemplate,
+        dataset_id: Any,
+        row: dict[str, Any],
+        context: PlanContext,
+    ) -> Dict[str, Any]:
+        """Build the metadata dict for a task."""
+        metadata = {
+            f"template_{key}": value for key, value in (template.metadata or {}).items()
+        }
+        if dataset_id is not None:
+            metadata["dataset_id"] = dataset_id
+        for field_name in context.metadata_fields:
+            if field_name in row:
+                metadata[field_name] = row[field_name]
+        return metadata
+
+    def _build_reference(
+        self, raw_reference: Any | None, context: PlanContext
+    ) -> core_entities.Reference | None:
+        """Build the reference object."""
+        if raw_reference is None:
+            return None
+        return core_entities.Reference(
+            kind=context.reference_field or "reference", value=raw_reference
+        )
+
+
+class FilteredExpansionStrategy:
+    """Expansion strategy that filters specific combinations.
+
+    Only generates tasks that pass the filter function. Useful for:
+    - Running expensive models only on hard problems
+    - Using specific templates with specific models
+    - Conditional generation based on metadata
+
+    Example:
+        >>> # Only use GPT-4 on hard problems
+        >>> strategy = FilteredExpansionStrategy(
+        ...     task_filter=lambda row, tpl, mdl, smp: (
+        ...         mdl.identifier != "gpt-4" or row.get("difficulty") == "hard"
+        ...     )
+        ... )
+    """
+
+    def __init__(
+        self,
+        task_filter: Callable[
+            [
+                dict[str, Any],  # row
+                templates.PromptTemplate,  # template
+                core_entities.ModelSpec,  # model
+                core_entities.SamplingConfig,  # sampling
+            ],
+            bool,
+        ],
+    ):
+        """Initialize the filtered expansion strategy.

+        Args:
+            task_filter: Function that returns True if the task should be generated
+        """
+        self._filter = task_filter
+        self._base_strategy = CartesianExpansionStrategy()
+
+    def expand(
+        self,
+        dataset: Sequence[dict[str, Any]],
+        context: PlanContext,
+    ) -> Iterator[core_entities.GenerationTask]:
+        """Expand with filtering."""
+        for row in dataset:
+            row_dict = dict(row)
+            ctx = (
+                context.context_builder(row_dict)
+                if context.context_builder
+                else row_dict
+            )
+            dataset_id = row_dict.get(context.dataset_id_field)
+            reference = (
+                row_dict.get(context.reference_field)
+                if context.reference_field
+                else None
+            )
+
+            for template in context.templates:
+                rendered = template.render_prompt(ctx)
+                base_metadata = self._base_strategy._build_metadata(
+                    template, dataset_id, row_dict, context
+                )
+
+                for model in context.models:
+                    for sampling in context.sampling_parameters:
+                        # Check whether this combination should be generated
+                        if self._filter(row_dict, template, model, sampling):
+                            yield core_entities.GenerationTask(
+                                prompt=rendered,
+                                model=model,
+                                sampling=sampling,
+                                metadata=dict(base_metadata),
+                                reference=self._base_strategy._build_reference(
+                                    reference, context
+                                ),
+                            )
+
+
+class ConditionalExpansionStrategy:
+    """Expansion strategy that routes to different strategies based on conditions.
+
+    Evaluates conditions in order and uses the first matching strategy.
+    Falls back to the default strategy if no conditions match.
+
+    Example:
+        >>> # Use different strategies for math vs code problems
+        >>> strategy = ConditionalExpansionStrategy(
+        ...     rules=[
+        ...         (lambda row: row.get("type") == "math", math_strategy),
+        ...         (lambda row: row.get("type") == "code", code_strategy),
+        ...     ],
+        ...     default=default_strategy
+        ... )
+    """
+
+    def __init__(
+        self,
+        rules: list[tuple[Callable[[dict], bool], ExpansionStrategy]],
+        default: ExpansionStrategy,
+    ):
+        """Initialize the conditional expansion strategy.
+
+        Args:
+            rules: List of (condition, strategy) tuples
+            default: Default strategy if no conditions match
+        """
+        self._rules = rules
+        self._default = default
+
+    def expand(
+        self,
+        dataset: Sequence[dict[str, Any]],
+        context: PlanContext,
+    ) -> Iterator[core_entities.GenerationTask]:
+        """Expand using conditional routing."""
+        # Group rows by which strategy applies
+        strategy_groups: dict[int, list[dict]] = {}
+
+        for row in dataset:
+            # Find the first matching rule
+            matched = False
+            for rule_idx, (condition, strategy) in enumerate(self._rules):
+                if condition(row):
+                    strategy_groups.setdefault(rule_idx, []).append(row)
+                    matched = True
+                    break
+
+            if not matched:
+                strategy_groups.setdefault(-1, []).append(row)
+
+        # Expand each group with its strategy
+        for rule_idx, group_rows in strategy_groups.items():
+            if rule_idx == -1:
+                strategy = self._default
+            else:
+                strategy = self._rules[rule_idx][1]
+
+            yield from strategy.expand(group_rows, context)
+
+
+class ChainedExpansionStrategy:
+    """Expansion strategy that chains multiple strategies.
+
+    Applies multiple strategies in sequence, yielding tasks from all of them.
+    Useful for combining different expansion approaches.
+
+    Example:
+        >>> # Generate baseline tasks plus extra high-temperature samples for hard problems
+        >>> strategy = ChainedExpansionStrategy([
+        ...     CartesianExpansionStrategy(),
+        ...     FilteredExpansionStrategy(
+        ...         task_filter=lambda row, tpl, mdl, smp: (
+        ...             row.get("difficulty") == "hard" and smp.temperature > 0.5
+        ...         )
+        ...     )
+        ... ])
+    """
+
+    def __init__(self, strategies: Sequence[ExpansionStrategy]):
+        """Initialize the chained expansion strategy.
+
+        Args:
+            strategies: Strategies to apply in sequence
+        """
+        self._strategies = strategies
+
+    def expand(
+        self,
+        dataset: Sequence[dict[str, Any]],
+        context: PlanContext,
+    ) -> Iterator[core_entities.GenerationTask]:
+        """Expand using all strategies in sequence."""
+        for strategy in self._strategies:
+            yield from strategy.expand(dataset, context)
+
+
+@dataclass
+class FlexibleGenerationPlan:
+    """Generation plan with a pluggable expansion strategy.
+
+    Allows controlling how dataset rows are expanded into generation tasks
+    using different expansion strategies.
+
+    Example:
+        >>> # Restrict the expensive model to hard problems only
+        >>> plan = FlexibleGenerationPlan(
+        ...     templates=[template1, template2],
+        ...     models=[cheap_model, expensive_model],
+        ...     sampling_parameters=[config],
+        ...     expansion_strategy=FilteredExpansionStrategy(
+        ...         task_filter=lambda row, tpl, mdl, smp: (
+        ...             mdl.identifier != "expensive" or row.get("difficulty") == "hard"
+        ...         )
+        ...     )
+        ... )
+    """
+
+    templates: Sequence[templates.PromptTemplate]
+    models: Sequence[core_entities.ModelSpec]
+    sampling_parameters: Sequence[core_entities.SamplingConfig]
+    expansion_strategy: ExpansionStrategy | None = None
+    dataset_id_field: str = "id"
+    reference_field: str | None = "expected"
+    metadata_fields: Sequence[str] = field(default_factory=tuple)
+    context_builder: Callable[[dict], dict] | None = None
+
+    def expand(
+        self, dataset: Sequence[dict[str, object]]
+    ) -> Iterator[core_entities.GenerationTask]:
+        """Expand the dataset into generation tasks using the configured strategy.
+
+        Args:
+            dataset: Dataset rows to expand
+
+        Yields:
+            Generation tasks
+        """
+        context = PlanContext(
+            templates=self.templates,
+            models=self.models,
+            sampling_parameters=self.sampling_parameters,
+            dataset_id_field=self.dataset_id_field,
+            reference_field=self.reference_field,
+            metadata_fields=self.metadata_fields,
+            context_builder=self.context_builder,
+        )
+
+        strategy = self.expansion_strategy or CartesianExpansionStrategy()
+        yield from strategy.expand(dataset, context)
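
The only contract a custom filter has to satisfy for FilteredExpansionStrategy is the task_filter(row, template, model, sampling) -> bool callable. Below is a minimal, self-contained sketch of that contract using plain dicts and a stand-in model object; the real PromptTemplate, ModelSpec, and SamplingConfig constructors live elsewhere in the package and are not part of this diff, so the stand-in names here are illustrative only.

    from dataclasses import dataclass

    @dataclass
    class StubModel:
        # Stand-in for themis.core.entities.ModelSpec; the filter only reads `identifier`.
        identifier: str

    def task_filter(row, template, model, sampling) -> bool:
        # Emit a task for the expensive model only when the row is marked "hard".
        return model.identifier != "gpt-4" or row.get("difficulty") == "hard"

    rows = [{"id": 1, "difficulty": "easy"}, {"id": 2, "difficulty": "hard"}]
    models = [StubModel("gpt-4o-mini"), StubModel("gpt-4")]

    kept = [
        (row["id"], model.identifier)
        for row in rows
        for model in models
        if task_filter(row, template=None, model=model, sampling=None)
    ]
    print(kept)  # [(1, 'gpt-4o-mini'), (2, 'gpt-4o-mini'), (2, 'gpt-4')]

FilteredExpansionStrategy applies the same predicate in its innermost loop, so skipped combinations still pay for prompt rendering in the outer template loop but never reach a provider.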
themis/generation/providers/litellm_provider.py
@@ -0,0 +1,221 @@
+"""LiteLLM provider supporting 100+ LLM providers through a unified interface."""
+
+from __future__ import annotations
+
+import threading
+from dataclasses import dataclass
+from typing import Any, Dict
+
+from themis.core import entities as core_entities
+from themis.interfaces import ModelProvider
+from themis.providers import register_provider
+
+
+@dataclass
+class LiteLLMProvider(ModelProvider):
+    """
+    Universal LLM provider using LiteLLM.
+
+    Supports 100+ providers including:
+    - OpenAI (gpt-4, gpt-3.5-turbo, etc.)
+    - Anthropic (claude-3-opus, claude-3-sonnet, etc.)
+    - Azure OpenAI (azure/<deployment-name>)
+    - AWS Bedrock (bedrock/<model-id>)
+    - Google AI (gemini-pro, etc.)
+    - Cohere, Replicate, Hugging Face, and many more
+
+    Configuration options:
+    - api_key: Optional API key (can also use env vars like OPENAI_API_KEY)
+    - api_base: Optional custom API base URL
+    - timeout: Request timeout in seconds (default: 60)
+    - max_retries: Number of retries for failed requests (default: 2)
+    - n_parallel: Maximum number of parallel requests (default: 10)
+    - drop_params: Whether to drop unsupported params (default: False)
+    - custom_llm_provider: Force a specific provider (e.g., "openai", "anthropic")
+    - extra_kwargs: Additional kwargs to pass to litellm.completion()
+    """
+
+    api_key: str | None = None
+    api_base: str | None = None
+    timeout: int = 60
+    max_retries: int = 2
+    n_parallel: int = 10
+    drop_params: bool = False
+    custom_llm_provider: str | None = None
+    extra_kwargs: Dict[str, Any] | None = None
+
+    def __post_init__(self) -> None:
+        self._semaphore = threading.Semaphore(max(1, self.n_parallel))
+        self._extra_kwargs = self.extra_kwargs or {}
+
+        # Lazy import to avoid import errors if litellm is not installed
+        try:
+            import litellm
+
+            self._litellm = litellm
+            # Configure litellm settings
+            litellm.drop_params = self.drop_params
+            if self.max_retries > 0:
+                litellm.num_retries = self.max_retries
+        except ImportError as exc:
+            raise RuntimeError(
+                "LiteLLM is not installed. Install via `pip install litellm` or "
+                "`uv add litellm` to use LiteLLMProvider."
+            ) from exc
+
+    def generate(
+        self, task: core_entities.GenerationTask
+    ) -> core_entities.GenerationRecord:  # type: ignore[override]
+        """Generate a response using LiteLLM."""
+
+        messages = self._build_messages(task)
+        completion_kwargs = self._build_completion_kwargs(task, messages)
+
+        try:
+            with self._semaphore:
+                response = self._litellm.completion(**completion_kwargs)
+
+            # Extract the generated text
+            text = response.choices[0].message.content or ""
+
+            # Extract usage information
+            usage = response.usage if hasattr(response, "usage") else None
+            usage_dict = None
+            metrics = {}
+            if usage:
+                prompt_tokens = getattr(usage, "prompt_tokens", None)
+                completion_tokens = getattr(usage, "completion_tokens", None)
+                total_tokens = getattr(usage, "total_tokens", None)
+
+                metrics["prompt_tokens"] = prompt_tokens
+                metrics["completion_tokens"] = completion_tokens
+                metrics["total_tokens"] = total_tokens
+                # Alias for consistency with other providers
+                metrics["response_tokens"] = completion_tokens
+
+                # Create usage dict for cost tracking
+                if prompt_tokens is not None and completion_tokens is not None:
+                    usage_dict = {
+                        "prompt_tokens": prompt_tokens,
+                        "completion_tokens": completion_tokens,
+                        "total_tokens": total_tokens or (prompt_tokens + completion_tokens),
+                    }
+
+            # Extract model information
+            model_used = getattr(response, "model", task.model.identifier)
+            metrics["model_used"] = model_used
+
+            # Convert the response to a dict for raw storage
+            raw_data = response.model_dump() if hasattr(response, "model_dump") else {}
+
+            return core_entities.GenerationRecord(
+                task=task,
+                output=core_entities.ModelOutput(text=text, raw=raw_data, usage=usage_dict),
+                error=None,
+                metrics=metrics,
+            )
+
+        except Exception as exc:
+            # Capture detailed error information
+            error_type = type(exc).__name__
+            error_message = str(exc)
+
+            # Extract additional context if available
+            details: Dict[str, Any] = {
+                "error_type": error_type,
+                "model": task.model.identifier,
+            }
+
+            # Check for specific litellm exceptions
+            if hasattr(exc, "status_code"):
+                details["status_code"] = exc.status_code  # type: ignore
+            if hasattr(exc, "llm_provider"):
+                details["llm_provider"] = exc.llm_provider  # type: ignore
+
+            return core_entities.GenerationRecord(
+                task=task,
+                output=None,
+                error=core_entities.ModelError(
+                    message=error_message,
+                    kind=error_type,
+                    details=details,
+                ),
+                metrics={},
+            )
+
+    def _build_messages(
+        self, task: core_entities.GenerationTask
+    ) -> list[dict[str, str]]:
+        """Build the messages array from the generation task."""
+        messages = []
+
+        # Add a system message if provided in metadata
+        system_prompt = task.prompt.metadata.get(
+            "system_prompt"
+        ) or task.prompt.spec.metadata.get("system_prompt")
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # Add the main user prompt
+        messages.append({"role": "user", "content": task.prompt.text})
+
+        # Support conversation history if provided
+        conversation_history = task.metadata.get("conversation_history")
+        if conversation_history:
+            # If conversation history is provided, prepend it
+            messages = conversation_history + messages
+
+        return messages
+
+    def _build_completion_kwargs(
+        self, task: core_entities.GenerationTask, messages: list[dict[str, str]]
+    ) -> Dict[str, Any]:
+        """Build the kwargs dictionary for litellm.completion()."""
+
+        kwargs: Dict[str, Any] = {
+            "model": task.model.identifier,
+            "messages": messages,
+            "temperature": task.sampling.temperature,
+            "top_p": task.sampling.top_p,
+            "timeout": self.timeout,
+        }
+
+        # Add max_tokens if specified (negative values mean no limit)
+        if task.sampling.max_tokens >= 0:
+            kwargs["max_tokens"] = task.sampling.max_tokens
+
+        # Add the API key if provided
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+
+        # Add a custom API base if provided
+        if self.api_base:
+            kwargs["api_base"] = self.api_base
+
+        # Add a custom provider if specified
+        if self.custom_llm_provider:
+            kwargs["custom_llm_provider"] = self.custom_llm_provider
+
+        # Merge any extra kwargs provided in configuration
+        if self._extra_kwargs:
+            kwargs.update(self._extra_kwargs)
+
+        # Allow task-level overrides via metadata
+        litellm_kwargs = task.metadata.get("litellm_kwargs", {})
+        if litellm_kwargs:
+            kwargs.update(litellm_kwargs)
+
+        return kwargs
+
+
+# Register the provider under multiple aliases for convenience
+register_provider("litellm", LiteLLMProvider)
+register_provider("openai", LiteLLMProvider)
+register_provider("anthropic", LiteLLMProvider)
+register_provider("azure", LiteLLMProvider)
+register_provider("bedrock", LiteLLMProvider)
+register_provider("gemini", LiteLLMProvider)
+register_provider("cohere", LiteLLMProvider)
+
+
+__all__ = ["LiteLLMProvider"]
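
The options listed in the LiteLLMProvider docstring map directly onto dataclass fields, so configuration is plain keyword arguments. A rough usage sketch, assuming litellm is installed (the GenerationTask itself is produced elsewhere, e.g. by the generation plans above, so it is elided here):

    from themis.generation.providers.litellm_provider import LiteLLMProvider

    # Values are illustrative; field names mirror the dataclass above.
    provider = LiteLLMProvider(
        timeout=30,      # per-request timeout in seconds
        max_retries=3,   # forwarded to litellm.num_retries in __post_init__
        n_parallel=4,    # caps concurrent requests via an internal semaphore
    )

    # record = provider.generate(task)  # task: themis.core.entities.GenerationTask
    # if record.error is None:
    #     print(record.output.text, record.metrics.get("total_tokens"))

Because register_provider is called with several aliases ("litellm", "openai", "anthropic", and so on), any of those provider names in a themis configuration resolves to this same class.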