themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/config/schema.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Structured configuration definitions for Hydra/OmegaConf."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class ProviderConfig:
    """Which model provider to use and its provider-specific options."""

    # Provider registry key; defaults to the in-process fake provider.
    name: str = "fake"
    # Free-form options; presumably forwarded to the provider implementation — confirm in provider registry.
    options: dict[str, Any] = field(default_factory=dict)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class RunnerConfig:
    """Concurrency and retry policy for the generation runner."""

    # Maximum number of generation tasks in flight at once.
    max_parallel: int = 1
    # Number of retries after a failed attempt.
    max_retries: int = 3
    # Delay before the first retry (presumably seconds — confirm in runner).
    retry_initial_delay: float = 0.5
    # Factor applied to the delay after each retry (exponential backoff).
    retry_backoff_multiplier: float = 2.0
    # Upper bound on the backoff delay; None presumably means unbounded — confirm in runner.
    retry_max_delay: float | None = 2.0
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class SamplingConfig:
    """Decoding parameters used for generation requests."""

    temperature: float = 0.0
    top_p: float = 0.95
    # Maximum number of tokens to generate per response.
    max_tokens: int = 512
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class GenerationConfig:
    """Model identity plus provider, sampling, and runner settings."""

    model_identifier: str = "fake-math-llm"
    provider: ProviderConfig = field(default_factory=ProviderConfig)
    sampling: SamplingConfig = field(default_factory=SamplingConfig)
    runner: RunnerConfig = field(default_factory=RunnerConfig)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class DatasetConfig:
    """Where and how to load the evaluation dataset."""

    # Loader backend, e.g. "huggingface" — confirm valid values against the dataset registry.
    source: str = "huggingface"
    # Dataset identifier for the chosen source.
    dataset_id: str | None = None
    # Local directory for sources that read from disk.
    data_dir: str | None = None
    # Cap on the number of samples loaded; None loads everything.
    limit: int | None = None
    split: str = "test"
    # Optional subject filter; semantics are dataset-specific.
    subjects: list[str] = field(default_factory=list)
    # Samples embedded directly in the config instead of being loaded from a source.
    inline_samples: list[dict[str, Any]] = field(default_factory=list)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class StorageConfig:
    """Filesystem location for persisted experiment artifacts."""

    # Explicit storage path chosen for this run.
    path: str | None = None
    # Default storage path; presumably used when `path` is unset — confirm in runtime config.
    default_path: str | None = None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class WandbConfig:
    """Weights & Biases integration settings."""

    # Integration is opt-in; disabled by default.
    enable: bool = False
    project: str | None = None
    entity: str | None = None
    tags: list[str] = field(default_factory=list)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class HuggingFaceHubConfig:
    """Hugging Face Hub integration settings."""

    # Integration is opt-in; disabled by default.
    enable: bool = False
    # Target repository identifier on the Hub.
    repository: str | None = None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class IntegrationsConfig:
    """Container grouping all third-party integration settings."""

    wandb: WandbConfig = field(default_factory=WandbConfig)
    huggingface_hub: HuggingFaceHubConfig = field(default_factory=HuggingFaceHubConfig)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration (root of the structured schema)."""

    name: str = "math500_zero_shot"
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    integrations: IntegrationsConfig = field(default_factory=IntegrationsConfig)
    # Cap on samples for the whole experiment; None means no cap.
    max_samples: int | None = None
    run_id: str | None = None
    resume: bool = True
    task: str | None = None
    task_options: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_file(cls, path: str | Path) -> ExperimentConfig:
        """Load configuration from a file."""
        # Imported lazily so module import does not require the loader's dependencies.
        from .loader import load_experiment_config

        return load_experiment_config(Path(path))

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ExperimentConfig:
        """Create configuration from a dictionary."""
        from omegaconf import OmegaConf

        # Merge user values over the structured defaults so missing keys
        # fall back to the dataclass defaults above.
        base = OmegaConf.structured(cls)
        merged = OmegaConf.merge(base, OmegaConf.create(data))
        return OmegaConf.to_object(merged) # type: ignore

    def to_file(self, path: str | Path) -> None:
        """Save configuration to a file."""
        from omegaconf import OmegaConf

        conf = OmegaConf.structured(self)
        # Ensure the parent directory exists before writing.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(conf, Path(path))
|
themis/core/conversation.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
"""Conversation primitives for multi-turn interactions.
|
|
2
|
+
|
|
3
|
+
This module provides abstractions for managing multi-turn conversations,
|
|
4
|
+
enabling research on dialogue systems, debugging interactions, and
|
|
5
|
+
agentic workflows.
|
|
6
|
+
|
|
7
|
+
Examples:
|
|
8
|
+
# Create a conversation
|
|
9
|
+
context = ConversationContext()
|
|
10
|
+
context.add_message("user", "What is 2+2?")
|
|
11
|
+
context.add_message("assistant", "2+2 equals 4.")
|
|
12
|
+
context.add_message("user", "What about 3+3?")
|
|
13
|
+
|
|
14
|
+
# Convert to prompt
|
|
15
|
+
prompt = context.to_prompt()
|
|
16
|
+
|
|
17
|
+
# Get conversation history
|
|
18
|
+
history = context.get_history(max_turns=2)
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import time
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from typing import Any, Callable, Literal
|
|
26
|
+
|
|
27
|
+
from themis.core import entities as core_entities
|
|
28
|
+
from themis.generation import templates
|
|
29
|
+
|
|
30
|
+
MessageRole = Literal["system", "user", "assistant", "tool"]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class Message:
    """One entry in a conversation transcript.

    Attributes:
        role: Who produced the message (system/user/assistant/tool)
        content: Text body of the message
        metadata: Arbitrary extra data (tool calls, timings, etc.)
        timestamp: Unix time at which the message object was created
    """

    role: MessageRole
    content: str
    metadata: dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this message to a plain dictionary."""
        return dict(
            role=self.role,
            content=self.content,
            metadata=self.metadata,
            timestamp=self.timestamp,
        )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class ConversationContext:
    """Mutable transcript of a conversation plus context-level metadata.

    Collects :class:`Message` objects in order and offers helpers for
    filtering, prompt rendering, and (de)serialization.
    """

    messages: list[Message] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    def add_message(self, role: MessageRole, content: str, **metadata: Any) -> None:
        """Append a new message with the given role, content, and metadata."""
        message = Message(role=role, content=content, metadata=metadata)
        self.messages.append(message)

    def get_history(self, max_turns: int | None = None) -> list[Message]:
        """Return the transcript, limited to the last *max_turns* messages if set."""
        if max_turns is not None:
            return self.messages[-max_turns:]
        return list(self.messages)

    def get_messages_by_role(self, role: MessageRole) -> list[Message]:
        """Return every message whose role equals *role*."""
        return [message for message in self.messages if message.role == role]

    def to_prompt(self, template: templates.PromptTemplate | None = None) -> str:
        """Render the transcript as a prompt string.

        Uses *template* when given; otherwise falls back to a plain
        "role: content" layout with blank lines between messages.
        """
        if template is None:
            # Default format: role-prefixed messages.
            return "\n\n".join(
                f"{message.role}: {message.content}" for message in self.messages
            )
        return template.render(messages=self.messages)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the conversation (messages plus metadata) to a dictionary."""
        serialized = [message.to_dict() for message in self.messages]
        return {"messages": serialized, "metadata": self.metadata}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ConversationContext:
        """Rebuild a ConversationContext from the output of :meth:`to_dict`."""
        restored = cls(metadata=data.get("metadata", {}))
        for raw in data.get("messages", []):
            restored.messages.append(
                Message(
                    role=raw["role"],
                    content=raw["content"],
                    metadata=raw.get("metadata", {}),
                    timestamp=raw.get("timestamp", time.time()),
                )
            )
        return restored

    def __len__(self) -> int:
        """Number of messages currently in the transcript."""
        return len(self.messages)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclass
class ConversationTask:
    """Specification for running one multi-turn conversation.

    Bundles the conversation state with the model/sampling settings and
    the stopping rules that bound the dialogue.

    Attributes:
        context: Conversation context with message history
        model: Model to use for generation
        sampling: Sampling configuration
        metadata: Additional metadata
        reference: Optional reference for evaluation
        max_turns: Maximum number of conversation turns
        stop_condition: Optional function to determine when to stop
    """

    context: ConversationContext
    model: core_entities.ModelSpec
    sampling: core_entities.SamplingConfig
    metadata: dict[str, Any] = field(default_factory=dict)
    reference: core_entities.Reference | None = None
    max_turns: int = 10
    stop_condition: Callable[[ConversationContext], bool] | None = None

    def should_stop(self) -> bool:
        """Return True when the conversation should end.

        NOTE(review): the cap compares ``len(self.context)`` — the message
        count — against ``max_turns``, so one user+assistant exchange
        consumes two "turns"; confirm this accounting is intended.
        """
        # Hard cap on conversation length.
        if len(self.context) >= self.max_turns:
            return True
        # Custom stopping rule, when one was supplied.
        condition = self.stop_condition
        return condition(self.context) if condition is not None else False
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@dataclass
class ConversationTurn:
    """Single turn in a conversation.

    Pairs the optional incoming user message with the model's generation
    for that turn and a snapshot of the context as it stood.

    Attributes:
        turn_number: Turn index (0-based)
        user_message: User message for this turn (if any)
        generation_record: Generation result for this turn
        context_snapshot: Conversation context at this turn
    """

    turn_number: int
    user_message: Message | None
    generation_record: core_entities.GenerationRecord
    context_snapshot: ConversationContext
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@dataclass
class ConversationRecord:
    """Full trace of an executed multi-turn conversation.

    Produced by running a :class:`ConversationTask` through a
    ConversationRunner.

    Attributes:
        task: Original conversation task
        context: Final conversation context
        turns: List of turns executed
        metadata: Additional metadata (e.g. total turns, stop reason)
    """

    task: ConversationTask
    context: ConversationContext
    turns: list[ConversationTurn] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    def get_final_output(self) -> core_entities.ModelOutput | None:
        """Return the last turn's model output, or None when no turns ran."""
        if self.turns:
            return self.turns[-1].generation_record.output
        return None

    def get_all_outputs(self) -> list[core_entities.ModelOutput | None]:
        """Return each turn's model output (entries may be None for failed turns)."""
        return [turn.generation_record.output for turn in self.turns]

    def total_turns(self) -> int:
        """Return how many turns were executed."""
        return len(self.turns)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
# Common stop conditions
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def stop_on_keyword(keyword: str) -> Callable[[ConversationContext], bool]:
    """Build a stop condition that fires when *keyword* appears
    (case-insensitively) in the most recent assistant message.

    Args:
        keyword: Keyword to look for in assistant messages

    Returns:
        Stop condition function
    """
    # Lowercase once up front; the comparison is case-insensitive.
    needle = keyword.lower()

    def condition(context: ConversationContext) -> bool:
        if not context.messages:
            return False
        last = context.messages[-1]
        return last.role == "assistant" and needle in last.content.lower()

    return condition
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def stop_on_pattern(
    pattern: str,
) -> Callable[[ConversationContext], bool]:
    """Build a stop condition that fires when *pattern* (a regex, matched
    case-insensitively) is found in the most recent assistant message.

    Args:
        pattern: Regex pattern to match

    Returns:
        Stop condition function
    """
    import re

    # Compile once so every invocation of the condition reuses the matcher.
    matcher = re.compile(pattern, re.IGNORECASE)

    def condition(context: ConversationContext) -> bool:
        if not context.messages:
            return False
        last = context.messages[-1]
        return last.role == "assistant" and matcher.search(last.content) is not None

    return condition
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def stop_on_empty_response() -> Callable[[ConversationContext], bool]:
    """Build a stop condition that fires when the most recent assistant
    message is empty or whitespace-only.

    Returns:
        Stop condition function
    """

    def condition(context: ConversationContext) -> bool:
        if not context.messages:
            return False
        last = context.messages[-1]
        return last.role == "assistant" and not last.content.strip()

    return condition
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# Explicit public API for `from <module> import *`.
__all__ = [
    "MessageRole",
    "Message",
    "ConversationContext",
    "ConversationTask",
    "ConversationTurn",
    "ConversationRecord",
    "stop_on_keyword",
    "stop_on_pattern",
    "stop_on_empty_response",
]
|
themis/core/entities.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Shared dataclasses that represent Themis' internal world."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, Generic, List, TypeVar
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from themis.evaluation.reports import EvaluationReport
|
|
10
|
+
|
|
11
|
+
# Type variable for generic Reference
|
|
12
|
+
T = TypeVar("T")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class SamplingConfig:
    """Immutable decoding parameters for a generation request."""

    temperature: float
    top_p: float
    # Maximum number of tokens to generate per response.
    max_tokens: int
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class ModelSpec:
    """Identifies a model and the provider that serves it."""

    identifier: str
    provider: str
    # Presumably used when a task supplies no sampling of its own — confirm in runner.
    default_sampling: SamplingConfig | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class PromptSpec:
    """Named prompt template prior to rendering."""

    name: str
    # Template source text; rendered elsewhere into a PromptRender.
    template: str
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True)
class PromptRender:
    """A PromptSpec rendered to concrete text, with its render context."""

    spec: PromptSpec
    text: str
    # Variables used to render the template.
    context: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def prompt_text(self) -> str:
        """Convenience alias for :attr:`text`."""
        return self.text

    @property
    def template_name(self) -> str:
        """Name of the underlying prompt template."""
        return self.spec.name
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass(frozen=True)
class Reference(Generic[T]):
    """Reference value with optional type information.

    This is a generic dataclass that can hold typed reference values.
    For backward compatibility, it can be used without type parameters
    and will behave like Reference[Any].

    The value field can hold any type including:
    - Simple types: str, int, float, bool
    - Collections: list, tuple, set
    - Dictionaries: dict (for multi-value references)
    - Custom objects

    Examples:
        # Simple reference
        ref = Reference(kind="answer", value="42")

        # Multi-value reference using dict
        ref = Reference(
            kind="countdown_task",
            value={"target": 122, "numbers": [25, 50, 75, 100]}
        )

        # List reference
        ref = Reference(kind="valid_answers", value=["yes", "no", "maybe"])

        # Typed reference
        ref: Reference[str] = Reference(kind="answer", value="42")
        ref: Reference[dict] = Reference(kind="task", value={"a": 1, "b": 2})

    Note:
        When using dict values, metrics can access individual fields directly:
        >>> target = reference.value["target"]
        >>> numbers = reference.value["numbers"]
    """

    kind: str
    value: T
    # Optional runtime type information; not enforced here.
    schema: type[T] | None = None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass(frozen=True)
class ModelOutput:
    """Successful model response."""

    text: str
    # Raw provider response, kept for debugging/inspection.
    raw: Any | None = None
    # Token usage: {prompt_tokens, completion_tokens, total_tokens}
    usage: dict[str, int] | None = None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@dataclass(frozen=True)
class ModelError:
    """Failure information for a generation attempt."""

    message: str
    # Error category; defaults to a generic model error.
    kind: str = "model_error"
    details: dict[str, Any] = field(default_factory=dict)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass
class GenerationTask:
    """One unit of work for the generation runner: prompt + model + sampling."""

    prompt: PromptRender
    model: ModelSpec
    sampling: SamplingConfig
    metadata: dict[str, Any] = field(default_factory=dict)
    # Optional ground-truth reference carried along for later evaluation.
    reference: Reference | None = None
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@dataclass
class GenerationRecord:
    """Result of executing a GenerationTask."""

    task: GenerationTask
    # Model output; presumably None when the call failed — confirm in runner.
    output: ModelOutput | None
    # Error info; presumably None on success — confirm in runner.
    error: ModelError | None
    metrics: dict[str, Any] = field(default_factory=dict)
    # Per-attempt records, e.g. from retries — confirm semantics in runner.
    attempts: list["GenerationRecord"] = field(default_factory=list)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@dataclass(frozen=True)
class EvaluationItem:
    """Pairs a generation record with the reference it is scored against."""

    record: GenerationRecord
    reference: Reference | None
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@dataclass(frozen=True)
class MetricScore:
    """Single metric result: a named numeric value plus extra detail."""

    metric_name: str
    value: float
    details: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@dataclass
class EvaluationSummary:
    """Scores for one evaluated item, plus any failure messages."""

    scores: list[MetricScore]
    failures: list[str] = field(default_factory=list)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@dataclass
class EvaluationRecord:
    """Evaluation outcome tied to a specific sample id."""

    sample_id: str | None
    scores: list[MetricScore]
    failures: list[str] = field(default_factory=list)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@dataclass
class ExperimentFailure:
    """Failure surfaced at the experiment level for a given sample."""

    sample_id: str | None
    message: str
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@dataclass
class ExperimentReport:
    """Aggregate result of a full experiment run."""

    generation_results: list[GenerationRecord]
    # Forward reference: EvaluationReport lives in themis.evaluation.reports
    # (imported only under TYPE_CHECKING to avoid a runtime cycle).
    evaluation_report: "EvaluationReport"
    failures: list[ExperimentFailure]
    metadata: dict[str, object]
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Explicit public API for `from themis.core.entities import *`.
__all__ = [
    "SamplingConfig",
    "ModelSpec",
    "PromptSpec",
    "PromptRender",
    "Reference",
    "ModelOutput",
    "ModelError",
    "GenerationTask",
    "GenerationRecord",
    "EvaluationItem",
    "EvaluationRecord",
    "MetricScore",
    "EvaluationSummary",
    "ExperimentFailure",
    "ExperimentReport",
]
|