eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,537 @@
1
+ import re
2
+ from collections.abc import Sequence
3
+ from dataclasses import asdict, dataclass
4
+ from enum import Enum
5
+ from pathlib import Path
6
+ from typing import Any, Literal, overload, override
7
+
8
+ from pydantic import BaseModel, field_serializer, field_validator
9
+
10
+ try:
11
+ from transformers import AutoTokenizer
12
+ except ImportError:
13
+ print("template_formatting: `transformers` package is not installed, HFFormatter will not be available.")
14
+
15
+
16
class Role(Enum):
    """Speaker of a chat message."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
20
+
21
+
22
class Property(Enum):
    """Section of a structured reasoning answer (used by ReasoningFormatter)."""

    ANSWER = "answer"
    THOUGHT = "thought"
    SOLUTION = "solution"
26
+
27
+
28
class Message(BaseModel):
    """A single chat message.

    ``role`` (and the 'scaling'-style fields ``has_loss``/``type``) are optional
    for compatibility with the legacy finetuning format, which stores
    pre-formatted content without roles.
    """

    role: Role | None = None  # Optional due to compatibility with legacy finetuning format.
    property: Property | None = None  # Reasoning section (thought/solution/answer), if any.
    content: str
    has_loss: bool | None = None  # 'scaling' style: whether loss is computed on this message.
    type: str | None = None  # 'scaling' style message type tag.

    # Renamed from `serialize_task_name`: this (de)serializes `role`, not a task name.
    @field_serializer("role")
    def serialize_role(self, value: Role | None) -> str | None:
        """Serialize the role enum to its string value; ``None`` stays ``None``."""
        if value is None:
            # Legacy finetuning format.
            return None
        return value.value

    @field_validator("role", mode="before")
    @classmethod
    def validate_role(cls, value: str | Role | None) -> Role | None:
        """Accept a ``Role``, its string value, or ``None`` (legacy format)."""
        if value is None:
            # Legacy finetuning format.
            return None
        if isinstance(value, str):
            return Role(value)
        return value
51
+
52
+
53
@dataclass
class ChatTemplate:
    """Control-token strings a formatter wraps around each chat turn."""

    begin_of_text: str  # Prepended once at the very start of the prompt.
    end_of_text: str  # Token marking the end of the full text (may be empty).
    begin_system_prompt: str
    system_prompt: str  # Default system prompt used when a system message has empty content.
    end_system_prompt: str
    begin_assistant_id: str
    end_assistant_id: str
    begin_user_id: str
    end_user_id: str
64
+
65
+
66
@dataclass
class ReasoningTemplate(ChatTemplate):
    """ChatTemplate extended with thought/solution/answer section delimiters."""

    begin_thought_id: str
    end_thought_id: str
    begin_solution_id: str
    end_solution_id: str
    begin_answer_id: str
    end_answer_id: str
74
+
75
+
76
class BaseFormatter:
    """Render chat ``Message`` sequences with a model-specific ``ChatTemplate``.

    Output comes in two flavours (see :meth:`format`):
      - "string": one concatenated prompt string ('eval-framework' style).
      - "list":   the messages with content wrapped in template tokens ('scaling' style).

    ``strip_content`` strips every message's content; ``never_strip`` disables the
    implicit stripping of trailing messages in "string" mode. Mutually exclusive.
    """

    template: ChatTemplate | ReasoningTemplate
    strip_content: bool = False
    never_strip: bool = False

    def __init__(self) -> None:
        super().__init__()
        assert not (self.strip_content and self.never_strip), "strip_content and never_strip cannot be both True"

    @staticmethod
    def _verify_messages(messages: Sequence[Message]) -> None:
        """Assert user/assistant roles alternate after an optional leading system prompt."""
        if not messages:
            # Nothing to verify for an empty conversation; previously this
            # raised IndexError on `grouped_messages[0]`.
            return
        grouped_messages = BaseFormatter._get_grouped_messages(messages)
        offset = int(grouped_messages[0][0].role == Role.SYSTEM)
        user_messages = grouped_messages[offset::2]
        assistant_messages = grouped_messages[offset + 1 :: 2]
        if grouped_messages[0][0].role is None:
            # Legacy finetuning format.
            assert all(m[0].role is None for m in user_messages)
        else:  # New format, assert role order.
            assert all(m[0].role == Role.USER for m in user_messages)
            assert all(m[0].role == Role.ASSISTANT for m in assistant_messages)

    @staticmethod
    def _verify_message_fields(messages: Sequence[Message], output_mode: str) -> None:
        """Check that each message carries exactly the fields its output mode requires."""
        if output_mode not in ("string", "list"):
            raise ValueError("Unsupported output_mode: choose 'string' or 'list'")

        for message in messages:
            if output_mode == "string":
                # eval-framework style
                if not hasattr(message, "role"):
                    raise ValueError("Message is missing 'role' property.")
                if (getattr(message, "type", None) is not None) or (getattr(message, "has_loss", None) is not None):
                    # Previously raised a bare ValueError with no message.
                    raise ValueError("Messages must not set 'type' or 'has_loss' in 'string' mode.")

            elif output_mode == "list":
                # scaling style
                if not hasattr(message, "type") or not hasattr(message, "has_loss"):
                    raise ValueError("Message is missing 'type' or 'has_loss' property.")

    @staticmethod
    def _get_grouped_messages(messages: Sequence[Message]) -> Sequence[Sequence[Message]]:
        """
        Groups consecutive messages to meet two criteria, while preserving the
        order of each sequence item:
        - Role is identical in each group.
        - Each property occurs once in each group.
        """
        if not messages:
            return []

        grouped_messages = []
        current_group = [messages[0]]

        for message in messages[1:]:
            role = current_group[0].role
            group_props = set(i.property for i in current_group)
            if message.role == role and message.property not in group_props:
                current_group.append(message)
            else:
                grouped_messages.append(current_group)
                current_group = [message]

        grouped_messages.append(current_group)
        return grouped_messages

    @overload
    def format(self, messages: Sequence[Message], output_mode: Literal["string"] = ...) -> str:
        pass

    @overload
    def format(self, messages: Sequence[Message], output_mode: Literal["list"]) -> list[Message]:
        pass

    def format(
        self, messages: Sequence[Message], output_mode: Literal["string", "list"] = "string"
    ) -> str | list[Message]:
        """
        Formats a list of messages using the provided template.
        output_mode: "string" returns a single concatenated string ('eval-framework' style),
                     "list" returns the messages with their content updated ('scaling' style).
        """
        self._verify_messages(messages)
        # Also rejects unsupported output modes, so no separate check is needed here.
        self._verify_message_fields(messages, output_mode)

        if output_mode == "string":
            # Generate formatted strings for each message and join them.
            formatted_parts = (
                self._format_message(message, i == len(messages) - 1, output_mode) for i, message in enumerate(messages)
            )
            return self.template.begin_of_text + "".join(formatted_parts)
        else:
            # Create a new list of messages with updated content.
            new_messages: list[Message] = [message.model_copy(deep=True) for message in messages]
            for i, message in enumerate(new_messages):
                formatted_content = self._format_message(messages[i], i == len(messages) - 1, output_mode)
                message.content = formatted_content

            # Prepend the begin_of_text to the first message's content.
            if new_messages:
                new_messages[0].content = self.template.begin_of_text + new_messages[0].content
            return new_messages

    def _format_message(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        """
        Returns the formatted string for a single message.
        """
        if message.role == Role.SYSTEM:
            text = getattr(message, "content", "")
            # Empty system content falls back to the template's default system prompt.
            if not text and hasattr(self.template, "system_prompt"):
                text = self.template.system_prompt
            if self.strip_content:
                text = text.strip()
            return f"{self.template.begin_system_prompt}{text}{self.template.end_system_prompt}"

        elif message.role == Role.USER:
            text = getattr(message, "content", "")
            if self.strip_content:
                text = text.strip()
            elif output_mode == "string":
                if is_last or (self.template.end_user_id != "" and not self.never_strip):
                    text = text.strip()
            if output_mode == "string" or (output_mode == "list" and not is_last):
                # start assistant message after user message
                result = (
                    f"{self.template.begin_user_id}{text}{self.template.end_user_id}{self.template.begin_assistant_id}"
                )
            else:
                # default HF behavior for applying chat template with
                # `add_generation_prompt=False` and `continue_final_message=False` (as used in 'scaling')
                result = f"{self.template.begin_user_id}{text}{self.template.end_user_id}"
            return result

        elif message.role == Role.ASSISTANT:
            return self._format_assistant(message, is_last, output_mode)

        elif message.role is None:
            # Legacy finetuning format: content is already fully formatted.
            return getattr(message, "content", "")

        else:
            raise ValueError(f"Unsupported role: {message.role}")

    def _format_assistant(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        """
        Formats an assistant message based on its property.
        """
        text = getattr(message, "content", "")
        if self.strip_content:
            text = text.strip()

        if message.property is not None:
            raise ValueError("Message properties require ReasoningFormatter")

        else:
            result = text
            # In string mode (i.e., 'eval-framework'), omit end_assistant_id if this is the last message.
            # In list mode (i.e., 'scaling'), always append it.
            if output_mode == "list" or (output_mode == "string" and not is_last):
                result += self.template.end_assistant_id
            elif output_mode == "string":
                if not self.never_strip:
                    result = result.strip()
            else:
                raise ValueError(f"Unknown output_mode: {output_mode}")

        return result
245
+
246
+
247
class IdentityFormatter(BaseFormatter):
    """A no-op formatter: every control token is the empty string."""

    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        system_prompt="",
        begin_system_prompt="",
        end_system_prompt="",
        begin_user_id="",
        end_user_id="",
        begin_assistant_id="",
        end_assistant_id="",
    )
259
+
260
+
261
class ConcatFormatter(BaseFormatter):
    """Plain concatenation: turns are separated only by blank lines."""

    # New lines are handled on task level, so content is not stripped here.
    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        system_prompt="",
        begin_system_prompt="",
        end_system_prompt="\n\n",
        begin_user_id="",
        end_user_id="",
        begin_assistant_id="",
        end_assistant_id="\n\n",
    )
274
+
275
+
276
class Llama3Formatter(BaseFormatter):
    """Chat template using the Llama-3 header / end-of-turn control tokens."""

    # Stripping content ensures consistency with the HF chat template formatter.
    strip_content = True

    template = ChatTemplate(
        begin_of_text="<|begin_of_text|>",
        end_of_text="",
        system_prompt="You are a helpful AI assistant",
        begin_system_prompt="<|start_header_id|>system<|end_header_id|>\n\n",
        end_system_prompt="<|eot_id|>",
        begin_user_id="<|start_header_id|>user<|end_header_id|>\n\n",
        end_user_id="<|eot_id|>",
        begin_assistant_id="<|start_header_id|>assistant<|end_header_id|>\n\n",
        end_assistant_id="<|eot_id|>",
    )
289
+
290
+
291
class HFFormatter(BaseFormatter):
    """Delegates formatting to a Hugging Face tokenizer's chat template."""

    def __init__(self, hf_llm_name: str | Path, chat_template_kwargs: dict[str, Any] | None = None) -> None:
        """Load the tokenizer for ``hf_llm_name``; raise if it has no chat template."""
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(hf_llm_name)
        self.chat_template_kwargs = chat_template_kwargs or {}

        if self.tokenizer.chat_template is None:
            raise ValueError(f"Chat template is not available for HF model: {hf_llm_name}")

    def _to_hf_message(self, message: Message) -> dict[str, str]:
        """Convert a framework ``Message`` into the HF ``{"role", "content"}`` dict."""
        if message.role is None:
            raise ValueError("Message role cannot be None")
        return {"role": message.role.value, "content": message.content}

    @override
    def format(  # type: ignore[override]
        self, messages: Sequence[Message], output_mode: Literal["string", "list"] = "string"
    ) -> str:
        """Apply the tokenizer's chat template to ``messages``."""
        hf_chat = [self._to_hf_message(message) for message in messages]

        template_kwargs = {"tokenize": False, **self.chat_template_kwargs}

        # output_mode encodes whether or not treat a trailing assistant message
        # as a pre-fill. Training uses 'list' mode, eval uses 'string' mode.
        # The naming is legacy, hence I wrote this comment to clarify. Both
        # code paths return strings.
        if output_mode == "string":
            # If the last message is an assistant message, treat it as a pre-fill
            # (i.e., assistant cue in evals). `bool(messages)` guards against an
            # IndexError on an empty conversation.
            is_prefill = bool(messages) and messages[-1].role == Role.ASSISTANT
            template_kwargs.update(
                {
                    "add_generation_prompt": not is_prefill,
                    "continue_final_message": is_prefill,
                }
            )

        return self.tokenizer.apply_chat_template(hf_chat, **template_kwargs)
328
+
329
+
330
+ class ReasoningFormatter(BaseFormatter):
331
+ template: ReasoningTemplate
332
+ remove_previous_thoughts: bool = False
333
+
334
+ def __init__(self, base_formatter: type[BaseFormatter]) -> None:
335
+ self.template = ReasoningTemplate(
336
+ **asdict(base_formatter.template),
337
+ begin_thought_id="<|begin_of_thought|>",
338
+ end_thought_id="<|end_of_thought|>",
339
+ begin_solution_id="<|begin_of_solution|>",
340
+ end_solution_id="<|end_of_solution|>",
341
+ begin_answer_id="<|begin_of_answer|>",
342
+ end_answer_id="<|end_of_answer|>",
343
+ )
344
+
345
+ def _format_message(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
346
+ result = super()._format_message(message, is_last, output_mode)
347
+ if message.role == Role.USER and output_mode == "string" and (is_last or not self.remove_previous_thoughts):
348
+ result = f"{result}{self.template.begin_thought_id}"
349
+ return result
350
+
351
+ def _format_assistant(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
352
+ """
353
+ Formats an assistant message based on its property.
354
+ """
355
+ text = getattr(message, "content", "")
356
+ if self.strip_content:
357
+ text = text.strip()
358
+
359
+ if message.property == Property.THOUGHT:
360
+ result = f"{text}{self.template.end_thought_id}{self.template.begin_solution_id}"
361
+
362
+ elif message.property == Property.SOLUTION:
363
+ result = f"{text}{self.template.begin_answer_id}"
364
+
365
+ elif message.property == Property.ANSWER:
366
+ result = (
367
+ f"{text}{self.template.end_answer_id}{self.template.end_solution_id}{self.template.end_assistant_id}"
368
+ )
369
+ if is_last:
370
+ result = f"{result}{self.template.end_of_text}"
371
+
372
+ elif message.property is None:
373
+ result = text
374
+ # In string mode (i.e., 'eval-framework'), omit end_assistant_id if this is the last message.
375
+ # In list mode (i.e., 'scaling'), always append it.
376
+ if output_mode == "list" or (output_mode == "string" and not is_last):
377
+ result += self.template.end_assistant_id
378
+ elif output_mode == "string":
379
+ if not self.never_strip:
380
+ result = result.strip()
381
+ else:
382
+ raise ValueError(f"Unknown output_mode: {output_mode}")
383
+
384
+ else:
385
+ raise ValueError(f"Unsupported property: {message.property}")
386
+
387
+ return result
388
+
389
+ @staticmethod
390
+ def _verify_messages(messages: Sequence[Message]) -> None:
391
+ # Verify role order.
392
+ BaseFormatter._verify_messages(messages)
393
+ # Verify assistant message sequence.
394
+ for group in BaseFormatter._get_grouped_messages(messages):
395
+ if group[0].role == Role.ASSISTANT:
396
+ if group[0].property is None:
397
+ for msg in group:
398
+ assert msg.property is None, "Assistant message group contains unexpected property combination."
399
+ continue
400
+ if len(group) == 1:
401
+ assert group[0].property == Property.THOUGHT
402
+ elif len(group) == 2:
403
+ assert group[0].property == Property.THOUGHT
404
+ assert group[1].property == Property.SOLUTION
405
+ elif len(group) == 3:
406
+ assert group[0].property == Property.THOUGHT
407
+ assert group[1].property == Property.SOLUTION
408
+ assert group[2].property == Property.ANSWER
409
+ else:
410
+ raise ValueError("Assistant message group is too long")
411
+
412
+ def _validate_output(self, output_str: str) -> tuple[str, ValueError | None]:
413
+ """Validate the output string according to following cases:
414
+ A) Duplicate Tokens,
415
+ B) Missing Tokens,
416
+ C) Wrong Order,
417
+ D) Still Thinking,
418
+ E) Incomplete,
419
+ F) valid.
420
+ """
421
+ required_tokens = [
422
+ self.template.end_thought_id,
423
+ self.template.begin_solution_id,
424
+ self.template.end_solution_id,
425
+ self.template.begin_answer_id,
426
+ self.template.end_answer_id,
427
+ ]
428
+
429
+ # --- Case A: Duplicate tokens ---
430
+ for token in [self.template.begin_thought_id, *required_tokens]:
431
+ count = output_str.count(token)
432
+ if count > 1:
433
+ return "error", ValueError(f"Duplicate tokens detected: '{token}' appears {count} times.")
434
+
435
+ # --- Case B: Wrong Order ---
436
+ last_index = -1
437
+ missing_tokens = []
438
+ for token in required_tokens:
439
+ index = output_str.find(token)
440
+ if index == -1: # Token is missing
441
+ missing_tokens.append(token)
442
+ else:
443
+ if missing_tokens: # Other token found before missing token
444
+ first = missing_tokens[0]
445
+ return "error", ValueError(f"Missing token: Expected '{first}' but found '{token}' instead.")
446
+ if index < last_index: # Token is out of order
447
+ return "error", ValueError(f"Incorrect token order: '{token}' appears before expected.")
448
+ last_index = index
449
+
450
+ # --- Case C: No end_thought_id ---
451
+ if self.template.end_thought_id in missing_tokens:
452
+ return "not_finished_thinking", None # Incomplete thinking (Case C)
453
+
454
+ # --- Case D: Correct Order but incomplete ---
455
+ elif missing_tokens:
456
+ return "incomplete", None # Incomplete output (Case D)
457
+
458
+ # --- Case E: Valid ---
459
+ else:
460
+ return "valid", None # valid (Case E)
461
+
462
+ def _parse_output(self, output_str: str, thought_only: bool = False) -> dict[str, str]:
463
+ """
464
+ Extracts reasoning, solution, and final answer texts.
465
+ - If 'thought_only=True', extracts only the reasoning part.
466
+ - Uses regex to handle partial/incomplete outputs.
467
+ """
468
+
469
+ if thought_only:
470
+ # Allow incomplete outputs (end_of_text is optional)
471
+ pattern = (
472
+ re.escape(self.template.begin_thought_id)
473
+ + r"(.*?)"
474
+ + re.escape(self.template.end_thought_id)
475
+ + r".*?"
476
+ + re.escape(self.template.end_of_text)
477
+ + r"$" # <-- Allows anything before <|end_of_text|>
478
+ )
479
+ else:
480
+ # Full extraction pattern
481
+ pattern = (
482
+ re.escape(self.template.begin_thought_id)
483
+ + r"(.*?)"
484
+ + re.escape(self.template.end_thought_id)
485
+ + re.escape(self.template.begin_solution_id)
486
+ + r"(.*?)"
487
+ + re.escape(self.template.end_solution_id)
488
+ + re.escape(self.template.begin_answer_id)
489
+ + r"(.*?)"
490
+ + re.escape(self.template.end_answer_id)
491
+ + r"(?:\s*"
492
+ + re.escape(self.template.end_of_text)
493
+ + r")?"
494
+ + r"$"
495
+ )
496
+
497
+ # Use re.search for partial extraction
498
+ match = re.search(pattern, output_str, re.DOTALL)
499
+ if not match:
500
+ raise ValueError("Parsing failed: Output format does not match expected structure.")
501
+
502
+ # Safely extract each part (handles missing sections)
503
+ reasoning_text = match.group(1).strip() if match.group(1) else ""
504
+ solution_text = match.group(2).strip() if len(match.groups()) > 1 and match.group(2) else ""
505
+ final_answer_text = match.group(3).strip() if len(match.groups()) > 2 and match.group(3) else ""
506
+
507
+ # Return structured Messages
508
+ return {"thought": reasoning_text, "solution": solution_text, "answer": final_answer_text}
509
+
510
+ def parse(self, output_str: str) -> tuple[dict[str, str], ValueError | None]:
511
+ (status, error) = self._validate_output(output_str)
512
+ match status:
513
+ case "error":
514
+ return {}, error
515
+ case "not_finished_thinking":
516
+ output_str_without_end = output_str.replace(self.template.end_of_text, "")
517
+ output_str_extended = output_str_without_end + self.template.end_thought_id + self.template.end_of_text
518
+ return self._parse_output(output_str_extended, thought_only=True), None
519
+ case "incomplete":
520
+ return self._parse_output(output_str, thought_only=True), None
521
+ case "valid":
522
+ return self._parse_output(output_str), None
523
+ case _:
524
+ raise ValueError("Invalid status")
525
+
526
+
527
def get_formatter(llm_name: str) -> BaseFormatter:
    """Heuristically pick a formatter from a model name (case-insensitive)."""
    name = llm_name.lower()
    if any(marker in name for marker in ("ng_7b", "pharia")):
        print("Use LuminousNextgenFormatter")
        return Llama3Formatter()
    if "llama-3" in name:
        print("Use Llama3Formatter")
        return Llama3Formatter()
    print("Use ConcatFormatter")
    return ConcatFormatter()
@@ -0,0 +1,159 @@
1
+ from collections.abc import Sequence
2
+ from typing import Literal, cast
3
+
4
+ from huggingface_hub import hf_hub_download, try_to_load_from_cache
5
+
6
+ # mistral's api specific imports
7
+ from mistral_common.protocol.instruct.messages import AssistantMessage, SystemMessage, UserMessage
8
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest, InstructRequest
9
+ from mistral_common.tokens.tokenizers.base import InstructTokenizer
10
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
11
+
12
+ # package level imports
13
+ from .formatter import BaseFormatter, ChatTemplate, Message, Role
14
+
15
+
16
+ class MistralSerializer:
17
+ def __init__(self, llm_target: str):
18
+ self.tokenizer = MistralTokenizer.from_hf_hub(llm_target)
19
+
20
+ def get_tokenizer(self) -> InstructTokenizer:
21
+ return self.tokenizer.instruct_tokenizer
22
+
23
+ @staticmethod
24
+ def convert_to_aa(msg_lst: Sequence[SystemMessage | UserMessage | AssistantMessage]) -> Sequence[Message]:
25
+ translated_messages: list[Message] = []
26
+ for msg in msg_lst:
27
+ match msg.role:
28
+ case "system":
29
+ translated_messages.append(Message(role=Role.SYSTEM, content=msg.content))
30
+ case "user":
31
+ translated_messages.append(Message(role=Role.USER, content=msg.content))
32
+ case "assistant":
33
+ translated_messages.append(Message(role=Role.ASSISTANT, content=msg.content))
34
+ case _:
35
+ raise ValueError("Role not supported")
36
+ return translated_messages
37
+
38
+ @staticmethod
39
+ def convert_from_aa(msg_lst: Sequence[Message]) -> Sequence[SystemMessage | UserMessage | AssistantMessage]:
40
+ translated_messages: list[SystemMessage | UserMessage | AssistantMessage] = []
41
+ for idx, msg in enumerate(msg_lst):
42
+ match msg.role:
43
+ case Role.SYSTEM:
44
+ translated_messages.append(SystemMessage(content=msg.content))
45
+ case Role.USER:
46
+ translated_messages.append(UserMessage(content=msg.content))
47
+ case Role.ASSISTANT:
48
+ is_completion_request = idx == (len(msg_lst) - 1) # insturcts model to complete
49
+ translated_messages.append(AssistantMessage(content=msg.content, prefix=is_completion_request))
50
+ case _:
51
+ raise ValueError("Role not supported")
52
+ return translated_messages
53
+
54
+ def build_mistral_request(
55
+ self, mistral_msg_lst: Sequence[SystemMessage | UserMessage | AssistantMessage]
56
+ ) -> InstructRequest:
57
+ # build chat request
58
+ request: ChatCompletionRequest = ChatCompletionRequest(messages=mistral_msg_lst)
59
+ # validate pydantic fields
60
+ self.tokenizer._chat_completion_request_validator.validate_request(request)
61
+ # merge same class messages
62
+ instruct_request = self.tokenizer._instruct_request_normalizer.from_chat_completion_request(request)
63
+ return instruct_request
64
+
65
+
66
+ class MistralFormatter(BaseFormatter):
67
+ def __init__(self, llm_target: str) -> None:
68
+ self.bridge_operator = MistralSerializer(llm_target=llm_target)
69
+
70
+ def format( # type: ignore[override]
71
+ self, messages: Sequence[Message], output_mode: Literal["list"] = "list"
72
+ ) -> list[Message]:
73
+ """
74
+ MistralFormatter intentionally restricts output_mode to 'list' only.
75
+
76
+ This restriction exists because Mistral's tokenization requires special handling
77
+ that bypasses traditional string-based formatting to preserve token boundaries.
78
+ String mode would break the careful tokenization that Mistral's API provides.
79
+
80
+ The type: ignore[override] is intentional; we're deliberately narrowing the
81
+ interface.
82
+
83
+ Args:
84
+ messages: Sequence of messages to format
85
+ output_mode: Must be "list" - string mode is not supported
86
+
87
+ Returns:
88
+ List of validated messages with plain text content
89
+
90
+ Raises:
91
+ ValueError: If output_mode is not "list"
92
+ """
93
+ # run back and forth translation and validate messages using mistral's API
94
+ if output_mode not in {"list"}:
95
+ raise ValueError("Unsupported output_mode: choose 'list'")
96
+
97
+ mistral_msg_lst = self.bridge_operator.convert_from_aa(msg_lst=messages)
98
+ mistral_request_object = self.bridge_operator.build_mistral_request(mistral_msg_lst=mistral_msg_lst)
99
+ aa_msg_lst = self.bridge_operator.convert_to_aa(msg_lst=mistral_request_object.messages)
100
+
101
+ # run validation using AA API
102
+ self._verify_messages(aa_msg_lst)
103
+ self._verify_message_fields(aa_msg_lst, "list")
104
+
105
+ return cast(list, aa_msg_lst)
106
+
107
+
108
+ class MagistralFormatter(MistralFormatter):
109
+ # these fields are not defined; left to MistralAPI to define; we only leverage system-prompt field
110
+ template = ChatTemplate(
111
+ begin_of_text="",
112
+ end_of_text="",
113
+ begin_system_prompt="",
114
+ system_prompt="",
115
+ end_system_prompt="",
116
+ begin_assistant_id="",
117
+ end_assistant_id="",
118
+ begin_user_id="",
119
+ end_user_id="",
120
+ )
121
+
122
+ def __init__(self, llm_target: str, sys_prompt_fname: str = "SYSTEM_PROMPT.txt") -> None:
123
+ """
124
+ sys_prompt_fname: name of folder on Magistral model card
125
+ """
126
+
127
+ def read_file(fname: str) -> str:
128
+ with open(fname) as f:
129
+ return f.read().strip()
130
+
131
+ super().__init__(llm_target)
132
+ prompt_path = try_to_load_from_cache(repo_id=llm_target, filename=sys_prompt_fname)
133
+ if isinstance(prompt_path, str):
134
+ self.template.system_prompt = read_file(fname=prompt_path)
135
+ else:
136
+ try:
137
+ prompt_path = hf_hub_download(repo_id=llm_target, filename=sys_prompt_fname)
138
+ self.template.system_prompt = read_file(fname=prompt_path)
139
+ except Exception as e:
140
+ raise e
141
+
142
+ def format( # type: ignore[override]
143
+ self, messages: Sequence[Message], output_mode: Literal["list"] = "list"
144
+ ) -> list[Message]:
145
+ """
146
+ MagistralFormatter extends MistralFormatter with automatic system prompt injection.
147
+
148
+ Inherits the same 'list'-only restriction from MistralFormatter for the same
149
+ tokenization reasons.
150
+ """
151
+ if output_mode not in {"list"}:
152
+ raise ValueError("Unsupported output_mode: choose 'list'")
153
+
154
+ if messages[0].role != Role.SYSTEM:
155
+ input_messages = [Message(role=Role.SYSTEM, content=self.template.system_prompt), *messages]
156
+ else:
157
+ input_messages = cast(list, messages)
158
+
159
+ return super().format(messages=input_messages)