themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
1
+ """vLLM provider using AsyncLLMEngine."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import threading
7
+ import time
8
+ from typing import Any, Dict, List
9
+
10
+ from themis.core import entities as core_entities
11
+ from themis.interfaces import ModelProvider
12
+ from themis.providers import register_provider
13
+
14
+
15
+ class VLLMProvider(ModelProvider):
16
+ def __init__(
17
+ self,
18
+ *,
19
+ model: str,
20
+ tensor_parallel_size: int = 1,
21
+ max_parallel: int = 2,
22
+ engine_kwargs: Dict[str, Any] | None = None,
23
+ ) -> None:
24
+ self._model_name = model
25
+ self._tp_size = max(1, tensor_parallel_size)
26
+ self._max_parallel = max(1, max_parallel)
27
+ self._engine_kwargs = engine_kwargs or {}
28
+ self._engines = self._create_engines()
29
+ self._engine_lock = threading.Lock()
30
+ self._rr_index = 0
31
+ self._semaphore = threading.Semaphore(self._max_parallel)
32
+
33
+ def generate(
34
+ self, task: core_entities.GenerationTask
35
+ ) -> core_entities.GenerationRecord: # type: ignore[override]
36
+ with self._semaphore:
37
+ engine = self._select_engine()
38
+ text, raw = asyncio.run(self._run_generation(engine, task))
39
+ metrics = {k: v for k, v in raw.items() if k != "chunks"}
40
+ return core_entities.GenerationRecord(
41
+ task=task,
42
+ output=core_entities.ModelOutput(text=text, raw=raw),
43
+ error=None,
44
+ metrics=metrics,
45
+ )
46
+
47
+ async def _run_generation(self, engine, task: core_entities.GenerationTask):
48
+ SamplingParams = self._sampling_params_cls
49
+ sampling_params = SamplingParams(
50
+ temperature=task.sampling.temperature,
51
+ top_p=task.sampling.top_p,
52
+ max_tokens=None
53
+ if task.sampling.max_tokens < 0
54
+ else task.sampling.max_tokens,
55
+ )
56
+ dataset_id = task.metadata.get("dataset_id", "sample")
57
+ request_id = f"themis-{dataset_id}-{time.time_ns()}"
58
+ chunks: List[str] = []
59
+ tokenizer = getattr(engine, "tokenizer", None)
60
+ async for output in engine.generate(
61
+ prompt=task.prompt.text,
62
+ sampling_params=sampling_params,
63
+ request_id=request_id,
64
+ ):
65
+ if output.outputs:
66
+ chunks.append(output.outputs[0].text)
67
+ final_text = chunks[-1] if chunks else ""
68
+ metrics = {"chunks": chunks}
69
+ if tokenizer is not None:
70
+ try:
71
+ metrics["prompt_tokens"] = len(tokenizer.encode(task.prompt.text))
72
+ metrics["response_tokens"] = len(tokenizer.encode(final_text))
73
+ except Exception: # pragma: no cover
74
+ pass
75
+ return final_text, metrics
76
+
77
+ def _select_engine(self):
78
+ with self._engine_lock:
79
+ engine = self._engines[self._rr_index]
80
+ self._rr_index = (self._rr_index + 1) % len(self._engines)
81
+ return engine
82
+
83
+ def _create_engines(self):
84
+ AsyncLLMEngine, SamplingParams = self._load_vllm_classes()
85
+ self._sampling_params_cls = SamplingParams
86
+ engine_count = self._determine_engine_count()
87
+ engines = []
88
+ for idx in range(engine_count):
89
+ engine = AsyncLLMEngine(
90
+ model=self._model_name,
91
+ tensor_parallel_size=self._tp_size,
92
+ **self._engine_kwargs,
93
+ )
94
+ engines.append(engine)
95
+ return engines
96
+
97
+ def _determine_engine_count(self) -> int:
98
+ device_count = 0
99
+ try:
100
+ import torch
101
+
102
+ if torch.cuda.is_available():
103
+ device_count = torch.cuda.device_count()
104
+ except ImportError:
105
+ device_count = 0
106
+ if device_count and device_count % self._tp_size == 0:
107
+ return max(1, device_count // self._tp_size)
108
+ return 1
109
+
110
+ def count_tokens(self, text: str) -> int | None:
111
+ tokenizer = (
112
+ getattr(self._engines[0], "tokenizer", None) if self._engines else None
113
+ )
114
+ if tokenizer is None:
115
+ return None
116
+ try:
117
+ return len(tokenizer.encode(text))
118
+ except Exception:
119
+ return None
120
+
121
+ @staticmethod
122
+ def _load_vllm_classes():
123
+ try:
124
+ from vllm import AsyncLLMEngine, SamplingParams
125
+ except ImportError as exc: # pragma: no cover - optional dep
126
+ raise RuntimeError(
127
+ "vLLM is not installed. Install via `pip install vllm` to use VLLMProvider."
128
+ ) from exc
129
+ return AsyncLLMEngine, SamplingParams
130
+
131
+
132
+ register_provider("vllm", VLLMProvider)
133
+
134
+
135
+ __all__ = ["VLLMProvider"]
@@ -0,0 +1,34 @@
1
+ """Utility router mapping generation tasks to providers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Mapping
6
+
7
+ from themis.core import entities as core_entities
8
+ from themis.interfaces import ModelProvider
9
+
10
+
11
+ class ProviderRouter(ModelProvider):
12
+ """Dispatches generation tasks to concrete providers by model identifier."""
13
+
14
+ def __init__(self, providers: Mapping[str, ModelProvider]):
15
+ self._providers = dict(providers)
16
+
17
+ def generate(
18
+ self, task: core_entities.GenerationTask
19
+ ) -> core_entities.GenerationRecord: # type: ignore[override]
20
+ provider = self._providers.get(task.model.identifier)
21
+ if provider is None:
22
+ known = ", ".join(sorted(self._providers)) or "<none>"
23
+ raise RuntimeError(
24
+ f"No provider registered for model '{task.model.identifier}'. "
25
+ f"Known providers: {known}."
26
+ )
27
+ return provider.generate(task)
28
+
29
+ @property
30
+ def providers(self) -> Mapping[str, ModelProvider]:
31
+ return self._providers
32
+
33
+
34
+ __all__ = ["ProviderRouter"]
@@ -0,0 +1,207 @@
1
+ """Generation runner primitives."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import time
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from typing import Callable, Iterable, Iterator, List
9
+
10
+ from themis.core import entities as core_entities
11
+ from themis.generation import strategies
12
+ from themis.interfaces import ModelProvider
13
+ from themis.utils import tracing
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class GenerationRunner:
19
+ """Delegates generation tasks to an injected provider with strategy support."""
20
+
21
+ def __init__(
22
+ self,
23
+ *,
24
+ provider: ModelProvider,
25
+ strategy_resolver: Callable[
26
+ [core_entities.GenerationTask], strategies.GenerationStrategy
27
+ ]
28
+ | None = None,
29
+ max_parallel: int = 1,
30
+ max_retries: int = 3,
31
+ retry_initial_delay: float = 0.5,
32
+ retry_backoff_multiplier: float = 2.0,
33
+ retry_max_delay: float | None = 2.0,
34
+ ) -> None:
35
+ self._provider = provider
36
+ self._strategy_resolver = strategy_resolver or (
37
+ lambda task: strategies.SingleAttemptStrategy()
38
+ )
39
+ self._max_parallel = max(1, max_parallel)
40
+ self._max_retries = max(1, int(max_retries))
41
+ self._retry_initial_delay = max(0.0, retry_initial_delay)
42
+ self._retry_backoff_multiplier = max(1.0, retry_backoff_multiplier)
43
+ self._retry_max_delay = (
44
+ retry_max_delay if retry_max_delay is None else max(0.0, retry_max_delay)
45
+ )
46
+
47
+ def run(
48
+ self, tasks: Iterable[core_entities.GenerationTask]
49
+ ) -> Iterator[core_entities.GenerationRecord]:
50
+ task_list = list(tasks)
51
+ if not task_list:
52
+ return
53
+ if self._max_parallel <= 1:
54
+ for task in task_list:
55
+ yield self._execute_task(task)
56
+ return
57
+
58
+ with ThreadPoolExecutor(max_workers=self._max_parallel) as executor:
59
+ futures = [executor.submit(self._execute_task, task) for task in task_list]
60
+ for future in futures:
61
+ yield future.result()
62
+
63
+ def _run_single_attempt(
64
+ self, task: core_entities.GenerationTask
65
+ ) -> core_entities.GenerationRecord:
66
+ attempt_errors: List[dict[str, object]] = []
67
+ last_error: Exception | None = None
68
+ delay = self._retry_initial_delay
69
+ task_label = task.metadata.get("dataset_id") or task.prompt.template_name
70
+ for attempt in range(1, self._max_retries + 1):
71
+ try:
72
+ logger.debug(
73
+ "Starting generation for %s attempt %s/%s",
74
+ task_label,
75
+ attempt,
76
+ self._max_retries,
77
+ )
78
+ record = self._invoke_provider(task)
79
+ record.metrics["generation_attempts"] = attempt
80
+ if attempt_errors:
81
+ record.metrics.setdefault("retry_errors", attempt_errors)
82
+ logger.debug("Completed %s in %s attempt(s)", task_label, attempt)
83
+ return record
84
+ except Exception as exc: # pragma: no cover - defensive path
85
+ last_error = exc
86
+ logger.warning(
87
+ "Attempt %s/%s for %s failed: %s",
88
+ attempt,
89
+ self._max_retries,
90
+ task_label,
91
+ exc,
92
+ )
93
+ attempt_errors.append(
94
+ {
95
+ "attempt": attempt,
96
+ "error": str(exc),
97
+ "exception_type": exc.__class__.__name__,
98
+ }
99
+ )
100
+ if attempt >= self._max_retries:
101
+ break
102
+ if delay > 0:
103
+ time.sleep(delay)
104
+ delay = self._next_delay(delay)
105
+
106
+ return self._build_failure_record(task, attempt_errors, last_error)
107
+
108
+ def _invoke_provider(
109
+ self, task: core_entities.GenerationTask
110
+ ) -> core_entities.GenerationRecord:
111
+ start = time.perf_counter()
112
+
113
+ with tracing.span("provider_generate", model=task.model.identifier):
114
+ record = self._provider.generate(task)
115
+
116
+ elapsed_ms = (time.perf_counter() - start) * 1000
117
+ record.metrics.setdefault("generation_time_ms", elapsed_ms)
118
+ record.metrics.setdefault("prompt_chars", len(task.prompt.text))
119
+ prompt_tokens = record.metrics.get("prompt_tokens")
120
+ if prompt_tokens is None:
121
+ prompt_tokens = self._count_tokens(task.prompt.text)
122
+ if prompt_tokens is None:
123
+ prompt_tokens = len(task.prompt.text.split())
124
+ record.metrics["prompt_tokens"] = prompt_tokens
125
+ if record.output:
126
+ record.metrics.setdefault("response_chars", len(record.output.text))
127
+ response_tokens = record.metrics.get("response_tokens")
128
+ if response_tokens is None:
129
+ response_tokens = self._count_tokens(record.output.text)
130
+ if response_tokens is None:
131
+ response_tokens = len(record.output.text.split())
132
+ record.metrics["response_tokens"] = response_tokens
133
+ return record
134
+
135
+ def _next_delay(self, previous_delay: float) -> float:
136
+ if previous_delay <= 0:
137
+ next_delay = self._retry_initial_delay
138
+ else:
139
+ next_delay = previous_delay * self._retry_backoff_multiplier
140
+ if self._retry_max_delay is not None:
141
+ next_delay = min(next_delay, self._retry_max_delay)
142
+ return next_delay
143
+
144
+ def _build_failure_record(
145
+ self,
146
+ task: core_entities.GenerationTask,
147
+ attempt_errors: List[dict[str, object]],
148
+ last_error: Exception | None,
149
+ ) -> core_entities.GenerationRecord:
150
+ attempts = len(attempt_errors) or 1
151
+ cause = str(last_error) if last_error else "unknown error"
152
+ message = (
153
+ f"Generation failed for model '{task.model.identifier}' "
154
+ f"after {attempts} attempt(s): {cause}"
155
+ )
156
+ logger.error(
157
+ "All attempts failed for %s after %s tries",
158
+ task.metadata.get("dataset_id") or task.prompt.template_name,
159
+ attempts,
160
+ exc_info=last_error,
161
+ )
162
+ return core_entities.GenerationRecord(
163
+ task=task,
164
+ output=None,
165
+ error=core_entities.ModelError(
166
+ message=message,
167
+ kind="provider_error",
168
+ details={
169
+ "attempts": attempt_errors,
170
+ "model": task.model.identifier,
171
+ "provider": task.model.provider,
172
+ },
173
+ ),
174
+ metrics={"generation_attempts": attempts, "retry_errors": attempt_errors},
175
+ )
176
+
177
+ def _execute_task(
178
+ self, task: core_entities.GenerationTask
179
+ ) -> core_entities.GenerationRecord:
180
+ task_id = task.metadata.get("dataset_id", "unknown")
181
+ model_id = task.model.identifier
182
+
183
+ with tracing.span("execute_task", task_id=task_id, model=model_id):
184
+ strategy = self._strategy_resolver(task)
185
+ attempt_records: List[core_entities.GenerationRecord] = []
186
+
187
+ with tracing.span("expand_strategy"):
188
+ expansion = list(strategy.expand(task))
189
+
190
+ for attempt_task in expansion:
191
+ with tracing.span("run_attempt"):
192
+ attempt_records.append(self._run_single_attempt(attempt_task))
193
+
194
+ with tracing.span("aggregate_strategy"):
195
+ aggregated = strategy.aggregate(task, attempt_records)
196
+
197
+ aggregated.attempts = attempt_records
198
+ return aggregated
199
+
200
+ def _count_tokens(self, text: str) -> int | None:
201
+ counter = getattr(self._provider, "count_tokens", None)
202
+ if callable(counter):
203
+ try:
204
+ return int(counter(text))
205
+ except Exception: # pragma: no cover - tokenization failure
206
+ return None
207
+ return None
@@ -0,0 +1,98 @@
1
+ """Generation strategy interfaces and default implementations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Iterable, List, Protocol
7
+
8
+ from themis.core import entities as core_entities
9
+
10
+
11
+ class GenerationStrategy(Protocol):
12
+ """Strategy responsible for expanding a task into one or more execution attempts."""
13
+
14
+ def expand(
15
+ self, task: core_entities.GenerationTask
16
+ ) -> Iterable[core_entities.GenerationTask]: # pragma: no cover - interface
17
+ ...
18
+
19
+ def aggregate(
20
+ self,
21
+ task: core_entities.GenerationTask,
22
+ records: List[core_entities.GenerationRecord],
23
+ ) -> core_entities.GenerationRecord: # pragma: no cover - interface
24
+ ...
25
+
26
+
27
+ @dataclass
28
+ class SingleAttemptStrategy:
29
+ """Default strategy – run exactly once and pass-through result."""
30
+
31
+ def expand(
32
+ self, task: core_entities.GenerationTask
33
+ ) -> Iterable[core_entities.GenerationTask]:
34
+ return [task]
35
+
36
+ def aggregate(
37
+ self,
38
+ task: core_entities.GenerationTask,
39
+ records: List[core_entities.GenerationRecord],
40
+ ) -> core_entities.GenerationRecord:
41
+ record = records[0]
42
+ return core_entities.GenerationRecord(
43
+ task=task,
44
+ output=record.output,
45
+ error=record.error,
46
+ metrics=dict(record.metrics),
47
+ )
48
+
49
+
50
+ @dataclass
51
+ class RepeatedSamplingStrategy:
52
+ """Repeat the same task multiple times for test-time scaling."""
53
+
54
+ attempts: int
55
+ metadata_label: str = "attempts"
56
+
57
+ def expand(
58
+ self, task: core_entities.GenerationTask
59
+ ) -> Iterable[core_entities.GenerationTask]:
60
+ for index in range(self.attempts):
61
+ attempt_metadata = dict(task.metadata)
62
+ attempt_metadata[self.metadata_label] = index
63
+ yield core_entities.GenerationTask(
64
+ prompt=task.prompt,
65
+ model=task.model,
66
+ sampling=task.sampling,
67
+ metadata=attempt_metadata,
68
+ reference=task.reference,
69
+ )
70
+
71
+ def aggregate(
72
+ self,
73
+ task: core_entities.GenerationTask,
74
+ records: List[core_entities.GenerationRecord],
75
+ ) -> core_entities.GenerationRecord:
76
+ best = next((record for record in records if not record.error), records[0])
77
+ aggregated = core_entities.GenerationRecord(
78
+ task=task,
79
+ output=best.output,
80
+ error=best.error,
81
+ metrics=dict(best.metrics),
82
+ )
83
+ aggregated.metrics["attempt_count"] = len(records)
84
+ aggregated.metrics["attempt_outcomes"] = [
85
+ {
86
+ "output": record.output.text if record.output else None,
87
+ "error": record.error.message if record.error else None,
88
+ }
89
+ for record in records
90
+ ]
91
+ return aggregated
92
+
93
+
94
+ __all__ = [
95
+ "GenerationStrategy",
96
+ "SingleAttemptStrategy",
97
+ "RepeatedSamplingStrategy",
98
+ ]
@@ -0,0 +1,71 @@
1
+ """Prompt template primitives for Themis generation domain."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from itertools import product
7
+ from typing import Any, Dict, Iterable, List
8
+
9
+ from themis.core import entities as core_entities
10
+
11
+
12
+ class TemplateRenderingError(RuntimeError):
13
+ """Raised when a prompt template cannot be rendered."""
14
+
15
+
16
+ @dataclass
17
+ class PromptTemplate:
18
+ """Represents a format string and associated metadata."""
19
+
20
+ name: str
21
+ template: str
22
+ metadata: Dict[str, Any] | None = None
23
+
24
+ def __post_init__(self) -> None:
25
+ self._spec = core_entities.PromptSpec(
26
+ name=self.name,
27
+ template=self.template,
28
+ metadata=dict(self.metadata or {}),
29
+ )
30
+
31
+ def render(self, **kwargs: Any) -> str:
32
+ try:
33
+ return self.template.format(**kwargs)
34
+ except KeyError as exc: # pragma: no cover - defensive path
35
+ missing = exc.args[0]
36
+ raise TemplateRenderingError(
37
+ f"Missing template variable: {missing}"
38
+ ) from exc
39
+
40
+ def expand_variants(
41
+ self,
42
+ *,
43
+ base_context: Dict[str, Any],
44
+ variant_values: Dict[str, Iterable[Any]],
45
+ ) -> List[core_entities.PromptRender]:
46
+ """Generate prompts for the cross-product of variant fields."""
47
+
48
+ if not variant_values:
49
+ return [self._render_context(base_context)]
50
+
51
+ keys = sorted(variant_values.keys())
52
+ prompts: list[core_entities.PromptRender] = []
53
+ for combo in product(*(variant_values[key] for key in keys)):
54
+ combo_context = dict(base_context)
55
+ combo_context.update(dict(zip(keys, combo)))
56
+ prompts.append(self._render_context(combo_context))
57
+ return prompts
58
+
59
+ def render_prompt(self, context: Dict[str, Any]) -> core_entities.PromptRender:
60
+ """Render the template to a core PromptRender."""
61
+ return self._render_context(context)
62
+
63
+ def _render_context(self, context: Dict[str, Any]) -> core_entities.PromptRender:
64
+ prompt_text = self.render(**context)
65
+ metadata = dict(self.metadata or {})
66
+ return core_entities.PromptRender(
67
+ spec=self._spec,
68
+ text=prompt_text,
69
+ context=dict(context),
70
+ metadata=metadata,
71
+ )