themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/config/runtime.py
@@ -0,0 +1,214 @@
+ """Runtime helpers for executing experiments from Hydra configs."""
+
+ from __future__ import annotations
+
+ from dataclasses import asdict
+ from pathlib import Path
+
+ from themis.core import entities as core_entities
+ from themis.datasets import create_dataset
+ from themis.experiment import math as math_experiment
+ from themis.experiment import mcq as mcq_experiment
+ from themis.experiment import orchestrator as experiment_orchestrator
+ from themis.experiment import storage as experiment_storage
+ from themis.providers import registry as provider_registry
+
+ from . import registry, schema
+
+
+ def run_experiment_from_config(
+     config: schema.ExperimentConfig,
+     *,
+     dataset: list[dict[str, object]] | None = None,
+     on_result=None,
+ ) -> experiment_orchestrator.ExperimentReport:
+     dataset_to_use = (
+         dataset
+         if dataset is not None
+         else _load_dataset(config.dataset, experiment_name=config.name)
+     )
+     experiment = _build_experiment(config)
+     return experiment.run(
+         dataset_to_use,
+         max_samples=config.max_samples,
+         run_id=config.run_id,
+         resume=config.resume,
+         on_result=on_result,
+     )
+
+
+ def summarize_report_for_config(
+     config: schema.ExperimentConfig,
+     report: experiment_orchestrator.ExperimentReport,
+ ) -> str:
+     if config.task in {
+         "math500",
+         "aime24",
+         "aime25",
+         "amc23",
+         "olympiadbench",
+         "beyondaime",
+     }:
+         return math_experiment.summarize_report(report)
+     if config.task in {"supergpqa", "mmlu_pro"}:
+         return mcq_experiment.summarize_report(report)
+     raise ValueError(f"Unsupported task '{config.task}' for summarization.")
+
+
+ def load_dataset_from_config(
+     config: schema.ExperimentConfig,
+ ) -> list[dict[str, object]]:
+     return _load_dataset(config.dataset, experiment_name=config.name)
+
+
+ def _build_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     if config.task:
+         builder = registry.get_experiment_builder(config.task)
+         return builder(config)
+
+     raise ValueError(
+         "Experiment configuration must specify a 'task'. "
+         f"Available tasks: {', '.join(sorted(registry._EXPERIMENT_BUILDERS.keys()))}"
+     )
+
+
+ @registry.register_experiment_builder("math500")
+ @registry.register_experiment_builder("aime24")
+ @registry.register_experiment_builder("aime25")
+ @registry.register_experiment_builder("amc23")
+ @registry.register_experiment_builder("olympiadbench")
+ @registry.register_experiment_builder("beyondaime")
+ def _build_math_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     # Use the specific path if provided, otherwise fall back to the default path
+     storage_path = config.storage.path or config.storage.default_path
+     storage = (
+         experiment_storage.ExperimentStorage(Path(storage_path))
+         if storage_path
+         else None
+     )
+     sampling_cfg = core_entities.SamplingConfig(
+         temperature=config.generation.sampling.temperature,
+         top_p=config.generation.sampling.top_p,
+         max_tokens=config.generation.sampling.max_tokens,
+     )
+     provider = provider_registry.create_provider(
+         config.generation.provider.name, **config.generation.provider.options
+     )
+     runner_options = asdict(config.generation.runner)
+
+     # Default to the configured task name; task_options may override it.
+     task_name = config.task or "math500"
+     if config.task_options and "task_name" in config.task_options:
+         task_name = config.task_options["task_name"]
+
+     return math_experiment.build_math500_zero_shot_experiment(
+         model_client=provider,
+         model_name=config.generation.model_identifier,
+         storage=storage,
+         sampling=sampling_cfg,
+         provider_name=config.generation.provider.name,
+         runner_options=runner_options,
+         task_name=task_name,
+     )
+
+
+ @registry.register_experiment_builder("supergpqa")
+ def _build_supergpqa_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     return _build_mcq_experiment(config, "supergpqa", "supergpqa")
+
+
+ @registry.register_experiment_builder("mmlu_pro")
+ def _build_mmlu_pro_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     return _build_mcq_experiment(config, "mmlu-pro", "mmlu_pro")
+
+
+ def _build_mcq_experiment(
+     config: schema.ExperimentConfig, dataset_name: str, task_id: str
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     # Use the specific path if provided, otherwise fall back to the default path
+     storage_path = config.storage.path or config.storage.default_path
+     storage = (
+         experiment_storage.ExperimentStorage(Path(storage_path))
+         if storage_path
+         else None
+     )
+     sampling_cfg = core_entities.SamplingConfig(
+         temperature=config.generation.sampling.temperature,
+         top_p=config.generation.sampling.top_p,
+         max_tokens=config.generation.sampling.max_tokens,
+     )
+     provider = provider_registry.create_provider(
+         config.generation.provider.name, **config.generation.provider.options
+     )
+     runner_options = asdict(config.generation.runner)
+
+     return mcq_experiment.build_multiple_choice_json_experiment(
+         dataset_name=dataset_name,
+         task_id=task_id,
+         model_client=provider,
+         model_name=config.generation.model_identifier,
+         storage=storage,
+         sampling=sampling_cfg,
+         provider_name=config.generation.provider.name,
+         runner_options=runner_options,
+     )
+
+
+ def _load_dataset(
+     config: schema.DatasetConfig, *, experiment_name: str
+ ) -> list[dict[str, object]]:
+     """Load dataset samples using the dataset registry.
+
+     Args:
+         config: Dataset configuration
+         experiment_name: Name of the experiment (currently unused)
+
+     Returns:
+         List of sample dictionaries ready for generation
+     """
+     # Handle inline datasets (not in registry)
+     if config.source == "inline":
+         if not config.inline_samples:
+             raise ValueError(
+                 "dataset.inline_samples must contain at least one row when"
+                 " dataset.source='inline'."
+             )
+         return list(config.inline_samples)
+
+     # Require an explicit dataset_id: only the DatasetConfig is available
+     # here, so a missing dataset_id cannot be inferred from the task.
+     dataset_name = config.dataset_id
+     if not dataset_name:
+         raise ValueError(
+             "dataset.dataset_id must be provided when source is not 'inline'."
+         )
+
+     # Prepare options for dataset factory
+     options = {
+         "source": config.source,
+         "data_dir": config.data_dir,
+         "split": config.split,
+         "limit": config.limit,
+         "subjects": list(config.subjects) if config.subjects else None,
+     }
+
+     # Load samples via registry
+     return create_dataset(dataset_name, **options)
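
For orientation, here is a minimal, hypothetical usage sketch of the new runtime entry points. It assumes the inline-dataset path and the default fake provider defined in schema.py below; the sample field names ("problem", "answer") are illustrative only and are not confirmed by this diff.

from themis.config import runtime, schema

config = schema.ExperimentConfig(
    task="math500",
    dataset=schema.DatasetConfig(
        source="inline",
        # Hypothetical sample fields; the expected schema is defined by the
        # math500 pipeline, which this diff does not show.
        inline_samples=[{"problem": "What is 2 + 2?", "answer": "4"}],
    ),
)

# Builds the registered "math500" experiment and runs it end to end.
report = runtime.run_experiment_from_config(config)
print(runtime.summarize_report_for_config(config, report))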
themis/config/schema.py
@@ -0,0 +1,112 @@
+ """Structured configuration definitions for Hydra/OmegaConf."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any
+
+
+ @dataclass
+ class ProviderConfig:
+     name: str = "fake"
+     options: dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class RunnerConfig:
+     max_parallel: int = 1
+     max_retries: int = 3
+     retry_initial_delay: float = 0.5
+     retry_backoff_multiplier: float = 2.0
+     retry_max_delay: float | None = 2.0
+
+
+ @dataclass
+ class SamplingConfig:
+     temperature: float = 0.0
+     top_p: float = 0.95
+     max_tokens: int = 512
+
+
+ @dataclass
+ class GenerationConfig:
+     model_identifier: str = "fake-math-llm"
+     provider: ProviderConfig = field(default_factory=ProviderConfig)
+     sampling: SamplingConfig = field(default_factory=SamplingConfig)
+     runner: RunnerConfig = field(default_factory=RunnerConfig)
+
+
+ @dataclass
+ class DatasetConfig:
+     source: str = "huggingface"
+     dataset_id: str | None = None
+     data_dir: str | None = None
+     limit: int | None = None
+     split: str = "test"
+     subjects: list[str] = field(default_factory=list)
+     inline_samples: list[dict[str, Any]] = field(default_factory=list)
+
+
+ @dataclass
+ class StorageConfig:
+     path: str | None = None
+     default_path: str | None = None  # New field for default storage path
+
+
+ @dataclass
+ class WandbConfig:
+     enable: bool = False
+     project: str | None = None
+     entity: str | None = None
+     tags: list[str] = field(default_factory=list)
+
+
+ @dataclass
+ class HuggingFaceHubConfig:
+     enable: bool = False
+     repository: str | None = None
+
+
+ @dataclass
+ class IntegrationsConfig:
+     wandb: WandbConfig = field(default_factory=WandbConfig)
+     huggingface_hub: HuggingFaceHubConfig = field(default_factory=HuggingFaceHubConfig)
+
+
+ @dataclass
+ class ExperimentConfig:
+     name: str = "math500_zero_shot"
+     dataset: DatasetConfig = field(default_factory=DatasetConfig)
+     generation: GenerationConfig = field(default_factory=GenerationConfig)
+     storage: StorageConfig = field(default_factory=StorageConfig)
+     integrations: IntegrationsConfig = field(default_factory=IntegrationsConfig)
+     max_samples: int | None = None
+     run_id: str | None = None
+     resume: bool = True
+     task: str | None = None
+     task_options: dict[str, Any] = field(default_factory=dict)
+
+     @classmethod
+     def from_file(cls, path: str | Path) -> ExperimentConfig:
+         """Load configuration from a file."""
+         from .loader import load_experiment_config
+
+         return load_experiment_config(Path(path))
+
+     @classmethod
+     def from_dict(cls, data: dict[str, Any]) -> ExperimentConfig:
+         """Create configuration from a dictionary."""
+         from omegaconf import OmegaConf
+
+         base = OmegaConf.structured(cls)
+         merged = OmegaConf.merge(base, OmegaConf.create(data))
+         return OmegaConf.to_object(merged)  # type: ignore
+
+     def to_file(self, path: str | Path) -> None:
+         """Save configuration to a file."""
+         from omegaconf import OmegaConf
+
+         conf = OmegaConf.structured(self)
+         Path(path).parent.mkdir(parents=True, exist_ok=True)
+         OmegaConf.save(conf, Path(path))
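
A short sketch of the round-trip these dataclasses support, assuming omegaconf is installed (both from_dict and to_file import it lazily). The provider key "litellm" and the model identifier are assumptions suggested by themis/generation/providers/litellm_provider.py in the file list, not values confirmed by this diff.

from themis.config.schema import ExperimentConfig

cfg = ExperimentConfig.from_dict(
    {
        "task": "mmlu_pro",
        "dataset": {"dataset_id": "mmlu_pro", "split": "test", "limit": 100},
        "generation": {
            "model_identifier": "gpt-4o-mini",  # hypothetical model name
            "provider": {"name": "litellm"},  # assumed provider registry key
        },
    }
)
cfg.to_file("configs/mmlu_pro.yaml")

# from_file delegates to themis.config.loader.load_experiment_config.
reloaded = ExperimentConfig.from_file("configs/mmlu_pro.yaml")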
themis/core/__init__.py
@@ -0,0 +1,5 @@
+ """Core datamodel for Themis."""
+
+ from . import entities, serialization
+
+ __all__ = ["entities", "serialization"]
themis/core/conversation.py
@@ -0,0 +1,354 @@
+ """Conversation primitives for multi-turn interactions.
+
+ This module provides abstractions for managing multi-turn conversations,
+ enabling research on dialogue systems, debugging interactions, and
+ agentic workflows.
+
+ Examples:
+     # Create a conversation
+     context = ConversationContext()
+     context.add_message("user", "What is 2+2?")
+     context.add_message("assistant", "2+2 equals 4.")
+     context.add_message("user", "What about 3+3?")
+
+     # Convert to prompt
+     prompt = context.to_prompt()
+
+     # Get conversation history
+     history = context.get_history(max_turns=2)
+ """
+
+ from __future__ import annotations
+
+ import time
+ from dataclasses import dataclass, field
+ from typing import Any, Callable, Literal
+
+ from themis.core import entities as core_entities
+ from themis.generation import templates
+
+ MessageRole = Literal["system", "user", "assistant", "tool"]
+
+
+ @dataclass
+ class Message:
+     """Single message in a conversation.
+
+     Attributes:
+         role: Message role (system/user/assistant/tool)
+         content: Message text content
+         metadata: Additional metadata (tool calls, timestamps, etc.)
+         timestamp: Unix timestamp when message was created
+     """
+
+     role: MessageRole
+     content: str
+     metadata: dict[str, Any] = field(default_factory=dict)
+     timestamp: float = field(default_factory=time.time)
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert message to dictionary."""
+         return {
+             "role": self.role,
+             "content": self.content,
+             "metadata": self.metadata,
+             "timestamp": self.timestamp,
+         }
+
+
+ @dataclass
+ class ConversationContext:
+     """Maintains conversation state across turns.
+
+     This class manages the conversation history and provides utilities
+     for rendering conversations as prompts.
+
+     Examples:
+         context = ConversationContext()
+         context.add_message("system", "You are a helpful assistant.")
+         context.add_message("user", "Hello!")
+         context.add_message("assistant", "Hi! How can I help you?")
+
+         # Get history
+         messages = context.get_history()
+
+         # Render to prompt
+         prompt = context.to_prompt()
+     """
+
+     messages: list[Message] = field(default_factory=list)
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+     def add_message(self, role: MessageRole, content: str, **metadata: Any) -> None:
+         """Add a message to the conversation.
+
+         Args:
+             role: Message role (system/user/assistant/tool)
+             content: Message text content
+             **metadata: Additional metadata to attach to message
+         """
+         self.messages.append(Message(role=role, content=content, metadata=metadata))
+
+     def get_history(self, max_turns: int | None = None) -> list[Message]:
+         """Get conversation history.
+
+         Args:
+             max_turns: Maximum number of messages to return (from end)
+
+         Returns:
+             List of messages (most recent if limited)
+         """
+         if max_turns is None:
+             return list(self.messages)
+         return self.messages[-max_turns:]
+
+     def get_messages_by_role(self, role: MessageRole) -> list[Message]:
+         """Get all messages with a specific role.
+
+         Args:
+             role: Role to filter by
+
+         Returns:
+             List of messages with matching role
+         """
+         return [msg for msg in self.messages if msg.role == role]
+
+     def to_prompt(self, template: templates.PromptTemplate | None = None) -> str:
+         """Render conversation to prompt string.
+
+         Args:
+             template: Optional template for custom formatting
+
+         Returns:
+             Formatted prompt string
+         """
+         if template is not None:
+             return template.render(messages=self.messages)
+
+         # Default format: role-prefixed messages
+         lines = []
+         for msg in self.messages:
+             lines.append(f"{msg.role}: {msg.content}")
+
+         return "\n\n".join(lines)
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert conversation to dictionary.
+
+         Returns:
+             Dictionary representation
+         """
+         return {
+             "messages": [msg.to_dict() for msg in self.messages],
+             "metadata": self.metadata,
+         }
+
+     @classmethod
+     def from_dict(cls, data: dict[str, Any]) -> ConversationContext:
+         """Create conversation from dictionary.
+
+         Args:
+             data: Dictionary with messages and metadata
+
+         Returns:
+             ConversationContext instance
+         """
+         context = cls(metadata=data.get("metadata", {}))
+         for msg_data in data.get("messages", []):
+             context.messages.append(
+                 Message(
+                     role=msg_data["role"],
+                     content=msg_data["content"],
+                     metadata=msg_data.get("metadata", {}),
+                     timestamp=msg_data.get("timestamp", time.time()),
+                 )
+             )
+         return context
+
+     def __len__(self) -> int:
+         """Return number of messages in conversation."""
+         return len(self.messages)
+
+
+ @dataclass
+ class ConversationTask:
+     """Task for multi-turn conversation execution.
+
+     This extends the basic GenerationTask concept to support
+     multi-turn conversations with configurable stopping conditions.
+
+     Attributes:
+         context: Conversation context with message history
+         model: Model to use for generation
+         sampling: Sampling configuration
+         metadata: Additional metadata
+         reference: Optional reference for evaluation
+         max_turns: Maximum number of conversation turns
+         stop_condition: Optional function to determine when to stop
+     """
+
+     context: ConversationContext
+     model: core_entities.ModelSpec
+     sampling: core_entities.SamplingConfig
+     metadata: dict[str, Any] = field(default_factory=dict)
+     reference: core_entities.Reference | None = None
+     max_turns: int = 10
+     stop_condition: Callable[[ConversationContext], bool] | None = None
+
+     def should_stop(self) -> bool:
+         """Check if conversation should stop.
+
+         Returns:
+             True if stop condition is met or max turns reached
+         """
+         if len(self.context) >= self.max_turns:
+             return True
+
+         if self.stop_condition is not None:
+             return self.stop_condition(self.context)
+
+         return False
+
+
+ @dataclass
+ class ConversationTurn:
+     """Single turn in a conversation.
+
+     Attributes:
+         turn_number: Turn index (0-based)
+         user_message: User message for this turn (if any)
+         generation_record: Generation result for this turn
+         context_snapshot: Conversation context at this turn
+     """
+
+     turn_number: int
+     user_message: Message | None
+     generation_record: core_entities.GenerationRecord
+     context_snapshot: ConversationContext
+
+
+ @dataclass
+ class ConversationRecord:
+     """Complete record of a multi-turn conversation.
+
+     This is the result of running a ConversationTask through
+     a ConversationRunner.
+
+     Attributes:
+         task: Original conversation task
+         context: Final conversation context
+         turns: List of turns executed
+         metadata: Additional metadata (e.g., total turns, stop reason)
+     """
+
+     task: ConversationTask
+     context: ConversationContext
+     turns: list[ConversationTurn] = field(default_factory=list)
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+     def get_final_output(self) -> core_entities.ModelOutput | None:
+         """Get the final model output.
+
+         Returns:
+             Last turn's output, or None if no turns
+         """
+         if not self.turns:
+             return None
+         return self.turns[-1].generation_record.output
+
+     def get_all_outputs(self) -> list[core_entities.ModelOutput | None]:
+         """Get all model outputs from all turns.
+
+         Returns:
+             List of outputs (may contain None for failed turns)
+         """
+         return [turn.generation_record.output for turn in self.turns]
+
+     def total_turns(self) -> int:
+         """Get total number of turns executed.
+
+         Returns:
+             Number of turns
+         """
+         return len(self.turns)
+
+
+ # Common stop conditions
+
+
+ def stop_on_keyword(keyword: str) -> Callable[[ConversationContext], bool]:
+     """Create stop condition that triggers when keyword appears.
+
+     Args:
+         keyword: Keyword to look for in assistant messages
+
+     Returns:
+         Stop condition function
+     """
+
+     def condition(context: ConversationContext) -> bool:
+         if not context.messages:
+             return False
+         last_msg = context.messages[-1]
+         if last_msg.role == "assistant":
+             return keyword.lower() in last_msg.content.lower()
+         return False
+
+     return condition
+
+
+ def stop_on_pattern(
+     pattern: str,
+ ) -> Callable[[ConversationContext], bool]:
+     """Create stop condition that triggers when regex pattern matches.
+
+     Args:
+         pattern: Regex pattern to match
+
+     Returns:
+         Stop condition function
+     """
+     import re
+
+     compiled = re.compile(pattern, re.IGNORECASE)
+
+     def condition(context: ConversationContext) -> bool:
+         if not context.messages:
+             return False
+         last_msg = context.messages[-1]
+         if last_msg.role == "assistant":
+             return compiled.search(last_msg.content) is not None
+         return False
+
+     return condition
+
+
+ def stop_on_empty_response() -> Callable[[ConversationContext], bool]:
+     """Create stop condition that triggers on empty assistant response.
+
+     Returns:
+         Stop condition function
+     """
+
+     def condition(context: ConversationContext) -> bool:
+         if not context.messages:
+             return False
+         last_msg = context.messages[-1]
+         if last_msg.role == "assistant":
+             return not last_msg.content.strip()
+         return False
+
+     return condition
+
+
+ __all__ = [
+     "MessageRole",
+     "Message",
+     "ConversationContext",
+     "ConversationTask",
+     "ConversationTurn",
+     "ConversationRecord",
+     "stop_on_keyword",
+     "stop_on_pattern",
+     "stop_on_empty_response",
+ ]
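
To make the stop-condition plumbing concrete, here is a self-contained sketch that exercises only the classes and helpers added above, with no model calls involved:

from themis.core.conversation import ConversationContext, stop_on_pattern

context = ConversationContext()
context.add_message("system", "You are a math tutor.")
context.add_message("user", "Compute 12 * 7 and end with FINAL: <answer>.")
context.add_message("assistant", "12 * 7 = 84. FINAL: 84")

# Stop conditions only ever inspect the last message, and only when it
# came from the assistant.
done = stop_on_pattern(r"FINAL:\s*\d+")
assert done(context)

# Round-trip through the dict form used for persistence.
restored = ConversationContext.from_dict(context.to_dict())
assert len(restored) == len(context) == 3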