themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""vLLM provider using AsyncLLMEngine."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any, Dict, List
|
|
9
|
+
|
|
10
|
+
from themis.core import entities as core_entities
|
|
11
|
+
from themis.interfaces import ModelProvider
|
|
12
|
+
from themis.providers import register_provider
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class VLLMProvider(ModelProvider):
    """Model provider that serves generations from vLLM async engines.

    ``generate`` is synchronous: it submits the request to a background event
    loop owned by this provider and blocks until the streamed output is
    complete. At most ``max_parallel`` calls run at once; when more than one
    engine exists, requests are dispatched to them round-robin.

    NOTE(review): when several engines are created they all receive identical
    arguments; nothing here assigns distinct GPUs to each engine, so confirm
    vLLM's device placement before running with more than one engine.
    """

    def __init__(
        self,
        *,
        model: str,
        tensor_parallel_size: int = 1,
        max_parallel: int = 2,
        engine_kwargs: Dict[str, Any] | None = None,
    ) -> None:
        """Load vLLM and create the engine pool for ``model``.

        Args:
            model: Model name or path passed to vLLM as ``model``.
            tensor_parallel_size: GPUs per engine; clamped to at least 1.
            max_parallel: Maximum concurrent ``generate`` calls; clamped to at
                least 1.
            engine_kwargs: Extra keyword arguments forwarded to the vLLM engine
                arguments (must not repeat ``model``/``tensor_parallel_size``).

        Raises:
            RuntimeError: If vLLM is not installed.
        """
        self._model_name = model
        self._tp_size = max(1, tensor_parallel_size)
        self._max_parallel = max(1, max_parallel)
        self._engine_kwargs = engine_kwargs or {}
        self._engines = self._create_engines()
        self._engine_lock = threading.Lock()
        self._rr_index = 0
        self._semaphore = threading.Semaphore(self._max_parallel)
        # One long-lived event loop (started lazily) runs every engine
        # coroutine. vLLM's async engine attaches its background work to the
        # loop it first runs on, so a throwaway ``asyncio.run`` loop per request
        # would leave later requests bound to a closed loop.
        self._loop = None
        self._loop_thread = None

    def generate(
        self, task: core_entities.GenerationTask
    ) -> core_entities.GenerationRecord:  # type: ignore[override]
        """Run ``task`` on the next engine and return a successful record.

        Blocks the calling thread until generation finishes. Exceptions raised
        by the engine propagate to the caller (no error record is built here).
        """
        with self._semaphore:
            engine = self._select_engine()
            loop = self._ensure_event_loop()
            future = asyncio.run_coroutine_threadsafe(
                self._run_generation(engine, task), loop
            )
            try:
                text, raw = future.result()
            except BaseException:
                # If the caller gives up (e.g. KeyboardInterrupt), stop the
                # request instead of leaving it streaming on the loop thread.
                # A no-op when the coroutine itself raised.
                future.cancel()
                raise
            # ``chunks`` stays in the raw output but is not a scalar metric.
            metrics = {k: v for k, v in raw.items() if k != "chunks"}
            return core_entities.GenerationRecord(
                task=task,
                output=core_entities.ModelOutput(text=text, raw=raw),
                error=None,
                metrics=metrics,
            )

    def _ensure_event_loop(self) -> asyncio.AbstractEventLoop:
        """Return the provider's background event loop, starting it on first use."""
        with self._engine_lock:
            loop = getattr(self, "_loop", None)
            if loop is None or loop.is_closed():
                loop = asyncio.new_event_loop()
                thread = threading.Thread(
                    target=loop.run_forever,
                    name=f"themis-vllm-{self._model_name}",
                    daemon=True,
                )
                thread.start()
                self._loop = loop
                self._loop_thread = thread
            return loop

    async def _run_generation(self, engine, task: core_entities.GenerationTask):
        """Stream ``task`` through ``engine`` and return ``(final_text, metrics)``.

        ``metrics["chunks"]`` records ``outputs[0].text`` from every streamed
        update; the last entry is used as the final text. Token counts are added
        only when the engine exposes a ``tokenizer`` attribute.

        NOTE(review): vLLM's ``AsyncLLMEngine`` may not expose ``tokenizer``
        directly, in which case the runner's fallback counting is used — confirm
        against the installed vLLM version.
        """
        import uuid

        sampling_params = self._sampling_params_cls(
            temperature=task.sampling.temperature,
            top_p=task.sampling.top_p,
            # A negative max_tokens means "no explicit limit" (None for vLLM).
            max_tokens=None
            if task.sampling.max_tokens < 0
            else task.sampling.max_tokens,
        )
        dataset_id = task.metadata.get("dataset_id", "sample")
        # vLLM tracks in-flight requests by id; a random suffix avoids the
        # collisions a nanosecond timestamp allows between concurrent threads.
        request_id = f"themis-{dataset_id}-{uuid.uuid4().hex}"
        chunks: List[str] = []
        tokenizer = getattr(engine, "tokenizer", None)
        async for output in engine.generate(
            prompt=task.prompt.text,
            sampling_params=sampling_params,
            request_id=request_id,
        ):
            if output.outputs:
                chunks.append(output.outputs[0].text)
        final_text = chunks[-1] if chunks else ""
        metrics = {"chunks": chunks}
        if tokenizer is not None:
            try:
                metrics["prompt_tokens"] = len(tokenizer.encode(task.prompt.text))
                metrics["response_tokens"] = len(tokenizer.encode(final_text))
            except Exception:  # pragma: no cover
                # Best effort: the runner fills in token counts when missing.
                pass
        return final_text, metrics

    def _select_engine(self):
        """Return the next engine in round-robin order (thread-safe)."""
        with self._engine_lock:
            engine = self._engines[self._rr_index]
            self._rr_index = (self._rr_index + 1) % len(self._engines)
            return engine

    def _create_engines(self):
        """Load vLLM, remember ``SamplingParams`` and build the engine pool."""
        engine_cls, sampling_params_cls = self._load_vllm_classes()
        self._sampling_params_cls = sampling_params_cls
        return [
            self._build_engine(engine_cls)
            for _ in range(self._determine_engine_count())
        ]

    def _build_engine(self, engine_cls):
        """Construct one async engine for the configured model.

        vLLM's ``AsyncLLMEngine`` is built from an ``AsyncEngineArgs`` object via
        ``from_engine_args``; its constructor does not accept ``model=`` and
        similar engine arguments. Direct construction is kept as a fallback for
        engine classes without that factory.
        """
        engine_args = dict(
            model=self._model_name,
            tensor_parallel_size=self._tp_size,
            **self._engine_kwargs,
        )
        factory = getattr(engine_cls, "from_engine_args", None)
        if callable(factory):
            try:
                from vllm import AsyncEngineArgs as engine_args_cls
            except ImportError:  # pragma: no cover - optional dep
                engine_args_cls = None
            if engine_args_cls is not None:
                return factory(engine_args_cls(**engine_args))
        return engine_cls(**engine_args)

    def _determine_engine_count(self) -> int:
        """Return how many engines to create.

        Returns ``visible_cuda_devices // tensor_parallel_size`` when torch sees
        CUDA devices and their count divides evenly by the tensor-parallel size.
        Otherwise returns 1.
        """
        device_count = 0
        try:
            import torch

            if torch.cuda.is_available():
                device_count = torch.cuda.device_count()
        except ImportError:
            device_count = 0
        if device_count and device_count % self._tp_size == 0:
            return max(1, device_count // self._tp_size)
        return 1

    def count_tokens(self, text: str) -> int | None:
        """Count tokens in ``text`` with the first engine's tokenizer, if any.

        Returns ``None`` when no tokenizer is available or encoding fails.
        """
        tokenizer = (
            getattr(self._engines[0], "tokenizer", None) if self._engines else None
        )
        if tokenizer is None:
            return None
        try:
            return len(tokenizer.encode(text))
        except Exception:
            return None

    @staticmethod
    def _load_vllm_classes():
        """Import and return ``(AsyncLLMEngine, SamplingParams)`` from vLLM.

        Raises:
            RuntimeError: If vLLM is not installed.
        """
        try:
            from vllm import AsyncLLMEngine, SamplingParams
        except ImportError as exc:  # pragma: no cover - optional dep
            raise RuntimeError(
                "vLLM is not installed. Install via `pip install vllm` to use VLLMProvider."
            ) from exc
        return AsyncLLMEngine, SamplingParams
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Register this provider with ``themis.providers`` under the name "vllm"
# (runs at import time).
register_provider("vllm", VLLMProvider)


__all__ = ["VLLMProvider"]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Utility router mapping generation tasks to providers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Mapping
|
|
6
|
+
|
|
7
|
+
from themis.core import entities as core_entities
|
|
8
|
+
from themis.interfaces import ModelProvider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProviderRouter(ModelProvider):
    """Forward each generation task to the provider registered for its model.

    Providers are looked up by ``task.model.identifier``. The mapping given to
    the constructor is copied, so later changes to it do not affect the router.
    """

    def __init__(self, providers: Mapping[str, ModelProvider]):
        self._providers = dict(providers)

    def generate(
        self, task: core_entities.GenerationTask
    ) -> core_entities.GenerationRecord:  # type: ignore[override]
        """Delegate ``task`` to its model's provider.

        Raises:
            RuntimeError: If no provider is registered for the task's model; the
                message lists the registered model identifiers.
        """
        model_id = task.model.identifier
        target = self._providers.get(model_id)
        if target is not None:
            return target.generate(task)
        registered = ", ".join(sorted(self._providers)) or "<none>"
        raise RuntimeError(
            f"No provider registered for model '{model_id}'. "
            f"Known providers: {registered}."
        )

    @property
    def providers(self) -> Mapping[str, ModelProvider]:
        """The router's identifier-to-provider mapping (the live internal dict)."""
        return self._providers
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Public API of this module.
__all__ = ["ProviderRouter"]
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Generation runner primitives."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
8
|
+
from typing import Callable, Iterable, Iterator, List
|
|
9
|
+
|
|
10
|
+
from themis.core import entities as core_entities
|
|
11
|
+
from themis.generation import strategies
|
|
12
|
+
from themis.interfaces import ModelProvider
|
|
13
|
+
from themis.utils import tracing
|
|
14
|
+
|
|
15
|
+
# Module logger for per-attempt debug messages, retry warnings and final failure errors.
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GenerationRunner:
    """Delegates generation tasks to an injected provider with strategy support.

    Each task is expanded by a strategy into one or more attempt tasks. Every
    attempt is sent to the provider with retries and exponential backoff, and
    the attempt records are aggregated back into one record per task.
    """

    def __init__(
        self,
        *,
        provider: ModelProvider,
        strategy_resolver: Callable[
            [core_entities.GenerationTask], strategies.GenerationStrategy
        ]
        | None = None,
        max_parallel: int = 1,
        max_retries: int = 3,
        retry_initial_delay: float = 0.5,
        retry_backoff_multiplier: float = 2.0,
        retry_max_delay: float | None = 2.0,
    ) -> None:
        """Configure the runner.

        Args:
            provider: Provider whose ``generate`` produces records.
            strategy_resolver: Maps a task to the strategy used to expand and
                aggregate it; defaults to a single attempt per task.
            max_parallel: Worker threads used by :meth:`run`; values <= 1 run
                tasks sequentially.
            max_retries: Total provider calls allowed per attempt (clamped to
                at least 1).
            retry_initial_delay: Seconds to wait after the first failure
                (clamped to >= 0 and to ``retry_max_delay``).
            retry_backoff_multiplier: Factor applied to each later wait
                (clamped to >= 1).
            retry_max_delay: Upper bound on every wait, or ``None`` for no bound.
        """
        self._provider = provider
        self._strategy_resolver = strategy_resolver or (
            lambda task: strategies.SingleAttemptStrategy()
        )
        self._max_parallel = max(1, max_parallel)
        self._max_retries = max(1, int(max_retries))
        self._retry_initial_delay = max(0.0, retry_initial_delay)
        self._retry_backoff_multiplier = max(1.0, retry_backoff_multiplier)
        self._retry_max_delay = (
            retry_max_delay if retry_max_delay is None else max(0.0, retry_max_delay)
        )
        if self._retry_max_delay is not None:
            # The cap applies to every wait, including the first. Without this,
            # an initial delay above the cap would be slept as-is.
            self._retry_initial_delay = min(
                self._retry_initial_delay, self._retry_max_delay
            )

    def run(
        self, tasks: Iterable[core_entities.GenerationTask]
    ) -> Iterator[core_entities.GenerationRecord]:
        """Execute ``tasks`` and yield one aggregated record per task, in order.

        With ``max_parallel > 1`` all tasks are submitted to a thread pool up
        front. If the consumer stops iterating early, or a task raises, tasks
        that have not started yet are cancelled rather than executed.
        """
        task_list = list(tasks)
        if not task_list:
            return
        if self._max_parallel <= 1:
            for task in task_list:
                yield self._execute_task(task)
            return

        with ThreadPoolExecutor(max_workers=self._max_parallel) as executor:
            futures = [executor.submit(self._execute_task, task) for task in task_list]
            try:
                for future in futures:
                    yield future.result()
            finally:
                # Runs on normal completion (no-op: all futures are done), on a
                # task exception, and on early generator close (GeneratorExit).
                for future in futures:
                    future.cancel()

    def _run_single_attempt(
        self, task: core_entities.GenerationTask
    ) -> core_entities.GenerationRecord:
        """Call the provider for ``task``, retrying failures with backoff.

        Returns the first successful record, with ``generation_attempts`` set
        and any earlier errors under ``retry_errors``. If every call raises,
        returns a failure record instead of raising.
        """
        attempt_errors: List[dict[str, object]] = []
        last_error: Exception | None = None
        delay = self._retry_initial_delay
        task_label = task.metadata.get("dataset_id") or task.prompt.template_name
        for attempt in range(1, self._max_retries + 1):
            try:
                logger.debug(
                    "Starting generation for %s attempt %s/%s",
                    task_label,
                    attempt,
                    self._max_retries,
                )
                record = self._invoke_provider(task)
                record.metrics["generation_attempts"] = attempt
                if attempt_errors:
                    record.metrics.setdefault("retry_errors", attempt_errors)
                logger.debug("Completed %s in %s attempt(s)", task_label, attempt)
                return record
            except Exception as exc:  # pragma: no cover - defensive path
                last_error = exc
                logger.warning(
                    "Attempt %s/%s for %s failed: %s",
                    attempt,
                    self._max_retries,
                    task_label,
                    exc,
                )
                attempt_errors.append(
                    {
                        "attempt": attempt,
                        "error": str(exc),
                        "exception_type": exc.__class__.__name__,
                    }
                )
                if attempt >= self._max_retries:
                    break
                if delay > 0:
                    time.sleep(delay)
                delay = self._next_delay(delay)

        return self._build_failure_record(task, attempt_errors, last_error)

    def _invoke_provider(
        self, task: core_entities.GenerationTask
    ) -> core_entities.GenerationRecord:
        """Call the provider once and fill in timing, size and token metrics.

        Token counts prefer values already reported by the provider, then the
        provider's ``count_tokens``, then a whitespace word count.
        """
        start = time.perf_counter()

        with tracing.span("provider_generate", model=task.model.identifier):
            record = self._provider.generate(task)

        elapsed_ms = (time.perf_counter() - start) * 1000
        record.metrics.setdefault("generation_time_ms", elapsed_ms)
        record.metrics.setdefault("prompt_chars", len(task.prompt.text))
        prompt_tokens = record.metrics.get("prompt_tokens")
        if prompt_tokens is None:
            prompt_tokens = self._count_tokens(task.prompt.text)
        if prompt_tokens is None:
            prompt_tokens = len(task.prompt.text.split())
        record.metrics["prompt_tokens"] = prompt_tokens
        if record.output:
            record.metrics.setdefault("response_chars", len(record.output.text))
            response_tokens = record.metrics.get("response_tokens")
            if response_tokens is None:
                response_tokens = self._count_tokens(record.output.text)
            if response_tokens is None:
                response_tokens = len(record.output.text.split())
            record.metrics["response_tokens"] = response_tokens
        return record

    def _next_delay(self, previous_delay: float) -> float:
        """Return the wait before the next retry, capped at ``retry_max_delay``."""
        if previous_delay <= 0:
            next_delay = self._retry_initial_delay
        else:
            next_delay = previous_delay * self._retry_backoff_multiplier
        if self._retry_max_delay is not None:
            next_delay = min(next_delay, self._retry_max_delay)
        return next_delay

    def _build_failure_record(
        self,
        task: core_entities.GenerationTask,
        attempt_errors: List[dict[str, object]],
        last_error: Exception | None,
    ) -> core_entities.GenerationRecord:
        """Log and build a ``provider_error`` record after all attempts failed."""
        attempts = len(attempt_errors) or 1
        cause = str(last_error) if last_error else "unknown error"
        message = (
            f"Generation failed for model '{task.model.identifier}' "
            f"after {attempts} attempt(s): {cause}"
        )
        logger.error(
            "All attempts failed for %s after %s tries",
            task.metadata.get("dataset_id") or task.prompt.template_name,
            attempts,
            exc_info=last_error,
        )
        return core_entities.GenerationRecord(
            task=task,
            output=None,
            error=core_entities.ModelError(
                message=message,
                kind="provider_error",
                details={
                    "attempts": attempt_errors,
                    "model": task.model.identifier,
                    "provider": task.model.provider,
                },
            ),
            metrics={"generation_attempts": attempts, "retry_errors": attempt_errors},
        )

    def _execute_task(
        self, task: core_entities.GenerationTask
    ) -> core_entities.GenerationRecord:
        """Expand ``task`` with its strategy, run each attempt and aggregate.

        The returned record carries the per-attempt records in ``attempts``.
        """
        task_id = task.metadata.get("dataset_id", "unknown")
        model_id = task.model.identifier

        with tracing.span("execute_task", task_id=task_id, model=model_id):
            strategy = self._strategy_resolver(task)
            attempt_records: List[core_entities.GenerationRecord] = []

            with tracing.span("expand_strategy"):
                expansion = list(strategy.expand(task))

            for attempt_task in expansion:
                with tracing.span("run_attempt"):
                    attempt_records.append(self._run_single_attempt(attempt_task))

            with tracing.span("aggregate_strategy"):
                aggregated = strategy.aggregate(task, attempt_records)

            aggregated.attempts = attempt_records
            return aggregated

    def _count_tokens(self, text: str) -> int | None:
        """Count tokens with the provider's ``count_tokens``, if it has one.

        Returns ``None`` when the provider has no counter, the counter returns
        ``None`` (e.g. no tokenizer available), or counting fails.
        """
        counter = getattr(self._provider, "count_tokens", None)
        if not callable(counter):
            return None
        try:
            count = counter(text)
            return None if count is None else int(count)
        except Exception:  # pragma: no cover - tokenization failure
            return None
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Generation strategy interfaces and default implementations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Iterable, List, Protocol
|
|
7
|
+
|
|
8
|
+
from themis.core import entities as core_entities
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GenerationStrategy(Protocol):
    """Structural interface for turning one task into attempts and back.

    ``expand`` yields the attempt tasks to execute for a task. ``aggregate``
    combines the records produced for those attempts into a single record for
    the original task.
    """

    def expand(
        self, task: core_entities.GenerationTask
    ) -> Iterable[core_entities.GenerationTask]:  # pragma: no cover - interface
        """Return the attempt tasks to execute for ``task``."""

    def aggregate(
        self,
        task: core_entities.GenerationTask,
        records: List[core_entities.GenerationRecord],
    ) -> core_entities.GenerationRecord:  # pragma: no cover - interface
        """Combine the attempt ``records`` into one record for ``task``."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class SingleAttemptStrategy:
    """Pass-through strategy: one attempt per task, result returned unchanged.

    The aggregated record copies the output, error and metrics of the first
    attempt record and re-targets it at the original task.
    """

    def expand(
        self, task: core_entities.GenerationTask
    ) -> Iterable[core_entities.GenerationTask]:
        single = [task]
        return single

    def aggregate(
        self,
        task: core_entities.GenerationTask,
        records: List[core_entities.GenerationRecord],
    ) -> core_entities.GenerationRecord:
        first = records[0]
        # Shallow-copy metrics so later edits don't leak into the attempt record.
        copied_metrics = {**first.metrics}
        return core_entities.GenerationRecord(
            task=task,
            output=first.output,
            error=first.error,
            metrics=copied_metrics,
        )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class RepeatedSamplingStrategy:
    """Repeat the same task multiple times for test-time scaling.

    ``expand`` yields ``attempts`` copies of the task, each tagged in its
    metadata with its 0-based index. ``aggregate`` returns the first error-free
    attempt (or the first attempt if all failed), plus ``attempt_count`` and
    per-attempt ``attempt_outcomes`` metrics.

    Raises:
        ValueError: If ``attempts`` is less than 1.
    """

    # Number of attempt tasks produced by ``expand``; must be >= 1.
    attempts: int
    # Metadata key that receives each attempt's 0-based index.
    metadata_label: str = "attempts"

    def __post_init__(self) -> None:
        # Zero attempts yields no records and makes ``aggregate`` unusable, so
        # reject it up front with a clear message.
        if self.attempts < 1:
            raise ValueError(
                f"RepeatedSamplingStrategy requires attempts >= 1, got {self.attempts}"
            )

    def expand(
        self, task: core_entities.GenerationTask
    ) -> Iterable[core_entities.GenerationTask]:
        """Yield ``attempts`` copies of ``task`` with the attempt index in metadata."""
        for index in range(self.attempts):
            attempt_metadata = dict(task.metadata)
            attempt_metadata[self.metadata_label] = index
            yield core_entities.GenerationTask(
                prompt=task.prompt,
                model=task.model,
                sampling=task.sampling,
                metadata=attempt_metadata,
                reference=task.reference,
            )

    def aggregate(
        self,
        task: core_entities.GenerationTask,
        records: List[core_entities.GenerationRecord],
    ) -> core_entities.GenerationRecord:
        """Pick the first successful record and summarize every attempt.

        Raises:
            ValueError: If ``records`` is empty.
        """
        if not records:
            raise ValueError(
                "RepeatedSamplingStrategy.aggregate requires at least one record"
            )
        best = next((record for record in records if not record.error), records[0])
        aggregated = core_entities.GenerationRecord(
            task=task,
            output=best.output,
            error=best.error,
            metrics=dict(best.metrics),
        )
        aggregated.metrics["attempt_count"] = len(records)
        aggregated.metrics["attempt_outcomes"] = [
            {
                "output": record.output.text if record.output else None,
                "error": record.error.message if record.error else None,
            }
            for record in records
        ]
        return aggregated
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Public API of this module: the strategy protocol and its two built-in implementations.
__all__ = [
    "GenerationStrategy",
    "SingleAttemptStrategy",
    "RepeatedSamplingStrategy",
]
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""Prompt template primitives for Themis generation domain."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from itertools import product
|
|
7
|
+
from typing import Any, Dict, Iterable, List
|
|
8
|
+
|
|
9
|
+
from themis.core import entities as core_entities
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TemplateRenderingError(RuntimeError):
    """Signals that a prompt template could not be turned into prompt text."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class PromptTemplate:
    """A named ``str.format`` template plus metadata, rendered into prompts.

    A :class:`core_entities.PromptSpec` is built once at construction and
    shared by every rendered prompt.
    """

    # Identifier of the template; copied into the PromptSpec.
    name: str
    # ``str.format`` text with named ``{field}`` placeholders.
    template: str
    # Optional extra data, copied into the spec and into every render.
    metadata: Dict[str, Any] | None = None

    def __post_init__(self) -> None:
        self._spec = core_entities.PromptSpec(
            name=self.name,
            template=self.template,
            metadata=dict(self.metadata or {}),
        )

    def render(self, **kwargs: Any) -> str:
        """Format the template with the keyword values in ``kwargs``.

        Extra keywords are ignored.

        Raises:
            TemplateRenderingError: If a named placeholder has no value, the
                template uses positional placeholders (``{}``/``{0}``), or the
                template text is malformed (e.g. an unmatched brace or an
                invalid format spec).
        """
        try:
            return self.template.format(**kwargs)
        except KeyError as exc:  # pragma: no cover - defensive path
            missing = exc.args[0]
            raise TemplateRenderingError(
                f"Missing template variable: {missing}"
            ) from exc
        except (IndexError, ValueError) as exc:
            # IndexError: positional placeholder with no positional args.
            # ValueError: malformed format string or format spec.
            raise TemplateRenderingError(
                f"Could not render template '{self.name}': {exc}"
            ) from exc

    def expand_variants(
        self,
        *,
        base_context: Dict[str, Any],
        variant_values: Dict[str, Iterable[Any]],
    ) -> List[core_entities.PromptRender]:
        """Generate prompts for the cross-product of variant fields.

        Variant keys are iterated in sorted order. Each combination is layered
        over ``base_context``, and variant values override base values. With no
        variants, a single prompt is rendered from ``base_context``.
        """

        if not variant_values:
            return [self._render_context(base_context)]

        keys = sorted(variant_values.keys())
        prompts: list[core_entities.PromptRender] = []
        for combo in product(*(variant_values[key] for key in keys)):
            combo_context = dict(base_context)
            combo_context.update(dict(zip(keys, combo)))
            prompts.append(self._render_context(combo_context))
        return prompts

    def render_prompt(self, context: Dict[str, Any]) -> core_entities.PromptRender:
        """Render the template to a core PromptRender."""
        return self._render_context(context)

    def _render_context(self, context: Dict[str, Any]) -> core_entities.PromptRender:
        """Render ``context`` and wrap the text with spec, context and metadata copies."""
        prompt_text = self.render(**context)
        metadata = dict(self.metadata or {})
        return core_entities.PromptRender(
            spec=self._spec,
            text=prompt_text,
            context=dict(context),
            metadata=metadata,
        )
|