themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/__init__.py +5 -0
  8. themis/cli/__main__.py +6 -0
  9. themis/cli/commands/__init__.py +19 -0
  10. themis/cli/commands/benchmarks.py +221 -0
  11. themis/cli/commands/comparison.py +394 -0
  12. themis/cli/commands/config_commands.py +244 -0
  13. themis/cli/commands/cost.py +214 -0
  14. themis/cli/commands/demo.py +68 -0
  15. themis/cli/commands/info.py +90 -0
  16. themis/cli/commands/leaderboard.py +362 -0
  17. themis/cli/commands/math_benchmarks.py +318 -0
  18. themis/cli/commands/mcq_benchmarks.py +207 -0
  19. themis/cli/commands/results.py +252 -0
  20. themis/cli/commands/sample_run.py +244 -0
  21. themis/cli/commands/visualize.py +299 -0
  22. themis/cli/main.py +463 -0
  23. themis/cli/new_project.py +33 -0
  24. themis/cli/utils.py +51 -0
  25. themis/comparison/__init__.py +25 -0
  26. themis/comparison/engine.py +348 -0
  27. themis/comparison/reports.py +283 -0
  28. themis/comparison/statistics.py +402 -0
  29. themis/config/__init__.py +19 -0
  30. themis/config/loader.py +27 -0
  31. themis/config/registry.py +34 -0
  32. themis/config/runtime.py +214 -0
  33. themis/config/schema.py +112 -0
  34. themis/core/__init__.py +5 -0
  35. themis/core/conversation.py +354 -0
  36. themis/core/entities.py +184 -0
  37. themis/core/serialization.py +231 -0
  38. themis/core/tools.py +393 -0
  39. themis/core/types.py +141 -0
  40. themis/datasets/__init__.py +273 -0
  41. themis/datasets/base.py +264 -0
  42. themis/datasets/commonsense_qa.py +174 -0
  43. themis/datasets/competition_math.py +265 -0
  44. themis/datasets/coqa.py +133 -0
  45. themis/datasets/gpqa.py +190 -0
  46. themis/datasets/gsm8k.py +123 -0
  47. themis/datasets/gsm_symbolic.py +124 -0
  48. themis/datasets/math500.py +122 -0
  49. themis/datasets/med_qa.py +179 -0
  50. themis/datasets/medmcqa.py +169 -0
  51. themis/datasets/mmlu_pro.py +262 -0
  52. themis/datasets/piqa.py +146 -0
  53. themis/datasets/registry.py +201 -0
  54. themis/datasets/schema.py +245 -0
  55. themis/datasets/sciq.py +150 -0
  56. themis/datasets/social_i_qa.py +151 -0
  57. themis/datasets/super_gpqa.py +263 -0
  58. themis/evaluation/__init__.py +1 -0
  59. themis/evaluation/conditional.py +410 -0
  60. themis/evaluation/extractors/__init__.py +19 -0
  61. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  62. themis/evaluation/extractors/exceptions.py +7 -0
  63. themis/evaluation/extractors/identity_extractor.py +29 -0
  64. themis/evaluation/extractors/json_field_extractor.py +45 -0
  65. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  66. themis/evaluation/extractors/regex_extractor.py +43 -0
  67. themis/evaluation/math_verify_utils.py +87 -0
  68. themis/evaluation/metrics/__init__.py +21 -0
  69. themis/evaluation/metrics/code/__init__.py +19 -0
  70. themis/evaluation/metrics/code/codebleu.py +144 -0
  71. themis/evaluation/metrics/code/execution.py +280 -0
  72. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  73. themis/evaluation/metrics/composite_metric.py +47 -0
  74. themis/evaluation/metrics/consistency_metric.py +80 -0
  75. themis/evaluation/metrics/exact_match.py +51 -0
  76. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  77. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  78. themis/evaluation/metrics/nlp/__init__.py +21 -0
  79. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  80. themis/evaluation/metrics/nlp/bleu.py +129 -0
  81. themis/evaluation/metrics/nlp/meteor.py +153 -0
  82. themis/evaluation/metrics/nlp/rouge.py +136 -0
  83. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  84. themis/evaluation/metrics/response_length.py +33 -0
  85. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  86. themis/evaluation/pipeline.py +49 -0
  87. themis/evaluation/pipelines/__init__.py +15 -0
  88. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  89. themis/evaluation/pipelines/standard_pipeline.py +348 -0
  90. themis/evaluation/reports.py +293 -0
  91. themis/evaluation/statistics/__init__.py +53 -0
  92. themis/evaluation/statistics/bootstrap.py +79 -0
  93. themis/evaluation/statistics/confidence_intervals.py +121 -0
  94. themis/evaluation/statistics/distributions.py +207 -0
  95. themis/evaluation/statistics/effect_sizes.py +124 -0
  96. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  97. themis/evaluation/statistics/types.py +139 -0
  98. themis/evaluation/strategies/__init__.py +13 -0
  99. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  100. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  101. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  102. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  103. themis/experiment/__init__.py +5 -0
  104. themis/experiment/builder.py +151 -0
  105. themis/experiment/cache_manager.py +134 -0
  106. themis/experiment/comparison.py +631 -0
  107. themis/experiment/cost.py +310 -0
  108. themis/experiment/definitions.py +62 -0
  109. themis/experiment/export.py +798 -0
  110. themis/experiment/export_csv.py +159 -0
  111. themis/experiment/integration_manager.py +104 -0
  112. themis/experiment/math.py +192 -0
  113. themis/experiment/mcq.py +169 -0
  114. themis/experiment/orchestrator.py +415 -0
  115. themis/experiment/pricing.py +317 -0
  116. themis/experiment/storage.py +1458 -0
  117. themis/experiment/visualization.py +588 -0
  118. themis/generation/__init__.py +1 -0
  119. themis/generation/agentic_runner.py +420 -0
  120. themis/generation/batching.py +254 -0
  121. themis/generation/clients.py +143 -0
  122. themis/generation/conversation_runner.py +236 -0
  123. themis/generation/plan.py +456 -0
  124. themis/generation/providers/litellm_provider.py +221 -0
  125. themis/generation/providers/vllm_provider.py +135 -0
  126. themis/generation/router.py +34 -0
  127. themis/generation/runner.py +207 -0
  128. themis/generation/strategies.py +98 -0
  129. themis/generation/templates.py +71 -0
  130. themis/generation/turn_strategies.py +393 -0
  131. themis/generation/types.py +9 -0
  132. themis/integrations/__init__.py +0 -0
  133. themis/integrations/huggingface.py +72 -0
  134. themis/integrations/wandb.py +77 -0
  135. themis/interfaces/__init__.py +169 -0
  136. themis/presets/__init__.py +10 -0
  137. themis/presets/benchmarks.py +354 -0
  138. themis/presets/models.py +190 -0
  139. themis/project/__init__.py +20 -0
  140. themis/project/definitions.py +98 -0
  141. themis/project/patterns.py +230 -0
  142. themis/providers/__init__.py +5 -0
  143. themis/providers/registry.py +39 -0
  144. themis/server/__init__.py +28 -0
  145. themis/server/app.py +337 -0
  146. themis/utils/api_generator.py +379 -0
  147. themis/utils/cost_tracking.py +376 -0
  148. themis/utils/dashboard.py +452 -0
  149. themis/utils/logging_utils.py +41 -0
  150. themis/utils/progress.py +58 -0
  151. themis/utils/tracing.py +320 -0
  152. themis_eval-0.2.0.dist-info/METADATA +596 -0
  153. themis_eval-0.2.0.dist-info/RECORD +157 -0
  154. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  155. themis_eval-0.1.0.dist-info/METADATA +0 -758
  156. themis_eval-0.1.0.dist-info/RECORD +0 -8
  157. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  158. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,112 @@
1
+ """Structured configuration definitions for Hydra/OmegaConf."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+
10
@dataclass
class ProviderConfig:
    """Model provider selection plus provider-specific options."""

    # Registry name of the provider backend; "fake" is a test/demo stub.
    name: str = "fake"
    # Free-form options forwarded to the provider implementation.
    options: dict[str, Any] = field(default_factory=dict)
14
+
15
+
16
@dataclass
class RunnerConfig:
    """Concurrency and retry policy for generation runs."""

    # Maximum number of tasks executed concurrently.
    max_parallel: int = 1
    # Retry attempts allowed per failed task.
    max_retries: int = 3
    # Seconds to wait before the first retry.
    retry_initial_delay: float = 0.5
    # Exponential factor applied to the delay between successive retries.
    retry_backoff_multiplier: float = 2.0
    # Upper bound on the retry delay; None means unbounded.
    retry_max_delay: float | None = 2.0
23
+
24
+
25
@dataclass
class SamplingConfig:
    """Token-sampling parameters passed to the model."""

    temperature: float = 0.0  # 0.0 = deterministic/greedy decoding
    top_p: float = 0.95  # nucleus-sampling probability mass
    max_tokens: int = 512  # completion length cap
30
+
31
+
32
@dataclass
class GenerationConfig:
    """Bundle of model identity, provider, sampling, and runner settings."""

    # Identifier understood by the configured provider.
    model_identifier: str = "fake-math-llm"
    provider: ProviderConfig = field(default_factory=ProviderConfig)
    sampling: SamplingConfig = field(default_factory=SamplingConfig)
    runner: RunnerConfig = field(default_factory=RunnerConfig)
38
+
39
+
40
@dataclass
class DatasetConfig:
    """Where and how to load evaluation samples."""

    # Data origin; "huggingface" presumably selects the HF Hub loader — see registry.
    source: str = "huggingface"
    # Dataset identifier for the chosen source; None when unused.
    dataset_id: str | None = None
    # Local directory override for on-disk data.
    data_dir: str | None = None
    # Cap on the number of samples loaded; None loads everything.
    limit: int | None = None
    split: str = "test"
    # Optional subject filter (dataset-specific semantics).
    subjects: list[str] = field(default_factory=list)
    # Samples embedded directly in the config, bypassing external loading.
    inline_samples: list[dict[str, Any]] = field(default_factory=list)
49
+
50
+
51
@dataclass
class StorageConfig:
    """Filesystem locations for experiment artifacts."""

    # Explicit storage path for this run.
    path: str | None = None
    # Default storage path — presumably used when ``path`` is unset; confirm in loader.
    default_path: str | None = None
55
+
56
+
57
@dataclass
class WandbConfig:
    """Weights & Biases integration settings."""

    enable: bool = False  # integration is opt-in
    project: str | None = None
    entity: str | None = None
    tags: list[str] = field(default_factory=list)
63
+
64
+
65
@dataclass
class HuggingFaceHubConfig:
    """Hugging Face Hub integration settings."""

    enable: bool = False  # integration is opt-in
    # Target repository identifier on the Hub.
    repository: str | None = None
69
+
70
+
71
@dataclass
class IntegrationsConfig:
    """Container for optional third-party integrations."""

    wandb: WandbConfig = field(default_factory=WandbConfig)
    huggingface_hub: HuggingFaceHubConfig = field(default_factory=HuggingFaceHubConfig)
75
+
76
+
77
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration (OmegaConf structured schema)."""

    name: str = "math500_zero_shot"
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    integrations: IntegrationsConfig = field(default_factory=IntegrationsConfig)
    max_samples: int | None = None
    run_id: str | None = None
    resume: bool = True
    task: str | None = None
    task_options: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_file(cls, path: str | Path) -> ExperimentConfig:
        """Load configuration from a file via the package loader."""
        from .loader import load_experiment_config

        return load_experiment_config(Path(path))

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ExperimentConfig:
        """Build a configuration by merging *data* over the schema defaults."""
        from omegaconf import OmegaConf

        schema = OmegaConf.structured(cls)
        overrides = OmegaConf.create(data)
        return OmegaConf.to_object(OmegaConf.merge(schema, overrides))  # type: ignore

    def to_file(self, path: str | Path) -> None:
        """Serialize this configuration to *path*, creating parent directories."""
        from omegaconf import OmegaConf

        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(OmegaConf.structured(self), target)
@@ -0,0 +1,5 @@
1
"""Core datamodel for Themis."""

from . import entities, serialization

# Submodules re-exported as the public core API.
__all__ = ["entities", "serialization"]
@@ -0,0 +1,354 @@
1
+ """Conversation primitives for multi-turn interactions.
2
+
3
+ This module provides abstractions for managing multi-turn conversations,
4
+ enabling research on dialogue systems, debugging interactions, and
5
+ agentic workflows.
6
+
7
+ Examples:
8
+ # Create a conversation
9
+ context = ConversationContext()
10
+ context.add_message("user", "What is 2+2?")
11
+ context.add_message("assistant", "2+2 equals 4.")
12
+ context.add_message("user", "What about 3+3?")
13
+
14
+ # Convert to prompt
15
+ prompt = context.to_prompt()
16
+
17
+ # Get conversation history
18
+ history = context.get_history(max_turns=2)
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import time
24
+ from dataclasses import dataclass, field
25
+ from typing import Any, Callable, Literal
26
+
27
+ from themis.core import entities as core_entities
28
+ from themis.generation import templates
29
+
30
+ MessageRole = Literal["system", "user", "assistant", "tool"]
31
+
32
+
33
@dataclass
class Message:
    """One entry in a conversation transcript.

    Attributes:
        role: Speaker role (system/user/assistant/tool).
        content: Text body of the message.
        metadata: Arbitrary extra data (tool calls, timestamps, ...).
        timestamp: Unix time recorded when the message was created.
    """

    role: MessageRole
    content: str
    metadata: dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of this message."""
        return dict(
            role=self.role,
            content=self.content,
            metadata=self.metadata,
            timestamp=self.timestamp,
        )
57
+
58
+
59
@dataclass
class ConversationContext:
    """Maintains conversation state across turns.

    Manages an ordered message history plus free-form metadata, and can
    render the transcript as a prompt string or round-trip it through a
    plain dictionary.

    Examples:
        context = ConversationContext()
        context.add_message("system", "You are a helpful assistant.")
        context.add_message("user", "Hello!")
        context.add_message("assistant", "Hi! How can I help you?")

        # Get history
        messages = context.get_history()

        # Render to prompt
        prompt = context.to_prompt()
    """

    messages: list[Message] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    def add_message(self, role: MessageRole, content: str, **metadata: Any) -> None:
        """Add a message to the conversation.

        Args:
            role: Message role (system/user/assistant/tool)
            content: Message text content
            **metadata: Additional metadata to attach to the message
        """
        self.messages.append(Message(role=role, content=content, metadata=metadata))

    def get_history(self, max_turns: int | None = None) -> list[Message]:
        """Get conversation history.

        Args:
            max_turns: Maximum number of messages to return, counted from
                the end of the history. ``None`` returns everything.

        Returns:
            List of messages (most recent if limited).
        """
        if max_turns is None:
            return list(self.messages)
        if max_turns <= 0:
            # Fix: ``self.messages[-0:]`` slices the whole list, so a limit
            # of zero used to return the entire history instead of nothing.
            return []
        return self.messages[-max_turns:]

    def get_messages_by_role(self, role: MessageRole) -> list[Message]:
        """Get all messages with a specific role.

        Args:
            role: Role to filter by

        Returns:
            List of messages with matching role
        """
        return [msg for msg in self.messages if msg.role == role]

    def to_prompt(self, template: templates.PromptTemplate | None = None) -> str:
        """Render conversation to a prompt string.

        Args:
            template: Optional template for custom formatting

        Returns:
            Formatted prompt string
        """
        if template is not None:
            return template.render(messages=self.messages)
        # Default format: blank-line separated "role: content" entries.
        return "\n\n".join(f"{msg.role}: {msg.content}" for msg in self.messages)

    def to_dict(self) -> dict[str, Any]:
        """Convert conversation to dictionary.

        Returns:
            Dictionary with ``messages`` and ``metadata`` keys.
        """
        return {
            "messages": [msg.to_dict() for msg in self.messages],
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ConversationContext:
        """Create conversation from dictionary.

        Args:
            data: Dictionary with messages and metadata

        Returns:
            ConversationContext instance
        """
        context = cls(metadata=data.get("metadata", {}))
        for msg_data in data.get("messages", []):
            context.messages.append(
                Message(
                    role=msg_data["role"],
                    content=msg_data["content"],
                    metadata=msg_data.get("metadata", {}),
                    # Missing timestamps fall back to "now".
                    timestamp=msg_data.get("timestamp", time.time()),
                )
            )
        return context

    def __len__(self) -> int:
        """Return number of messages in conversation."""
        return len(self.messages)
171
+
172
+
173
@dataclass
class ConversationTask:
    """Task describing a multi-turn conversation run.

    Extends the single-shot GenerationTask concept with a message
    history and configurable stopping behaviour.

    Attributes:
        context: Conversation context with message history
        model: Model to use for generation
        sampling: Sampling configuration
        metadata: Additional metadata
        reference: Optional reference for evaluation
        max_turns: Maximum number of conversation turns
        stop_condition: Optional predicate deciding when to stop
    """

    context: ConversationContext
    model: core_entities.ModelSpec
    sampling: core_entities.SamplingConfig
    metadata: dict[str, Any] = field(default_factory=dict)
    reference: core_entities.Reference | None = None
    max_turns: int = 10
    stop_condition: Callable[[ConversationContext], bool] | None = None

    def should_stop(self) -> bool:
        """Return True once the budget is spent or the stop condition fires.

        NOTE(review): the budget check compares the context's *message
        count* against ``max_turns`` — confirm whether a "turn" here is
        one message or one user/assistant exchange.
        """
        if len(self.context) >= self.max_turns:
            return True
        condition = self.stop_condition
        return condition(self.context) if condition is not None else False
211
+
212
+
213
@dataclass
class ConversationTurn:
    """Single turn in a conversation.

    Attributes:
        turn_number: Turn index (0-based)
        user_message: User message for this turn (if any)
        generation_record: Generation result for this turn
        context_snapshot: Conversation context at this turn
    """

    turn_number: int
    user_message: Message | None
    generation_record: core_entities.GenerationRecord
    # NOTE(review): snapshot semantics (deep vs. shallow copy) are decided by
    # the runner that builds this record — confirm before relying on isolation.
    context_snapshot: ConversationContext
228
+
229
+
230
@dataclass
class ConversationRecord:
    """Full transcript of an executed multi-turn conversation.

    Produced by running a ConversationTask through a ConversationRunner.

    Attributes:
        task: Original conversation task
        context: Final conversation context
        turns: Turns that were executed, in order
        metadata: Extra information (e.g. total turns, stop reason)
    """

    task: ConversationTask
    context: ConversationContext
    turns: list[ConversationTurn] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    def get_final_output(self) -> core_entities.ModelOutput | None:
        """Return the last turn's model output, or None when no turns ran."""
        return self.turns[-1].generation_record.output if self.turns else None

    def get_all_outputs(self) -> list[core_entities.ModelOutput | None]:
        """Return each turn's output; failed turns contribute None."""
        return [executed.generation_record.output for executed in self.turns]

    def total_turns(self) -> int:
        """Return how many turns were executed."""
        return len(self.turns)
274
+
275
+
276
+ # Common stop conditions
277
+
278
+
279
def stop_on_keyword(keyword: str) -> Callable[[ConversationContext], bool]:
    """Build a stop condition that fires when *keyword* appears.

    Args:
        keyword: Substring searched case-insensitively in the most
            recent assistant message.

    Returns:
        Stop condition function
    """
    needle = keyword.lower()

    def condition(context: ConversationContext) -> bool:
        if not context.messages:
            return False
        last = context.messages[-1]
        return last.role == "assistant" and needle in last.content.lower()

    return condition
298
+
299
+
300
def stop_on_pattern(
    pattern: str,
) -> Callable[[ConversationContext], bool]:
    """Build a stop condition that fires when a regex matches.

    Args:
        pattern: Regex pattern, matched case-insensitively against the
            most recent assistant message.

    Returns:
        Stop condition function
    """
    import re

    # Compile once up front; the closure reuses the bound search method.
    search = re.compile(pattern, re.IGNORECASE).search

    def condition(context: ConversationContext) -> bool:
        if not context.messages:
            return False
        last = context.messages[-1]
        if last.role != "assistant":
            return False
        return search(last.content) is not None

    return condition
324
+
325
+
326
def stop_on_empty_response() -> Callable[[ConversationContext], bool]:
    """Build a stop condition that fires on a blank assistant reply.

    Returns:
        Stop condition function
    """

    def condition(context: ConversationContext) -> bool:
        if not context.messages:
            return False
        last = context.messages[-1]
        if last.role != "assistant":
            return False
        # Whitespace-only content counts as empty.
        return last.content.strip() == ""

    return condition
342
+
343
+
344
# Public API of the conversation module.
__all__ = [
    "MessageRole",
    "Message",
    "ConversationContext",
    "ConversationTask",
    "ConversationTurn",
    "ConversationRecord",
    "stop_on_keyword",
    "stop_on_pattern",
    "stop_on_empty_response",
]
@@ -0,0 +1,184 @@
1
+ """Shared dataclasses that represent Themis' internal world."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any, Dict, Generic, List, TypeVar
7
+
8
+ if TYPE_CHECKING:
9
+ from themis.evaluation.reports import EvaluationReport
10
+
11
+ # Type variable for generic Reference
12
+ T = TypeVar("T")
13
+
14
+
15
@dataclass(frozen=True)
class SamplingConfig:
    """Immutable sampling parameters for a generation request."""

    temperature: float
    top_p: float
    max_tokens: int
20
+
21
+
22
@dataclass(frozen=True)
class ModelSpec:
    """Identifies a model and the provider that serves it."""

    identifier: str
    provider: str
    # Optional default sampling configuration for this model.
    default_sampling: SamplingConfig | None = None
    metadata: Dict[str, Any] = field(default_factory=dict)
28
+
29
+
30
@dataclass(frozen=True)
class PromptSpec:
    """Named prompt template."""

    name: str
    template: str
    metadata: Dict[str, Any] = field(default_factory=dict)
35
+
36
+
37
@dataclass(frozen=True)
class PromptRender:
    """A prompt spec together with its rendered text and render context."""

    spec: PromptSpec
    text: str
    # Variables substituted into the template during rendering.
    context: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def prompt_text(self) -> str:
        """Alias for ``text``."""
        return self.text

    @property
    def template_name(self) -> str:
        """Name of the underlying prompt spec."""
        return self.spec.name
51
+
52
+
53
@dataclass(frozen=True)
class Reference(Generic[T]):
    """Reference value with optional type information.

    This is a generic dataclass that can hold typed reference values.
    For backward compatibility, it can be used without type parameters
    and will behave like Reference[Any].

    The value field can hold any type including:
    - Simple types: str, int, float, bool
    - Collections: list, tuple, set
    - Dictionaries: dict (for multi-value references)
    - Custom objects

    Examples:
        # Simple reference
        ref = Reference(kind="answer", value="42")

        # Multi-value reference using dict
        ref = Reference(
            kind="countdown_task",
            value={"target": 122, "numbers": [25, 50, 75, 100]}
        )

        # List reference
        ref = Reference(kind="valid_answers", value=["yes", "no", "maybe"])

        # Typed reference
        ref: Reference[str] = Reference(kind="answer", value="42")
        ref: Reference[dict] = Reference(kind="task", value={"a": 1, "b": 2})

    Note:
        When using dict values, metrics can access individual fields directly:
        >>> target = reference.value["target"]
        >>> numbers = reference.value["numbers"]
    """

    # Discriminator describing what the value represents (e.g. "answer").
    kind: str
    value: T
    # Optional runtime type information for ``value``.
    schema: type[T] | None = None
93
+
94
+
95
@dataclass(frozen=True)
class ModelOutput:
    """Successful model response."""

    text: str
    # Provider-specific raw response object, when retained.
    raw: Any | None = None
    # Token usage: {prompt_tokens, completion_tokens, total_tokens}.
    usage: Dict[str, int] | None = None
100
+
101
+
102
@dataclass(frozen=True)
class ModelError:
    """Failure information for a generation attempt."""

    message: str
    # Error category; defaults to a generic model error.
    kind: str = "model_error"
    details: Dict[str, Any] = field(default_factory=dict)
107
+
108
+
109
@dataclass
class GenerationTask:
    """Unit of work: one rendered prompt to run against one model."""

    prompt: PromptRender
    model: ModelSpec
    sampling: SamplingConfig
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Optional ground-truth reference carried along for later evaluation.
    reference: Reference | None = None
116
+
117
+
118
@dataclass
class GenerationRecord:
    """Outcome of executing a GenerationTask.

    NOTE(review): presumably exactly one of ``output``/``error`` is set;
    this is not enforced here — confirm against the runner.
    """

    task: GenerationTask
    output: ModelOutput | None
    error: ModelError | None
    metrics: Dict[str, Any] = field(default_factory=dict)
    # Presumably earlier (retry) attempts, each a full record — confirm with runner.
    attempts: List["GenerationRecord"] = field(default_factory=list)
125
+
126
+
127
@dataclass(frozen=True)
class EvaluationItem:
    """Pairs a generation record with its reference for scoring."""

    record: GenerationRecord
    reference: Reference | None
131
+
132
+
133
@dataclass(frozen=True)
class MetricScore:
    """Single metric result."""

    metric_name: str
    value: float
    # Metric-specific extra data accompanying the score.
    details: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
139
+
140
+
141
@dataclass
class EvaluationSummary:
    """Collected scores plus failure messages for an evaluation."""

    scores: List[MetricScore]
    failures: List[str] = field(default_factory=list)
145
+
146
+
147
@dataclass
class EvaluationRecord:
    """Per-sample evaluation result."""

    sample_id: str | None
    scores: List[MetricScore]
    failures: List[str] = field(default_factory=list)
152
+
153
+
154
@dataclass
class ExperimentFailure:
    """Failure attributed to a single sample within an experiment."""

    sample_id: str | None
    message: str
158
+
159
+
160
@dataclass
class ExperimentReport:
    """Bundle of everything produced by one experiment run."""

    generation_results: list[GenerationRecord]
    # Forward reference: EvaluationReport is imported only under TYPE_CHECKING.
    evaluation_report: "EvaluationReport"
    failures: list[ExperimentFailure]
    metadata: dict[str, object]
166
+
167
+
168
# Public API of the entities module.
__all__ = [
    "SamplingConfig",
    "ModelSpec",
    "PromptSpec",
    "PromptRender",
    "Reference",
    "ModelOutput",
    "ModelError",
    "GenerationTask",
    "GenerationRecord",
    "EvaluationItem",
    "EvaluationRecord",
    "MetricScore",
    "EvaluationSummary",
    "ExperimentFailure",
    "ExperimentReport",
]