eval_framework-0.2.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0

template_formatting/formatter.py
@@ -0,0 +1,537 @@

import re
from collections.abc import Sequence
from dataclasses import asdict, dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Literal, overload, override

from pydantic import BaseModel, field_serializer, field_validator

try:
    from transformers import AutoTokenizer
except ImportError:
    print("template_formatting: `transformers` package is not installed, HFFormatter will not be available.")


class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


class Property(Enum):
    ANSWER = "answer"
    THOUGHT = "thought"
    SOLUTION = "solution"


class Message(BaseModel):
    role: Role | None = None  # Optional for compatibility with the legacy finetuning format.
    property: Property | None = None
    content: str
    has_loss: bool | None = None
    type: str | None = None

    @field_serializer("role")
    def serialize_role(self, value: Role | None) -> str | None:
        if value is None:
            # Legacy finetuning format.
            return None
        return value.value

    @field_validator("role", mode="before")
    @classmethod
    def validate_role(cls, value: str | Role | None) -> Role | None:
        if value is None:
            # Legacy finetuning format.
            return None
        if isinstance(value, str):
            return Role(value)
        return value
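Because `Message` is a pydantic model whose `role` field has a string serializer and a `mode="before"` validator, messages round-trip cleanly through plain dicts. A minimal sketch (assuming the wheel is installed):

```python
from template_formatting.formatter import Message, Role

msg = Message(role=Role.USER, content="Hello")
data = msg.model_dump()      # "role" is serialized to the plain string "user"
restored = Message(**data)   # the "before" validator turns "user" back into Role.USER
assert restored.role is Role.USER
```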
@dataclass
class ChatTemplate:
    begin_of_text: str
    end_of_text: str
    begin_system_prompt: str
    system_prompt: str
    end_system_prompt: str
    begin_assistant_id: str
    end_assistant_id: str
    begin_user_id: str
    end_user_id: str


@dataclass
class ReasoningTemplate(ChatTemplate):
    begin_thought_id: str
    end_thought_id: str
    begin_solution_id: str
    end_solution_id: str
    begin_answer_id: str
    end_answer_id: str


class BaseFormatter:
    template: ChatTemplate | ReasoningTemplate
    strip_content: bool = False
    never_strip: bool = False

    def __init__(self) -> None:
        super().__init__()
        assert not (self.strip_content and self.never_strip), "strip_content and never_strip cannot both be True"

    @staticmethod
    def _verify_messages(messages: Sequence[Message]) -> None:
        grouped_messages = BaseFormatter._get_grouped_messages(messages)
        offset = int(grouped_messages[0][0].role == Role.SYSTEM)
        user_messages = grouped_messages[offset::2]
        assistant_messages = grouped_messages[offset + 1 :: 2]
        if grouped_messages[0][0].role is None:
            # Legacy finetuning format.
            assert all(m[0].role is None for m in user_messages)
        else:  # New format, assert role order.
            assert all(m[0].role == Role.USER for m in user_messages)
            assert all(m[0].role == Role.ASSISTANT for m in assistant_messages)

    @staticmethod
    def _verify_message_fields(messages: Sequence[Message], output_mode: str) -> None:
        if output_mode not in ("string", "list"):
            raise ValueError("Unsupported output_mode: choose 'string' or 'list'")

        for message in messages:
            if output_mode == "string":
                # eval-framework style
                if not hasattr(message, "role"):
                    raise ValueError("Message is missing 'role' property.")
                if (getattr(message, "type", None) is not None) or (getattr(message, "has_loss", None) is not None):
                    raise ValueError("Message must not set 'type' or 'has_loss' in 'string' output_mode.")

            elif output_mode == "list":
                # scaling style
                if not hasattr(message, "type") or not hasattr(message, "has_loss"):
                    raise ValueError("Message is missing 'type' or 'has_loss' property.")

    @staticmethod
    def _get_grouped_messages(messages: Sequence[Message]) -> Sequence[Sequence[Message]]:
        """
        Groups consecutive messages to meet two criteria, while preserving the
        order of each sequence item:
        - Role is identical in each group.
        - Each property occurs at most once in each group.
        """
        if not messages:
            return []

        grouped_messages = []
        current_group = [messages[0]]

        for message in messages[1:]:
            role = current_group[0].role
            group_props = set(i.property for i in current_group)
            if message.role == role and message.property not in group_props:
                current_group.append(message)
            else:
                grouped_messages.append(current_group)
                current_group = [message]

        grouped_messages.append(current_group)
        return grouped_messages

    @overload
    def format(self, messages: Sequence[Message], output_mode: Literal["string"] = ...) -> str:
        pass

    @overload
    def format(self, messages: Sequence[Message], output_mode: Literal["list"]) -> list[Message]:
        pass

    def format(
        self, messages: Sequence[Message], output_mode: Literal["string", "list"] = "string"
    ) -> str | list[Message]:
        """
        Formats a list of messages using the provided template.
        output_mode: "string" returns a single concatenated string ('eval-framework' style),
            "list" returns the messages with their content updated ('scaling' style).
        """
        self._verify_messages(messages)
        self._verify_message_fields(messages, output_mode)  # also rejects unsupported output_mode values

        if output_mode == "string":
            # Generate formatted strings for each message and join them.
            formatted_parts = (
                self._format_message(message, i == len(messages) - 1, output_mode) for i, message in enumerate(messages)
            )
            return self.template.begin_of_text + "".join(formatted_parts)
        else:
            # Create a new list of messages with updated content.
            new_messages: list[Message] = [message.model_copy(deep=True) for message in messages]
            for i, message in enumerate(new_messages):
                formatted_content = self._format_message(messages[i], i == len(messages) - 1, output_mode)
                message.content = formatted_content

            # Prepend the begin_of_text to the first message's content.
            if new_messages:
                new_messages[0].content = self.template.begin_of_text + new_messages[0].content
            return new_messages

    def _format_message(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        """
        Returns the formatted string for a single message.
        """
        if message.role == Role.SYSTEM:
            text = getattr(message, "content", "")
            if not text and hasattr(self.template, "system_prompt"):
                text = self.template.system_prompt
            if self.strip_content:
                text = text.strip()
            return f"{self.template.begin_system_prompt}{text}{self.template.end_system_prompt}"

        elif message.role == Role.USER:
            text = getattr(message, "content", "")
            if self.strip_content:
                text = text.strip()
            elif output_mode == "string":
                if is_last or (self.template.end_user_id != "" and not self.never_strip):
                    text = text.strip()
            if output_mode == "string" or (output_mode == "list" and not is_last):
                # start assistant message after user message
                result = (
                    f"{self.template.begin_user_id}{text}{self.template.end_user_id}{self.template.begin_assistant_id}"
                )
            else:
                # default HF behavior for applying chat template with
                # `add_generation_prompt=False` and `continue_final_message=False` (as used in 'scaling')
                result = f"{self.template.begin_user_id}{text}{self.template.end_user_id}"
            return result

        elif message.role == Role.ASSISTANT:
            return self._format_assistant(message, is_last, output_mode)

        elif message.role is None:
            return getattr(message, "content", "")

        else:
            raise ValueError(f"Unsupported role: {message.role}")

    def _format_assistant(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        """
        Formats an assistant message based on its property.
        """
        text = getattr(message, "content", "")
        if self.strip_content:
            text = text.strip()

        if message.property is not None:
            raise ValueError("Message properties require ReasoningFormatter")

        else:
            result = text
            # In string mode (i.e., 'eval-framework'), omit end_assistant_id if this is the last message.
            # In list mode (i.e., 'scaling'), always append it.
            if output_mode == "list" or (output_mode == "string" and not is_last):
                result += self.template.end_assistant_id
            elif output_mode == "string":
                if not self.never_strip:
                    result = result.strip()
            else:
                raise ValueError(f"Unknown output_mode: {output_mode}")

        return result


class IdentityFormatter(BaseFormatter):
    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        begin_system_prompt="",
        system_prompt="",
        end_system_prompt="",
        begin_assistant_id="",
        end_assistant_id="",
        begin_user_id="",
        end_user_id="",
    )


class ConcatFormatter(BaseFormatter):
    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        begin_system_prompt="",
        system_prompt="",
        end_system_prompt="\n\n",
        begin_assistant_id="",
        end_assistant_id="\n\n",
        begin_user_id="",
        end_user_id="",
    )
    # newlines are handled at the task level, so we don't need to strip content here


class Llama3Formatter(BaseFormatter):
    template = ChatTemplate(
        begin_of_text="<|begin_of_text|>",
        end_of_text="",
        begin_system_prompt="<|start_header_id|>system<|end_header_id|>\n\n",
        system_prompt="You are a helpful AI assistant",
        end_system_prompt="<|eot_id|>",
        begin_assistant_id="<|start_header_id|>assistant<|end_header_id|>\n\n",
        end_assistant_id="<|eot_id|>",
        begin_user_id="<|start_header_id|>user<|end_header_id|>\n\n",
        end_user_id="<|eot_id|>",
    )
    strip_content = True  # strip content to ensure consistency with the HF chat template formatter
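Each of the formatters above differs only in its `ChatTemplate` fields. For instance, formatting a single user turn with `Llama3Formatter` in the default "string" mode produces a ready-to-complete Llama 3 prompt; a minimal sketch, with the expected output implied by the template fields:

```python
from template_formatting.formatter import Llama3Formatter, Message, Role

formatter = Llama3Formatter()
prompt = formatter.format([Message(role=Role.USER, content="What is 2 + 2?")])
# prompt == ("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
#            "What is 2 + 2?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
```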
class HFFormatter(BaseFormatter):
    def __init__(self, hf_llm_name: str | Path, chat_template_kwargs: dict[str, Any] | None = None) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(hf_llm_name)
        self.chat_template_kwargs = chat_template_kwargs or {}

        if self.tokenizer.chat_template is None:
            raise ValueError(f"Chat template is not available for HF model: {hf_llm_name}")

    def _to_hf_message(self, message: Message) -> dict[str, str]:
        if message.role is None:
            raise ValueError("Message role cannot be None")
        return {"role": message.role.value, "content": message.content}

    @override
    def format(  # type: ignore[override]
        self, messages: Sequence[Message], output_mode: Literal["string", "list"] = "string"
    ) -> str:
        hf_chat = [self._to_hf_message(message) for message in messages]

        template_kwargs = {"tokenize": False, **self.chat_template_kwargs}

        # output_mode encodes whether or not to treat a trailing assistant message
        # as a pre-fill: training uses 'list' mode, eval uses 'string' mode.
        # The naming is legacy; both code paths return strings.
        if output_mode == "string":
            # if the last message is an assistant message, treat it as a pre-fill (i.e., assistant cue in evals)
            is_prefill = messages[-1].role == Role.ASSISTANT
            template_kwargs.update(
                {
                    "add_generation_prompt": not is_prefill,
                    "continue_final_message": is_prefill,
                }
            )

        return self.tokenizer.apply_chat_template(hf_chat, **template_kwargs)
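`HFFormatter` delegates to the checkpoint's own chat template instead of a hand-written `ChatTemplate`. A minimal sketch (the model id is only an example; any Hub checkpoint that ships a chat template should work):

```python
from template_formatting.formatter import HFFormatter, Message, Role

formatter = HFFormatter("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = formatter.format(
    [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant"),
        Message(role=Role.USER, content="What is 2 + 2?"),
    ]
)
# The last message is not an assistant pre-fill, so the template is applied with
# add_generation_prompt=True and continue_final_message=False.
```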
class ReasoningFormatter(BaseFormatter):
    template: ReasoningTemplate
    remove_previous_thoughts: bool = False

    def __init__(self, base_formatter: type[BaseFormatter]) -> None:
        super().__init__()
        self.template = ReasoningTemplate(
            **asdict(base_formatter.template),
            begin_thought_id="<|begin_of_thought|>",
            end_thought_id="<|end_of_thought|>",
            begin_solution_id="<|begin_of_solution|>",
            end_solution_id="<|end_of_solution|>",
            begin_answer_id="<|begin_of_answer|>",
            end_answer_id="<|end_of_answer|>",
        )

    def _format_message(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        result = super()._format_message(message, is_last, output_mode)
        if message.role == Role.USER and output_mode == "string" and (is_last or not self.remove_previous_thoughts):
            result = f"{result}{self.template.begin_thought_id}"
        return result

    def _format_assistant(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        """
        Formats an assistant message based on its property.
        """
        text = getattr(message, "content", "")
        if self.strip_content:
            text = text.strip()

        if message.property == Property.THOUGHT:
            result = f"{text}{self.template.end_thought_id}{self.template.begin_solution_id}"

        elif message.property == Property.SOLUTION:
            result = f"{text}{self.template.begin_answer_id}"

        elif message.property == Property.ANSWER:
            result = (
                f"{text}{self.template.end_answer_id}{self.template.end_solution_id}{self.template.end_assistant_id}"
            )
            if is_last:
                result = f"{result}{self.template.end_of_text}"

        elif message.property is None:
            result = text
            # In string mode (i.e., 'eval-framework'), omit end_assistant_id if this is the last message.
            # In list mode (i.e., 'scaling'), always append it.
            if output_mode == "list" or (output_mode == "string" and not is_last):
                result += self.template.end_assistant_id
            elif output_mode == "string":
                if not self.never_strip:
                    result = result.strip()
            else:
                raise ValueError(f"Unknown output_mode: {output_mode}")

        else:
            raise ValueError(f"Unsupported property: {message.property}")

        return result

    @staticmethod
    def _verify_messages(messages: Sequence[Message]) -> None:
        # Verify role order.
        BaseFormatter._verify_messages(messages)
        # Verify assistant message sequence.
        for group in BaseFormatter._get_grouped_messages(messages):
            if group[0].role == Role.ASSISTANT:
                if group[0].property is None:
                    for msg in group:
                        assert msg.property is None, "Assistant message group contains unexpected property combination."
                    continue
                if len(group) == 1:
                    assert group[0].property == Property.THOUGHT
                elif len(group) == 2:
                    assert group[0].property == Property.THOUGHT
                    assert group[1].property == Property.SOLUTION
                elif len(group) == 3:
                    assert group[0].property == Property.THOUGHT
                    assert group[1].property == Property.SOLUTION
                    assert group[2].property == Property.ANSWER
                else:
                    raise ValueError("Assistant message group is too long")

    def _validate_output(self, output_str: str) -> tuple[str, ValueError | None]:
        """Validate the output string according to the following cases:
        A) Duplicate Tokens,
        B) Missing Tokens,
        C) Wrong Order,
        D) Still Thinking,
        E) Incomplete,
        F) Valid.
        """
        required_tokens = [
            self.template.end_thought_id,
            self.template.begin_solution_id,
            self.template.end_solution_id,
            self.template.begin_answer_id,
            self.template.end_answer_id,
        ]

        # --- Case A: Duplicate tokens ---
        for token in [self.template.begin_thought_id, *required_tokens]:
            count = output_str.count(token)
            if count > 1:
                return "error", ValueError(f"Duplicate tokens detected: '{token}' appears {count} times.")

        # --- Cases B and C: Missing tokens / wrong order ---
        last_index = -1
        missing_tokens = []
        for token in required_tokens:
            index = output_str.find(token)
            if index == -1:  # Token is missing
                missing_tokens.append(token)
            else:
                if missing_tokens:  # Other token found before missing token
                    first = missing_tokens[0]
                    return "error", ValueError(f"Missing token: Expected '{first}' but found '{token}' instead.")
                if index < last_index:  # Token is out of order
                    return "error", ValueError(f"Incorrect token order: '{token}' appears before expected.")
                last_index = index

        # --- Case D: Still thinking (no end_thought_id) ---
        if self.template.end_thought_id in missing_tokens:
            return "not_finished_thinking", None  # Incomplete thinking (Case D)

        # --- Case E: Correct order but incomplete ---
        elif missing_tokens:
            return "incomplete", None  # Incomplete output (Case E)

        # --- Case F: Valid ---
        else:
            return "valid", None  # Valid (Case F)

    def _parse_output(self, output_str: str, thought_only: bool = False) -> dict[str, str]:
        """
        Extracts reasoning, solution, and final answer texts.
        - If 'thought_only=True', extracts only the reasoning part.
        - Uses regex to handle partial/incomplete outputs.
        """

        if thought_only:
            # Allow incomplete outputs (end_of_text is optional)
            pattern = (
                re.escape(self.template.begin_thought_id)
                + r"(.*?)"
                + re.escape(self.template.end_thought_id)
                + r".*?"
                + re.escape(self.template.end_of_text)
                + r"$"  # allows anything between end_thought_id and end_of_text
            )
        else:
            # Full extraction pattern
            pattern = (
                re.escape(self.template.begin_thought_id)
                + r"(.*?)"
                + re.escape(self.template.end_thought_id)
                + re.escape(self.template.begin_solution_id)
                + r"(.*?)"
                + re.escape(self.template.end_solution_id)
                + re.escape(self.template.begin_answer_id)
                + r"(.*?)"
                + re.escape(self.template.end_answer_id)
                + r"(?:\s*"
                + re.escape(self.template.end_of_text)
                + r")?"
                + r"$"
            )

        # Use re.search for partial extraction
        match = re.search(pattern, output_str, re.DOTALL)
        if not match:
            raise ValueError("Parsing failed: Output format does not match expected structure.")

        # Safely extract each part (handles missing sections)
        reasoning_text = match.group(1).strip() if match.group(1) else ""
        solution_text = match.group(2).strip() if len(match.groups()) > 1 and match.group(2) else ""
        final_answer_text = match.group(3).strip() if len(match.groups()) > 2 and match.group(3) else ""

        # Return the structured parts
        return {"thought": reasoning_text, "solution": solution_text, "answer": final_answer_text}

    def parse(self, output_str: str) -> tuple[dict[str, str], ValueError | None]:
        (status, error) = self._validate_output(output_str)
        match status:
            case "error":
                return {}, error
            case "not_finished_thinking":
                output_str_without_end = output_str.replace(self.template.end_of_text, "")
                output_str_extended = output_str_without_end + self.template.end_thought_id + self.template.end_of_text
                return self._parse_output(output_str_extended, thought_only=True), None
            case "incomplete":
                return self._parse_output(output_str, thought_only=True), None
            case "valid":
                return self._parse_output(output_str), None
            case _:
                raise ValueError("Invalid status")
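On the happy path, `parse` returns all three parts of a fully tagged completion. A minimal sketch (using `ConcatFormatter` as the base template):

```python
from template_formatting.formatter import ConcatFormatter, ReasoningFormatter

reasoning = ReasoningFormatter(ConcatFormatter)
completion = (
    "<|begin_of_thought|>Add the two numbers digit by digit.<|end_of_thought|>"
    "<|begin_of_solution|>2 + 2 = 4<|end_of_solution|>"
    "<|begin_of_answer|>4<|end_of_answer|>"
)
parsed, error = reasoning.parse(completion)
assert error is None
assert parsed == {"thought": "Add the two numbers digit by digit.", "solution": "2 + 2 = 4", "answer": "4"}
```

Truncated generations degrade gracefully: a completion that never emits `<|end_of_thought|>` is classified as "not_finished_thinking" and re-parsed with `thought_only=True`, so only the `thought` field is populated.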
def get_formatter(llm_name: str) -> BaseFormatter:
    llm_name = llm_name.lower()
    if "ng_7b" in llm_name or "pharia" in llm_name:
        print("Use LuminousNextgenFormatter")
        return Llama3Formatter()
    elif "llama-3" in llm_name:
        print("Use Llama3Formatter")
        return Llama3Formatter()
    else:
        print("Use ConcatFormatter")
        return ConcatFormatter()
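`get_formatter` is a name-based heuristic: `ng_7b`/Pharia checkpoints and explicit `llama-3` names get the Llama 3 template, while anything unrecognized falls back to plain concatenation:

```python
from template_formatting.formatter import ConcatFormatter, Llama3Formatter, get_formatter

assert isinstance(get_formatter("Llama-3.1-8B-Instruct"), Llama3Formatter)  # prints "Use Llama3Formatter"
assert isinstance(get_formatter("my-custom-model"), ConcatFormatter)        # prints "Use ConcatFormatter"
```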

template_formatting/mistral_formatter.py
@@ -0,0 +1,159 @@

from collections.abc import Sequence
from typing import Literal, cast

from huggingface_hub import hf_hub_download, try_to_load_from_cache

# Mistral API specific imports
from mistral_common.protocol.instruct.messages import AssistantMessage, SystemMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest, InstructRequest
from mistral_common.tokens.tokenizers.base import InstructTokenizer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# package-level imports
from .formatter import BaseFormatter, ChatTemplate, Message, Role


class MistralSerializer:
    def __init__(self, llm_target: str):
        self.tokenizer = MistralTokenizer.from_hf_hub(llm_target)

    def get_tokenizer(self) -> InstructTokenizer:
        return self.tokenizer.instruct_tokenizer

    @staticmethod
    def convert_to_aa(msg_lst: Sequence[SystemMessage | UserMessage | AssistantMessage]) -> Sequence[Message]:
        translated_messages: list[Message] = []
        for msg in msg_lst:
            match msg.role:
                case "system":
                    translated_messages.append(Message(role=Role.SYSTEM, content=msg.content))
                case "user":
                    translated_messages.append(Message(role=Role.USER, content=msg.content))
                case "assistant":
                    translated_messages.append(Message(role=Role.ASSISTANT, content=msg.content))
                case _:
                    raise ValueError("Role not supported")
        return translated_messages

    @staticmethod
    def convert_from_aa(msg_lst: Sequence[Message]) -> Sequence[SystemMessage | UserMessage | AssistantMessage]:
        translated_messages: list[SystemMessage | UserMessage | AssistantMessage] = []
        for idx, msg in enumerate(msg_lst):
            match msg.role:
                case Role.SYSTEM:
                    translated_messages.append(SystemMessage(content=msg.content))
                case Role.USER:
                    translated_messages.append(UserMessage(content=msg.content))
                case Role.ASSISTANT:
                    is_completion_request = idx == (len(msg_lst) - 1)  # instructs the model to complete
                    translated_messages.append(AssistantMessage(content=msg.content, prefix=is_completion_request))
                case _:
                    raise ValueError("Role not supported")
        return translated_messages

    def build_mistral_request(
        self, mistral_msg_lst: Sequence[SystemMessage | UserMessage | AssistantMessage]
    ) -> InstructRequest:
        # build chat request
        request: ChatCompletionRequest = ChatCompletionRequest(messages=mistral_msg_lst)
        # validate pydantic fields
        self.tokenizer._chat_completion_request_validator.validate_request(request)
        # merge messages of the same class
        instruct_request = self.tokenizer._instruct_request_normalizer.from_chat_completion_request(request)
        return instruct_request


class MistralFormatter(BaseFormatter):
    def __init__(self, llm_target: str) -> None:
        super().__init__()
        self.bridge_operator = MistralSerializer(llm_target=llm_target)

    def format(  # type: ignore[override]
        self, messages: Sequence[Message], output_mode: Literal["list"] = "list"
    ) -> list[Message]:
        """
        MistralFormatter intentionally restricts output_mode to 'list' only.

        This restriction exists because Mistral's tokenization requires special handling
        that bypasses traditional string-based formatting to preserve token boundaries.
        String mode would break the careful tokenization that Mistral's API provides.

        The type: ignore[override] is intentional; we're deliberately narrowing the
        interface.

        Args:
            messages: Sequence of messages to format
            output_mode: Must be "list"; string mode is not supported

        Returns:
            List of validated messages with plain text content

        Raises:
            ValueError: If output_mode is not "list"
        """
        # run a round-trip translation and validate the messages using Mistral's API
        if output_mode != "list":
            raise ValueError("Unsupported output_mode: choose 'list'")

        mistral_msg_lst = self.bridge_operator.convert_from_aa(msg_lst=messages)
        mistral_request_object = self.bridge_operator.build_mistral_request(mistral_msg_lst=mistral_msg_lst)
        aa_msg_lst = self.bridge_operator.convert_to_aa(msg_lst=mistral_request_object.messages)

        # run validation using the AA API
        self._verify_messages(aa_msg_lst)
        self._verify_message_fields(aa_msg_lst, "list")

        return cast(list, aa_msg_lst)
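A minimal usage sketch for the round-trip (the repo id is only an example and must provide `mistral_common` tokenizer files on the Hugging Face Hub):

```python
from template_formatting.formatter import Message, Role
from template_formatting.mistral_formatter import MistralFormatter

formatter = MistralFormatter(llm_target="mistralai/Mistral-7B-Instruct-v0.3")
chat = formatter.format([Message(role=Role.USER, content="What is 2 + 2?")])
# `chat` holds the same conversation after normalization and validation by
# mistral_common; tokenization happens later, e.g. via
# formatter.bridge_operator.get_tokenizer().
```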
class MagistralFormatter(MistralFormatter):
    # these fields are left to the Mistral API to define; we only leverage the system_prompt field
    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        begin_system_prompt="",
        system_prompt="",
        end_system_prompt="",
        begin_assistant_id="",
        end_assistant_id="",
        begin_user_id="",
        end_user_id="",
    )

    def __init__(self, llm_target: str, sys_prompt_fname: str = "SYSTEM_PROMPT.txt") -> None:
        """
        sys_prompt_fname: name of the system prompt file on the Magistral model card
        """

        def read_file(fname: str) -> str:
            with open(fname) as f:
                return f.read().strip()

        super().__init__(llm_target)
        prompt_path = try_to_load_from_cache(repo_id=llm_target, filename=sys_prompt_fname)
        if not isinstance(prompt_path, str):
            # not in the local cache: fetch from the Hub
            prompt_path = hf_hub_download(repo_id=llm_target, filename=sys_prompt_fname)
        self.template.system_prompt = read_file(fname=prompt_path)

    def format(  # type: ignore[override]
        self, messages: Sequence[Message], output_mode: Literal["list"] = "list"
    ) -> list[Message]:
        """
        MagistralFormatter extends MistralFormatter with automatic system prompt injection.

        Inherits the same 'list'-only restriction from MistralFormatter for the same
        tokenization reasons.
        """
        if output_mode != "list":
            raise ValueError("Unsupported output_mode: choose 'list'")

        if messages[0].role != Role.SYSTEM:
            input_messages = [Message(role=Role.SYSTEM, content=self.template.system_prompt), *messages]
        else:
            input_messages = cast(list, messages)

        return super().format(messages=input_messages)
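Usage mirrors `MistralFormatter`, except that a conversation without a leading system message gets one injected from the downloaded prompt file (the repo id is again only an example):

```python
from template_formatting.formatter import Message, Role
from template_formatting.mistral_formatter import MagistralFormatter

formatter = MagistralFormatter(llm_target="mistralai/Magistral-Small-2506")
chat = formatter.format([Message(role=Role.USER, content="What is 2 + 2?")])
assert chat[0].role == Role.SYSTEM  # carries the content of SYSTEM_PROMPT.txt
```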