eval-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,408 @@
1
+ # ruff: noqa: E501
2
+ import importlib.util
3
+
4
+ import pytest
5
+
6
+ from template_formatting.formatter import (
7
+ BaseFormatter,
8
+ ConcatFormatter,
9
+ HFFormatter,
10
+ Llama3Formatter,
11
+ Message,
12
+ Property,
13
+ ReasoningFormatter,
14
+ Role,
15
+ get_formatter,
16
+ )
17
+
18
# Optional-dependency probe: tests that compare against HFFormatter are skipped when
# `transformers` cannot be found.
_transformers_spec = importlib.util.find_spec("transformers")
package_exists = _transformers_spec is not None

# no tests requiring a GPU runner are contained here -> no additional pytest GPU markers
21
+
22
+
23
@pytest.fixture
def concat_formatter() -> BaseFormatter:
    """Provide a fresh ConcatFormatter instance for each test."""
    formatter = ConcatFormatter()
    return formatter
26
+
27
+
28
@pytest.fixture
def llama3_formatter() -> BaseFormatter:
    """Provide a fresh hand-written Llama3Formatter instance for each test."""
    formatter = Llama3Formatter()
    return formatter
31
+
32
+
33
@pytest.fixture
def hf_formatter() -> BaseFormatter:
    """Provide an HFFormatter for Meta-Llama-3-8B-Instruct, used as the reference rendering.

    NOTE(review): presumably this loads the model's chat template via `transformers`
    (and possibly the Hugging Face Hub) — confirm network/auth requirements.
    """
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    return HFFormatter(model_id)
36
+
37
+
38
@pytest.fixture
def llama3_reasoning_formatter() -> BaseFormatter:
    """Provide a ReasoningFormatter layered on Llama3Formatter with an explicit end-of-text token."""
    formatter = ReasoningFormatter(Llama3Formatter)
    # Set explicitly so a fully finished reasoning turn renders a trailing "<|end_of_text|>".
    formatter.template.end_of_text = "<|end_of_text|>"
    return formatter
43
+
44
+
45
def test_concat_formatter(concat_formatter: BaseFormatter) -> None:
    """ConcatFormatter renders the chat as plain text, without any chat-template tokens."""
    system_prompt = "You are a helpful AI assistant for travel tips and recommendations"
    question = "What is France's capital?\n"  # new line has to be handled on task level
    reply = "Bonjour! The capital of France is Paris!"
    follow_up = "Great, thanks!"

    chat = [
        Message(role=Role.SYSTEM, content=system_prompt),
        Message(role=Role.USER, content=question),
        Message(role=Role.ASSISTANT, content=reply),
        Message(role=Role.USER, content=follow_up),
    ]

    rendered = concat_formatter.format(chat, output_mode="string")

    assert rendered == (
        "You are a helpful AI assistant for travel tips and recommendations\n\n"
        "What is France's capital?\n"
        "Bonjour! The capital of France is Paris!\n\n"
        "Great, thanks!"
    )
62
+
63
+
64
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_system_and_assistant_simple(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """A chat ending on an assistant turn leaves that turn open (no trailing <|eot_id|>).

    Llama3Formatter and HFFormatter must both produce exactly the same string.
    """
    dialogue = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!"),
    ]
    expected = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!"
    )

    # Hand-written formatter first, then the HF reference.
    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(dialogue, output_mode="string") == expected
89
+
90
+
91
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_system_and_assistant(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """A chat ending on a user turn closes every turn and ends with an open assistant header."""
    dialogue = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!"),
        Message(role=Role.USER, content="Great, thanks!"),
    ]
    expected = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Great, thanks!<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(dialogue, output_mode="string") == expected
118
+
119
+
120
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_without_system_and_assistant(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """A single user message with no system prompt starts directly with the user header."""
    dialogue = [Message(role=Role.USER, content="What is France's capital?")]
    expected = "".join(
        (
            "<|begin_of_text|>",
            "<|start_header_id|>user<|end_header_id|>\n\n",
            "What is France's capital?<|eot_id|>",
            "<|start_header_id|>assistant<|end_header_id|>\n\n",
        )
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(dialogue, output_mode="string") == expected
141
+
142
+
143
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_without_system_multiple_rounds(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """Multi-turn chat without a system prompt.

    Every completed turn is closed with <|eot_id|>. The final user turn is followed by an
    open assistant header.
    """
    paris_tips = (
        "Paris offers many attractions and activities. "
        "Some popular things to do include visiting the Eiffel Tower, "
        "exploring the Louvre Museum, taking a river cruise along the Seine, "
        "and strolling through charming neighborhoods like Montmartre."
    )
    dialogue = [
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!"),
        Message(role=Role.USER, content="What can I do there?"),
        Message(role=Role.ASSISTANT, content=paris_tips),
        Message(role=Role.USER, content="What else?"),
    ]
    expected = (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What can I do there?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Paris offers many attractions and activities. Some popular things to do "
        "include visiting the Eiffel Tower, exploring the Louvre Museum, taking a river "
        "cruise along the Seine, and strolling through charming neighborhoods like Montmartre.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What else?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(dialogue, output_mode="string") == expected
185
+
186
+
187
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_prefilling(llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter) -> None:
    """A trailing partial assistant message (the "cue") is emitted without <|eot_id|>.

    This leaves the turn open for continuation.
    """
    question = "How many helicopters can a human eat in one sitting?"
    cue = "A human can"
    dialogue = [
        Message(role=Role.USER, content=question),
        Message(role=Role.ASSISTANT, content=cue),
    ]
    expected = (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "How many helicopters can a human eat in one sitting?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "A human can"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(dialogue, output_mode="string") == expected
208
+
209
+
210
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_stripping_of_whitespace(llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter) -> None:
    """Both formatters strip leading/trailing whitespace from message contents.

    This includes the open prefill turn.
    """
    dialogue = [
        Message(role=Role.USER, content=" What is the capital of France? "),
        Message(role=Role.ASSISTANT, content=" The capital of France is "),
    ]
    expected = (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is the capital of France?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "The capital of France is"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(dialogue, output_mode="string") == expected
231
+
232
+
233
@pytest.mark.parametrize(
    "model_name, expected_formatter",
    [
        # Names containing "llama-3" in the cases below map to Llama3Formatter ...
        pytest.param("llama-3", Llama3Formatter, id="llama-3"),
        pytest.param("llama-3-base", Llama3Formatter, id="llama-3-base"),
        pytest.param("llama-3-large", Llama3Formatter, id="llama-3-large"),
        pytest.param("my-llama-3-model", Llama3Formatter, id="custom-llama-3-model"),
        # ... everything else listed here falls back to ConcatFormatter.
        pytest.param("gpt2", ConcatFormatter, id="gpt2"),
        pytest.param("bert", ConcatFormatter, id="bert"),
        pytest.param("roberta", ConcatFormatter, id="roberta"),
        pytest.param("distilbert", ConcatFormatter, id="distilbert"),
        pytest.param("custom-model", ConcatFormatter, id="custom-non-llama3-model"),
        pytest.param("", ConcatFormatter, id="empty-model-name"),
    ],
)
def test_get_formatter(model_name: str, expected_formatter: type[BaseFormatter]) -> None:
    """get_formatter returns an instance of the formatter class expected for each model name."""
    assert isinstance(get_formatter(model_name), expected_formatter)
251
+
252
+
253
+ # ReasoningFormatter tests
254
+
255
+
256
def test_reasoning_formatter_with_system_and_user(llama3_reasoning_formatter: BaseFormatter) -> None:
    """With only system + user messages, the assistant turn opens with <|begin_of_thought|>."""
    dialogue = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
    ]

    rendered = llama3_reasoning_formatter.format(dialogue, output_mode="string")

    assert rendered == (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>"
    )
271
+
272
+
273
def test_reasoning_formatter_with_user(llama3_reasoning_formatter: BaseFormatter) -> None:
    """A lone user message is followed by an assistant header and an opening <|begin_of_thought|>."""
    dialogue = [Message(role=Role.USER, content="What is France's capital?")]

    rendered = llama3_reasoning_formatter.format(dialogue, output_mode="string")

    assert rendered == (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>"
    )
286
+
287
+
288
def test_reasoning_formatter_with_system_user_and_thought(llama3_reasoning_formatter: BaseFormatter) -> None:
    """A completed THOUGHT is closed and immediately followed by an open <|begin_of_solution|>."""
    thought = "Bonjour! Let me think about this..."
    dialogue = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, property=Property.THOUGHT, content=thought),
    ]

    rendered = llama3_reasoning_formatter.format(dialogue, output_mode="string")

    assert rendered == (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>Bonjour! Let me think about this...<|end_of_thought|>"
        "<|begin_of_solution|>"
    )
304
+
305
+
306
def test_reasoning_formatter_with_system_user_thought_and_solution(llama3_reasoning_formatter: BaseFormatter) -> None:
    """After THOUGHT and SOLUTION messages, the output ends with an open <|begin_of_answer|>."""
    thought = "Bonjour! Let me think about this..."
    solution = "Merci! The capital of France is Paris!"
    dialogue = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, property=Property.THOUGHT, content=thought),
        Message(role=Role.ASSISTANT, property=Property.SOLUTION, content=solution),
    ]

    rendered = llama3_reasoning_formatter.format(dialogue, output_mode="string")

    assert rendered == (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>Bonjour! Let me think about this...<|end_of_thought|>"
        "<|begin_of_solution|>Merci! The capital of France is Paris!"
        "<|begin_of_answer|>"
    )
325
+
326
+
327
def test_reasoning_formatter_with_system_user_thought_solution_and_answer(
    llama3_reasoning_formatter: BaseFormatter,
) -> None:
    """A full THOUGHT/SOLUTION/ANSWER turn closes every span.

    The output ends with <|eot_id|> followed by the configured end-of-text token.
    """
    thought = "Bonjour! Let me think about this..."
    solution = "Merci! The capital of France is Paris!"
    answer = "\\boxed{Paris}"
    dialogue = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, property=Property.THOUGHT, content=thought),
        Message(role=Role.ASSISTANT, property=Property.SOLUTION, content=solution),
        Message(role=Role.ASSISTANT, property=Property.ANSWER, content=answer),
    ]

    rendered = llama3_reasoning_formatter.format(dialogue, output_mode="string")

    assert rendered == (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>Bonjour! Let me think about this...<|end_of_thought|>"
        "<|begin_of_solution|>Merci! The capital of France is Paris!"
        "<|begin_of_answer|>\\boxed{Paris}<|end_of_answer|>"
        "<|end_of_solution|><|eot_id|><|end_of_text|>"
    )
349
+
350
+
351
def test_reasoning_formatter_parse_wrong_order() -> None:
    """parse() reports a ValueError when the solution opens before the thought is closed.

    parse() returns ``(parsed, error)`` rather than raising, so the error's type is
    checked directly instead of re-raising it inside ``pytest.raises``.
    """
    rf = ReasoningFormatter(Llama3Formatter)
    rt = rf.template
    output_str = (
        rt.begin_thought_id
        + "thought"
        + rt.begin_solution_id
        + "solution"  # Wrong: begin_solution_id comes before end_thought_id.
        + rt.end_thought_id
        + rt.end_solution_id
        + rt.begin_answer_id
        + "answer"
        + rt.end_answer_id
        + rt.end_of_text
    )

    _, error = rf.parse(output_str)

    # isinstance also fails for None, so this covers the former `error is not None` check.
    assert isinstance(error, ValueError), f"expected a ValueError, got {error!r}"
371
+
372
+
373
def test_reasoning_formatter_parse_incomplete() -> None:
    """A response containing only a closed thought parses without error.

    The solution and answer fields come back empty or absent.
    """
    formatter = ReasoningFormatter(Llama3Formatter)
    tokens = formatter.template

    response = "".join((tokens.begin_thought_id, "only thought", tokens.end_thought_id))
    parsed, error = formatter.parse(response)

    assert error is None
    assert parsed["thought"] == "only thought"
    for missing_field in ("solution", "answer"):
        assert parsed.get(missing_field, "") == ""
384
+
385
+
386
def test_reasoning_formatter_parse_duplicate_tokens() -> None:
    """parse() reports a ValueError when <begin_of_thought> appears twice.

    parse() returns ``(parsed, error)`` rather than raising, so the error's type is
    checked directly instead of re-raising it inside ``pytest.raises``.
    """
    rf = ReasoningFormatter(Llama3Formatter)
    rt = rf.template
    output_str = (
        rt.begin_thought_id
        + "thought"
        + rt.begin_thought_id  # Duplicate opening token.
        + "duplicate"
        + rt.end_thought_id
        + rt.begin_solution_id
        + "solution"
        + rt.end_solution_id
        + rt.begin_answer_id
        + "answer"
        + rt.end_answer_id
        + rt.end_of_text
    )

    _, error = rf.parse(output_str)

    # isinstance also fails for None, so this covers the former `error is not None` check.
    assert isinstance(error, ValueError), f"expected a ValueError, got {error!r}"
@@ -0,0 +1,253 @@
1
+ # ruff: noqa: E501
2
+ import importlib.util
3
+
4
+ import pytest
5
+
6
+ from template_formatting.formatter import (
7
+ BaseFormatter,
8
+ ConcatFormatter,
9
+ HFFormatter,
10
+ Llama3Formatter,
11
+ Message,
12
+ Property,
13
+ ReasoningFormatter,
14
+ Role,
15
+ )
16
+
17
# Optional-dependency probe: HFFormatter comparisons are skipped when `transformers` is absent.
_transformers_spec = importlib.util.find_spec("transformers")
package_exists = _transformers_spec is not None
18
+
19
+
20
@pytest.fixture
def concat_formatter() -> BaseFormatter:
    """Provide a fresh ConcatFormatter instance for each test."""
    formatter = ConcatFormatter()
    return formatter
23
+
24
+
25
@pytest.fixture
def llama3_formatter() -> BaseFormatter:
    """Provide a fresh hand-written Llama3Formatter instance for each test."""
    formatter = Llama3Formatter()
    return formatter
28
+
29
+
30
@pytest.fixture
def hf_formatter() -> BaseFormatter:
    """Provide an HFFormatter for Meta-Llama-3-8B-Instruct, used as the reference rendering.

    NOTE(review): presumably this loads the model's chat template via `transformers`
    (and possibly the Hugging Face Hub) — confirm network/auth requirements.
    """
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    return HFFormatter(model_id)
33
+
34
+
35
def test_get_grouped_messages_same_property() -> None:
    """Messages that all have ``property=None`` each land in their own group.

    This holds even for consecutive assistant turns.
    """
    shared_fields = {"content": "dummy", "has_loss": False, "type": "text"}
    roles = [Role.USER, Role.ASSISTANT, Role.ASSISTANT]

    inputs = [Message(role=role, property=None, **shared_fields) for role in roles]
    grouped = BaseFormatter._get_grouped_messages(inputs)

    # Expected messages are built fresh so the comparison is by value, not identity.
    assert grouped == [[Message(role=role, property=None, **shared_fields)] for role in roles]
51
+
52
+
53
def test_get_grouped_messages_different_property() -> None:
    """An assistant ANSWER following a property-less assistant message joins its group."""
    shared_fields = {"content": "dummy", "has_loss": False, "type": "text"}

    inputs = [
        Message(role=Role.USER, property=None, **shared_fields),
        Message(role=Role.ASSISTANT, property=None, **shared_fields),
        Message(role=Role.ASSISTANT, property=Property.ANSWER, **shared_fields),
    ]
    grouped = BaseFormatter._get_grouped_messages(inputs)

    user_group = [Message(role=Role.USER, property=None, **shared_fields)]
    assistant_group = [
        Message(role=Role.ASSISTANT, property=None, **shared_fields),
        Message(role=Role.ASSISTANT, property=Property.ANSWER, **shared_fields),
    ]
    assert grouped == [user_group, assistant_group]
71
+
72
+
73
def test_base_verify_messages() -> None:
    """A single user/assistant exchange passes BaseFormatter verification."""
    shared_fields = {"content": "dummy", "has_loss": False, "type": "text"}
    exchange = [Message(role=role, property=None, **shared_fields) for role in (Role.USER, Role.ASSISTANT)]

    # Must complete without raising.
    BaseFormatter._verify_messages(exchange)
83
+
84
+
85
def test_base_verify_messages_raises_exception() -> None:
    """Two consecutive property-less assistant messages are rejected with an AssertionError."""
    shared_fields = {"content": "dummy", "has_loss": False, "type": "text"}
    roles = (Role.USER, Role.ASSISTANT, Role.ASSISTANT)
    exchange = [Message(role=role, property=None, **shared_fields) for role in roles]

    with pytest.raises(AssertionError):
        BaseFormatter._verify_messages(exchange)
96
+
97
+
98
def test_reasoning_verify_messages() -> None:
    """ReasoningFormatter accepts a user turn followed by assistant THOUGHT, SOLUTION and ANSWER."""
    shared_fields = {"content": "dummy", "has_loss": False, "type": "text"}
    assistant_steps = (Property.THOUGHT, Property.SOLUTION, Property.ANSWER)

    exchange = [Message(role=Role.USER, property=None, **shared_fields)]
    exchange.extend(Message(role=Role.ASSISTANT, property=step, **shared_fields) for step in assistant_steps)

    # Must complete without raising.
    ReasoningFormatter._verify_messages(exchange)
110
+
111
+
112
def test_reasoning_verify_messages_raises_exception() -> None:
    """ReasoningFormatter rejects an ANSWER that follows a property-less assistant message."""
    shared_fields = {"content": "dummy", "has_loss": False, "type": "text"}
    exchange = [
        Message(role=Role.USER, property=None, **shared_fields),
        Message(role=Role.ASSISTANT, property=None, **shared_fields),
        Message(role=Role.ASSISTANT, property=Property.ANSWER, **shared_fields),
    ]

    with pytest.raises(AssertionError):
        ReasoningFormatter._verify_messages(exchange)
123
+
124
+
125
+ ## Assert that formatting is in line with HF Formatter
126
+
127
+
128
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_system_and_assistant_simple(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """Check Llama3Formatter's list-mode output.

    It yields one formatted element per message, and the concatenation of those elements
    matches the HFFormatter rendering of the same conversation.
    """
    conversation = [
        Message(
            role=Role.SYSTEM,
            content="You are a helpful AI assistant for travel tips and recommendations",
            has_loss=False,
            type="text",
        ),
        Message(role=Role.USER, content="What is France's capital?", has_loss=False, type="text"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!", has_loss=True, type="text"),
    ]

    formatted_conversation = llama3_formatter.format(conversation, output_mode="list")

    expected_contents = [
        (
            "<|begin_of_text|>"
            "<|start_header_id|>system<|end_header_id|>\n\n"
            "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        ),
        (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            "What is France's capital?<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        "Bonjour! The capital of France is Paris!<|eot_id|>",
    ]

    # Compare whole lists rather than zip()-ing: zip truncates silently, so a missing or
    # extra element would otherwise go unreported by the per-message check.
    assert [elm.content for elm in formatted_conversation] == expected_contents

    # stringify the list
    formatted_conversation_str = "".join(elm.content for elm in formatted_conversation)

    expected_output_str = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!<|eot_id|>"
    )

    assert formatted_conversation_str == expected_output_str

    # NOTE(review): the HF result for output_mode="list" is compared to a plain string here,
    # which implies HFFormatter returns a single string even in list mode — confirm that
    # this is the intended contract.
    hf_formatted_conversation = hf_formatter.format(conversation, output_mode="list")
    assert hf_formatted_conversation == formatted_conversation_str
179
+
180
+
181
+ @pytest.mark.skipif(
182
+ not package_exists,
183
+ reason="`transformers` package is not installed, HFFormatter will not be available.",
184
+ )
185
+ def test_llama3_formatter_without_system_multiple_rounds_list(
186
+ llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
187
+ ) -> None:
188
+ conversation = [
189
+ Message(role=Role.USER, content="What is France's capital?", has_loss=False, type="text"),
190
+ Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!", has_loss=True, type="text"),
191
+ Message(role=Role.USER, content="What can I do there?", has_loss=False, type="text"),
192
+ Message(
193
+ role=Role.ASSISTANT,
194
+ content=(
195
+ "Paris offers many attractions and activities. "
196
+ "Some popular things to do include visiting the Eiffel Tower, "
197
+ "exploring the Louvre Museum, taking a river cruise along the Seine, "
198
+ "and strolling through charming neighborhoods like Montmartre."
199
+ ),
200
+ has_loss=False,
201
+ type="text",
202
+ ),
203
+ Message(role=Role.USER, content="What else?", has_loss=False, type="text"),
204
+ ]
205
+
206
+ original_conversation = conversation.copy()
207
+
208
+ formatted_conversation = llama3_formatter.format(conversation, output_mode="list")
209
+
210
+ expected_contents = [
211
+ (
212
+ "<|begin_of_text|>"
213
+ "<|start_header_id|>user<|end_header_id|>\n\n"
214
+ "What is France's capital?<|eot_id|>"
215
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
216
+ ),
217
+ "Bonjour! The capital of France is Paris!<|eot_id|>",
218
+ (
219
+ "<|start_header_id|>user<|end_header_id|>\n\n"
220
+ "What can I do there?<|eot_id|>"
221
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
222
+ ),
223
+ (
224
+ "Paris offers many attractions and activities. Some popular things to do include visiting the Eiffel Tower, "
225
+ "exploring the Louvre Museum, taking a river cruise along the Seine, and strolling through charming neighborhoods like Montmartre.<|eot_id|>"
226
+ ),
227
+ ("<|start_header_id|>user<|end_header_id|>\n\nWhat else?<|eot_id|>"),
228
+ ]
229
+
230
+ for formatted_message, expected in zip(formatted_conversation, expected_contents):
231
+ assert formatted_message.content == expected
232
+
233
+ # stringify the list
234
+ formatted_conversation_str = "".join(elm.content for elm in formatted_conversation)
235
+
236
+ expected_output_str = (
237
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
238
+ "What is France's capital?<|eot_id|>"
239
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
240
+ "Bonjour! The capital of France is Paris!<|eot_id|>"
241
+ "<|start_header_id|>user<|end_header_id|>\n\n"
242
+ "What can I do there?<|eot_id|>"
243
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
244
+ "Paris offers many attractions and activities. Some popular things to do "
245
+ "include visiting the Eiffel Tower, exploring the Louvre Museum, taking a river "
246
+ "cruise along the Seine, and strolling through charming neighborhoods like Montmartre.<|eot_id|>"
247
+ "<|start_header_id|>user<|end_header_id|>\n\n"
248
+ "What else?<|eot_id|>"
249
+ )
250
+ assert formatted_conversation_str == expected_output_str
251
+
252
+ hf_formatted_conversation = hf_formatter.format(original_conversation, output_mode="list")
253
+ assert hf_formatted_conversation == expected_output_str