eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
# ruff: noqa: E501
|
|
2
|
+
import importlib.util
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from template_formatting.formatter import (
|
|
7
|
+
BaseFormatter,
|
|
8
|
+
ConcatFormatter,
|
|
9
|
+
HFFormatter,
|
|
10
|
+
Llama3Formatter,
|
|
11
|
+
Message,
|
|
12
|
+
Property,
|
|
13
|
+
ReasoningFormatter,
|
|
14
|
+
Role,
|
|
15
|
+
get_formatter,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Optional-dependency probe: HFFormatter-based tests are skipped when `transformers` is absent.
package_exists = bool(importlib.util.find_spec("transformers"))
|
|
19
|
+
|
|
20
|
+
# This module contains no tests that require a GPU runner, so no pytest GPU markers are needed.
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.fixture
def concat_formatter() -> BaseFormatter:
    """Provide a fresh ``ConcatFormatter`` for each test."""
    formatter = ConcatFormatter()
    return formatter
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@pytest.fixture
def llama3_formatter() -> BaseFormatter:
    """Provide a fresh native ``Llama3Formatter`` for each test."""
    formatter = Llama3Formatter()
    return formatter
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pytest.fixture
def hf_formatter() -> BaseFormatter:
    """Provide an ``HFFormatter`` for the Llama-3 8B Instruct model.

    NOTE(review): this presumably fetches the model's tokenizer / chat template,
    which may require network access or Hub credentials — confirm.
    """
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    return HFFormatter(model_id)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@pytest.fixture
def llama3_reasoning_formatter() -> BaseFormatter:
    """Provide a ``ReasoningFormatter`` wrapping the Llama-3 formatter.

    The template's ``end_of_text`` marker is set explicitly so the reasoning tests
    can assert on a concrete ``<|end_of_text|>`` token.
    """
    # ReasoningFormatter is given the formatter class itself, not an instance.
    formatter = ReasoningFormatter(Llama3Formatter)
    formatter.template.end_of_text = "<|end_of_text|>"
    return formatter
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_concat_formatter(concat_formatter: BaseFormatter) -> None:
    """ConcatFormatter renders a multi-turn conversation as plain text without role markup."""
    system_msg = Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations")
    # The trailing newline is intentional: it has to be handled on task level, not by the formatter.
    question = Message(role=Role.USER, content="What is France's capital?\n")
    reply = Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!")
    follow_up = Message(role=Role.USER, content="Great, thanks!")

    rendered = concat_formatter.format([system_msg, question, reply, follow_up], output_mode="string")

    assert rendered == (
        "You are a helpful AI assistant for travel tips and recommendations\n\n"
        "What is France's capital?\n"
        "Bonjour! The capital of France is Paris!\n\n"
        "Great, thanks!"
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_system_and_assistant_simple(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """A conversation ending in an assistant turn is rendered without a trailing ``<|eot_id|>``.

    The native Llama3Formatter and the HF chat-template formatter must produce identical output.
    """
    conversation = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!"),
    ]
    expected_output = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!"
    )

    # Native formatter first, then the HF reference implementation.
    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(conversation, output_mode="string") == expected_output
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_system_and_assistant(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """When the last turn is from the user, the output ends with an open assistant header.

    The native Llama3Formatter and the HF chat-template formatter must agree.
    """
    conversation = [
        Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations"),
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!"),
        Message(role=Role.USER, content="Great, thanks!"),
    ]
    expected_output = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "Great, thanks!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(conversation, output_mode="string") == expected_output
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_without_system_and_assistant(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """A lone user message yields no system header and ends with an open assistant header."""
    conversation = [Message(role=Role.USER, content="What is France's capital?")]
    expected_output = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(conversation, output_mode="string") == expected_output
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_without_system_multiple_rounds(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """Several user/assistant rounds without a system prompt are each closed by ``<|eot_id|>``."""
    paris_activities = (
        "Paris offers many attractions and activities. "
        "Some popular things to do include visiting the Eiffel Tower, "
        "exploring the Louvre Museum, taking a river cruise along the Seine, "
        "and strolling through charming neighborhoods like Montmartre."
    )
    conversation = [
        Message(role=Role.USER, content="What is France's capital?"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!"),
        Message(role=Role.USER, content="What can I do there?"),
        Message(role=Role.ASSISTANT, content=paris_activities),
        Message(role=Role.USER, content="What else?"),
    ]
    expected_output = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What can I do there?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Paris offers many attractions and activities. Some popular things to do "
        "include visiting the Eiffel Tower, exploring the Louvre Museum, taking a river "
        "cruise along the Seine, and strolling through charming neighborhoods like Montmartre.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What else?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format(conversation, output_mode="string") == expected_output
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_prefilling(llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter) -> None:
    """A trailing assistant message acts as a prefill ("cue") and is left open for continuation."""
    question = Message(role=Role.USER, content="How many helicopters can a human eat in one sitting?")
    cue = Message(role=Role.ASSISTANT, content="A human can")
    expected_output = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "How many helicopters can a human eat in one sitting?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "A human can"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format([question, cue], output_mode="string") == expected_output
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_stripping_of_whitespace(llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter) -> None:
    """Leading and trailing whitespace around message content is stripped, including the prefill."""
    padded_question = Message(role=Role.USER, content=" What is the capital of France? ")
    padded_cue = Message(role=Role.ASSISTANT, content=" The capital of France is ")
    expected_output = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is the capital of France?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "The capital of France is"
    )

    for formatter in (llama3_formatter, hf_formatter):
        assert formatter.format([padded_question, padded_cue], output_mode="string") == expected_output
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@pytest.mark.parametrize(
    "model_name, expected_formatter",
    [
        pytest.param("llama-3", Llama3Formatter, id="llama-3"),
        pytest.param("llama-3-base", Llama3Formatter, id="llama-3-base"),
        pytest.param("llama-3-large", Llama3Formatter, id="llama-3-large"),
        pytest.param("my-llama-3-model", Llama3Formatter, id="custom-llama-3-model"),
        pytest.param("gpt2", ConcatFormatter, id="gpt2"),
        pytest.param("bert", ConcatFormatter, id="bert"),
        pytest.param("roberta", ConcatFormatter, id="roberta"),
        pytest.param("distilbert", ConcatFormatter, id="distilbert"),
        pytest.param("custom-model", ConcatFormatter, id="custom-non-llama3-model"),
        pytest.param("", ConcatFormatter, id="empty-model-name"),
    ],
)
def test_get_formatter(model_name: str, expected_formatter: type[BaseFormatter]) -> None:
    """get_formatter selects the formatter class from the model name.

    Names containing ``llama-3`` map to Llama3Formatter. All other names, including
    the empty string, fall back to ConcatFormatter.
    """
    assert isinstance(get_formatter(model_name), expected_formatter)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# ReasoningFormatter tests
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def test_reasoning_formatter_with_system_and_user(llama3_reasoning_formatter: BaseFormatter) -> None:
    """With no assistant content yet, the prompt ends by opening the thought section."""
    system_msg = Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations")
    question = Message(role=Role.USER, content="What is France's capital?")

    rendered = llama3_reasoning_formatter.format([system_msg, question], output_mode="string")

    assert rendered == (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>"
    )
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def test_reasoning_formatter_with_user(llama3_reasoning_formatter: BaseFormatter) -> None:
    """A lone user message yields an assistant header followed by an opened thought section."""
    rendered = llama3_reasoning_formatter.format(
        [Message(role=Role.USER, content="What is France's capital?")],
        output_mode="string",
    )

    assert rendered == (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>"
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def test_reasoning_formatter_with_system_user_and_thought(llama3_reasoning_formatter: BaseFormatter) -> None:
    """A THOUGHT message is closed and the solution section is opened after it."""
    system_msg = Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations")
    question = Message(role=Role.USER, content="What is France's capital?")
    thought = Message(role=Role.ASSISTANT, property=Property.THOUGHT, content="Bonjour! Let me think about this...")

    rendered = llama3_reasoning_formatter.format([system_msg, question, thought], output_mode="string")

    assert rendered == (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>Bonjour! Let me think about this...<|end_of_thought|>"
        "<|begin_of_solution|>"
    )
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def test_reasoning_formatter_with_system_user_thought_and_solution(llama3_reasoning_formatter: BaseFormatter) -> None:
    """After THOUGHT and SOLUTION messages, the output ends by opening the answer section."""
    system_msg = Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations")
    question = Message(role=Role.USER, content="What is France's capital?")
    thought = Message(role=Role.ASSISTANT, property=Property.THOUGHT, content="Bonjour! Let me think about this...")
    solution = Message(role=Role.ASSISTANT, property=Property.SOLUTION, content="Merci! The capital of France is Paris!")

    rendered = llama3_reasoning_formatter.format([system_msg, question, thought, solution], output_mode="string")

    assert rendered == (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>Bonjour! Let me think about this..."
        "<|end_of_thought|><|begin_of_solution|>Merci! The capital of France is Paris!"
        "<|begin_of_answer|>"
    )
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def test_reasoning_formatter_with_system_user_thought_solution_and_answer(
    llama3_reasoning_formatter: BaseFormatter,
) -> None:
    """A complete THOUGHT/SOLUTION/ANSWER turn closes every section and terminates the text."""
    system_msg = Message(role=Role.SYSTEM, content="You are a helpful AI assistant for travel tips and recommendations")
    question = Message(role=Role.USER, content="What is France's capital?")
    thought = Message(role=Role.ASSISTANT, property=Property.THOUGHT, content="Bonjour! Let me think about this...")
    solution = Message(role=Role.ASSISTANT, property=Property.SOLUTION, content="Merci! The capital of France is Paris!")
    answer = Message(role=Role.ASSISTANT, property=Property.ANSWER, content="\\boxed{Paris}")

    rendered = llama3_reasoning_formatter.format(
        [system_msg, question, thought, solution, answer],
        output_mode="string",
    )

    assert rendered == (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|begin_of_thought|>Bonjour! Let me think about this...<|end_of_thought|>"
        "<|begin_of_solution|>Merci! The capital of France is Paris!"
        "<|begin_of_answer|>\\boxed{Paris}<|end_of_answer|><|end_of_solution|><|eot_id|><|end_of_text|>"
    )
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def test_reasoning_formatter_parse_wrong_order() -> None:
    """parse() reports a ValueError when the solution opens before the thought section closes.

    parse() returns ``(parsed, error)`` instead of raising. The error's type is checked
    directly rather than re-raising it inside ``pytest.raises``.
    """
    rf = ReasoningFormatter(Llama3Formatter)
    rt = rf.template
    output_str = (
        rt.begin_thought_id
        + "thought"
        + rt.begin_solution_id
        + "solution"  # Wrong: begin_solution_id comes before end_thought_id.
        + rt.end_thought_id
        + rt.end_solution_id
        + rt.begin_answer_id
        + "answer"
        + rt.end_answer_id
        + rt.end_of_text
    )

    _, error = rf.parse(output_str)

    assert isinstance(error, ValueError), f"expected a ValueError from parse(), got {error!r}"
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def test_reasoning_formatter_parse_incomplete() -> None:
    """Output containing only a closed thought section parses without an error."""
    formatter = ReasoningFormatter(Llama3Formatter)
    tokens = formatter.template

    parsed, error = formatter.parse(tokens.begin_thought_id + "only thought" + tokens.end_thought_id)

    assert error is None
    assert parsed["thought"] == "only thought"
    # Sections that never appeared may be missing or empty; either is accepted.
    for section in ("solution", "answer"):
        assert parsed.get(section, "") == ""
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def test_reasoning_formatter_parse_duplicate_tokens() -> None:
    """parse() reports a ValueError when a section-opening token appears twice.

    parse() returns ``(parsed, error)`` instead of raising. The error's type is checked
    directly rather than re-raising it inside ``pytest.raises``.
    """
    rf = ReasoningFormatter(Llama3Formatter)
    rt = rf.template
    output_str = (
        rt.begin_thought_id
        + "thought"
        + rt.begin_thought_id  # Wrong: the thought section is opened a second time.
        + "duplicate"
        + rt.end_thought_id
        + rt.begin_solution_id
        + "solution"
        + rt.end_solution_id
        + rt.begin_answer_id
        + "answer"
        + rt.end_answer_id
        + rt.end_of_text
    )

    _, error = rf.parse(output_str)

    assert isinstance(error, ValueError), f"expected a ValueError from parse(), got {error!r}"
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
# ruff: noqa: E501
|
|
2
|
+
import importlib.util
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from template_formatting.formatter import (
|
|
7
|
+
BaseFormatter,
|
|
8
|
+
ConcatFormatter,
|
|
9
|
+
HFFormatter,
|
|
10
|
+
Llama3Formatter,
|
|
11
|
+
Message,
|
|
12
|
+
Property,
|
|
13
|
+
ReasoningFormatter,
|
|
14
|
+
Role,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Optional-dependency probe: HFFormatter-based tests are skipped when `transformers` is absent.
package_exists = bool(importlib.util.find_spec("transformers"))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.fixture
def concat_formatter() -> BaseFormatter:
    """Provide a fresh ``ConcatFormatter`` for each test."""
    formatter = ConcatFormatter()
    return formatter
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def llama3_formatter() -> BaseFormatter:
    """Provide a fresh native ``Llama3Formatter`` for each test."""
    formatter = Llama3Formatter()
    return formatter
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@pytest.fixture
def hf_formatter() -> BaseFormatter:
    """Provide an ``HFFormatter`` for the Llama-3 8B Instruct model.

    NOTE(review): this presumably fetches the model's tokenizer / chat template,
    which may require network access or Hub credentials — confirm.
    """
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    return HFFormatter(model_id)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_get_grouped_messages_same_property() -> None:
    """Consecutive assistant messages that both have no property land in separate groups."""
    defaults = {"content": "dummy", "has_loss": False, "type": "text"}

    def make(role: Role) -> Message:
        return Message(role=role, property=None, **defaults)

    roles = (Role.USER, Role.ASSISTANT, Role.ASSISTANT)
    grouped_messages = BaseFormatter._get_grouped_messages([make(role) for role in roles])

    assert grouped_messages == [[make(role)] for role in roles]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def test_get_grouped_messages_different_property() -> None:
    """An untagged assistant message followed by an ANSWER-tagged one share a single group."""
    defaults = {"content": "dummy", "has_loss": False, "type": "text"}

    def make(role: Role, prop: Property | None = None) -> Message:
        return Message(role=role, property=prop, **defaults)

    grouped_messages = BaseFormatter._get_grouped_messages(
        [make(Role.USER), make(Role.ASSISTANT), make(Role.ASSISTANT, Property.ANSWER)]
    )

    user_group = [make(Role.USER)]
    assistant_group = [make(Role.ASSISTANT), make(Role.ASSISTANT, Property.ANSWER)]
    assert grouped_messages == [user_group, assistant_group]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_base_verify_messages() -> None:
    """A single user -> assistant exchange passes BaseFormatter verification."""
    defaults = {"content": "dummy", "has_loss": False, "type": "text"}
    conversation = [Message(role=role, property=None, **defaults) for role in (Role.USER, Role.ASSISTANT)]

    # Passing means no AssertionError is raised.
    BaseFormatter._verify_messages(conversation)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def test_base_verify_messages_raises_exception() -> None:
    """Two consecutive assistant messages without a property are rejected.

    ``BaseFormatter._verify_messages`` is expected to raise ``AssertionError`` for
    the sequence USER, ASSISTANT, ASSISTANT.
    """
    common = {"content": "dummy", "has_loss": False, "type": "text"}
    roles = (Role.USER, Role.ASSISTANT, Role.ASSISTANT)
    # Built outside the `raises` context so a construction failure is not
    # mistaken for the expected verification error.
    conversation = [Message(role=r, property=None, **common) for r in roles]

    with pytest.raises(AssertionError):
        BaseFormatter._verify_messages(conversation)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def test_reasoning_verify_messages() -> None:
    """USER followed by assistant THOUGHT, SOLUTION, ANSWER is accepted.

    ``ReasoningFormatter._verify_messages`` must not raise for this sequence.
    """
    common = {"content": "dummy", "has_loss": False, "type": "text"}
    spec = [
        (Role.USER, None),
        (Role.ASSISTANT, Property.THOUGHT),
        (Role.ASSISTANT, Property.SOLUTION),
        (Role.ASSISTANT, Property.ANSWER),
    ]
    conversation = [Message(role=role, property=prop, **common) for role, prop in spec]

    # Passing means no AssertionError escapes the call.
    ReasoningFormatter._verify_messages(conversation)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_reasoning_verify_messages_raises_exception() -> None:
    """An assistant message without a property before an ANSWER is rejected.

    ``ReasoningFormatter._verify_messages`` is expected to raise ``AssertionError``
    for USER, ASSISTANT(None), ASSISTANT(ANSWER).
    """
    common = {"content": "dummy", "has_loss": False, "type": "text"}
    spec = [
        (Role.USER, None),
        (Role.ASSISTANT, None),
        (Role.ASSISTANT, Property.ANSWER),
    ]
    # Built outside the `raises` context so a construction failure is not
    # mistaken for the expected verification error.
    conversation = [Message(role=role, property=prop, **common) for role, prop in spec]

    with pytest.raises(AssertionError):
        ReasoningFormatter._verify_messages(conversation)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
## Assert that formatting is in line with HF Formatter
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_with_system_and_assistant_simple(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """Llama 3 formatting of a system/user/assistant exchange.

    Checks the per-message contents produced by ``llama3_formatter`` in list mode,
    their concatenation, and that ``hf_formatter`` yields the same string.
    """
    import copy

    conversation = [
        Message(
            role=Role.SYSTEM,
            content="You are a helpful AI assistant for travel tips and recommendations",
            has_loss=False,
            type="text",
        ),
        Message(role=Role.USER, content="What is France's capital?", has_loss=False, type="text"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!", has_loss=True, type="text"),
    ]

    # Snapshot the input (list and messages) so that any in-place changes made by
    # the Llama 3 formatter cannot leak into the HF formatter comparison below.
    original_conversation = copy.deepcopy(conversation)

    formatted_conversation = llama3_formatter.format(conversation, output_mode="list")

    expected_contents = [
        (
            "<|begin_of_text|>"
            "<|start_header_id|>system<|end_header_id|>\n\n"
            "You are a helpful AI assistant for travel tips and recommendations<|eot_id|>"
        ),
        (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            "What is France's capital?<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        "Bonjour! The capital of France is Paris!<|eot_id|>",
    ]

    # zip() stops at the shorter input; pin the length so missing or extra
    # formatted messages cannot slip through the per-message comparison.
    assert len(formatted_conversation) == len(expected_contents)
    for formatted_message, expected in zip(formatted_conversation, expected_contents):
        assert formatted_message.content == expected

    # stringify the list
    formatted_conversation_str = "".join(elm.content for elm in formatted_conversation)

    expected_output_str = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!<|eot_id|>"
    )

    assert formatted_conversation_str == expected_output_str

    # NOTE(review): output_mode="list" is passed, yet the result is compared to a
    # str; presumably HFFormatter returns a single string regardless -- confirm.
    hf_formatted_conversation = hf_formatter.format(original_conversation, output_mode="list")
    assert hf_formatted_conversation == expected_output_str
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@pytest.mark.skipif(
    not package_exists,
    reason="`transformers` package is not installed, HFFormatter will not be available.",
)
def test_llama3_formatter_without_system_multiple_rounds_list(
    llama3_formatter: BaseFormatter, hf_formatter: BaseFormatter
) -> None:
    """Llama 3 formatting of a multi-round conversation without a system prompt.

    The conversation ends on a user turn. Checks the per-message contents produced
    by ``llama3_formatter`` in list mode, their concatenation, and that
    ``hf_formatter`` (given an untouched copy of the input) yields the same string.
    """
    import copy

    conversation = [
        Message(role=Role.USER, content="What is France's capital?", has_loss=False, type="text"),
        Message(role=Role.ASSISTANT, content="Bonjour! The capital of France is Paris!", has_loss=True, type="text"),
        Message(role=Role.USER, content="What can I do there?", has_loss=False, type="text"),
        Message(
            role=Role.ASSISTANT,
            content=(
                "Paris offers many attractions and activities. "
                "Some popular things to do include visiting the Eiffel Tower, "
                "exploring the Louvre Museum, taking a river cruise along the Seine, "
                "and strolling through charming neighborhoods like Montmartre."
            ),
            has_loss=False,
            type="text",
        ),
        Message(role=Role.USER, content="What else?", has_loss=False, type="text"),
    ]

    # A deep copy protects both the list and the Message objects; a shallow
    # list.copy() would still share messages the formatter might mutate.
    original_conversation = copy.deepcopy(conversation)

    formatted_conversation = llama3_formatter.format(conversation, output_mode="list")

    expected_contents = [
        (
            "<|begin_of_text|>"
            "<|start_header_id|>user<|end_header_id|>\n\n"
            "What is France's capital?<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        "Bonjour! The capital of France is Paris!<|eot_id|>",
        (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            "What can I do there?<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        (
            "Paris offers many attractions and activities. Some popular things to do include visiting the Eiffel Tower, "
            "exploring the Louvre Museum, taking a river cruise along the Seine, and strolling through charming neighborhoods like Montmartre.<|eot_id|>"
        ),
        "<|start_header_id|>user<|end_header_id|>\n\nWhat else?<|eot_id|>",
    ]

    # zip() stops at the shorter input; pin the length so missing or extra
    # formatted messages cannot slip through the per-message comparison.
    assert len(formatted_conversation) == len(expected_contents)
    for formatted_message, expected in zip(formatted_conversation, expected_contents):
        assert formatted_message.content == expected

    # stringify the list
    formatted_conversation_str = "".join(elm.content for elm in formatted_conversation)

    expected_output_str = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is France's capital?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Bonjour! The capital of France is Paris!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What can I do there?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Paris offers many attractions and activities. Some popular things to do "
        "include visiting the Eiffel Tower, exploring the Louvre Museum, taking a river "
        "cruise along the Seine, and strolling through charming neighborhoods like Montmartre.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "What else?<|eot_id|>"
    )
    assert formatted_conversation_str == expected_output_str

    # NOTE(review): output_mode="list" is passed, yet the result is compared to a
    # str; presumably HFFormatter returns a single string regardless -- confirm.
    hf_formatted_conversation = hf_formatter.format(original_conversation, output_mode="list")
    assert hf_formatted_conversation == expected_output_str
|