eval-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,536 @@
1
+ import re
2
+ from collections.abc import Sequence
3
+ from dataclasses import asdict, dataclass
4
+ from enum import Enum
5
+ from typing import Any, Literal, overload, override
6
+
7
+ from pydantic import BaseModel, field_serializer, field_validator
8
+
9
+ try:
10
+ from transformers import AutoTokenizer
11
+ except ImportError:
12
+ print("template_formatting: `transformers` package is not installed, HFFormatter will not be available.")
13
+
14
+
15
class Role(Enum):
    """Author of a chat message; each value is the role's serialized string."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
19
+
20
+
21
class Property(Enum):
    """Optional tag on a message naming which part of a structured reply it holds."""

    ANSWER = "answer"
    THOUGHT = "thought"
    SOLUTION = "solution"
25
+
26
+
27
class Message(BaseModel):
    """A single chat message handed to a formatter."""

    role: Role | None = None  # Optional due to compatibility with legacy finetuning format.
    property: Property | None = None
    content: str
    # `has_loss` and `type` must stay unset when formatting in "string" output mode.
    has_loss: bool | None = None
    type: str | None = None

    @field_serializer("role")
    def serialize_task_name(self, value: Role | None) -> str | None:
        """Serialize the role as its string value; None marks the legacy finetuning format."""
        return None if value is None else value.value

    @field_validator("role", mode="before")
    @classmethod
    def validate_task_name(cls, value: str | Role | None) -> Role | None:
        """Coerce a role string to ``Role``; ``Role`` instances and None pass through unchanged.

        Raises:
            ValueError: if the string is not a valid ``Role`` value.
        """
        if isinstance(value, str):
            return Role(value)
        return value
50
+
51
+
52
@dataclass
class ChatTemplate:
    """Delimiter strings a formatter places around system, user and assistant content."""

    begin_of_text: str  # prepended once, before the first message
    end_of_text: str
    begin_system_prompt: str
    system_prompt: str  # fallback text used when a system message has empty content
    end_system_prompt: str
    begin_assistant_id: str
    end_assistant_id: str
    begin_user_id: str
    end_user_id: str
63
+
64
+
65
@dataclass
class ReasoningTemplate(ChatTemplate):
    """``ChatTemplate`` extended with delimiters for thought, solution and answer sections."""

    begin_thought_id: str
    end_thought_id: str
    begin_solution_id: str
    end_solution_id: str
    begin_answer_id: str
    end_answer_id: str
73
+
74
+
75
class BaseFormatter:
    """Render a sequence of ``Message`` objects using the class-level ``template``.

    Two output modes are supported:

    - ``"string"`` ('eval-framework' style): one concatenated prompt string.
    - ``"list"`` ('scaling' style): copies of the messages with templated content.
    """

    template: ChatTemplate | ReasoningTemplate
    # Strip surrounding whitespace from every message's content before templating.
    strip_content: bool = False
    # Suppress the conditional stripping that string mode otherwise applies.
    never_strip: bool = False

    def __init__(self) -> None:
        super().__init__()
        assert not (self.strip_content and self.never_strip), "strip_content and never_strip cannot be both True"

    @staticmethod
    def _verify_messages(messages: Sequence[Message]) -> None:
        """Assert that message groups alternate user/assistant after an optional system group.

        An empty sequence is accepted as trivially valid.

        Raises:
            AssertionError: if the role order is violated.
        """
        grouped_messages = BaseFormatter._get_grouped_messages(messages)
        if not grouped_messages:
            # Nothing to verify; previously this raised IndexError on the lookup below.
            return

        first_role = grouped_messages[0][0].role
        # Skip a leading system group so that user groups sit at even positions.
        offset = int(first_role == Role.SYSTEM)
        user_messages = grouped_messages[offset::2]
        assistant_messages = grouped_messages[offset + 1 :: 2]
        if first_role is None:
            # Legacy finetuning format.
            assert all(m[0].role is None for m in user_messages)
        else:  # New format, assert role order.
            assert all(m[0].role == Role.USER for m in user_messages)
            assert all(m[0].role == Role.ASSISTANT for m in assistant_messages)

    @staticmethod
    def _verify_message_fields(messages: Sequence[Message], output_mode: str) -> None:
        """Check ``output_mode`` and that each message carries the fields that mode expects.

        Raises:
            ValueError: on an unsupported ``output_mode`` or a field mismatch.
        """
        if output_mode not in ("string", "list"):
            raise ValueError("Unsupported output_mode: choose 'string' or 'list'")

        for message in messages:
            if output_mode == "string":
                # eval-framework style
                if not hasattr(message, "role"):
                    raise ValueError("Message is missing 'role' property.")
                if (getattr(message, "type", None) is not None) or (getattr(message, "has_loss", None) is not None):
                    raise ValueError("Message must not set 'type' or 'has_loss' in 'string' output mode.")

            elif output_mode == "list":
                # scaling style
                if not hasattr(message, "type") or not hasattr(message, "has_loss"):
                    raise ValueError("Message is missing 'type' or 'has_loss' property.")

    @staticmethod
    def _get_grouped_messages(messages: Sequence[Message]) -> Sequence[Sequence[Message]]:
        """
        Groups consecutive messages to meet two criteria, while preserving the
        order of each sequence item:
        - Role is identical in each group.
        - Each property occurs once in each group.
        """
        if not messages:
            return []

        grouped_messages = []
        current_group = [messages[0]]

        for message in messages[1:]:
            role = current_group[0].role
            group_props = {member.property for member in current_group}
            if message.role == role and message.property not in group_props:
                current_group.append(message)
            else:
                # Role changed or property repeated: close the group and start a new one.
                grouped_messages.append(current_group)
                current_group = [message]

        grouped_messages.append(current_group)
        return grouped_messages

    @overload
    def format(self, messages: Sequence[Message], output_mode: Literal["string"] = ...) -> str:
        pass

    @overload
    def format(self, messages: Sequence[Message], output_mode: Literal["list"]) -> list[Message]:
        pass

    def format(
        self, messages: Sequence[Message], output_mode: Literal["string", "list"] = "string"
    ) -> str | list[Message]:
        """
        Formats a list of messages using the provided template.
        output_mode: "string" returns a single concatenated string ('eval-framework' style),
        "list" returns the messages with their content updated ('scaling' style).

        Raises:
            AssertionError: if the role order is invalid.
            ValueError: if ``output_mode`` or the message fields are invalid.
        """
        self._verify_messages(messages)
        # Also rejects unsupported output modes, so no second check is needed below.
        self._verify_message_fields(messages, output_mode)

        last_index = len(messages) - 1

        if output_mode == "string":
            # Generate formatted strings for each message and join them.
            formatted_parts = (
                self._format_message(message, i == last_index, output_mode) for i, message in enumerate(messages)
            )
            return self.template.begin_of_text + "".join(formatted_parts)

        # Create a new list of messages with updated content; the inputs stay untouched.
        new_messages: list[Message] = [message.model_copy(deep=True) for message in messages]
        for i, (source, copied) in enumerate(zip(messages, new_messages)):
            copied.content = self._format_message(source, i == last_index, output_mode)

        # Prepend the begin_of_text to the first message's content.
        if new_messages:
            new_messages[0].content = self.template.begin_of_text + new_messages[0].content
        return new_messages

    def _format_message(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        """
        Returns the formatted string for a single message.

        Raises:
            ValueError: if the message role is not supported.
        """
        if message.role == Role.SYSTEM:
            text = getattr(message, "content", "")
            if not text and hasattr(self.template, "system_prompt"):
                # Empty system message: fall back to the template's default prompt.
                text = self.template.system_prompt
            if self.strip_content:
                text = text.strip()
            return f"{self.template.begin_system_prompt}{text}{self.template.end_system_prompt}"

        elif message.role == Role.USER:
            text = getattr(message, "content", "")
            if self.strip_content:
                text = text.strip()
            elif output_mode == "string":
                if is_last or (self.template.end_user_id != "" and not self.never_strip):
                    text = text.strip()
            if output_mode == "string" or (output_mode == "list" and not is_last):
                # start assistant message after user message
                result = (
                    f"{self.template.begin_user_id}{text}{self.template.end_user_id}{self.template.begin_assistant_id}"
                )
            else:
                # default HF behavior for applying chat template with
                # `add_generation_prompt=False` and `continue_final_message=False` (as used in 'scaling')
                result = f"{self.template.begin_user_id}{text}{self.template.end_user_id}"
            return result

        elif message.role == Role.ASSISTANT:
            return self._format_assistant(message, is_last, output_mode)

        elif message.role is None:
            # Legacy finetuning format: content is passed through untouched.
            return getattr(message, "content", "")

        else:
            raise ValueError(f"Unsupported role: {message.role}")

    def _format_assistant(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
        """
        Formats an assistant message based on its property.

        Raises:
            ValueError: if the message has a property set or ``output_mode`` is unknown.
        """
        text = getattr(message, "content", "")
        if self.strip_content:
            text = text.strip()

        if message.property is not None:
            raise ValueError("Message properties require ReasoningFormatter")

        result = text
        # In string mode (i.e., 'eval-framework'), omit end_assistant_id if this is the last message.
        # In list mode (i.e., 'scaling'), always append it.
        if output_mode == "list" or (output_mode == "string" and not is_last):
            result += self.template.end_assistant_id
        elif output_mode == "string":
            if not self.never_strip:
                result = result.strip()
        else:
            raise ValueError(f"Unknown output_mode: {output_mode}")

        return result
244
+
245
+
246
class IdentityFormatter(BaseFormatter):
    """Formatter whose template fields are all empty strings, so no delimiters are added."""

    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        begin_system_prompt="",
        system_prompt="",
        end_system_prompt="",
        begin_assistant_id="",
        end_assistant_id="",
        begin_user_id="",
        end_user_id="",
    )
258
+
259
+
260
class ConcatFormatter(BaseFormatter):
    """Formatter that appends a blank line after system and assistant messages and nothing else."""

    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        begin_system_prompt="",
        system_prompt="",
        end_system_prompt="\n\n",
        begin_assistant_id="",
        end_assistant_id="\n\n",
        begin_user_id="",
        end_user_id="",
    )
    # new lines are handled on task level, so we don't need to strip content here
273
+
274
+
275
class Llama3Formatter(BaseFormatter):
    """Formatter using ``<|start_header_id|>`` / ``<|eot_id|>`` header and end-of-turn markers."""

    template = ChatTemplate(
        begin_of_text="<|begin_of_text|>",
        end_of_text="",
        begin_system_prompt="<|start_header_id|>system<|end_header_id|>\n\n",
        system_prompt="You are a helpful AI assistant",
        end_system_prompt="<|eot_id|>",
        begin_assistant_id="<|start_header_id|>assistant<|end_header_id|>\n\n",
        end_assistant_id="<|eot_id|>",
        begin_user_id="<|start_header_id|>user<|end_header_id|>\n\n",
        end_user_id="<|eot_id|>",
    )
    strip_content = True  # stripping content to ensure consistency with HF chat template formatter
288
+
289
+
290
class HFFormatter(BaseFormatter):
    """Formatter that delegates to the chat template shipped with a Hugging Face tokenizer."""

    def __init__(self, hf_llm_name: str, chat_template_kwargs: dict[str, Any] | None = None) -> None:
        """Load the tokenizer for ``hf_llm_name`` and remember extra chat-template kwargs.

        Raises:
            ImportError: if the optional `transformers` import at module level failed.
            ValueError: if the loaded tokenizer has no chat template.
        """
        super().__init__()
        try:
            # `AutoTokenizer` is only bound when the guarded module-level import succeeded.
            tokenizer_cls = AutoTokenizer
        except NameError as err:
            raise ImportError(
                "HFFormatter requires the `transformers` package, which is not installed."
            ) from err
        self.tokenizer = tokenizer_cls.from_pretrained(hf_llm_name)
        self.chat_template_kwargs = chat_template_kwargs or {}

        if self.tokenizer.chat_template is None:
            raise ValueError(f"Chat template is not available for HF model: {hf_llm_name}")

    def _to_hf_message(self, message: Message) -> dict[str, str]:
        """Convert a ``Message`` into a ``{"role", "content"}`` dict.

        Raises:
            ValueError: if the message has no role.
        """
        if message.role is None:
            raise ValueError("Message role cannot be None")
        return {"role": message.role.value, "content": message.content}

    @override
    def format(  # type: ignore[override]
        self, messages: Sequence[Message], output_mode: Literal["string", "list"] = "string"
    ) -> str:
        """Apply the tokenizer's chat template to ``messages`` and return the rendered string.

        Both output modes return a string; the mode only controls pre-fill handling.
        """
        hf_chat = [self._to_hf_message(message) for message in messages]

        template_kwargs = {"tokenize": False, **self.chat_template_kwargs}

        # output_mode encodes whether or not treat a trailing assistant message
        # as a pre-fill. Training uses 'list' mode, eval uses 'string' mode.
        # The naming is legacy, hence I wrote this comment to clarify. Both
        # code paths return strings.
        if output_mode == "string":
            # if the last message is an assistant message, treat it as a pre-fill (i.e., assistant cue in evals)
            # An empty conversation has no trailing message and is never a pre-fill.
            is_prefill = bool(messages) and messages[-1].role == Role.ASSISTANT
            template_kwargs.update(
                {
                    "add_generation_prompt": not is_prefill,
                    "continue_final_message": is_prefill,
                }
            )

        return self.tokenizer.apply_chat_template(hf_chat, **template_kwargs)
327
+
328
+
329
+ class ReasoningFormatter(BaseFormatter):
330
+ template: ReasoningTemplate
331
+ remove_previous_thoughts: bool = False
332
+
333
+ def __init__(self, base_formatter: type[BaseFormatter]) -> None:
334
+ self.template = ReasoningTemplate(
335
+ **asdict(base_formatter.template),
336
+ begin_thought_id="<|begin_of_thought|>",
337
+ end_thought_id="<|end_of_thought|>",
338
+ begin_solution_id="<|begin_of_solution|>",
339
+ end_solution_id="<|end_of_solution|>",
340
+ begin_answer_id="<|begin_of_answer|>",
341
+ end_answer_id="<|end_of_answer|>",
342
+ )
343
+
344
+ def _format_message(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
345
+ result = super()._format_message(message, is_last, output_mode)
346
+ if message.role == Role.USER and output_mode == "string" and (is_last or not self.remove_previous_thoughts):
347
+ result = f"{result}{self.template.begin_thought_id}"
348
+ return result
349
+
350
+ def _format_assistant(self, message: Message, is_last: bool, output_mode: Literal["string", "list"]) -> str:
351
+ """
352
+ Formats an assistant message based on its property.
353
+ """
354
+ text = getattr(message, "content", "")
355
+ if self.strip_content:
356
+ text = text.strip()
357
+
358
+ if message.property == Property.THOUGHT:
359
+ result = f"{text}{self.template.end_thought_id}{self.template.begin_solution_id}"
360
+
361
+ elif message.property == Property.SOLUTION:
362
+ result = f"{text}{self.template.begin_answer_id}"
363
+
364
+ elif message.property == Property.ANSWER:
365
+ result = (
366
+ f"{text}{self.template.end_answer_id}{self.template.end_solution_id}{self.template.end_assistant_id}"
367
+ )
368
+ if is_last:
369
+ result = f"{result}{self.template.end_of_text}"
370
+
371
+ elif message.property is None:
372
+ result = text
373
+ # In string mode (i.e., 'eval-framework'), omit end_assistant_id if this is the last message.
374
+ # In list mode (i.e., 'scaling'), always append it.
375
+ if output_mode == "list" or (output_mode == "string" and not is_last):
376
+ result += self.template.end_assistant_id
377
+ elif output_mode == "string":
378
+ if not self.never_strip:
379
+ result = result.strip()
380
+ else:
381
+ raise ValueError(f"Unknown output_mode: {output_mode}")
382
+
383
+ else:
384
+ raise ValueError(f"Unsupported property: {message.property}")
385
+
386
+ return result
387
+
388
+ @staticmethod
389
+ def _verify_messages(messages: Sequence[Message]) -> None:
390
+ # Verify role order.
391
+ BaseFormatter._verify_messages(messages)
392
+ # Verify assistant message sequence.
393
+ for group in BaseFormatter._get_grouped_messages(messages):
394
+ if group[0].role == Role.ASSISTANT:
395
+ if group[0].property is None:
396
+ for msg in group:
397
+ assert msg.property is None, "Assistant message group contains unexpected property combination."
398
+ continue
399
+ if len(group) == 1:
400
+ assert group[0].property == Property.THOUGHT
401
+ elif len(group) == 2:
402
+ assert group[0].property == Property.THOUGHT
403
+ assert group[1].property == Property.SOLUTION
404
+ elif len(group) == 3:
405
+ assert group[0].property == Property.THOUGHT
406
+ assert group[1].property == Property.SOLUTION
407
+ assert group[2].property == Property.ANSWER
408
+ else:
409
+ raise ValueError("Assistant message group is too long")
410
+
411
+ def _validate_output(self, output_str: str) -> tuple[str, ValueError | None]:
412
+ """Validate the output string according to following cases:
413
+ A) Duplicate Tokens,
414
+ B) Missing Tokens,
415
+ C) Wrong Order,
416
+ D) Still Thinking,
417
+ E) Incomplete,
418
+ F) valid.
419
+ """
420
+ required_tokens = [
421
+ self.template.end_thought_id,
422
+ self.template.begin_solution_id,
423
+ self.template.end_solution_id,
424
+ self.template.begin_answer_id,
425
+ self.template.end_answer_id,
426
+ ]
427
+
428
+ # --- Case A: Duplicate tokens ---
429
+ for token in [self.template.begin_thought_id, *required_tokens]:
430
+ count = output_str.count(token)
431
+ if count > 1:
432
+ return "error", ValueError(f"Duplicate tokens detected: '{token}' appears {count} times.")
433
+
434
+ # --- Case B: Wrong Order ---
435
+ last_index = -1
436
+ missing_tokens = []
437
+ for token in required_tokens:
438
+ index = output_str.find(token)
439
+ if index == -1: # Token is missing
440
+ missing_tokens.append(token)
441
+ else:
442
+ if missing_tokens: # Other token found before missing token
443
+ first = missing_tokens[0]
444
+ return "error", ValueError(f"Missing token: Expected '{first}' but found '{token}' instead.")
445
+ if index < last_index: # Token is out of order
446
+ return "error", ValueError(f"Incorrect token order: '{token}' appears before expected.")
447
+ last_index = index
448
+
449
+ # --- Case C: No end_thought_id ---
450
+ if self.template.end_thought_id in missing_tokens:
451
+ return "not_finished_thinking", None # Incomplete thinking (Case C)
452
+
453
+ # --- Case D: Correct Order but incomplete ---
454
+ elif missing_tokens:
455
+ return "incomplete", None # Incomplete output (Case D)
456
+
457
+ # --- Case E: Valid ---
458
+ else:
459
+ return "valid", None # valid (Case E)
460
+
461
+ def _parse_output(self, output_str: str, thought_only: bool = False) -> dict[str, str]:
462
+ """
463
+ Extracts reasoning, solution, and final answer texts.
464
+ - If 'thought_only=True', extracts only the reasoning part.
465
+ - Uses regex to handle partial/incomplete outputs.
466
+ """
467
+
468
+ if thought_only:
469
+ # Allow incomplete outputs (end_of_text is optional)
470
+ pattern = (
471
+ re.escape(self.template.begin_thought_id)
472
+ + r"(.*?)"
473
+ + re.escape(self.template.end_thought_id)
474
+ + r".*?"
475
+ + re.escape(self.template.end_of_text)
476
+ + r"$" # <-- Allows anything before <|end_of_text|>
477
+ )
478
+ else:
479
+ # Full extraction pattern
480
+ pattern = (
481
+ re.escape(self.template.begin_thought_id)
482
+ + r"(.*?)"
483
+ + re.escape(self.template.end_thought_id)
484
+ + re.escape(self.template.begin_solution_id)
485
+ + r"(.*?)"
486
+ + re.escape(self.template.end_solution_id)
487
+ + re.escape(self.template.begin_answer_id)
488
+ + r"(.*?)"
489
+ + re.escape(self.template.end_answer_id)
490
+ + r"(?:\s*"
491
+ + re.escape(self.template.end_of_text)
492
+ + r")?"
493
+ + r"$"
494
+ )
495
+
496
+ # Use re.search for partial extraction
497
+ match = re.search(pattern, output_str, re.DOTALL)
498
+ if not match:
499
+ raise ValueError("Parsing failed: Output format does not match expected structure.")
500
+
501
+ # Safely extract each part (handles missing sections)
502
+ reasoning_text = match.group(1).strip() if match.group(1) else ""
503
+ solution_text = match.group(2).strip() if len(match.groups()) > 1 and match.group(2) else ""
504
+ final_answer_text = match.group(3).strip() if len(match.groups()) > 2 and match.group(3) else ""
505
+
506
+ # Return structured Messages
507
+ return {"thought": reasoning_text, "solution": solution_text, "answer": final_answer_text}
508
+
509
+ def parse(self, output_str: str) -> tuple[dict[str, str], ValueError | None]:
510
+ (status, error) = self._validate_output(output_str)
511
+ match status:
512
+ case "error":
513
+ return {}, error
514
+ case "not_finished_thinking":
515
+ output_str_without_end = output_str.replace(self.template.end_of_text, "")
516
+ output_str_extended = output_str_without_end + self.template.end_thought_id + self.template.end_of_text
517
+ return self._parse_output(output_str_extended, thought_only=True), None
518
+ case "incomplete":
519
+ return self._parse_output(output_str, thought_only=True), None
520
+ case "valid":
521
+ return self._parse_output(output_str), None
522
+ case _:
523
+ raise ValueError("Invalid status")
524
+
525
+
526
def get_formatter(llm_name: str) -> BaseFormatter:
    """Pick a formatter from substrings of the lower-cased model name.

    Names containing "ng_7b", "pharia" or "llama-3" get a ``Llama3Formatter``;
    every other name gets a ``ConcatFormatter``. The choice is printed to stdout.
    """
    llm_name = llm_name.lower()
    if "ng_7b" in llm_name or "pharia" in llm_name:
        # NOTE(review): the message names LuminousNextgenFormatter but a Llama3Formatter is returned.
        print("Use LuminousNextgenFormatter")
        return Llama3Formatter()
    if "llama-3" in llm_name:
        print("Use Llama3Formatter")
        return Llama3Formatter()
    print("Use ConcatFormatter")
    return ConcatFormatter()
@@ -0,0 +1,159 @@
1
+ from collections.abc import Sequence
2
+ from typing import Literal, cast
3
+
4
+ from huggingface_hub import hf_hub_download, try_to_load_from_cache
5
+
6
+ # mistral's api specific imports
7
+ from mistral_common.protocol.instruct.messages import AssistantMessage, SystemMessage, UserMessage
8
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest, InstructRequest
9
+ from mistral_common.tokens.tokenizers.base import InstructTokenizer
10
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
11
+
12
+ # package level imports
13
+ from .formatter import BaseFormatter, ChatTemplate, Message, Role
14
+
15
+
16
class MistralSerializer:
    """Translate between this package's ``Message`` objects and mistral_common message types."""

    def __init__(self, llm_target: str):
        # The tokenizer is obtained through MistralTokenizer.from_hf_hub for the given target.
        self.tokenizer = MistralTokenizer.from_hf_hub(llm_target)

    def get_tokenizer(self) -> InstructTokenizer:
        """Return the ``instruct_tokenizer`` attribute of the wrapped tokenizer."""
        return self.tokenizer.instruct_tokenizer

    @staticmethod
    def convert_to_aa(msg_lst: Sequence[SystemMessage | UserMessage | AssistantMessage]) -> Sequence[Message]:
        """
        Map Mistral messages onto package ``Message`` objects, keeping order and content.

        Raises:
            ValueError: If a message role is not "system", "user" or "assistant".
        """
        converted: list[Message] = []
        for item in msg_lst:
            role = item.role
            if role == "system":
                converted.append(Message(role=Role.SYSTEM, content=item.content))
            elif role == "user":
                converted.append(Message(role=Role.USER, content=item.content))
            elif role == "assistant":
                converted.append(Message(role=Role.ASSISTANT, content=item.content))
            else:
                raise ValueError("Role not supported")
        return converted

    @staticmethod
    def convert_from_aa(msg_lst: Sequence[Message]) -> Sequence[SystemMessage | UserMessage | AssistantMessage]:
        """
        Map package ``Message`` objects onto Mistral message types, keeping order and content.

        An assistant message in the final position is created with ``prefix=True``;
        assistant messages anywhere else get ``prefix=False``.

        Raises:
            ValueError: If a message role is not SYSTEM, USER or ASSISTANT.
        """
        outgoing: list[SystemMessage | UserMessage | AssistantMessage] = []
        for position, item in enumerate(msg_lst):
            role = item.role
            if role == Role.SYSTEM:
                outgoing.append(SystemMessage(content=item.content))
            elif role == Role.USER:
                outgoing.append(UserMessage(content=item.content))
            elif role == Role.ASSISTANT:
                # A trailing assistant turn instructs the model to complete it.
                continues_turn = position == (len(msg_lst) - 1)
                outgoing.append(AssistantMessage(content=item.content, prefix=continues_turn))
            else:
                raise ValueError("Role not supported")
        return outgoing

    def build_mistral_request(
        self, mistral_msg_lst: Sequence[SystemMessage | UserMessage | AssistantMessage]
    ) -> InstructRequest:
        """Wrap the messages in a chat request, validate it, and return the normalized instruct request."""
        chat_request: ChatCompletionRequest = ChatCompletionRequest(messages=mistral_msg_lst)
        # Validation and normalization both go through the tokenizer's private helper objects.
        self.tokenizer._chat_completion_request_validator.validate_request(chat_request)
        # The original author notes that this step merges same-class messages.
        return self.tokenizer._instruct_request_normalizer.from_chat_completion_request(chat_request)
64
+
65
+
66
class MistralFormatter(BaseFormatter):
    """Formatter that validates messages through Mistral's request pipeline instead of rendering a string."""

    def __init__(self, llm_target: str) -> None:
        self.bridge_operator = MistralSerializer(llm_target=llm_target)

    def format(  # type: ignore[override]
        self, messages: Sequence[Message], output_mode: Literal["list"] = "list"
    ) -> list[Message]:
        """
        Round-trip ``messages`` through Mistral's request objects and return them as a list.

        Only ``output_mode="list"`` is accepted. Per the original design note, Mistral's
        tokenization needs handling that bypasses string-based formatting to preserve
        token boundaries, so a string mode is deliberately not offered. The
        ``type: ignore[override]`` marks this intentional narrowing of the base interface.

        Args:
            messages: Sequence of messages to format
            output_mode: Must be "list" - string mode is not supported

        Returns:
            List of validated messages with plain text content

        Raises:
            ValueError: If output_mode is not "list"
        """
        if output_mode not in {"list"}:
            raise ValueError("Unsupported output_mode: choose 'list'")

        # Translate to Mistral types, let Mistral validate/normalize, then translate back.
        bridge = self.bridge_operator
        request = bridge.build_mistral_request(mistral_msg_lst=bridge.convert_from_aa(msg_lst=messages))
        validated = bridge.convert_to_aa(msg_lst=request.messages)

        # Apply the package's own checks to the round-tripped messages.
        self._verify_messages(validated)
        self._verify_message_fields(validated, "list")

        return cast(list, validated)
106
+
107
+
108
class MagistralFormatter(MistralFormatter):
    """MistralFormatter variant that injects a system prompt read from the model's repository."""

    # these fields are not defined; left to MistralAPI to define; we only leverage system-prompt field
    template = ChatTemplate(
        begin_of_text="",
        end_of_text="",
        begin_system_prompt="",
        system_prompt="",
        end_system_prompt="",
        begin_assistant_id="",
        end_assistant_id="",
        begin_user_id="",
        end_user_id="",
    )

    def __init__(self, llm_target: str, sys_prompt_fname: str = "SYSTEM_PROMPT.txt") -> None:
        """
        Load the system prompt for ``llm_target`` and store it on this instance's template.

        The prompt file is taken from the local Hugging Face cache when
        ``try_to_load_from_cache`` returns a path; otherwise it is fetched with
        ``hf_hub_download``. Any error from the download or from reading the
        file propagates to the caller.

        Args:
            llm_target: repo id used both for the tokenizer and to locate the prompt file
            sys_prompt_fname: name of the system-prompt file in that repo
        """
        import copy

        super().__init__(llm_target)

        # The class-level ``template`` is shared by every instance; writing the prompt
        # onto it would let one formatter overwrite another's system prompt. Work on a
        # per-instance shallow copy instead.
        self.template = copy.copy(self.template)

        prompt_path = try_to_load_from_cache(repo_id=llm_target, filename=sys_prompt_fname)
        if not isinstance(prompt_path, str):
            # Anything other than a path string means the file is not usable from cache.
            prompt_path = hf_hub_download(repo_id=llm_target, filename=sys_prompt_fname)

        # Explicit encoding keeps decoding independent of the platform's locale default.
        with open(prompt_path, encoding="utf-8") as f:
            self.template.system_prompt = f.read().strip()

    def format(  # type: ignore[override]
        self, messages: Sequence[Message], output_mode: Literal["list"] = "list"
    ) -> list[Message]:
        """
        MagistralFormatter extends MistralFormatter with automatic system prompt injection.

        If the first message is not a system message, one carrying the loaded system
        prompt is prepended; otherwise the messages are passed through unchanged.
        Inherits the same 'list'-only restriction from MistralFormatter for the same
        tokenization reasons.

        Args:
            messages: Non-empty sequence of messages to format
            output_mode: Must be "list" - string mode is not supported

        Returns:
            List of validated messages, starting with a system message

        Raises:
            ValueError: If output_mode is not "list"
            IndexError: If ``messages`` is empty
        """
        if output_mode not in {"list"}:
            raise ValueError("Unsupported output_mode: choose 'list'")

        if messages[0].role != Role.SYSTEM:
            input_messages = [Message(role=Role.SYSTEM, content=self.template.system_prompt), *messages]
        else:
            input_messages = cast(list, messages)

        return super().format(messages=input_messages)
File without changes