kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +234 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
- kiln_ai/adapters/eval/base_eval.py +8 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -65
- kiln_ai/adapters/eval/g_eval.py +26 -8
- kiln_ai/adapters/eval/test_base_eval.py +166 -15
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
- kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
- kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
- kiln_ai/adapters/ml_model_list.py +556 -45
- kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
- kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -17
- kiln_ai/adapters/provider_tools.py +39 -4
- kiln_ai/adapters/repair/test_repair_task.py +27 -5
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +158 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -3
- kiln_ai/adapters/test_prompt_builders.py +27 -19
- kiln_ai/adapters/test_provider_tools.py +130 -12
- kiln_ai/datamodel/__init__.py +2 -2
- kiln_ai/datamodel/datamodel_enums.py +43 -4
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +13 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +1 -1
- kiln_ai/datamodel/task_run.py +39 -7
- kiln_ai/datamodel/test_basemodel.py +5 -8
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -8
- kiln_ai/datamodel/test_example_models.py +54 -0
- kiln_ai/datamodel/test_models.py +80 -9
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +81 -19
- kiln_ai/utils/logging.py +165 -0
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +272 -10
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
- kiln_ai-0.17.0.dist-info/RECORD +113 -0
- kiln_ai-0.15.0.dist-info/RECORD +0 -104
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,32 +1,34 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import re
|
|
3
4
|
import tempfile
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from unittest.mock import Mock
|
|
6
7
|
|
|
7
8
|
import pytest
|
|
8
9
|
|
|
10
|
+
from kiln_ai.adapters.chat.chat_formatter import COT_FINAL_ANSWER_PROMPT, ChatMessage
|
|
9
11
|
from kiln_ai.adapters.fine_tune.dataset_formatter import (
|
|
12
|
+
VERTEX_GEMINI_ROLE_MAP,
|
|
10
13
|
DatasetFormat,
|
|
11
14
|
DatasetFormatter,
|
|
12
|
-
|
|
13
|
-
build_training_data,
|
|
15
|
+
build_training_chat,
|
|
14
16
|
generate_chat_message_response,
|
|
15
17
|
generate_chat_message_toolcall,
|
|
16
18
|
generate_huggingface_chat_template,
|
|
17
19
|
generate_huggingface_chat_template_toolcall,
|
|
18
20
|
generate_vertex_gemini,
|
|
21
|
+
serialize_r1_style_message,
|
|
19
22
|
)
|
|
20
|
-
from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
|
|
21
23
|
from kiln_ai.datamodel import (
|
|
22
24
|
DatasetSplit,
|
|
23
25
|
DataSource,
|
|
24
26
|
DataSourceType,
|
|
25
|
-
FinetuneDataStrategy,
|
|
26
27
|
Task,
|
|
27
28
|
TaskOutput,
|
|
28
29
|
TaskRun,
|
|
29
30
|
)
|
|
31
|
+
from kiln_ai.datamodel.datamodel_enums import ChatStrategy
|
|
30
32
|
|
|
31
33
|
logger = logging.getLogger(__name__)
|
|
32
34
|
|
|
@@ -42,6 +44,7 @@ def mock_task():
|
|
|
42
44
|
"input": '{"test": "input 你好"}',
|
|
43
45
|
"repaired_output": None,
|
|
44
46
|
"intermediate_outputs": {},
|
|
47
|
+
"thinking_training_data": Mock(return_value=None),
|
|
45
48
|
"input_source": Mock(
|
|
46
49
|
spec=DataSource,
|
|
47
50
|
**{
|
|
@@ -83,6 +86,7 @@ def mock_task():
|
|
|
83
86
|
def mock_intermediate_outputs(mock_task):
|
|
84
87
|
for run in mock_task.runs():
|
|
85
88
|
run.intermediate_outputs = {"reasoning": "thinking output"}
|
|
89
|
+
run.thinking_training_data.return_value = "thinking output"
|
|
86
90
|
mock_task.thinking_instruction = "thinking instructions"
|
|
87
91
|
return mock_task
|
|
88
92
|
|
|
@@ -96,41 +100,56 @@ def mock_dataset(mock_task):
|
|
|
96
100
|
return dataset
|
|
97
101
|
|
|
98
102
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
103
|
+
@pytest.fixture
|
|
104
|
+
def mock_training_chat_short():
|
|
105
|
+
return [
|
|
106
|
+
ChatMessage(role="system", content="system message"),
|
|
107
|
+
ChatMessage(
|
|
108
|
+
role="user",
|
|
109
|
+
content="test input",
|
|
110
|
+
),
|
|
111
|
+
ChatMessage(role="assistant", content="test output"),
|
|
112
|
+
]
|
|
105
113
|
|
|
106
|
-
result = generate_chat_message_response(thinking_data)
|
|
107
114
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
+
@pytest.fixture
|
|
116
|
+
def mock_training_chat_two_step_plaintext():
|
|
117
|
+
return [
|
|
118
|
+
ChatMessage(role="system", content="system message"),
|
|
119
|
+
ChatMessage(
|
|
120
|
+
role="user",
|
|
121
|
+
content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
122
|
+
),
|
|
123
|
+
ChatMessage(role="assistant", content="thinking output"),
|
|
124
|
+
ChatMessage(role="user", content="thinking final answer prompt"),
|
|
125
|
+
ChatMessage(role="assistant", content="test output"),
|
|
126
|
+
]
|
|
115
127
|
|
|
116
128
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
129
|
+
@pytest.fixture
|
|
130
|
+
def mock_training_chat_two_step_json():
|
|
131
|
+
return [
|
|
132
|
+
ChatMessage(role="system", content="system message"),
|
|
133
|
+
ChatMessage(
|
|
134
|
+
role="user",
|
|
135
|
+
content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
136
|
+
),
|
|
137
|
+
ChatMessage(role="assistant", content="thinking output"),
|
|
138
|
+
ChatMessage(role="user", content="thinking final answer prompt"),
|
|
139
|
+
ChatMessage(role="assistant", content='{"a":"你好"}'),
|
|
140
|
+
]
|
|
126
141
|
|
|
127
|
-
|
|
142
|
+
|
|
143
|
+
def test_generate_chat_message_response(mock_training_chat_two_step_plaintext):
|
|
144
|
+
result = generate_chat_message_response(mock_training_chat_two_step_plaintext)
|
|
128
145
|
|
|
129
146
|
assert result == {
|
|
130
147
|
"messages": [
|
|
131
148
|
{"role": "system", "content": "system message"},
|
|
132
|
-
{
|
|
133
|
-
|
|
149
|
+
{
|
|
150
|
+
"role": "user",
|
|
151
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
152
|
+
},
|
|
134
153
|
{"role": "assistant", "content": "thinking output"},
|
|
135
154
|
{"role": "user", "content": "thinking final answer prompt"},
|
|
136
155
|
{"role": "assistant", "content": "test output"},
|
|
@@ -138,54 +157,33 @@ def test_generate_chat_message_response_thinking():
|
|
|
138
157
|
}
|
|
139
158
|
|
|
140
159
|
|
|
141
|
-
def
|
|
142
|
-
|
|
143
|
-
input="test input 你好",
|
|
144
|
-
system_message="system message 你好",
|
|
145
|
-
final_output='{"key": "value 你好"}',
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
result = generate_chat_message_toolcall(training_data)
|
|
160
|
+
def test_generate_chat_message_response_json(mock_training_chat_two_step_json):
|
|
161
|
+
result = generate_chat_message_response(mock_training_chat_two_step_json)
|
|
149
162
|
|
|
150
163
|
assert result == {
|
|
151
164
|
"messages": [
|
|
152
|
-
{"role": "system", "content": "system message
|
|
153
|
-
{"role": "user", "content": "test input 你好"},
|
|
165
|
+
{"role": "system", "content": "system message"},
|
|
154
166
|
{
|
|
155
|
-
"role": "
|
|
156
|
-
"content":
|
|
157
|
-
"tool_calls": [
|
|
158
|
-
{
|
|
159
|
-
"id": "call_1",
|
|
160
|
-
"type": "function",
|
|
161
|
-
"function": {
|
|
162
|
-
"name": "task_response",
|
|
163
|
-
"arguments": '{"key": "value 你好"}',
|
|
164
|
-
},
|
|
165
|
-
}
|
|
166
|
-
],
|
|
167
|
+
"role": "user",
|
|
168
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
167
169
|
},
|
|
170
|
+
{"role": "assistant", "content": "thinking output"},
|
|
171
|
+
{"role": "user", "content": "thinking final answer prompt"},
|
|
172
|
+
{"role": "assistant", "content": '{"a":"你好"}'},
|
|
168
173
|
]
|
|
169
174
|
}
|
|
170
175
|
|
|
171
176
|
|
|
172
|
-
def
|
|
173
|
-
|
|
174
|
-
input="test input",
|
|
175
|
-
system_message="system message",
|
|
176
|
-
final_output='{"key": "value"}',
|
|
177
|
-
thinking="thinking output",
|
|
178
|
-
thinking_instructions="thinking instructions",
|
|
179
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
180
|
-
)
|
|
181
|
-
|
|
182
|
-
result = generate_chat_message_toolcall(training_data)
|
|
177
|
+
def test_generate_chat_message_toolcall(mock_training_chat_two_step_json):
|
|
178
|
+
result = generate_chat_message_toolcall(mock_training_chat_two_step_json)
|
|
183
179
|
|
|
184
180
|
assert result == {
|
|
185
181
|
"messages": [
|
|
186
182
|
{"role": "system", "content": "system message"},
|
|
187
|
-
{
|
|
188
|
-
|
|
183
|
+
{
|
|
184
|
+
"role": "user",
|
|
185
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
186
|
+
},
|
|
189
187
|
{"role": "assistant", "content": "thinking output"},
|
|
190
188
|
{"role": "user", "content": "thinking final answer prompt"},
|
|
191
189
|
{
|
|
@@ -197,7 +195,7 @@ def test_generate_chat_message_toolcall_thinking():
|
|
|
197
195
|
"type": "function",
|
|
198
196
|
"function": {
|
|
199
197
|
"name": "task_response",
|
|
200
|
-
"arguments": '{"
|
|
198
|
+
"arguments": '{"a": "你好"}',
|
|
201
199
|
},
|
|
202
200
|
}
|
|
203
201
|
],
|
|
@@ -206,31 +204,17 @@ def test_generate_chat_message_toolcall_thinking():
|
|
|
206
204
|
}
|
|
207
205
|
|
|
208
206
|
|
|
209
|
-
def test_generate_chat_message_toolcall_invalid_json():
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
final_output="invalid json",
|
|
214
|
-
)
|
|
215
|
-
|
|
216
|
-
with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
|
|
217
|
-
generate_chat_message_toolcall(training_data)
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
def test_dataset_formatter_init_no_parent_task(mock_dataset):
|
|
221
|
-
mock_dataset.parent_task.return_value = None
|
|
222
|
-
|
|
223
|
-
with pytest.raises(ValueError, match="Dataset has no parent task"):
|
|
224
|
-
DatasetFormatter(mock_dataset, "system message")
|
|
207
|
+
def test_generate_chat_message_toolcall_invalid_json(mock_training_chat_two_step_json):
|
|
208
|
+
mock_training_chat_two_step_json[-1].content = "invalid json"
|
|
209
|
+
with pytest.raises(ValueError, match="^Last message is not JSON"):
|
|
210
|
+
generate_chat_message_toolcall(mock_training_chat_two_step_json)
|
|
225
211
|
|
|
226
212
|
|
|
227
213
|
def test_dataset_formatter_dump_invalid_format(mock_dataset):
|
|
228
214
|
formatter = DatasetFormatter(mock_dataset, "system message")
|
|
229
215
|
|
|
230
216
|
with pytest.raises(ValueError, match="Unsupported format"):
|
|
231
|
-
formatter.dump_to_file(
|
|
232
|
-
"train", "invalid_format", FinetuneDataStrategy.final_only
|
|
233
|
-
) # type: ignore
|
|
217
|
+
formatter.dump_to_file("train", "invalid_format", ChatStrategy.single_turn)
|
|
234
218
|
|
|
235
219
|
|
|
236
220
|
def test_dataset_formatter_dump_invalid_split(mock_dataset):
|
|
@@ -240,7 +224,7 @@ def test_dataset_formatter_dump_invalid_split(mock_dataset):
|
|
|
240
224
|
formatter.dump_to_file(
|
|
241
225
|
"invalid_split",
|
|
242
226
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
243
|
-
|
|
227
|
+
ChatStrategy.single_turn,
|
|
244
228
|
)
|
|
245
229
|
|
|
246
230
|
|
|
@@ -252,7 +236,7 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
|
|
|
252
236
|
"train",
|
|
253
237
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
254
238
|
path=output_path,
|
|
255
|
-
data_strategy=
|
|
239
|
+
data_strategy=ChatStrategy.single_turn,
|
|
256
240
|
)
|
|
257
241
|
|
|
258
242
|
assert result_path == output_path
|
|
@@ -278,7 +262,7 @@ def test_dataset_formatter_dump_to_temp_file(mock_dataset):
|
|
|
278
262
|
result_path = formatter.dump_to_file(
|
|
279
263
|
"train",
|
|
280
264
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
281
|
-
data_strategy=
|
|
265
|
+
data_strategy=ChatStrategy.single_turn,
|
|
282
266
|
)
|
|
283
267
|
|
|
284
268
|
assert result_path.exists()
|
|
@@ -309,7 +293,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
|
|
|
309
293
|
"train",
|
|
310
294
|
DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
|
|
311
295
|
path=output_path,
|
|
312
|
-
data_strategy=
|
|
296
|
+
data_strategy=ChatStrategy.single_turn,
|
|
313
297
|
)
|
|
314
298
|
|
|
315
299
|
assert result_path == output_path
|
|
@@ -349,7 +333,7 @@ def test_dataset_formatter_dump_with_intermediate_data(
|
|
|
349
333
|
result_path = formatter.dump_to_file(
|
|
350
334
|
"train",
|
|
351
335
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
352
|
-
data_strategy=
|
|
336
|
+
data_strategy=ChatStrategy.two_message_cot_legacy,
|
|
353
337
|
)
|
|
354
338
|
|
|
355
339
|
assert result_path.exists()
|
|
@@ -368,17 +352,19 @@ def test_dataset_formatter_dump_with_intermediate_data(
|
|
|
368
352
|
assert "thinking instructions" in line
|
|
369
353
|
|
|
370
354
|
|
|
371
|
-
def
|
|
355
|
+
def test_dataset_formatter_dump_with_intermediate_data_r1_style(
|
|
372
356
|
mock_dataset, mock_intermediate_outputs
|
|
373
357
|
):
|
|
374
358
|
formatter = DatasetFormatter(
|
|
375
|
-
mock_dataset,
|
|
359
|
+
mock_dataset,
|
|
360
|
+
"system message 你好",
|
|
361
|
+
thinking_instructions=None,
|
|
376
362
|
)
|
|
377
363
|
|
|
378
364
|
result_path = formatter.dump_to_file(
|
|
379
365
|
"train",
|
|
380
366
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
381
|
-
data_strategy=
|
|
367
|
+
data_strategy=ChatStrategy.single_turn_r1_thinking,
|
|
382
368
|
)
|
|
383
369
|
|
|
384
370
|
assert result_path.exists()
|
|
@@ -393,46 +379,50 @@ def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
|
|
|
393
379
|
lines = f.readlines()
|
|
394
380
|
assert len(lines) == 2
|
|
395
381
|
for line in lines:
|
|
396
|
-
assert "
|
|
397
|
-
assert "
|
|
398
|
-
assert "thinking output" in line
|
|
382
|
+
assert "<think>" in line
|
|
383
|
+
assert "</think>" in line
|
|
399
384
|
|
|
400
385
|
|
|
401
|
-
def
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
386
|
+
def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
|
|
387
|
+
mock_dataset, mock_intermediate_outputs
|
|
388
|
+
):
|
|
389
|
+
formatter = DatasetFormatter(
|
|
390
|
+
mock_dataset, "custom system message 你好", "custom thinking instructions"
|
|
406
391
|
)
|
|
407
392
|
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
{"role": "user", "content": "test input"},
|
|
414
|
-
{"role": "assistant", "content": "test output"},
|
|
415
|
-
]
|
|
416
|
-
}
|
|
417
|
-
|
|
393
|
+
result_path = formatter.dump_to_file(
|
|
394
|
+
"train",
|
|
395
|
+
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
396
|
+
data_strategy=ChatStrategy.two_message_cot_legacy,
|
|
397
|
+
)
|
|
418
398
|
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
thinking_instructions="thinking instructions",
|
|
426
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
399
|
+
assert result_path.exists()
|
|
400
|
+
assert result_path.parent == Path(tempfile.gettempdir())
|
|
401
|
+
# Test our nice naming, with cot
|
|
402
|
+
assert (
|
|
403
|
+
result_path.name
|
|
404
|
+
== "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
|
|
427
405
|
)
|
|
406
|
+
# Verify file contents
|
|
407
|
+
with open(result_path) as f:
|
|
408
|
+
lines = f.readlines()
|
|
409
|
+
assert len(lines) == 2
|
|
410
|
+
for line in lines:
|
|
411
|
+
assert "custom system message 你好" in line
|
|
412
|
+
assert "custom thinking instructions" in line
|
|
413
|
+
assert "thinking output" in line
|
|
428
414
|
|
|
429
|
-
|
|
415
|
+
|
|
416
|
+
def test_generate_huggingface_chat_template(mock_training_chat_two_step_plaintext):
|
|
417
|
+
result = generate_huggingface_chat_template(mock_training_chat_two_step_plaintext)
|
|
430
418
|
|
|
431
419
|
assert result == {
|
|
432
420
|
"conversations": [
|
|
433
421
|
{"role": "system", "content": "system message"},
|
|
434
|
-
{
|
|
435
|
-
|
|
422
|
+
{
|
|
423
|
+
"role": "user",
|
|
424
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
425
|
+
},
|
|
436
426
|
{"role": "assistant", "content": "thinking output"},
|
|
437
427
|
{"role": "user", "content": "thinking final answer prompt"},
|
|
438
428
|
{"role": "assistant", "content": "test output"},
|
|
@@ -440,14 +430,8 @@ def test_generate_huggingface_chat_template_thinking():
|
|
|
440
430
|
}
|
|
441
431
|
|
|
442
432
|
|
|
443
|
-
def test_generate_vertex_template():
|
|
444
|
-
|
|
445
|
-
input="test input",
|
|
446
|
-
system_message="system message",
|
|
447
|
-
final_output="test output",
|
|
448
|
-
)
|
|
449
|
-
|
|
450
|
-
result = generate_vertex_gemini(training_data)
|
|
433
|
+
def test_generate_vertex_template(mock_training_chat_short):
|
|
434
|
+
result = generate_vertex_gemini(mock_training_chat_short)
|
|
451
435
|
|
|
452
436
|
assert result == {
|
|
453
437
|
"systemInstruction": {
|
|
@@ -465,19 +449,8 @@ def test_generate_vertex_template():
|
|
|
465
449
|
}
|
|
466
450
|
|
|
467
451
|
|
|
468
|
-
def test_generate_vertex_template_thinking():
|
|
469
|
-
|
|
470
|
-
input="test input",
|
|
471
|
-
system_message="system message",
|
|
472
|
-
final_output="test output",
|
|
473
|
-
thinking="thinking output",
|
|
474
|
-
thinking_instructions="thinking instructions",
|
|
475
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
476
|
-
)
|
|
477
|
-
|
|
478
|
-
result = generate_vertex_gemini(training_data)
|
|
479
|
-
|
|
480
|
-
logger.info(result)
|
|
452
|
+
def test_generate_vertex_template_thinking(mock_training_chat_two_step_plaintext):
|
|
453
|
+
result = generate_vertex_gemini(mock_training_chat_two_step_plaintext)
|
|
481
454
|
|
|
482
455
|
assert result == {
|
|
483
456
|
"systemInstruction": {
|
|
@@ -489,8 +462,14 @@ def test_generate_vertex_template_thinking():
|
|
|
489
462
|
],
|
|
490
463
|
},
|
|
491
464
|
"contents": [
|
|
492
|
-
{
|
|
493
|
-
|
|
465
|
+
{
|
|
466
|
+
"role": "user",
|
|
467
|
+
"parts": [
|
|
468
|
+
{
|
|
469
|
+
"text": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
470
|
+
}
|
|
471
|
+
],
|
|
472
|
+
},
|
|
494
473
|
{"role": "model", "parts": [{"text": "thinking output"}]},
|
|
495
474
|
{"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
|
|
496
475
|
{"role": "model", "parts": [{"text": "test output"}]},
|
|
@@ -499,13 +478,13 @@ def test_generate_vertex_template_thinking():
|
|
|
499
478
|
|
|
500
479
|
|
|
501
480
|
def test_generate_huggingface_chat_template_toolcall():
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
481
|
+
messages = [
|
|
482
|
+
ChatMessage("system", "system message"),
|
|
483
|
+
ChatMessage("user", "test input"),
|
|
484
|
+
ChatMessage("assistant", '{"key":"value"}'),
|
|
485
|
+
]
|
|
507
486
|
|
|
508
|
-
result = generate_huggingface_chat_template_toolcall(
|
|
487
|
+
result = generate_huggingface_chat_template_toolcall(messages)
|
|
509
488
|
|
|
510
489
|
assert result["conversations"][0] == {"role": "system", "content": "system message"}
|
|
511
490
|
assert result["conversations"][1] == {"role": "user", "content": "test input"}
|
|
@@ -520,34 +499,28 @@ def test_generate_huggingface_chat_template_toolcall():
|
|
|
520
499
|
assert tool_call["function"]["arguments"] == {"key": "value"}
|
|
521
500
|
|
|
522
501
|
|
|
523
|
-
def test_generate_huggingface_chat_template_toolcall_thinking(
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
thinking="thinking output",
|
|
529
|
-
thinking_instructions="thinking instructions",
|
|
530
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
502
|
+
def test_generate_huggingface_chat_template_toolcall_thinking(
|
|
503
|
+
mock_training_chat_two_step_json,
|
|
504
|
+
):
|
|
505
|
+
result = generate_huggingface_chat_template_toolcall(
|
|
506
|
+
mock_training_chat_two_step_json
|
|
531
507
|
)
|
|
532
508
|
|
|
533
|
-
result = generate_huggingface_chat_template_toolcall(training_data)
|
|
534
|
-
|
|
535
509
|
assert result["conversations"][0] == {"role": "system", "content": "system message"}
|
|
536
|
-
assert result["conversations"][1] == {
|
|
537
|
-
assert result["conversations"][2] == {
|
|
510
|
+
assert result["conversations"][1] == {
|
|
538
511
|
"role": "user",
|
|
539
|
-
"content": "
|
|
512
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
540
513
|
}
|
|
541
|
-
assert result["conversations"][
|
|
514
|
+
assert result["conversations"][2] == {
|
|
542
515
|
"role": "assistant",
|
|
543
516
|
"content": "thinking output",
|
|
544
517
|
}
|
|
545
|
-
assert result["conversations"][
|
|
518
|
+
assert result["conversations"][3] == {
|
|
546
519
|
"role": "user",
|
|
547
520
|
"content": "thinking final answer prompt",
|
|
548
521
|
}
|
|
549
522
|
|
|
550
|
-
assistant_msg = result["conversations"][
|
|
523
|
+
assistant_msg = result["conversations"][4]
|
|
551
524
|
assert assistant_msg["role"] == "assistant"
|
|
552
525
|
assert len(assistant_msg["tool_calls"]) == 1
|
|
553
526
|
tool_call = assistant_msg["tool_calls"][0]
|
|
@@ -555,31 +528,39 @@ def test_generate_huggingface_chat_template_toolcall_thinking():
|
|
|
555
528
|
assert tool_call["function"]["name"] == "task_response"
|
|
556
529
|
assert len(tool_call["function"]["id"]) == 9 # UUID is truncated to 9 chars
|
|
557
530
|
assert tool_call["function"]["id"].isalnum() # Check ID is alphanumeric
|
|
558
|
-
assert tool_call["function"]["arguments"] == {"
|
|
531
|
+
assert tool_call["function"]["arguments"] == {"a": "你好"}
|
|
559
532
|
|
|
560
533
|
|
|
561
|
-
def test_generate_huggingface_chat_template_toolcall_invalid_json(
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
final_output="invalid json",
|
|
566
|
-
)
|
|
534
|
+
def test_generate_huggingface_chat_template_toolcall_invalid_json(
|
|
535
|
+
mock_training_chat_two_step_json,
|
|
536
|
+
):
|
|
537
|
+
mock_training_chat_two_step_json[-1].content = "invalid json"
|
|
567
538
|
|
|
568
|
-
with pytest.raises(ValueError, match="
|
|
569
|
-
generate_huggingface_chat_template_toolcall(
|
|
539
|
+
with pytest.raises(ValueError, match="^Last message is not JSON"):
|
|
540
|
+
generate_huggingface_chat_template_toolcall(mock_training_chat_two_step_json)
|
|
570
541
|
|
|
571
542
|
|
|
572
|
-
def
|
|
543
|
+
def test_build_training_chat(mock_task):
|
|
573
544
|
# Non repaired should use original output
|
|
574
545
|
mock_task_run = mock_task.runs()[0]
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
assert
|
|
582
|
-
|
|
546
|
+
messages = build_training_chat(
|
|
547
|
+
mock_task_run,
|
|
548
|
+
"system message",
|
|
549
|
+
data_strategy=ChatStrategy.single_turn,
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
assert len(messages) == 3
|
|
553
|
+
system_msg = messages[0]
|
|
554
|
+
assert system_msg.role == "system"
|
|
555
|
+
assert system_msg.content == "system message"
|
|
556
|
+
|
|
557
|
+
user_msg = messages[1]
|
|
558
|
+
assert user_msg.role == "user"
|
|
559
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
560
|
+
|
|
561
|
+
final_msg = messages[2]
|
|
562
|
+
assert final_msg.role == "assistant"
|
|
563
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
583
564
|
|
|
584
565
|
|
|
585
566
|
def test_build_training_data_with_COT(mock_task):
|
|
@@ -587,20 +568,109 @@ def test_build_training_data_with_COT(mock_task):
|
|
|
587
568
|
mock_task_run = mock_task.runs()[0]
|
|
588
569
|
assert mock_task_run.parent_task() == mock_task
|
|
589
570
|
mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
|
|
571
|
+
mock_task_run.thinking_training_data.return_value = "cot output"
|
|
590
572
|
|
|
591
|
-
|
|
573
|
+
messages = build_training_chat(
|
|
592
574
|
mock_task_run,
|
|
593
575
|
"system message",
|
|
594
|
-
|
|
576
|
+
data_strategy=ChatStrategy.two_message_cot,
|
|
595
577
|
thinking_instructions="thinking instructions",
|
|
596
578
|
)
|
|
597
|
-
|
|
598
|
-
assert
|
|
599
|
-
|
|
600
|
-
assert
|
|
601
|
-
assert
|
|
602
|
-
|
|
603
|
-
|
|
579
|
+
|
|
580
|
+
assert len(messages) == 5
|
|
581
|
+
system_msg = messages[0]
|
|
582
|
+
assert system_msg.role == "system"
|
|
583
|
+
assert system_msg.content == "system message"
|
|
584
|
+
|
|
585
|
+
user_msg = messages[1]
|
|
586
|
+
assert user_msg.role == "user"
|
|
587
|
+
assert (
|
|
588
|
+
user_msg.content
|
|
589
|
+
== 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
assistant_msg = messages[2]
|
|
593
|
+
assert assistant_msg.role == "assistant"
|
|
594
|
+
assert assistant_msg.content == "cot output"
|
|
595
|
+
|
|
596
|
+
final_answer_prompt_msg = messages[3]
|
|
597
|
+
assert final_answer_prompt_msg.role == "user"
|
|
598
|
+
assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
|
|
599
|
+
|
|
600
|
+
final_msg = messages[4]
|
|
601
|
+
assert final_msg.role == "assistant"
|
|
602
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
def test_build_training_data_with_COT_legacy(mock_task):
|
|
606
|
+
# Setup with needed fields for thinking
|
|
607
|
+
mock_task_run = mock_task.runs()[0]
|
|
608
|
+
assert mock_task_run.parent_task() == mock_task
|
|
609
|
+
mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
|
|
610
|
+
mock_task_run.thinking_training_data.return_value = "cot output"
|
|
611
|
+
|
|
612
|
+
messages = build_training_chat(
|
|
613
|
+
mock_task_run,
|
|
614
|
+
"system message",
|
|
615
|
+
data_strategy=ChatStrategy.two_message_cot_legacy,
|
|
616
|
+
thinking_instructions="thinking instructions",
|
|
617
|
+
)
|
|
618
|
+
|
|
619
|
+
assert len(messages) == 6
|
|
620
|
+
system_msg = messages[0]
|
|
621
|
+
assert system_msg.role == "system"
|
|
622
|
+
assert system_msg.content == "system message"
|
|
623
|
+
|
|
624
|
+
user_msg = messages[1]
|
|
625
|
+
assert user_msg.role == "user"
|
|
626
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
627
|
+
|
|
628
|
+
cot_msg = messages[2]
|
|
629
|
+
assert cot_msg.role == "system"
|
|
630
|
+
assert cot_msg.content == "thinking instructions"
|
|
631
|
+
|
|
632
|
+
assistant_msg = messages[3]
|
|
633
|
+
assert assistant_msg.role == "assistant"
|
|
634
|
+
assert assistant_msg.content == "cot output"
|
|
635
|
+
|
|
636
|
+
final_answer_prompt_msg = messages[4]
|
|
637
|
+
assert final_answer_prompt_msg.role == "user"
|
|
638
|
+
assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
|
|
639
|
+
|
|
640
|
+
final_msg = messages[5]
|
|
641
|
+
assert final_msg.role == "assistant"
|
|
642
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
def test_build_training_data_with_COT_r1_style(mock_task):
|
|
646
|
+
# Setup with needed fields for thinking
|
|
647
|
+
mock_task_run = mock_task.runs()[0]
|
|
648
|
+
assert mock_task_run.parent_task() == mock_task
|
|
649
|
+
mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
|
|
650
|
+
mock_task_run.thinking_training_data.return_value = "cot output"
|
|
651
|
+
|
|
652
|
+
messages = build_training_chat(
|
|
653
|
+
mock_task_run,
|
|
654
|
+
"system message",
|
|
655
|
+
data_strategy=ChatStrategy.single_turn_r1_thinking,
|
|
656
|
+
thinking_instructions=None,
|
|
657
|
+
)
|
|
658
|
+
|
|
659
|
+
assert len(messages) == 3
|
|
660
|
+
system_msg = messages[0]
|
|
661
|
+
assert system_msg.role == "system"
|
|
662
|
+
assert system_msg.content == "system message"
|
|
663
|
+
|
|
664
|
+
user_msg = messages[1]
|
|
665
|
+
assert user_msg.role == "user"
|
|
666
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
667
|
+
|
|
668
|
+
final_msg = messages[2]
|
|
669
|
+
assert final_msg.role == "assistant"
|
|
670
|
+
assert (
|
|
671
|
+
final_msg.content
|
|
672
|
+
== '<think>\ncot output\n</think>\n\n{"test": "output 你好"}'
|
|
673
|
+
)
|
|
604
674
|
|
|
605
675
|
|
|
606
676
|
def test_build_training_data_with_thinking(mock_task):
|
|
@@ -612,22 +682,78 @@ def test_build_training_data_with_thinking(mock_task):
|
|
|
612
682
|
"reasoning": "thinking output",
|
|
613
683
|
"chain_of_thought": "cot output",
|
|
614
684
|
}
|
|
685
|
+
mock_task_run.thinking_training_data.return_value = "thinking output"
|
|
615
686
|
mock_task.thinking_instruction = "thinking instructions"
|
|
616
687
|
assert mock_task.thinking_instruction == "thinking instructions"
|
|
617
688
|
|
|
618
|
-
|
|
689
|
+
messages = build_training_chat(
|
|
619
690
|
mock_task_run,
|
|
620
691
|
"system message",
|
|
621
|
-
|
|
692
|
+
ChatStrategy.two_message_cot,
|
|
622
693
|
thinking_instructions="thinking instructions",
|
|
623
694
|
)
|
|
624
|
-
|
|
625
|
-
assert
|
|
626
|
-
|
|
627
|
-
assert
|
|
628
|
-
assert
|
|
629
|
-
|
|
630
|
-
|
|
695
|
+
|
|
696
|
+
assert len(messages) == 5
|
|
697
|
+
system_msg = messages[0]
|
|
698
|
+
assert system_msg.role == "system"
|
|
699
|
+
assert system_msg.content == "system message"
|
|
700
|
+
|
|
701
|
+
user_msg = messages[1]
|
|
702
|
+
assert user_msg.role == "user"
|
|
703
|
+
assert (
|
|
704
|
+
user_msg.content
|
|
705
|
+
== 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
|
|
706
|
+
)
|
|
707
|
+
|
|
708
|
+
assistant_msg = messages[2]
|
|
709
|
+
assert assistant_msg.role == "assistant"
|
|
710
|
+
assert assistant_msg.content == "thinking output"
|
|
711
|
+
|
|
712
|
+
final_answer_prompt_msg = messages[3]
|
|
713
|
+
assert final_answer_prompt_msg.role == "user"
|
|
714
|
+
assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
|
|
715
|
+
|
|
716
|
+
final_msg = messages[4]
|
|
717
|
+
assert final_msg.role == "assistant"
|
|
718
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
def test_build_training_data_with_thinking_r1_style(mock_task):
|
|
722
|
+
# Setup with needed fields for thinking
|
|
723
|
+
mock_task_run = mock_task.runs()[0]
|
|
724
|
+
assert mock_task_run.parent_task() == mock_task
|
|
725
|
+
# It should just use the reasoning output if both thinking and chain_of_thought are present
|
|
726
|
+
mock_task_run.intermediate_outputs = {
|
|
727
|
+
"reasoning": "thinking output",
|
|
728
|
+
"chain_of_thought": "cot output",
|
|
729
|
+
}
|
|
730
|
+
mock_task_run.thinking_training_data.return_value = "thinking output"
|
|
731
|
+
mock_task.thinking_instruction = "thinking instructions"
|
|
732
|
+
|
|
733
|
+
assert mock_task.thinking_instruction == "thinking instructions"
|
|
734
|
+
|
|
735
|
+
messages = build_training_chat(
|
|
736
|
+
mock_task_run,
|
|
737
|
+
"system message",
|
|
738
|
+
ChatStrategy.single_turn_r1_thinking,
|
|
739
|
+
thinking_instructions=None,
|
|
740
|
+
)
|
|
741
|
+
|
|
742
|
+
assert len(messages) == 3
|
|
743
|
+
system_msg = messages[0]
|
|
744
|
+
assert system_msg.role == "system"
|
|
745
|
+
assert system_msg.content == "system message"
|
|
746
|
+
|
|
747
|
+
user_msg = messages[1]
|
|
748
|
+
assert user_msg.role == "user"
|
|
749
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
750
|
+
|
|
751
|
+
final_msg = messages[2]
|
|
752
|
+
assert final_msg.role == "assistant"
|
|
753
|
+
assert (
|
|
754
|
+
final_msg.content
|
|
755
|
+
== '<think>\nthinking output\n</think>\n\n{"test": "output 你好"}'
|
|
756
|
+
)
|
|
631
757
|
|
|
632
758
|
|
|
633
759
|
def test_build_training_data_with_repaired_output(mock_task):
|
|
@@ -642,13 +768,25 @@ def test_build_training_data_with_repaired_output(mock_task):
|
|
|
642
768
|
),
|
|
643
769
|
)
|
|
644
770
|
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
assert
|
|
771
|
+
messages = build_training_chat(
|
|
772
|
+
mock_task_run,
|
|
773
|
+
"system message",
|
|
774
|
+
data_strategy=ChatStrategy.single_turn,
|
|
775
|
+
)
|
|
776
|
+
|
|
777
|
+
assert len(messages) == 3
|
|
778
|
+
system_msg = messages[0]
|
|
779
|
+
assert system_msg.role == "system"
|
|
780
|
+
assert system_msg.content == "system message"
|
|
781
|
+
|
|
782
|
+
user_msg = messages[1]
|
|
783
|
+
assert user_msg.role == "user"
|
|
784
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
785
|
+
|
|
786
|
+
final_msg = messages[2]
|
|
787
|
+
assert final_msg.role == "assistant"
|
|
788
|
+
# Note we re-format the json
|
|
789
|
+
assert final_msg.content == '{"test": "repaired output"}'
|
|
652
790
|
|
|
653
791
|
|
|
654
792
|
def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
|
|
@@ -659,7 +797,7 @@ def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_pat
|
|
|
659
797
|
"train",
|
|
660
798
|
DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
|
|
661
799
|
path=output_path,
|
|
662
|
-
data_strategy=
|
|
800
|
+
data_strategy=ChatStrategy.single_turn,
|
|
663
801
|
)
|
|
664
802
|
|
|
665
803
|
assert result_path == output_path
|
|
@@ -683,3 +821,56 @@ def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_pat
|
|
|
683
821
|
assert assistant_msg["content"] == '{"test": "output 你好"}'
|
|
684
822
|
json_content = json.loads(assistant_msg["content"])
|
|
685
823
|
assert json_content == {"test": "output 你好"}
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
@pytest.mark.parametrize(
|
|
827
|
+
"thinking,final_output,expected_output",
|
|
828
|
+
[
|
|
829
|
+
("thinking", "final output", "<think>\nthinking\n</think>\n\nfinal output"),
|
|
830
|
+
("thinking", '{"name":"joe"}', '<think>\nthinking\n</think>\n\n{"name":"joe"}'),
|
|
831
|
+
],
|
|
832
|
+
)
|
|
833
|
+
def test_serialize_r1_style_message(thinking, final_output, expected_output):
|
|
834
|
+
assert (
|
|
835
|
+
serialize_r1_style_message(thinking=thinking, final_output=final_output)
|
|
836
|
+
== expected_output
|
|
837
|
+
)
|
|
838
|
+
|
|
839
|
+
|
|
840
|
+
@pytest.mark.parametrize(
|
|
841
|
+
"thinking,final_output",
|
|
842
|
+
[
|
|
843
|
+
(None, "final output"),
|
|
844
|
+
("", "final output"),
|
|
845
|
+
(" ", "final output"),
|
|
846
|
+
],
|
|
847
|
+
)
|
|
848
|
+
def test_serialize_r1_style_message_missing_thinking(thinking, final_output):
|
|
849
|
+
with pytest.raises(
|
|
850
|
+
ValueError,
|
|
851
|
+
match=re.escape(
|
|
852
|
+
"Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
|
|
853
|
+
),
|
|
854
|
+
):
|
|
855
|
+
serialize_r1_style_message(thinking=thinking, final_output=final_output)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def test_vertex_gemini_role_map_coverage():
|
|
859
|
+
"""Test that VERTEX_GEMINI_ROLE_MAP covers all possible ChatMessage.role values"""
|
|
860
|
+
from typing import Literal, get_type_hints
|
|
861
|
+
|
|
862
|
+
# Get the Literal type from ChatMessage.role
|
|
863
|
+
role_type = get_type_hints(ChatMessage)["role"]
|
|
864
|
+
# Extract the possible values from the Literal type
|
|
865
|
+
possible_roles = role_type.__args__ # type: ignore
|
|
866
|
+
|
|
867
|
+
# Check that every possible role is in the map
|
|
868
|
+
for role in possible_roles:
|
|
869
|
+
assert role in VERTEX_GEMINI_ROLE_MAP, (
|
|
870
|
+
f"Role {role} is not mapped in VERTEX_GEMINI_ROLE_MAP"
|
|
871
|
+
)
|
|
872
|
+
|
|
873
|
+
# Check that there are no extra mappings
|
|
874
|
+
assert set(VERTEX_GEMINI_ROLE_MAP.keys()) == set(possible_roles), (
|
|
875
|
+
"VERTEX_GEMINI_ROLE_MAP has extra mappings"
|
|
876
|
+
)
|