kiln-ai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +234 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
- kiln_ai/adapters/eval/base_eval.py +8 -6
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +23 -5
- kiln_ai/adapters/eval/test_base_eval.py +166 -15
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
- kiln_ai/adapters/fine_tune/dataset_formatter.py +138 -272
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
- kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
- kiln_ai/adapters/ml_model_list.py +80 -43
- kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
- kiln_ai/adapters/model_adapters/litellm_adapter.py +79 -97
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -60
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +56 -21
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
- kiln_ai/adapters/prompt_builders.py +0 -16
- kiln_ai/adapters/provider_tools.py +27 -9
- kiln_ai/adapters/repair/test_repair_task.py +24 -3
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +158 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -3
- kiln_ai/adapters/test_prompt_builders.py +3 -16
- kiln_ai/adapters/test_provider_tools.py +69 -20
- kiln_ai/datamodel/__init__.py +0 -2
- kiln_ai/datamodel/datamodel_enums.py +38 -13
- kiln_ai/datamodel/finetune.py +12 -7
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/test_basemodel.py +2 -1
- kiln_ai/datamodel/test_dataset_split.py +0 -8
- kiln_ai/datamodel/test_models.py +33 -10
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +1 -1
- kiln_ai/utils/logging.py +165 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +30 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/RECORD +54 -49
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -7,11 +7,12 @@ from unittest.mock import Mock
|
|
|
7
7
|
|
|
8
8
|
import pytest
|
|
9
9
|
|
|
10
|
+
from kiln_ai.adapters.chat.chat_formatter import COT_FINAL_ANSWER_PROMPT, ChatMessage
|
|
10
11
|
from kiln_ai.adapters.fine_tune.dataset_formatter import (
|
|
12
|
+
VERTEX_GEMINI_ROLE_MAP,
|
|
11
13
|
DatasetFormat,
|
|
12
14
|
DatasetFormatter,
|
|
13
|
-
|
|
14
|
-
build_training_data,
|
|
15
|
+
build_training_chat,
|
|
15
16
|
generate_chat_message_response,
|
|
16
17
|
generate_chat_message_toolcall,
|
|
17
18
|
generate_huggingface_chat_template,
|
|
@@ -19,16 +20,15 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import (
|
|
|
19
20
|
generate_vertex_gemini,
|
|
20
21
|
serialize_r1_style_message,
|
|
21
22
|
)
|
|
22
|
-
from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
|
|
23
23
|
from kiln_ai.datamodel import (
|
|
24
24
|
DatasetSplit,
|
|
25
25
|
DataSource,
|
|
26
26
|
DataSourceType,
|
|
27
|
-
FinetuneDataStrategy,
|
|
28
27
|
Task,
|
|
29
28
|
TaskOutput,
|
|
30
29
|
TaskRun,
|
|
31
30
|
)
|
|
31
|
+
from kiln_ai.datamodel.datamodel_enums import ChatStrategy
|
|
32
32
|
|
|
33
33
|
logger = logging.getLogger(__name__)
|
|
34
34
|
|
|
@@ -100,41 +100,56 @@ def mock_dataset(mock_task):
|
|
|
100
100
|
return dataset
|
|
101
101
|
|
|
102
102
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
103
|
+
@pytest.fixture
|
|
104
|
+
def mock_training_chat_short():
|
|
105
|
+
return [
|
|
106
|
+
ChatMessage(role="system", content="system message"),
|
|
107
|
+
ChatMessage(
|
|
108
|
+
role="user",
|
|
109
|
+
content="test input",
|
|
110
|
+
),
|
|
111
|
+
ChatMessage(role="assistant", content="test output"),
|
|
112
|
+
]
|
|
109
113
|
|
|
110
|
-
result = generate_chat_message_response(thinking_data)
|
|
111
114
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
115
|
+
@pytest.fixture
|
|
116
|
+
def mock_training_chat_two_step_plaintext():
|
|
117
|
+
return [
|
|
118
|
+
ChatMessage(role="system", content="system message"),
|
|
119
|
+
ChatMessage(
|
|
120
|
+
role="user",
|
|
121
|
+
content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
122
|
+
),
|
|
123
|
+
ChatMessage(role="assistant", content="thinking output"),
|
|
124
|
+
ChatMessage(role="user", content="thinking final answer prompt"),
|
|
125
|
+
ChatMessage(role="assistant", content="test output"),
|
|
126
|
+
]
|
|
119
127
|
|
|
120
128
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
129
|
+
@pytest.fixture
|
|
130
|
+
def mock_training_chat_two_step_json():
|
|
131
|
+
return [
|
|
132
|
+
ChatMessage(role="system", content="system message"),
|
|
133
|
+
ChatMessage(
|
|
134
|
+
role="user",
|
|
135
|
+
content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
136
|
+
),
|
|
137
|
+
ChatMessage(role="assistant", content="thinking output"),
|
|
138
|
+
ChatMessage(role="user", content="thinking final answer prompt"),
|
|
139
|
+
ChatMessage(role="assistant", content='{"a":"你好"}'),
|
|
140
|
+
]
|
|
130
141
|
|
|
131
|
-
|
|
142
|
+
|
|
143
|
+
def test_generate_chat_message_response(mock_training_chat_two_step_plaintext):
|
|
144
|
+
result = generate_chat_message_response(mock_training_chat_two_step_plaintext)
|
|
132
145
|
|
|
133
146
|
assert result == {
|
|
134
147
|
"messages": [
|
|
135
148
|
{"role": "system", "content": "system message"},
|
|
136
|
-
{
|
|
137
|
-
|
|
149
|
+
{
|
|
150
|
+
"role": "user",
|
|
151
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
152
|
+
},
|
|
138
153
|
{"role": "assistant", "content": "thinking output"},
|
|
139
154
|
{"role": "user", "content": "thinking final answer prompt"},
|
|
140
155
|
{"role": "assistant", "content": "test output"},
|
|
@@ -142,79 +157,33 @@ def test_generate_chat_message_response_thinking():
|
|
|
142
157
|
}
|
|
143
158
|
|
|
144
159
|
|
|
145
|
-
def
|
|
146
|
-
|
|
147
|
-
input="test input",
|
|
148
|
-
system_message="system message",
|
|
149
|
-
final_output="test output",
|
|
150
|
-
thinking="thinking output",
|
|
151
|
-
thinking_instructions=None,
|
|
152
|
-
thinking_final_answer_prompt=None,
|
|
153
|
-
thinking_r1_style=True,
|
|
154
|
-
)
|
|
155
|
-
|
|
156
|
-
result = generate_chat_message_response(thinking_data)
|
|
160
|
+
def test_generate_chat_message_response_json(mock_training_chat_two_step_json):
|
|
161
|
+
result = generate_chat_message_response(mock_training_chat_two_step_json)
|
|
157
162
|
|
|
158
163
|
assert result == {
|
|
159
164
|
"messages": [
|
|
160
165
|
{"role": "system", "content": "system message"},
|
|
161
|
-
{"role": "user", "content": "test input"},
|
|
162
166
|
{
|
|
163
|
-
"role": "
|
|
164
|
-
"content": "<
|
|
167
|
+
"role": "user",
|
|
168
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
165
169
|
},
|
|
170
|
+
{"role": "assistant", "content": "thinking output"},
|
|
171
|
+
{"role": "user", "content": "thinking final answer prompt"},
|
|
172
|
+
{"role": "assistant", "content": '{"a":"你好"}'},
|
|
166
173
|
]
|
|
167
174
|
}
|
|
168
175
|
|
|
169
176
|
|
|
170
|
-
def test_generate_chat_message_toolcall():
|
|
171
|
-
|
|
172
|
-
input="test input 你好",
|
|
173
|
-
system_message="system message 你好",
|
|
174
|
-
final_output='{"key": "value 你好"}',
|
|
175
|
-
)
|
|
176
|
-
|
|
177
|
-
result = generate_chat_message_toolcall(training_data)
|
|
177
|
+
def test_generate_chat_message_toolcall(mock_training_chat_two_step_json):
|
|
178
|
+
result = generate_chat_message_toolcall(mock_training_chat_two_step_json)
|
|
178
179
|
|
|
179
180
|
assert result == {
|
|
180
181
|
"messages": [
|
|
181
|
-
{"role": "system", "content": "system message
|
|
182
|
-
{"role": "user", "content": "test input 你好"},
|
|
182
|
+
{"role": "system", "content": "system message"},
|
|
183
183
|
{
|
|
184
|
-
"role": "
|
|
185
|
-
"content":
|
|
186
|
-
"tool_calls": [
|
|
187
|
-
{
|
|
188
|
-
"id": "call_1",
|
|
189
|
-
"type": "function",
|
|
190
|
-
"function": {
|
|
191
|
-
"name": "task_response",
|
|
192
|
-
"arguments": '{"key": "value 你好"}',
|
|
193
|
-
},
|
|
194
|
-
}
|
|
195
|
-
],
|
|
184
|
+
"role": "user",
|
|
185
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
196
186
|
},
|
|
197
|
-
]
|
|
198
|
-
}
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
def test_generate_chat_message_toolcall_thinking():
|
|
202
|
-
training_data = ModelTrainingData(
|
|
203
|
-
input="test input",
|
|
204
|
-
system_message="system message",
|
|
205
|
-
final_output='{"key": "value"}',
|
|
206
|
-
thinking="thinking output",
|
|
207
|
-
thinking_instructions="thinking instructions",
|
|
208
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
result = generate_chat_message_toolcall(training_data)
|
|
212
|
-
|
|
213
|
-
assert result == {
|
|
214
|
-
"messages": [
|
|
215
|
-
{"role": "system", "content": "system message"},
|
|
216
|
-
{"role": "user", "content": "test input"},
|
|
217
|
-
{"role": "user", "content": "thinking instructions"},
|
|
218
187
|
{"role": "assistant", "content": "thinking output"},
|
|
219
188
|
{"role": "user", "content": "thinking final answer prompt"},
|
|
220
189
|
{
|
|
@@ -226,7 +195,7 @@ def test_generate_chat_message_toolcall_thinking():
|
|
|
226
195
|
"type": "function",
|
|
227
196
|
"function": {
|
|
228
197
|
"name": "task_response",
|
|
229
|
-
"arguments": '{"
|
|
198
|
+
"arguments": '{"a": "你好"}',
|
|
230
199
|
},
|
|
231
200
|
}
|
|
232
201
|
],
|
|
@@ -235,49 +204,17 @@ def test_generate_chat_message_toolcall_thinking():
|
|
|
235
204
|
}
|
|
236
205
|
|
|
237
206
|
|
|
238
|
-
def
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
final_output='{"key": "value"}',
|
|
243
|
-
thinking="thinking output",
|
|
244
|
-
thinking_instructions=None,
|
|
245
|
-
thinking_final_answer_prompt=None,
|
|
246
|
-
thinking_r1_style=True,
|
|
247
|
-
)
|
|
248
|
-
|
|
249
|
-
with pytest.raises(
|
|
250
|
-
ValueError,
|
|
251
|
-
match="R1 style thinking is not supported for tool call downloads",
|
|
252
|
-
):
|
|
253
|
-
generate_chat_message_toolcall(training_data)
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
def test_generate_chat_message_toolcall_invalid_json():
|
|
257
|
-
training_data = ModelTrainingData(
|
|
258
|
-
input="test input",
|
|
259
|
-
system_message="system message",
|
|
260
|
-
final_output="invalid json",
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
|
|
264
|
-
generate_chat_message_toolcall(training_data)
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
def test_dataset_formatter_init_no_parent_task(mock_dataset):
|
|
268
|
-
mock_dataset.parent_task.return_value = None
|
|
269
|
-
|
|
270
|
-
with pytest.raises(ValueError, match="Dataset has no parent task"):
|
|
271
|
-
DatasetFormatter(mock_dataset, "system message")
|
|
207
|
+
def test_generate_chat_message_toolcall_invalid_json(mock_training_chat_two_step_json):
|
|
208
|
+
mock_training_chat_two_step_json[-1].content = "invalid json"
|
|
209
|
+
with pytest.raises(ValueError, match="^Last message is not JSON"):
|
|
210
|
+
generate_chat_message_toolcall(mock_training_chat_two_step_json)
|
|
272
211
|
|
|
273
212
|
|
|
274
213
|
def test_dataset_formatter_dump_invalid_format(mock_dataset):
|
|
275
214
|
formatter = DatasetFormatter(mock_dataset, "system message")
|
|
276
215
|
|
|
277
216
|
with pytest.raises(ValueError, match="Unsupported format"):
|
|
278
|
-
formatter.dump_to_file(
|
|
279
|
-
"train", "invalid_format", FinetuneDataStrategy.final_only
|
|
280
|
-
) # type: ignore
|
|
217
|
+
formatter.dump_to_file("train", "invalid_format", ChatStrategy.single_turn)
|
|
281
218
|
|
|
282
219
|
|
|
283
220
|
def test_dataset_formatter_dump_invalid_split(mock_dataset):
|
|
@@ -287,7 +224,7 @@ def test_dataset_formatter_dump_invalid_split(mock_dataset):
|
|
|
287
224
|
formatter.dump_to_file(
|
|
288
225
|
"invalid_split",
|
|
289
226
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
290
|
-
|
|
227
|
+
ChatStrategy.single_turn,
|
|
291
228
|
)
|
|
292
229
|
|
|
293
230
|
|
|
@@ -299,7 +236,7 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
|
|
|
299
236
|
"train",
|
|
300
237
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
301
238
|
path=output_path,
|
|
302
|
-
data_strategy=
|
|
239
|
+
data_strategy=ChatStrategy.single_turn,
|
|
303
240
|
)
|
|
304
241
|
|
|
305
242
|
assert result_path == output_path
|
|
@@ -325,7 +262,7 @@ def test_dataset_formatter_dump_to_temp_file(mock_dataset):
|
|
|
325
262
|
result_path = formatter.dump_to_file(
|
|
326
263
|
"train",
|
|
327
264
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
328
|
-
data_strategy=
|
|
265
|
+
data_strategy=ChatStrategy.single_turn,
|
|
329
266
|
)
|
|
330
267
|
|
|
331
268
|
assert result_path.exists()
|
|
@@ -356,7 +293,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
|
|
|
356
293
|
"train",
|
|
357
294
|
DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
|
|
358
295
|
path=output_path,
|
|
359
|
-
data_strategy=
|
|
296
|
+
data_strategy=ChatStrategy.single_turn,
|
|
360
297
|
)
|
|
361
298
|
|
|
362
299
|
assert result_path == output_path
|
|
@@ -396,7 +333,7 @@ def test_dataset_formatter_dump_with_intermediate_data(
|
|
|
396
333
|
result_path = formatter.dump_to_file(
|
|
397
334
|
"train",
|
|
398
335
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
399
|
-
data_strategy=
|
|
336
|
+
data_strategy=ChatStrategy.two_message_cot_legacy,
|
|
400
337
|
)
|
|
401
338
|
|
|
402
339
|
assert result_path.exists()
|
|
@@ -427,7 +364,7 @@ def test_dataset_formatter_dump_with_intermediate_data_r1_style(
|
|
|
427
364
|
result_path = formatter.dump_to_file(
|
|
428
365
|
"train",
|
|
429
366
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
430
|
-
data_strategy=
|
|
367
|
+
data_strategy=ChatStrategy.single_turn_r1_thinking,
|
|
431
368
|
)
|
|
432
369
|
|
|
433
370
|
assert result_path.exists()
|
|
@@ -456,7 +393,7 @@ def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
|
|
|
456
393
|
result_path = formatter.dump_to_file(
|
|
457
394
|
"train",
|
|
458
395
|
DatasetFormat.OPENAI_CHAT_JSONL,
|
|
459
|
-
data_strategy=
|
|
396
|
+
data_strategy=ChatStrategy.two_message_cot_legacy,
|
|
460
397
|
)
|
|
461
398
|
|
|
462
399
|
assert result_path.exists()
|
|
@@ -476,41 +413,16 @@ def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
|
|
|
476
413
|
assert "thinking output" in line
|
|
477
414
|
|
|
478
415
|
|
|
479
|
-
def test_generate_huggingface_chat_template():
|
|
480
|
-
|
|
481
|
-
input="test input",
|
|
482
|
-
system_message="system message",
|
|
483
|
-
final_output="test output",
|
|
484
|
-
)
|
|
485
|
-
|
|
486
|
-
result = generate_huggingface_chat_template(training_data)
|
|
487
|
-
|
|
488
|
-
assert result == {
|
|
489
|
-
"conversations": [
|
|
490
|
-
{"role": "system", "content": "system message"},
|
|
491
|
-
{"role": "user", "content": "test input"},
|
|
492
|
-
{"role": "assistant", "content": "test output"},
|
|
493
|
-
]
|
|
494
|
-
}
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
def test_generate_huggingface_chat_template_thinking():
|
|
498
|
-
training_data = ModelTrainingData(
|
|
499
|
-
input="test input",
|
|
500
|
-
system_message="system message",
|
|
501
|
-
final_output="test output",
|
|
502
|
-
thinking="thinking output",
|
|
503
|
-
thinking_instructions="thinking instructions",
|
|
504
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
505
|
-
)
|
|
506
|
-
|
|
507
|
-
result = generate_huggingface_chat_template(training_data)
|
|
416
|
+
def test_generate_huggingface_chat_template(mock_training_chat_two_step_plaintext):
|
|
417
|
+
result = generate_huggingface_chat_template(mock_training_chat_two_step_plaintext)
|
|
508
418
|
|
|
509
419
|
assert result == {
|
|
510
420
|
"conversations": [
|
|
511
421
|
{"role": "system", "content": "system message"},
|
|
512
|
-
{
|
|
513
|
-
|
|
422
|
+
{
|
|
423
|
+
"role": "user",
|
|
424
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
425
|
+
},
|
|
514
426
|
{"role": "assistant", "content": "thinking output"},
|
|
515
427
|
{"role": "user", "content": "thinking final answer prompt"},
|
|
516
428
|
{"role": "assistant", "content": "test output"},
|
|
@@ -518,39 +430,8 @@ def test_generate_huggingface_chat_template_thinking():
|
|
|
518
430
|
}
|
|
519
431
|
|
|
520
432
|
|
|
521
|
-
def
|
|
522
|
-
|
|
523
|
-
input="test input",
|
|
524
|
-
system_message="system message",
|
|
525
|
-
final_output="test output",
|
|
526
|
-
thinking="thinking output",
|
|
527
|
-
thinking_instructions=None,
|
|
528
|
-
thinking_final_answer_prompt=None,
|
|
529
|
-
thinking_r1_style=True,
|
|
530
|
-
)
|
|
531
|
-
|
|
532
|
-
result = generate_huggingface_chat_template(training_data)
|
|
533
|
-
|
|
534
|
-
assert result == {
|
|
535
|
-
"conversations": [
|
|
536
|
-
{"role": "system", "content": "system message"},
|
|
537
|
-
{"role": "user", "content": "test input"},
|
|
538
|
-
{
|
|
539
|
-
"role": "assistant",
|
|
540
|
-
"content": "<think>\nthinking output\n</think>\n\ntest output",
|
|
541
|
-
},
|
|
542
|
-
]
|
|
543
|
-
}
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
def test_generate_vertex_template():
|
|
547
|
-
training_data = ModelTrainingData(
|
|
548
|
-
input="test input",
|
|
549
|
-
system_message="system message",
|
|
550
|
-
final_output="test output",
|
|
551
|
-
)
|
|
552
|
-
|
|
553
|
-
result = generate_vertex_gemini(training_data)
|
|
433
|
+
def test_generate_vertex_template(mock_training_chat_short):
|
|
434
|
+
result = generate_vertex_gemini(mock_training_chat_short)
|
|
554
435
|
|
|
555
436
|
assert result == {
|
|
556
437
|
"systemInstruction": {
|
|
@@ -568,17 +449,8 @@ def test_generate_vertex_template():
|
|
|
568
449
|
}
|
|
569
450
|
|
|
570
451
|
|
|
571
|
-
def test_generate_vertex_template_thinking():
|
|
572
|
-
|
|
573
|
-
input="test input",
|
|
574
|
-
system_message="system message",
|
|
575
|
-
final_output="test output",
|
|
576
|
-
thinking="thinking output",
|
|
577
|
-
thinking_instructions="thinking instructions",
|
|
578
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
579
|
-
)
|
|
580
|
-
|
|
581
|
-
result = generate_vertex_gemini(training_data)
|
|
452
|
+
def test_generate_vertex_template_thinking(mock_training_chat_two_step_plaintext):
|
|
453
|
+
result = generate_vertex_gemini(mock_training_chat_two_step_plaintext)
|
|
582
454
|
|
|
583
455
|
assert result == {
|
|
584
456
|
"systemInstruction": {
|
|
@@ -590,8 +462,14 @@ def test_generate_vertex_template_thinking():
|
|
|
590
462
|
],
|
|
591
463
|
},
|
|
592
464
|
"contents": [
|
|
593
|
-
{
|
|
594
|
-
|
|
465
|
+
{
|
|
466
|
+
"role": "user",
|
|
467
|
+
"parts": [
|
|
468
|
+
{
|
|
469
|
+
"text": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
470
|
+
}
|
|
471
|
+
],
|
|
472
|
+
},
|
|
595
473
|
{"role": "model", "parts": [{"text": "thinking output"}]},
|
|
596
474
|
{"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
|
|
597
475
|
{"role": "model", "parts": [{"text": "test output"}]},
|
|
@@ -599,31 +477,14 @@ def test_generate_vertex_template_thinking():
|
|
|
599
477
|
}
|
|
600
478
|
|
|
601
479
|
|
|
602
|
-
def test_generate_vertex_template_thinking_r1_style():
|
|
603
|
-
training_data = ModelTrainingData(
|
|
604
|
-
input="test input",
|
|
605
|
-
system_message="system message",
|
|
606
|
-
final_output="test output",
|
|
607
|
-
thinking="thinking output",
|
|
608
|
-
thinking_instructions=None,
|
|
609
|
-
thinking_final_answer_prompt=None,
|
|
610
|
-
thinking_r1_style=True,
|
|
611
|
-
)
|
|
612
|
-
|
|
613
|
-
with pytest.raises(
|
|
614
|
-
ValueError, match="R1 style thinking is not supported for Vertex Gemini"
|
|
615
|
-
):
|
|
616
|
-
generate_vertex_gemini(training_data)
|
|
617
|
-
|
|
618
|
-
|
|
619
480
|
def test_generate_huggingface_chat_template_toolcall():
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
481
|
+
messages = [
|
|
482
|
+
ChatMessage("system", "system message"),
|
|
483
|
+
ChatMessage("user", "test input"),
|
|
484
|
+
ChatMessage("assistant", '{"key":"value"}'),
|
|
485
|
+
]
|
|
625
486
|
|
|
626
|
-
result = generate_huggingface_chat_template_toolcall(
|
|
487
|
+
result = generate_huggingface_chat_template_toolcall(messages)
|
|
627
488
|
|
|
628
489
|
assert result["conversations"][0] == {"role": "system", "content": "system message"}
|
|
629
490
|
assert result["conversations"][1] == {"role": "user", "content": "test input"}
|
|
@@ -638,34 +499,28 @@ def test_generate_huggingface_chat_template_toolcall():
|
|
|
638
499
|
assert tool_call["function"]["arguments"] == {"key": "value"}
|
|
639
500
|
|
|
640
501
|
|
|
641
|
-
def test_generate_huggingface_chat_template_toolcall_thinking(
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
thinking="thinking output",
|
|
647
|
-
thinking_instructions="thinking instructions",
|
|
648
|
-
thinking_final_answer_prompt="thinking final answer prompt",
|
|
502
|
+
def test_generate_huggingface_chat_template_toolcall_thinking(
|
|
503
|
+
mock_training_chat_two_step_json,
|
|
504
|
+
):
|
|
505
|
+
result = generate_huggingface_chat_template_toolcall(
|
|
506
|
+
mock_training_chat_two_step_json
|
|
649
507
|
)
|
|
650
508
|
|
|
651
|
-
result = generate_huggingface_chat_template_toolcall(training_data)
|
|
652
|
-
|
|
653
509
|
assert result["conversations"][0] == {"role": "system", "content": "system message"}
|
|
654
|
-
assert result["conversations"][1] == {
|
|
655
|
-
assert result["conversations"][2] == {
|
|
510
|
+
assert result["conversations"][1] == {
|
|
656
511
|
"role": "user",
|
|
657
|
-
"content": "
|
|
512
|
+
"content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
|
|
658
513
|
}
|
|
659
|
-
assert result["conversations"][
|
|
514
|
+
assert result["conversations"][2] == {
|
|
660
515
|
"role": "assistant",
|
|
661
516
|
"content": "thinking output",
|
|
662
517
|
}
|
|
663
|
-
assert result["conversations"][
|
|
518
|
+
assert result["conversations"][3] == {
|
|
664
519
|
"role": "user",
|
|
665
520
|
"content": "thinking final answer prompt",
|
|
666
521
|
}
|
|
667
522
|
|
|
668
|
-
assistant_msg = result["conversations"][
|
|
523
|
+
assistant_msg = result["conversations"][4]
|
|
669
524
|
assert assistant_msg["role"] == "assistant"
|
|
670
525
|
assert len(assistant_msg["tool_calls"]) == 1
|
|
671
526
|
tool_call = assistant_msg["tool_calls"][0]
|
|
@@ -673,53 +528,39 @@ def test_generate_huggingface_chat_template_toolcall_thinking():
|
|
|
673
528
|
assert tool_call["function"]["name"] == "task_response"
|
|
674
529
|
assert len(tool_call["function"]["id"]) == 9 # UUID is truncated to 9 chars
|
|
675
530
|
assert tool_call["function"]["id"].isalnum() # Check ID is alphanumeric
|
|
676
|
-
assert tool_call["function"]["arguments"] == {"
|
|
531
|
+
assert tool_call["function"]["arguments"] == {"a": "你好"}
|
|
677
532
|
|
|
678
533
|
|
|
679
|
-
def
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
final_output='{"key": "value"}',
|
|
684
|
-
thinking="thinking output",
|
|
685
|
-
thinking_instructions=None,
|
|
686
|
-
thinking_final_answer_prompt=None,
|
|
687
|
-
thinking_r1_style=True,
|
|
688
|
-
)
|
|
689
|
-
|
|
690
|
-
with pytest.raises(
|
|
691
|
-
ValueError,
|
|
692
|
-
match="R1 style thinking is not supported for tool call downloads",
|
|
693
|
-
):
|
|
694
|
-
generate_huggingface_chat_template_toolcall(training_data)
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
def test_generate_huggingface_chat_template_toolcall_invalid_json():
|
|
698
|
-
training_data = ModelTrainingData(
|
|
699
|
-
input="test input",
|
|
700
|
-
system_message="system message",
|
|
701
|
-
final_output="invalid json",
|
|
702
|
-
)
|
|
534
|
+
def test_generate_huggingface_chat_template_toolcall_invalid_json(
|
|
535
|
+
mock_training_chat_two_step_json,
|
|
536
|
+
):
|
|
537
|
+
mock_training_chat_two_step_json[-1].content = "invalid json"
|
|
703
538
|
|
|
704
|
-
with pytest.raises(ValueError, match="
|
|
705
|
-
generate_huggingface_chat_template_toolcall(
|
|
539
|
+
with pytest.raises(ValueError, match="^Last message is not JSON"):
|
|
540
|
+
generate_huggingface_chat_template_toolcall(mock_training_chat_two_step_json)
|
|
706
541
|
|
|
707
542
|
|
|
708
|
-
def
|
|
543
|
+
def test_build_training_chat(mock_task):
|
|
709
544
|
# Non repaired should use original output
|
|
710
545
|
mock_task_run = mock_task.runs()[0]
|
|
711
|
-
|
|
546
|
+
messages = build_training_chat(
|
|
712
547
|
mock_task_run,
|
|
713
548
|
"system message",
|
|
714
|
-
data_strategy=
|
|
549
|
+
data_strategy=ChatStrategy.single_turn,
|
|
715
550
|
)
|
|
716
|
-
|
|
717
|
-
assert
|
|
718
|
-
|
|
719
|
-
assert
|
|
720
|
-
assert
|
|
721
|
-
|
|
722
|
-
|
|
551
|
+
|
|
552
|
+
assert len(messages) == 3
|
|
553
|
+
system_msg = messages[0]
|
|
554
|
+
assert system_msg.role == "system"
|
|
555
|
+
assert system_msg.content == "system message"
|
|
556
|
+
|
|
557
|
+
user_msg = messages[1]
|
|
558
|
+
assert user_msg.role == "user"
|
|
559
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
560
|
+
|
|
561
|
+
final_msg = messages[2]
|
|
562
|
+
assert final_msg.role == "assistant"
|
|
563
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
723
564
|
|
|
724
565
|
|
|
725
566
|
def test_build_training_data_with_COT(mock_task):
|
|
@@ -729,47 +570,76 @@ def test_build_training_data_with_COT(mock_task):
|
|
|
729
570
|
mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
|
|
730
571
|
mock_task_run.thinking_training_data.return_value = "cot output"
|
|
731
572
|
|
|
732
|
-
|
|
573
|
+
messages = build_training_chat(
|
|
733
574
|
mock_task_run,
|
|
734
575
|
"system message",
|
|
735
|
-
data_strategy=
|
|
576
|
+
data_strategy=ChatStrategy.two_message_cot,
|
|
736
577
|
thinking_instructions="thinking instructions",
|
|
737
578
|
)
|
|
738
|
-
|
|
739
|
-
assert
|
|
740
|
-
|
|
741
|
-
assert
|
|
742
|
-
assert
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
assert
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
training_data = ModelTrainingData(
|
|
750
|
-
input="test input",
|
|
751
|
-
system_message="system message",
|
|
752
|
-
final_output="test output",
|
|
753
|
-
thinking="thinking output",
|
|
754
|
-
thinking_instructions="thinking instructions",
|
|
755
|
-
thinking_final_answer_prompt=COT_FINAL_ANSWER_PROMPT,
|
|
756
|
-
thinking_r1_style=False,
|
|
579
|
+
|
|
580
|
+
assert len(messages) == 5
|
|
581
|
+
system_msg = messages[0]
|
|
582
|
+
assert system_msg.role == "system"
|
|
583
|
+
assert system_msg.content == "system message"
|
|
584
|
+
|
|
585
|
+
user_msg = messages[1]
|
|
586
|
+
assert user_msg.role == "user"
|
|
587
|
+
assert (
|
|
588
|
+
user_msg.content
|
|
589
|
+
== 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
|
|
757
590
|
)
|
|
758
|
-
assert training_data.supports_cot() == True
|
|
759
591
|
|
|
592
|
+
assistant_msg = messages[2]
|
|
593
|
+
assert assistant_msg.role == "assistant"
|
|
594
|
+
assert assistant_msg.content == "cot output"
|
|
595
|
+
|
|
596
|
+
final_answer_prompt_msg = messages[3]
|
|
597
|
+
assert final_answer_prompt_msg.role == "user"
|
|
598
|
+
assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
|
|
599
|
+
|
|
600
|
+
final_msg = messages[4]
|
|
601
|
+
assert final_msg.role == "assistant"
|
|
602
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
def test_build_training_data_with_COT_legacy(mock_task):
|
|
606
|
+
# Setup with needed fields for thinking
|
|
607
|
+
mock_task_run = mock_task.runs()[0]
|
|
608
|
+
assert mock_task_run.parent_task() == mock_task
|
|
609
|
+
mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
|
|
610
|
+
mock_task_run.thinking_training_data.return_value = "cot output"
|
|
760
611
|
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
final_output="test output",
|
|
766
|
-
thinking="thinking output",
|
|
612
|
+
messages = build_training_chat(
|
|
613
|
+
mock_task_run,
|
|
614
|
+
"system message",
|
|
615
|
+
data_strategy=ChatStrategy.two_message_cot_legacy,
|
|
767
616
|
thinking_instructions="thinking instructions",
|
|
768
|
-
thinking_r1_style=True,
|
|
769
617
|
)
|
|
770
618
|
|
|
771
|
-
|
|
772
|
-
|
|
619
|
+
assert len(messages) == 6
|
|
620
|
+
system_msg = messages[0]
|
|
621
|
+
assert system_msg.role == "system"
|
|
622
|
+
assert system_msg.content == "system message"
|
|
623
|
+
|
|
624
|
+
user_msg = messages[1]
|
|
625
|
+
assert user_msg.role == "user"
|
|
626
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
627
|
+
|
|
628
|
+
cot_msg = messages[2]
|
|
629
|
+
assert cot_msg.role == "system"
|
|
630
|
+
assert cot_msg.content == "thinking instructions"
|
|
631
|
+
|
|
632
|
+
assistant_msg = messages[3]
|
|
633
|
+
assert assistant_msg.role == "assistant"
|
|
634
|
+
assert assistant_msg.content == "cot output"
|
|
635
|
+
|
|
636
|
+
final_answer_prompt_msg = messages[4]
|
|
637
|
+
assert final_answer_prompt_msg.role == "user"
|
|
638
|
+
assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
|
|
639
|
+
|
|
640
|
+
final_msg = messages[5]
|
|
641
|
+
assert final_msg.role == "assistant"
|
|
642
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
773
643
|
|
|
774
644
|
|
|
775
645
|
def test_build_training_data_with_COT_r1_style(mock_task):
|
|
@@ -779,19 +649,28 @@ def test_build_training_data_with_COT_r1_style(mock_task):
|
|
|
779
649
|
mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
|
|
780
650
|
mock_task_run.thinking_training_data.return_value = "cot output"
|
|
781
651
|
|
|
782
|
-
|
|
652
|
+
messages = build_training_chat(
|
|
783
653
|
mock_task_run,
|
|
784
654
|
"system message",
|
|
785
|
-
data_strategy=
|
|
655
|
+
data_strategy=ChatStrategy.single_turn_r1_thinking,
|
|
786
656
|
thinking_instructions=None,
|
|
787
657
|
)
|
|
788
|
-
|
|
789
|
-
assert
|
|
790
|
-
|
|
791
|
-
assert
|
|
792
|
-
assert
|
|
793
|
-
|
|
794
|
-
|
|
658
|
+
|
|
659
|
+
assert len(messages) == 3
|
|
660
|
+
system_msg = messages[0]
|
|
661
|
+
assert system_msg.role == "system"
|
|
662
|
+
assert system_msg.content == "system message"
|
|
663
|
+
|
|
664
|
+
user_msg = messages[1]
|
|
665
|
+
assert user_msg.role == "user"
|
|
666
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
667
|
+
|
|
668
|
+
final_msg = messages[2]
|
|
669
|
+
assert final_msg.role == "assistant"
|
|
670
|
+
assert (
|
|
671
|
+
final_msg.content
|
|
672
|
+
== '<think>\ncot output\n</think>\n\n{"test": "output 你好"}'
|
|
673
|
+
)
|
|
795
674
|
|
|
796
675
|
|
|
797
676
|
def test_build_training_data_with_thinking(mock_task):
|
|
@@ -807,19 +686,36 @@ def test_build_training_data_with_thinking(mock_task):
|
|
|
807
686
|
mock_task.thinking_instruction = "thinking instructions"
|
|
808
687
|
assert mock_task.thinking_instruction == "thinking instructions"
|
|
809
688
|
|
|
810
|
-
|
|
689
|
+
messages = build_training_chat(
|
|
811
690
|
mock_task_run,
|
|
812
691
|
"system message",
|
|
813
|
-
|
|
692
|
+
ChatStrategy.two_message_cot,
|
|
814
693
|
thinking_instructions="thinking instructions",
|
|
815
694
|
)
|
|
816
|
-
|
|
817
|
-
assert
|
|
818
|
-
|
|
819
|
-
assert
|
|
820
|
-
assert
|
|
821
|
-
|
|
822
|
-
|
|
695
|
+
|
|
696
|
+
assert len(messages) == 5
|
|
697
|
+
system_msg = messages[0]
|
|
698
|
+
assert system_msg.role == "system"
|
|
699
|
+
assert system_msg.content == "system message"
|
|
700
|
+
|
|
701
|
+
user_msg = messages[1]
|
|
702
|
+
assert user_msg.role == "user"
|
|
703
|
+
assert (
|
|
704
|
+
user_msg.content
|
|
705
|
+
== 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
|
|
706
|
+
)
|
|
707
|
+
|
|
708
|
+
assistant_msg = messages[2]
|
|
709
|
+
assert assistant_msg.role == "assistant"
|
|
710
|
+
assert assistant_msg.content == "thinking output"
|
|
711
|
+
|
|
712
|
+
final_answer_prompt_msg = messages[3]
|
|
713
|
+
assert final_answer_prompt_msg.role == "user"
|
|
714
|
+
assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
|
|
715
|
+
|
|
716
|
+
final_msg = messages[4]
|
|
717
|
+
assert final_msg.role == "assistant"
|
|
718
|
+
assert final_msg.content == '{"test": "output 你好"}'
|
|
823
719
|
|
|
824
720
|
|
|
825
721
|
def test_build_training_data_with_thinking_r1_style(mock_task):
|
|
@@ -836,19 +732,28 @@ def test_build_training_data_with_thinking_r1_style(mock_task):
|
|
|
836
732
|
|
|
837
733
|
assert mock_task.thinking_instruction == "thinking instructions"
|
|
838
734
|
|
|
839
|
-
|
|
735
|
+
messages = build_training_chat(
|
|
840
736
|
mock_task_run,
|
|
841
737
|
"system message",
|
|
842
|
-
|
|
738
|
+
ChatStrategy.single_turn_r1_thinking,
|
|
843
739
|
thinking_instructions=None,
|
|
844
740
|
)
|
|
845
|
-
|
|
846
|
-
assert
|
|
847
|
-
|
|
848
|
-
assert
|
|
849
|
-
assert
|
|
850
|
-
|
|
851
|
-
|
|
741
|
+
|
|
742
|
+
assert len(messages) == 3
|
|
743
|
+
system_msg = messages[0]
|
|
744
|
+
assert system_msg.role == "system"
|
|
745
|
+
assert system_msg.content == "system message"
|
|
746
|
+
|
|
747
|
+
user_msg = messages[1]
|
|
748
|
+
assert user_msg.role == "user"
|
|
749
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
750
|
+
|
|
751
|
+
final_msg = messages[2]
|
|
752
|
+
assert final_msg.role == "assistant"
|
|
753
|
+
assert (
|
|
754
|
+
final_msg.content
|
|
755
|
+
== '<think>\nthinking output\n</think>\n\n{"test": "output 你好"}'
|
|
756
|
+
)
|
|
852
757
|
|
|
853
758
|
|
|
854
759
|
def test_build_training_data_with_repaired_output(mock_task):
|
|
@@ -863,17 +768,25 @@ def test_build_training_data_with_repaired_output(mock_task):
|
|
|
863
768
|
),
|
|
864
769
|
)
|
|
865
770
|
|
|
866
|
-
|
|
771
|
+
messages = build_training_chat(
|
|
867
772
|
mock_task_run,
|
|
868
773
|
"system message",
|
|
869
|
-
data_strategy=
|
|
774
|
+
data_strategy=ChatStrategy.single_turn,
|
|
870
775
|
)
|
|
871
|
-
|
|
872
|
-
assert
|
|
873
|
-
|
|
874
|
-
assert
|
|
875
|
-
assert
|
|
876
|
-
|
|
776
|
+
|
|
777
|
+
assert len(messages) == 3
|
|
778
|
+
system_msg = messages[0]
|
|
779
|
+
assert system_msg.role == "system"
|
|
780
|
+
assert system_msg.content == "system message"
|
|
781
|
+
|
|
782
|
+
user_msg = messages[1]
|
|
783
|
+
assert user_msg.role == "user"
|
|
784
|
+
assert user_msg.content == '{"test": "input 你好"}'
|
|
785
|
+
|
|
786
|
+
final_msg = messages[2]
|
|
787
|
+
assert final_msg.role == "assistant"
|
|
788
|
+
# Note we re-format the json
|
|
789
|
+
assert final_msg.content == '{"test": "repaired output"}'
|
|
877
790
|
|
|
878
791
|
|
|
879
792
|
def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
|
|
@@ -884,7 +797,7 @@ def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_pat
|
|
|
884
797
|
"train",
|
|
885
798
|
DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
|
|
886
799
|
path=output_path,
|
|
887
|
-
data_strategy=
|
|
800
|
+
data_strategy=ChatStrategy.single_turn,
|
|
888
801
|
)
|
|
889
802
|
|
|
890
803
|
assert result_path == output_path
|
|
@@ -940,3 +853,24 @@ def test_serialize_r1_style_message_missing_thinking(thinking, final_output):
|
|
|
940
853
|
),
|
|
941
854
|
):
|
|
942
855
|
serialize_r1_style_message(thinking=thinking, final_output=final_output)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def test_vertex_gemini_role_map_coverage():
|
|
859
|
+
"""Test that VERTEX_GEMINI_ROLE_MAP covers all possible ChatMessage.role values"""
|
|
860
|
+
from typing import Literal, get_type_hints
|
|
861
|
+
|
|
862
|
+
# Get the Literal type from ChatMessage.role
|
|
863
|
+
role_type = get_type_hints(ChatMessage)["role"]
|
|
864
|
+
# Extract the possible values from the Literal type
|
|
865
|
+
possible_roles = role_type.__args__ # type: ignore
|
|
866
|
+
|
|
867
|
+
# Check that every possible role is in the map
|
|
868
|
+
for role in possible_roles:
|
|
869
|
+
assert role in VERTEX_GEMINI_ROLE_MAP, (
|
|
870
|
+
f"Role {role} is not mapped in VERTEX_GEMINI_ROLE_MAP"
|
|
871
|
+
)
|
|
872
|
+
|
|
873
|
+
# Check that there are no extra mappings
|
|
874
|
+
assert set(VERTEX_GEMINI_ROLE_MAP.keys()) == set(possible_roles), (
|
|
875
|
+
"VERTEX_GEMINI_ROLE_MAP has extra mappings"
|
|
876
|
+
)
|