themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Model provider implementations used for experiments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import math
|
|
7
|
+
import random
|
|
8
|
+
import re
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
from themis.core import entities as core_entities
|
|
12
|
+
from themis.interfaces import ModelProvider
|
|
13
|
+
from themis.providers import register_provider
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FakeMathModelClient(ModelProvider):
    """A lightweight heuristic provider used for math experiments.

    This provider never contacts a real model.  It inspects the prompt text,
    applies small deterministic solvers (rectangular-to-polar conversion and
    binary arithmetic), and wraps the result in a JSON payload that mimics a
    model response.  Latency is simulated with a seeded RNG so runs can be
    reproduced.
    """

    # Matches "point (x, y)" with integer coordinates, e.g. "point (3, -4)".
    _POINT_PATTERN = re.compile(
        r"point\s*\(\s*(-?\d+)\s*,\s*(-?\d+)\s*\)", re.IGNORECASE
    )
    # Matches one binary arithmetic expression, e.g. "12 * -3" or "1.5 + 2".
    _ARITHMETIC_PATTERN = re.compile(
        r"(-?\d+(?:\.\d+)?)\s*([+\-*/])\s*(-?\d+(?:\.\d+)?)"
    )

    def __init__(
        self, *, seed: int | None = None, default_answer: str = "unknown"
    ) -> None:
        """Initialize the fake client.

        Args:
            seed: Optional seed for the simulated-latency RNG (reproducibility).
            default_answer: Answer returned when no heuristic solver applies.
        """
        self._rng = random.Random(seed)
        self._default_answer = default_answer

    def generate(
        self, task: core_entities.GenerationTask
    ) -> core_entities.GenerationRecord:  # type: ignore[override]
        """Produce a heuristic answer for *task* as a JSON-encoded payload.

        If the prompt template expects a boxed answer (metadata key
        ``template_expect_boxed``), the answer is wrapped in ``\\boxed{...}``
        unless it is already boxed.
        """
        prompt_text = task.prompt.text
        answer, reason = self._solve(prompt_text)
        expect_boxed = bool(task.metadata.get("template_expect_boxed"))
        if expect_boxed and "\\boxed{" not in answer:
            answer = f"\\boxed{{{answer}}}"
        payload = {
            "answer": answer,
            "reasoning": reason,
            "model": task.model.identifier,
        }
        # Simulated latency keeps downstream metrics code paths exercised.
        latency = self._rng.randint(8, 18)
        return core_entities.GenerationRecord(
            task=task,
            output=core_entities.ModelOutput(text=json.dumps(payload), raw=payload),
            error=None,
            metrics={"latency_ms": latency},
        )

    def _solve(self, prompt: str) -> Tuple[str, str]:
        """Try each heuristic solver in turn; fall back to the default answer.

        Returns:
            An ``(answer, reasoning)`` tuple.
        """
        prompt_lower = prompt.lower()
        polar = self._solve_polar_coordinates(prompt_lower)
        if polar is not None:
            return polar

        arithmetic = self._solve_arithmetic(prompt_lower)
        if arithmetic is not None:
            return arithmetic

        return self._default_answer, "Unable to derive answer with heuristic solver."

    def _solve_polar_coordinates(self, prompt_lower: str) -> Tuple[str, str] | None:
        """Convert an integer rectangular point mentioned in the prompt to polar form.

        Returns ``None`` when the prompt does not mention "polar" or when no
        integer point of the form ``point (x, y)`` can be found.
        """
        if "polar" not in prompt_lower:
            return None
        match = self._POINT_PATTERN.search(prompt_lower)
        if not match:
            return None
        x = int(match.group(1))
        y = int(match.group(2))
        radius_squared = x * x + y * y
        radius = math.sqrt(radius_squared)
        # Prefer an exact form: plain integer when the radius is integral,
        # \sqrt{...} of the squared radius otherwise.
        if math.isclose(radius, round(radius)):
            radius_str = str(int(round(radius)))
        else:
            radius_str = f"\\sqrt{{{radius_squared}}}"
        theta = math.atan2(y, x)
        theta_str = self._format_theta(theta)
        answer = f"\\left( {radius_str}, {theta_str} \\right)"
        reasoning = f"Converted rectangular point ({x}, {y}) into polar coordinates."
        return answer, reasoning

    def _format_theta(self, theta: float) -> str:
        """Render an angle in radians as LaTeX, normalized to ``[0, 2*pi)``.

        Generalized from a fixed lookup table: any angle that is (within
        1e-6) a rational multiple of pi with denominator <= 12 is rendered
        exactly (e.g. ``\\frac{3\\pi}{4}``, which the old table missed and
        printed as a decimal); other angles fall back to a 3-decimal float,
        matching the previous behavior.
        """
        from fractions import Fraction

        tau = 2 * math.pi
        theta = theta % tau
        # Rational approximation of theta/pi; denominator 12 covers all the
        # angles the old hard-coded table handled (pi/6, pi/4, pi/3, ...).
        frac = Fraction(theta / math.pi).limit_denominator(12)
        if math.isclose(theta, float(frac) * math.pi, abs_tol=1e-6):
            frac %= 2  # an angle within tolerance of 2*pi wraps back to 0
            if frac == 0:
                return "0"
            numerator, denominator = frac.numerator, frac.denominator
            pi_term = "\\pi" if numerator == 1 else f"{numerator}\\pi"
            if denominator == 1:
                return pi_term
            return f"\\frac{{{pi_term}}}{{{denominator}}}"
        return f"{theta:.3f}"

    def _solve_arithmetic(self, prompt_lower: str) -> Tuple[str, str] | None:
        """Evaluate the first binary arithmetic expression in the prompt.

        Only fires for prompts phrased as a question ("what is" / "compute").
        Division by zero yields an explicit "undefined" answer instead of
        raising.
        """
        if "what is" not in prompt_lower and "compute" not in prompt_lower:
            return None
        match = self._ARITHMETIC_PATTERN.search(prompt_lower)
        if not match:
            return None
        left = float(match.group(1))
        op = match.group(2)
        right = float(match.group(3))
        if op == "+":
            result = left + right
        elif op == "-":
            result = left - right
        elif op == "*":
            result = left * right
        elif op == "/":
            if right == 0:
                return "undefined", "Division by zero encountered."
            result = left / right
        else:  # pragma: no cover - regex restricts op to the four cases above
            return None
        # Render integral results without a trailing ".0"; 3 decimals otherwise.
        if result.is_integer():
            answer = str(int(result))
        else:
            answer = f"{result:.3f}"
        reasoning = f"Evaluated {left} {op} {right} using arithmetic solver."
        return answer, reasoning

    def count_tokens(self, text: str) -> int:
        """Approximate the token count as the number of whitespace-separated words."""
        return len(text.split())
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# Public names exported by this module.
__all__ = ["FakeMathModelClient"]


# Register the heuristic client under the "fake" provider name so it can be
# selected through the provider registry (e.g. from experiment configs).
register_provider("fake", FakeMathModelClient)
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Conversation runner for multi-turn interactions.
|
|
2
|
+
|
|
3
|
+
This module provides a runner that executes multi-turn conversations
|
|
4
|
+
using turn strategies to determine the flow of the conversation.
|
|
5
|
+
|
|
6
|
+
Examples:
|
|
7
|
+
from themis.generation import conversation_runner, turn_strategies
|
|
8
|
+
from themis.core import conversation, entities
|
|
9
|
+
|
|
10
|
+
# Create provider and strategy
|
|
11
|
+
provider = FakeProvider()
|
|
12
|
+
strategy = turn_strategies.FixedSequenceTurnStrategy([
|
|
13
|
+
"What is 2+2?",
|
|
14
|
+
"What about 3+3?"
|
|
15
|
+
])
|
|
16
|
+
|
|
17
|
+
# Create runner
|
|
18
|
+
runner = conversation_runner.ConversationRunner(
|
|
19
|
+
provider=provider,
|
|
20
|
+
turn_strategy=strategy,
|
|
21
|
+
max_turns=5
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Create conversation task
|
|
25
|
+
context = conversation.ConversationContext()
|
|
26
|
+
context.add_message("system", "You are a math tutor.")
|
|
27
|
+
|
|
28
|
+
task = conversation.ConversationTask(
|
|
29
|
+
context=context,
|
|
30
|
+
model=model_spec,
|
|
31
|
+
sampling=sampling_config
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Run conversation
|
|
35
|
+
record = runner.run_conversation(task)
|
|
36
|
+
print(f"Conversation had {record.total_turns()} turns")
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
from __future__ import annotations
|
|
40
|
+
|
|
41
|
+
import logging
|
|
42
|
+
from typing import Any
|
|
43
|
+
|
|
44
|
+
from themis.core import conversation as conv
|
|
45
|
+
from themis.core import entities as core_entities
|
|
46
|
+
from themis.generation import turn_strategies
|
|
47
|
+
from themis.interfaces import ModelProvider
|
|
48
|
+
from themis.utils import tracing
|
|
49
|
+
|
|
50
|
+
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ConversationRunner:
    """Drives multi-turn conversations against a model provider.

    On each turn the runner renders the current context into a prompt, asks
    the provider for a completion, records the result, and then consults a
    ``TurnStrategy`` to decide whether (and with what user message) the
    conversation continues.

    Attributes:
        provider: Model provider used to generate assistant replies
        turn_strategy: Strategy that plans each follow-up user message
        max_turns: Runner-wide upper bound on the number of turns
        prompt_template: Optional template used to render the context
    """

    def __init__(
        self,
        *,
        provider: ModelProvider,
        turn_strategy: turn_strategies.TurnStrategy,
        max_turns: int = 10,
        prompt_template: Any | None = None,
    ):
        """Set up the runner.

        Args:
            provider: Model provider used to generate assistant replies
            turn_strategy: Strategy that plans each follow-up user message
            max_turns: Runner-wide upper bound on the number of turns
            prompt_template: Optional template used to render the context
        """
        self._provider = provider
        self._turn_strategy = turn_strategy
        self._max_turns = max_turns
        self._prompt_template = prompt_template

    def run_conversation(self, task: conv.ConversationTask) -> conv.ConversationRecord:
        """Run *task* to completion.

        Args:
            task: Conversation task to execute

        Returns:
            ConversationRecord carrying every turn plus summary metadata
        """
        with tracing.span(
            "run_conversation",
            model=task.model.identifier,
            max_turns=task.max_turns,
        ):
            history: list[conv.ConversationTurn] = []
            ctx = task.context
            # Effective budget is the stricter of the task limit and the
            # runner-wide limit.
            turn_budget = min(task.max_turns, self._max_turns)

            for idx in range(turn_budget):
                with tracing.span("conversation_turn", turn=idx):
                    logger.debug(
                        "Starting conversation turn %d/%d", idx + 1, turn_budget
                    )

                    # Ask the provider for the assistant reply to the current
                    # context.
                    with tracing.span("generate_response"):
                        rendered = ctx.to_prompt(self._prompt_template)
                        gen_task = self._create_generation_task(task, rendered, idx)
                        gen_result = self._provider.generate(gen_task)

                    if gen_result.output:
                        ctx.add_message("assistant", gen_result.output.text)
                    else:
                        # Generation failed; the turn is still recorded below.
                        logger.warning(
                            "Generation failed at turn %d: %s",
                            idx,
                            gen_result.error.message if gen_result.error else "unknown error",
                        )

                    # Record the turn now; the follow-up user message (if any)
                    # is back-filled once the strategy has planned it.
                    entry = conv.ConversationTurn(
                        turn_number=idx,
                        user_message=None,
                        generation_record=gen_result,
                        context_snapshot=self._snapshot_context(ctx),
                    )
                    history.append(entry)

                    if task.should_stop():
                        logger.debug("Task stop condition met at turn %d", idx)
                        break

                    with tracing.span("plan_next_turn"):
                        followup = self._turn_strategy.next_turn(ctx, gen_result)

                    if followup is None:
                        logger.debug(
                            "Turn strategy ended conversation at turn %d", idx
                        )
                        break

                    # Queue the planned user message for the next turn and
                    # attach it to the turn that prompted it.
                    entry.user_message = conv.Message(role="user", content=followup)
                    ctx.add_message("user", followup)

                    logger.debug(
                        "Planned next turn: %s",
                        followup[:50] + ("..." if len(followup) > 50 else ""),
                    )

            outcome = conv.ConversationRecord(
                task=task,
                context=ctx,
                turns=history,
                metadata={
                    "total_turns": len(history),
                    "max_turns_reached": len(history) >= turn_budget,
                    "stop_condition_met": task.should_stop(),
                },
            )

            logger.info(
                "Conversation completed: %d turns, stop_reason=%s",
                len(history),
                "max_turns" if outcome.metadata["max_turns_reached"] else "strategy",
            )

            return outcome

    def _create_generation_task(
        self, conv_task: conv.ConversationTask, prompt_text: str, turn_num: int
    ) -> core_entities.GenerationTask:
        """Build the per-turn GenerationTask from conversation state.

        Args:
            conv_task: Conversation task
            prompt_text: Rendered prompt text
            turn_num: Current turn number

        Returns:
            GenerationTask for this turn
        """
        from themis.core.entities import PromptRender, PromptSpec

        # Each turn gets its own synthetic prompt spec; the rendered text is
        # the whole conversation so far.
        spec = PromptSpec(
            name=f"conversation_turn_{turn_num}",
            template="",
            metadata={"turn": turn_num},
        )
        render = PromptRender(
            spec=spec,
            text=prompt_text,
            context={"turn": turn_num},
            metadata={"turn": turn_num},
        )

        # Copy task metadata so per-turn keys never leak back into the task.
        extra = dict(conv_task.metadata)
        extra["turn"] = turn_num
        extra["conversation"] = True

        return core_entities.GenerationTask(
            prompt=render,
            model=conv_task.model,
            sampling=conv_task.sampling,
            metadata=extra,
            reference=conv_task.reference,
        )

    def _snapshot_context(
        self, context: conv.ConversationContext
    ) -> conv.ConversationContext:
        """Return an independent copy of *context* via a dict round-trip.

        Args:
            context: Context to snapshot

        Returns:
            Copy of context
        """
        serialized = context.to_dict()
        return conv.ConversationContext.from_dict(serialized)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# Public names exported by this module.
__all__ = ["ConversationRunner"]
|