kiln-ai 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +233 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
  7. kiln_ai/adapters/data_gen/data_gen_task.py +49 -36
  8. kiln_ai/adapters/data_gen/test_data_gen_task.py +330 -40
  9. kiln_ai/adapters/eval/base_eval.py +7 -6
  10. kiln_ai/adapters/eval/eval_runner.py +9 -2
  11. kiln_ai/adapters/eval/g_eval.py +40 -17
  12. kiln_ai/adapters/eval/test_base_eval.py +174 -17
  13. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  14. kiln_ai/adapters/eval/test_g_eval.py +116 -5
  15. kiln_ai/adapters/fine_tune/base_finetune.py +3 -8
  16. kiln_ai/adapters/fine_tune/dataset_formatter.py +135 -273
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
  19. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +3 -3
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  21. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  22. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +6 -11
  23. kiln_ai/adapters/fine_tune/together_finetune.py +13 -2
  24. kiln_ai/adapters/ml_model_list.py +370 -84
  25. kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
  26. kiln_ai/adapters/model_adapters/litellm_adapter.py +88 -97
  27. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  28. kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -61
  29. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +104 -21
  30. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
  31. kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
  32. kiln_ai/adapters/parsers/parser_registry.py +0 -2
  33. kiln_ai/adapters/parsers/r1_parser.py +0 -1
  34. kiln_ai/adapters/prompt_builders.py +0 -16
  35. kiln_ai/adapters/provider_tools.py +27 -9
  36. kiln_ai/adapters/remote_config.py +66 -0
  37. kiln_ai/adapters/repair/repair_task.py +1 -6
  38. kiln_ai/adapters/repair/test_repair_task.py +24 -3
  39. kiln_ai/adapters/test_adapter_registry.py +88 -28
  40. kiln_ai/adapters/test_ml_model_list.py +176 -0
  41. kiln_ai/adapters/test_prompt_adaptors.py +17 -7
  42. kiln_ai/adapters/test_prompt_builders.py +3 -16
  43. kiln_ai/adapters/test_provider_tools.py +69 -20
  44. kiln_ai/adapters/test_remote_config.py +100 -0
  45. kiln_ai/datamodel/__init__.py +0 -2
  46. kiln_ai/datamodel/datamodel_enums.py +38 -13
  47. kiln_ai/datamodel/eval.py +32 -0
  48. kiln_ai/datamodel/finetune.py +12 -8
  49. kiln_ai/datamodel/task.py +68 -7
  50. kiln_ai/datamodel/task_output.py +0 -2
  51. kiln_ai/datamodel/task_run.py +0 -2
  52. kiln_ai/datamodel/test_basemodel.py +2 -1
  53. kiln_ai/datamodel/test_dataset_split.py +0 -8
  54. kiln_ai/datamodel/test_eval_model.py +146 -4
  55. kiln_ai/datamodel/test_models.py +33 -10
  56. kiln_ai/datamodel/test_task.py +168 -2
  57. kiln_ai/utils/config.py +3 -2
  58. kiln_ai/utils/dataset_import.py +1 -1
  59. kiln_ai/utils/logging.py +166 -0
  60. kiln_ai/utils/test_config.py +23 -0
  61. kiln_ai/utils/test_dataset_import.py +30 -0
  62. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/METADATA +2 -2
  63. kiln_ai-0.18.0.dist-info/RECORD +115 -0
  64. kiln_ai-0.16.0.dist-info/RECORD +0 -108
  65. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/WHEEL +0 -0
  66. {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -7,11 +7,12 @@ from unittest.mock import Mock
7
7
 
8
8
  import pytest
9
9
 
10
+ from kiln_ai.adapters.chat.chat_formatter import COT_FINAL_ANSWER_PROMPT, ChatMessage
10
11
  from kiln_ai.adapters.fine_tune.dataset_formatter import (
12
+ VERTEX_GEMINI_ROLE_MAP,
11
13
  DatasetFormat,
12
14
  DatasetFormatter,
13
- ModelTrainingData,
14
- build_training_data,
15
+ build_training_chat,
15
16
  generate_chat_message_response,
16
17
  generate_chat_message_toolcall,
17
18
  generate_huggingface_chat_template,
@@ -19,16 +20,15 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import (
19
20
  generate_vertex_gemini,
20
21
  serialize_r1_style_message,
21
22
  )
22
- from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
23
23
  from kiln_ai.datamodel import (
24
24
  DatasetSplit,
25
25
  DataSource,
26
26
  DataSourceType,
27
- FinetuneDataStrategy,
28
27
  Task,
29
28
  TaskOutput,
30
29
  TaskRun,
31
30
  )
31
+ from kiln_ai.datamodel.datamodel_enums import ChatStrategy
32
32
 
33
33
  logger = logging.getLogger(__name__)
34
34
 
@@ -100,41 +100,56 @@ def mock_dataset(mock_task):
100
100
  return dataset
101
101
 
102
102
 
103
- def test_generate_chat_message_response():
104
- thinking_data = ModelTrainingData(
105
- input="test input",
106
- system_message="system message",
107
- final_output="test output",
108
- )
103
+ @pytest.fixture
104
+ def mock_training_chat_short():
105
+ return [
106
+ ChatMessage(role="system", content="system message"),
107
+ ChatMessage(
108
+ role="user",
109
+ content="test input",
110
+ ),
111
+ ChatMessage(role="assistant", content="test output"),
112
+ ]
109
113
 
110
- result = generate_chat_message_response(thinking_data)
111
114
 
112
- assert result == {
113
- "messages": [
114
- {"role": "system", "content": "system message"},
115
- {"role": "user", "content": "test input"},
116
- {"role": "assistant", "content": "test output"},
117
- ]
118
- }
115
+ @pytest.fixture
116
+ def mock_training_chat_two_step_plaintext():
117
+ return [
118
+ ChatMessage(role="system", content="system message"),
119
+ ChatMessage(
120
+ role="user",
121
+ content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
122
+ ),
123
+ ChatMessage(role="assistant", content="thinking output"),
124
+ ChatMessage(role="user", content="thinking final answer prompt"),
125
+ ChatMessage(role="assistant", content="test output"),
126
+ ]
119
127
 
120
128
 
121
- def test_generate_chat_message_response_thinking():
122
- thinking_data = ModelTrainingData(
123
- input="test input",
124
- system_message="system message",
125
- final_output="test output",
126
- thinking="thinking output",
127
- thinking_instructions="thinking instructions",
128
- thinking_final_answer_prompt="thinking final answer prompt",
129
- )
129
+ @pytest.fixture
130
+ def mock_training_chat_two_step_json():
131
+ return [
132
+ ChatMessage(role="system", content="system message"),
133
+ ChatMessage(
134
+ role="user",
135
+ content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
136
+ ),
137
+ ChatMessage(role="assistant", content="thinking output"),
138
+ ChatMessage(role="user", content="thinking final answer prompt"),
139
+ ChatMessage(role="assistant", content='{"a":"你好"}'),
140
+ ]
130
141
 
131
- result = generate_chat_message_response(thinking_data)
142
+
143
+ def test_generate_chat_message_response(mock_training_chat_two_step_plaintext):
144
+ result = generate_chat_message_response(mock_training_chat_two_step_plaintext)
132
145
 
133
146
  assert result == {
134
147
  "messages": [
135
148
  {"role": "system", "content": "system message"},
136
- {"role": "user", "content": "test input"},
137
- {"role": "user", "content": "thinking instructions"},
149
+ {
150
+ "role": "user",
151
+ "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
152
+ },
138
153
  {"role": "assistant", "content": "thinking output"},
139
154
  {"role": "user", "content": "thinking final answer prompt"},
140
155
  {"role": "assistant", "content": "test output"},
@@ -142,79 +157,33 @@ def test_generate_chat_message_response_thinking():
142
157
  }
143
158
 
144
159
 
145
- def test_generate_chat_message_response_thinking_r1_style():
146
- thinking_data = ModelTrainingData(
147
- input="test input",
148
- system_message="system message",
149
- final_output="test output",
150
- thinking="thinking output",
151
- thinking_instructions=None,
152
- thinking_final_answer_prompt=None,
153
- thinking_r1_style=True,
154
- )
155
-
156
- result = generate_chat_message_response(thinking_data)
160
+ def test_generate_chat_message_response_json(mock_training_chat_two_step_json):
161
+ result = generate_chat_message_response(mock_training_chat_two_step_json)
157
162
 
158
163
  assert result == {
159
164
  "messages": [
160
165
  {"role": "system", "content": "system message"},
161
- {"role": "user", "content": "test input"},
162
166
  {
163
- "role": "assistant",
164
- "content": "<think>\nthinking output\n</think>\n\ntest output",
167
+ "role": "user",
168
+ "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
165
169
  },
170
+ {"role": "assistant", "content": "thinking output"},
171
+ {"role": "user", "content": "thinking final answer prompt"},
172
+ {"role": "assistant", "content": '{"a":"你好"}'},
166
173
  ]
167
174
  }
168
175
 
169
176
 
170
- def test_generate_chat_message_toolcall():
171
- training_data = ModelTrainingData(
172
- input="test input 你好",
173
- system_message="system message 你好",
174
- final_output='{"key": "value 你好"}',
175
- )
176
-
177
- result = generate_chat_message_toolcall(training_data)
177
+ def test_generate_chat_message_toolcall(mock_training_chat_two_step_json):
178
+ result = generate_chat_message_toolcall(mock_training_chat_two_step_json)
178
179
 
179
180
  assert result == {
180
181
  "messages": [
181
- {"role": "system", "content": "system message 你好"},
182
- {"role": "user", "content": "test input 你好"},
182
+ {"role": "system", "content": "system message"},
183
183
  {
184
- "role": "assistant",
185
- "content": None,
186
- "tool_calls": [
187
- {
188
- "id": "call_1",
189
- "type": "function",
190
- "function": {
191
- "name": "task_response",
192
- "arguments": '{"key": "value 你好"}',
193
- },
194
- }
195
- ],
184
+ "role": "user",
185
+ "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
196
186
  },
197
- ]
198
- }
199
-
200
-
201
- def test_generate_chat_message_toolcall_thinking():
202
- training_data = ModelTrainingData(
203
- input="test input",
204
- system_message="system message",
205
- final_output='{"key": "value"}',
206
- thinking="thinking output",
207
- thinking_instructions="thinking instructions",
208
- thinking_final_answer_prompt="thinking final answer prompt",
209
- )
210
-
211
- result = generate_chat_message_toolcall(training_data)
212
-
213
- assert result == {
214
- "messages": [
215
- {"role": "system", "content": "system message"},
216
- {"role": "user", "content": "test input"},
217
- {"role": "user", "content": "thinking instructions"},
218
187
  {"role": "assistant", "content": "thinking output"},
219
188
  {"role": "user", "content": "thinking final answer prompt"},
220
189
  {
@@ -226,7 +195,7 @@ def test_generate_chat_message_toolcall_thinking():
226
195
  "type": "function",
227
196
  "function": {
228
197
  "name": "task_response",
229
- "arguments": '{"key": "value"}',
198
+ "arguments": '{"a": "你好"}',
230
199
  },
231
200
  }
232
201
  ],
@@ -235,49 +204,17 @@ def test_generate_chat_message_toolcall_thinking():
235
204
  }
236
205
 
237
206
 
238
- def test_generate_chat_message_toolcall_thinking_r1_style():
239
- training_data = ModelTrainingData(
240
- input="test input",
241
- system_message="system message",
242
- final_output='{"key": "value"}',
243
- thinking="thinking output",
244
- thinking_instructions=None,
245
- thinking_final_answer_prompt=None,
246
- thinking_r1_style=True,
247
- )
248
-
249
- with pytest.raises(
250
- ValueError,
251
- match="R1 style thinking is not supported for tool call downloads",
252
- ):
253
- generate_chat_message_toolcall(training_data)
254
-
255
-
256
- def test_generate_chat_message_toolcall_invalid_json():
257
- training_data = ModelTrainingData(
258
- input="test input",
259
- system_message="system message",
260
- final_output="invalid json",
261
- )
262
-
263
- with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
264
- generate_chat_message_toolcall(training_data)
265
-
266
-
267
- def test_dataset_formatter_init_no_parent_task(mock_dataset):
268
- mock_dataset.parent_task.return_value = None
269
-
270
- with pytest.raises(ValueError, match="Dataset has no parent task"):
271
- DatasetFormatter(mock_dataset, "system message")
207
+ def test_generate_chat_message_toolcall_invalid_json(mock_training_chat_two_step_json):
208
+ mock_training_chat_two_step_json[-1].content = "invalid json"
209
+ with pytest.raises(ValueError, match="^Last message is not JSON"):
210
+ generate_chat_message_toolcall(mock_training_chat_two_step_json)
272
211
 
273
212
 
274
213
  def test_dataset_formatter_dump_invalid_format(mock_dataset):
275
214
  formatter = DatasetFormatter(mock_dataset, "system message")
276
215
 
277
216
  with pytest.raises(ValueError, match="Unsupported format"):
278
- formatter.dump_to_file(
279
- "train", "invalid_format", FinetuneDataStrategy.final_only
280
- ) # type: ignore
217
+ formatter.dump_to_file("train", "invalid_format", ChatStrategy.single_turn)
281
218
 
282
219
 
283
220
  def test_dataset_formatter_dump_invalid_split(mock_dataset):
@@ -287,7 +224,7 @@ def test_dataset_formatter_dump_invalid_split(mock_dataset):
287
224
  formatter.dump_to_file(
288
225
  "invalid_split",
289
226
  DatasetFormat.OPENAI_CHAT_JSONL,
290
- FinetuneDataStrategy.final_only,
227
+ ChatStrategy.single_turn,
291
228
  )
292
229
 
293
230
 
@@ -299,7 +236,7 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
299
236
  "train",
300
237
  DatasetFormat.OPENAI_CHAT_JSONL,
301
238
  path=output_path,
302
- data_strategy=FinetuneDataStrategy.final_only,
239
+ data_strategy=ChatStrategy.single_turn,
303
240
  )
304
241
 
305
242
  assert result_path == output_path
@@ -325,7 +262,7 @@ def test_dataset_formatter_dump_to_temp_file(mock_dataset):
325
262
  result_path = formatter.dump_to_file(
326
263
  "train",
327
264
  DatasetFormat.OPENAI_CHAT_JSONL,
328
- data_strategy=FinetuneDataStrategy.final_only,
265
+ data_strategy=ChatStrategy.single_turn,
329
266
  )
330
267
 
331
268
  assert result_path.exists()
@@ -356,7 +293,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
356
293
  "train",
357
294
  DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
358
295
  path=output_path,
359
- data_strategy=FinetuneDataStrategy.final_only,
296
+ data_strategy=ChatStrategy.single_turn,
360
297
  )
361
298
 
362
299
  assert result_path == output_path
@@ -396,7 +333,7 @@ def test_dataset_formatter_dump_with_intermediate_data(
396
333
  result_path = formatter.dump_to_file(
397
334
  "train",
398
335
  DatasetFormat.OPENAI_CHAT_JSONL,
399
- data_strategy=FinetuneDataStrategy.final_and_intermediate,
336
+ data_strategy=ChatStrategy.two_message_cot_legacy,
400
337
  )
401
338
 
402
339
  assert result_path.exists()
@@ -427,7 +364,7 @@ def test_dataset_formatter_dump_with_intermediate_data_r1_style(
427
364
  result_path = formatter.dump_to_file(
428
365
  "train",
429
366
  DatasetFormat.OPENAI_CHAT_JSONL,
430
- data_strategy=FinetuneDataStrategy.final_and_intermediate_r1_compatible,
367
+ data_strategy=ChatStrategy.single_turn_r1_thinking,
431
368
  )
432
369
 
433
370
  assert result_path.exists()
@@ -456,7 +393,7 @@ def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
456
393
  result_path = formatter.dump_to_file(
457
394
  "train",
458
395
  DatasetFormat.OPENAI_CHAT_JSONL,
459
- data_strategy=FinetuneDataStrategy.final_and_intermediate,
396
+ data_strategy=ChatStrategy.two_message_cot_legacy,
460
397
  )
461
398
 
462
399
  assert result_path.exists()
@@ -476,41 +413,16 @@ def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
476
413
  assert "thinking output" in line
477
414
 
478
415
 
479
- def test_generate_huggingface_chat_template():
480
- training_data = ModelTrainingData(
481
- input="test input",
482
- system_message="system message",
483
- final_output="test output",
484
- )
485
-
486
- result = generate_huggingface_chat_template(training_data)
487
-
488
- assert result == {
489
- "conversations": [
490
- {"role": "system", "content": "system message"},
491
- {"role": "user", "content": "test input"},
492
- {"role": "assistant", "content": "test output"},
493
- ]
494
- }
495
-
496
-
497
- def test_generate_huggingface_chat_template_thinking():
498
- training_data = ModelTrainingData(
499
- input="test input",
500
- system_message="system message",
501
- final_output="test output",
502
- thinking="thinking output",
503
- thinking_instructions="thinking instructions",
504
- thinking_final_answer_prompt="thinking final answer prompt",
505
- )
506
-
507
- result = generate_huggingface_chat_template(training_data)
416
+ def test_generate_huggingface_chat_template(mock_training_chat_two_step_plaintext):
417
+ result = generate_huggingface_chat_template(mock_training_chat_two_step_plaintext)
508
418
 
509
419
  assert result == {
510
420
  "conversations": [
511
421
  {"role": "system", "content": "system message"},
512
- {"role": "user", "content": "test input"},
513
- {"role": "user", "content": "thinking instructions"},
422
+ {
423
+ "role": "user",
424
+ "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
425
+ },
514
426
  {"role": "assistant", "content": "thinking output"},
515
427
  {"role": "user", "content": "thinking final answer prompt"},
516
428
  {"role": "assistant", "content": "test output"},
@@ -518,39 +430,8 @@ def test_generate_huggingface_chat_template_thinking():
518
430
  }
519
431
 
520
432
 
521
- def test_generate_huggingface_chat_template_thinking_r1_style():
522
- training_data = ModelTrainingData(
523
- input="test input",
524
- system_message="system message",
525
- final_output="test output",
526
- thinking="thinking output",
527
- thinking_instructions=None,
528
- thinking_final_answer_prompt=None,
529
- thinking_r1_style=True,
530
- )
531
-
532
- result = generate_huggingface_chat_template(training_data)
533
-
534
- assert result == {
535
- "conversations": [
536
- {"role": "system", "content": "system message"},
537
- {"role": "user", "content": "test input"},
538
- {
539
- "role": "assistant",
540
- "content": "<think>\nthinking output\n</think>\n\ntest output",
541
- },
542
- ]
543
- }
544
-
545
-
546
- def test_generate_vertex_template():
547
- training_data = ModelTrainingData(
548
- input="test input",
549
- system_message="system message",
550
- final_output="test output",
551
- )
552
-
553
- result = generate_vertex_gemini(training_data)
433
+ def test_generate_vertex_template(mock_training_chat_short):
434
+ result = generate_vertex_gemini(mock_training_chat_short)
554
435
 
555
436
  assert result == {
556
437
  "systemInstruction": {
@@ -568,17 +449,8 @@ def test_generate_vertex_template():
568
449
  }
569
450
 
570
451
 
571
- def test_generate_vertex_template_thinking():
572
- training_data = ModelTrainingData(
573
- input="test input",
574
- system_message="system message",
575
- final_output="test output",
576
- thinking="thinking output",
577
- thinking_instructions="thinking instructions",
578
- thinking_final_answer_prompt="thinking final answer prompt",
579
- )
580
-
581
- result = generate_vertex_gemini(training_data)
452
+ def test_generate_vertex_template_thinking(mock_training_chat_two_step_plaintext):
453
+ result = generate_vertex_gemini(mock_training_chat_two_step_plaintext)
582
454
 
583
455
  assert result == {
584
456
  "systemInstruction": {
@@ -590,8 +462,14 @@ def test_generate_vertex_template_thinking():
590
462
  ],
591
463
  },
592
464
  "contents": [
593
- {"role": "user", "parts": [{"text": "test input"}]},
594
- {"role": "user", "parts": [{"text": "thinking instructions"}]},
465
+ {
466
+ "role": "user",
467
+ "parts": [
468
+ {
469
+ "text": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
470
+ }
471
+ ],
472
+ },
595
473
  {"role": "model", "parts": [{"text": "thinking output"}]},
596
474
  {"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
597
475
  {"role": "model", "parts": [{"text": "test output"}]},
@@ -599,31 +477,14 @@ def test_generate_vertex_template_thinking():
599
477
  }
600
478
 
601
479
 
602
- def test_generate_vertex_template_thinking_r1_style():
603
- training_data = ModelTrainingData(
604
- input="test input",
605
- system_message="system message",
606
- final_output="test output",
607
- thinking="thinking output",
608
- thinking_instructions=None,
609
- thinking_final_answer_prompt=None,
610
- thinking_r1_style=True,
611
- )
612
-
613
- with pytest.raises(
614
- ValueError, match="R1 style thinking is not supported for Vertex Gemini"
615
- ):
616
- generate_vertex_gemini(training_data)
617
-
618
-
619
480
  def test_generate_huggingface_chat_template_toolcall():
620
- training_data = ModelTrainingData(
621
- input="test input",
622
- system_message="system message",
623
- final_output='{"key": "value"}',
624
- )
481
+ messages = [
482
+ ChatMessage("system", "system message"),
483
+ ChatMessage("user", "test input"),
484
+ ChatMessage("assistant", '{"key":"value"}'),
485
+ ]
625
486
 
626
- result = generate_huggingface_chat_template_toolcall(training_data)
487
+ result = generate_huggingface_chat_template_toolcall(messages)
627
488
 
628
489
  assert result["conversations"][0] == {"role": "system", "content": "system message"}
629
490
  assert result["conversations"][1] == {"role": "user", "content": "test input"}
@@ -638,34 +499,28 @@ def test_generate_huggingface_chat_template_toolcall():
638
499
  assert tool_call["function"]["arguments"] == {"key": "value"}
639
500
 
640
501
 
641
- def test_generate_huggingface_chat_template_toolcall_thinking():
642
- training_data = ModelTrainingData(
643
- input="test input",
644
- system_message="system message",
645
- final_output='{"key": "value"}',
646
- thinking="thinking output",
647
- thinking_instructions="thinking instructions",
648
- thinking_final_answer_prompt="thinking final answer prompt",
502
+ def test_generate_huggingface_chat_template_toolcall_thinking(
503
+ mock_training_chat_two_step_json,
504
+ ):
505
+ result = generate_huggingface_chat_template_toolcall(
506
+ mock_training_chat_two_step_json
649
507
  )
650
508
 
651
- result = generate_huggingface_chat_template_toolcall(training_data)
652
-
653
509
  assert result["conversations"][0] == {"role": "system", "content": "system message"}
654
- assert result["conversations"][1] == {"role": "user", "content": "test input"}
655
- assert result["conversations"][2] == {
510
+ assert result["conversations"][1] == {
656
511
  "role": "user",
657
- "content": "thinking instructions",
512
+ "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
658
513
  }
659
- assert result["conversations"][3] == {
514
+ assert result["conversations"][2] == {
660
515
  "role": "assistant",
661
516
  "content": "thinking output",
662
517
  }
663
- assert result["conversations"][4] == {
518
+ assert result["conversations"][3] == {
664
519
  "role": "user",
665
520
  "content": "thinking final answer prompt",
666
521
  }
667
522
 
668
- assistant_msg = result["conversations"][5]
523
+ assistant_msg = result["conversations"][4]
669
524
  assert assistant_msg["role"] == "assistant"
670
525
  assert len(assistant_msg["tool_calls"]) == 1
671
526
  tool_call = assistant_msg["tool_calls"][0]
@@ -673,53 +528,39 @@ def test_generate_huggingface_chat_template_toolcall_thinking():
673
528
  assert tool_call["function"]["name"] == "task_response"
674
529
  assert len(tool_call["function"]["id"]) == 9 # UUID is truncated to 9 chars
675
530
  assert tool_call["function"]["id"].isalnum() # Check ID is alphanumeric
676
- assert tool_call["function"]["arguments"] == {"key": "value"}
531
+ assert tool_call["function"]["arguments"] == {"a": "你好"}
677
532
 
678
533
 
679
- def test_generate_huggingface_chat_template_toolcall_thinking_r1_style():
680
- training_data = ModelTrainingData(
681
- input="test input",
682
- system_message="system message",
683
- final_output='{"key": "value"}',
684
- thinking="thinking output",
685
- thinking_instructions=None,
686
- thinking_final_answer_prompt=None,
687
- thinking_r1_style=True,
688
- )
689
-
690
- with pytest.raises(
691
- ValueError,
692
- match="R1 style thinking is not supported for tool call downloads",
693
- ):
694
- generate_huggingface_chat_template_toolcall(training_data)
695
-
696
-
697
- def test_generate_huggingface_chat_template_toolcall_invalid_json():
698
- training_data = ModelTrainingData(
699
- input="test input",
700
- system_message="system message",
701
- final_output="invalid json",
702
- )
534
+ def test_generate_huggingface_chat_template_toolcall_invalid_json(
535
+ mock_training_chat_two_step_json,
536
+ ):
537
+ mock_training_chat_two_step_json[-1].content = "invalid json"
703
538
 
704
- with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
705
- generate_huggingface_chat_template_toolcall(training_data)
539
+ with pytest.raises(ValueError, match="^Last message is not JSON"):
540
+ generate_huggingface_chat_template_toolcall(mock_training_chat_two_step_json)
706
541
 
707
542
 
708
- def test_build_training_data(mock_task):
543
+ def test_build_training_chat(mock_task):
709
544
  # Non repaired should use original output
710
545
  mock_task_run = mock_task.runs()[0]
711
- training_data_output = build_training_data(
546
+ messages = build_training_chat(
712
547
  mock_task_run,
713
548
  "system message",
714
- data_strategy=FinetuneDataStrategy.final_only,
549
+ data_strategy=ChatStrategy.single_turn,
715
550
  )
716
- assert training_data_output.final_output == '{"test": "output 你好"}'
717
- assert training_data_output.thinking is None
718
- assert training_data_output.thinking_instructions is None
719
- assert training_data_output.thinking_final_answer_prompt is None
720
- assert training_data_output.input == '{"test": "input 你好"}'
721
- assert training_data_output.system_message == "system message"
722
- assert not training_data_output.supports_cot()
551
+
552
+ assert len(messages) == 3
553
+ system_msg = messages[0]
554
+ assert system_msg.role == "system"
555
+ assert system_msg.content == "system message"
556
+
557
+ user_msg = messages[1]
558
+ assert user_msg.role == "user"
559
+ assert user_msg.content == '{"test": "input 你好"}'
560
+
561
+ final_msg = messages[2]
562
+ assert final_msg.role == "assistant"
563
+ assert final_msg.content == '{"test": "output 你好"}'
723
564
 
724
565
 
725
566
  def test_build_training_data_with_COT(mock_task):
@@ -729,47 +570,76 @@ def test_build_training_data_with_COT(mock_task):
729
570
  mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
730
571
  mock_task_run.thinking_training_data.return_value = "cot output"
731
572
 
732
- training_data_output = build_training_data(
573
+ messages = build_training_chat(
733
574
  mock_task_run,
734
575
  "system message",
735
- data_strategy=FinetuneDataStrategy.final_and_intermediate,
576
+ data_strategy=ChatStrategy.two_message_cot,
736
577
  thinking_instructions="thinking instructions",
737
578
  )
738
- assert training_data_output.final_output == '{"test": "output 你好"}'
739
- assert training_data_output.thinking == "cot output"
740
- assert training_data_output.thinking_instructions == "thinking instructions"
741
- assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
742
- assert training_data_output.input == '{"test": "input 你好"}'
743
- assert training_data_output.system_message == "system message"
744
- assert training_data_output.thinking_r1_style == False
745
- assert training_data_output.supports_cot()
746
-
747
-
748
- def test_model_training_data_supports_cot(mock_task):
749
- training_data = ModelTrainingData(
750
- input="test input",
751
- system_message="system message",
752
- final_output="test output",
753
- thinking="thinking output",
754
- thinking_instructions="thinking instructions",
755
- thinking_final_answer_prompt=COT_FINAL_ANSWER_PROMPT,
756
- thinking_r1_style=False,
579
+
580
+ assert len(messages) == 5
581
+ system_msg = messages[0]
582
+ assert system_msg.role == "system"
583
+ assert system_msg.content == "system message"
584
+
585
+ user_msg = messages[1]
586
+ assert user_msg.role == "user"
587
+ assert (
588
+ user_msg.content
589
+ == 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
757
590
  )
758
- assert training_data.supports_cot() == True
759
591
 
592
+ assistant_msg = messages[2]
593
+ assert assistant_msg.role == "assistant"
594
+ assert assistant_msg.content == "cot output"
595
+
596
+ final_answer_prompt_msg = messages[3]
597
+ assert final_answer_prompt_msg.role == "user"
598
+ assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
599
+
600
+ final_msg = messages[4]
601
+ assert final_msg.role == "assistant"
602
+ assert final_msg.content == '{"test": "output 你好"}'
603
+
604
+
605
+ def test_build_training_data_with_COT_legacy(mock_task):
606
+ # Setup with needed fields for thinking
607
+ mock_task_run = mock_task.runs()[0]
608
+ assert mock_task_run.parent_task() == mock_task
609
+ mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
610
+ mock_task_run.thinking_training_data.return_value = "cot output"
760
611
 
761
- def test_model_training_data_supports_cot_r1_style(mock_task):
762
- training_data = ModelTrainingData(
763
- input="test input",
764
- system_message="system message",
765
- final_output="test output",
766
- thinking="thinking output",
612
+ messages = build_training_chat(
613
+ mock_task_run,
614
+ "system message",
615
+ data_strategy=ChatStrategy.two_message_cot_legacy,
767
616
  thinking_instructions="thinking instructions",
768
- thinking_r1_style=True,
769
617
  )
770
618
 
771
- with pytest.raises(ValueError, match="R1 style does not support COT"):
772
- training_data.supports_cot()
619
+ assert len(messages) == 6
620
+ system_msg = messages[0]
621
+ assert system_msg.role == "system"
622
+ assert system_msg.content == "system message"
623
+
624
+ user_msg = messages[1]
625
+ assert user_msg.role == "user"
626
+ assert user_msg.content == '{"test": "input 你好"}'
627
+
628
+ cot_msg = messages[2]
629
+ assert cot_msg.role == "system"
630
+ assert cot_msg.content == "thinking instructions"
631
+
632
+ assistant_msg = messages[3]
633
+ assert assistant_msg.role == "assistant"
634
+ assert assistant_msg.content == "cot output"
635
+
636
+ final_answer_prompt_msg = messages[4]
637
+ assert final_answer_prompt_msg.role == "user"
638
+ assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
639
+
640
+ final_msg = messages[5]
641
+ assert final_msg.role == "assistant"
642
+ assert final_msg.content == '{"test": "output 你好"}'
773
643
 
774
644
 
775
645
  def test_build_training_data_with_COT_r1_style(mock_task):
@@ -779,19 +649,28 @@ def test_build_training_data_with_COT_r1_style(mock_task):
779
649
  mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
780
650
  mock_task_run.thinking_training_data.return_value = "cot output"
781
651
 
782
- training_data_output = build_training_data(
652
+ messages = build_training_chat(
783
653
  mock_task_run,
784
654
  "system message",
785
- data_strategy=FinetuneDataStrategy.final_and_intermediate_r1_compatible,
655
+ data_strategy=ChatStrategy.single_turn_r1_thinking,
786
656
  thinking_instructions=None,
787
657
  )
788
- assert training_data_output.final_output == '{"test": "output 你好"}'
789
- assert training_data_output.thinking == "cot output"
790
- assert training_data_output.thinking_instructions == None
791
- assert training_data_output.thinking_final_answer_prompt == None
792
- assert training_data_output.input == '{"test": "input 你好"}'
793
- assert training_data_output.system_message == "system message"
794
- assert training_data_output.thinking_r1_style == True
658
+
659
+ assert len(messages) == 3
660
+ system_msg = messages[0]
661
+ assert system_msg.role == "system"
662
+ assert system_msg.content == "system message"
663
+
664
+ user_msg = messages[1]
665
+ assert user_msg.role == "user"
666
+ assert user_msg.content == '{"test": "input 你好"}'
667
+
668
+ final_msg = messages[2]
669
+ assert final_msg.role == "assistant"
670
+ assert (
671
+ final_msg.content
672
+ == '<think>\ncot output\n</think>\n\n{"test": "output 你好"}'
673
+ )
795
674
 
796
675
 
797
676
  def test_build_training_data_with_thinking(mock_task):
@@ -807,19 +686,36 @@ def test_build_training_data_with_thinking(mock_task):
807
686
  mock_task.thinking_instruction = "thinking instructions"
808
687
  assert mock_task.thinking_instruction == "thinking instructions"
809
688
 
810
- training_data_output = build_training_data(
689
+ messages = build_training_chat(
811
690
  mock_task_run,
812
691
  "system message",
813
- FinetuneDataStrategy.final_and_intermediate,
692
+ ChatStrategy.two_message_cot,
814
693
  thinking_instructions="thinking instructions",
815
694
  )
816
- assert training_data_output.final_output == '{"test": "output 你好"}'
817
- assert training_data_output.thinking == "thinking output"
818
- assert training_data_output.thinking_instructions == "thinking instructions"
819
- assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
820
- assert training_data_output.input == '{"test": "input 你好"}'
821
- assert training_data_output.system_message == "system message"
822
- assert training_data_output.thinking_r1_style == False
695
+
696
+ assert len(messages) == 5
697
+ system_msg = messages[0]
698
+ assert system_msg.role == "system"
699
+ assert system_msg.content == "system message"
700
+
701
+ user_msg = messages[1]
702
+ assert user_msg.role == "user"
703
+ assert (
704
+ user_msg.content
705
+ == 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
706
+ )
707
+
708
+ assistant_msg = messages[2]
709
+ assert assistant_msg.role == "assistant"
710
+ assert assistant_msg.content == "thinking output"
711
+
712
+ final_answer_prompt_msg = messages[3]
713
+ assert final_answer_prompt_msg.role == "user"
714
+ assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
715
+
716
+ final_msg = messages[4]
717
+ assert final_msg.role == "assistant"
718
+ assert final_msg.content == '{"test": "output 你好"}'
823
719
 
824
720
 
825
721
  def test_build_training_data_with_thinking_r1_style(mock_task):
@@ -836,19 +732,28 @@ def test_build_training_data_with_thinking_r1_style(mock_task):
836
732
 
837
733
  assert mock_task.thinking_instruction == "thinking instructions"
838
734
 
839
- training_data_output = build_training_data(
735
+ messages = build_training_chat(
840
736
  mock_task_run,
841
737
  "system message",
842
- FinetuneDataStrategy.final_and_intermediate_r1_compatible,
738
+ ChatStrategy.single_turn_r1_thinking,
843
739
  thinking_instructions=None,
844
740
  )
845
- assert training_data_output.final_output == '{"test": "output 你好"}'
846
- assert training_data_output.thinking == "thinking output"
847
- assert training_data_output.thinking_instructions == None
848
- assert training_data_output.thinking_final_answer_prompt == None
849
- assert training_data_output.input == '{"test": "input 你好"}'
850
- assert training_data_output.system_message == "system message"
851
- assert training_data_output.thinking_r1_style == True
741
+
742
+ assert len(messages) == 3
743
+ system_msg = messages[0]
744
+ assert system_msg.role == "system"
745
+ assert system_msg.content == "system message"
746
+
747
+ user_msg = messages[1]
748
+ assert user_msg.role == "user"
749
+ assert user_msg.content == '{"test": "input 你好"}'
750
+
751
+ final_msg = messages[2]
752
+ assert final_msg.role == "assistant"
753
+ assert (
754
+ final_msg.content
755
+ == '<think>\nthinking output\n</think>\n\n{"test": "output 你好"}'
756
+ )
852
757
 
853
758
 
854
759
  def test_build_training_data_with_repaired_output(mock_task):
@@ -863,17 +768,25 @@ def test_build_training_data_with_repaired_output(mock_task):
863
768
  ),
864
769
  )
865
770
 
866
- training_data_output = build_training_data(
771
+ messages = build_training_chat(
867
772
  mock_task_run,
868
773
  "system message",
869
- data_strategy=FinetuneDataStrategy.final_only,
774
+ data_strategy=ChatStrategy.single_turn,
870
775
  )
871
- assert training_data_output.final_output == '{"test": "repaired output"}'
872
- assert training_data_output.thinking is None
873
- assert training_data_output.thinking_instructions is None
874
- assert training_data_output.thinking_final_answer_prompt is None
875
- assert training_data_output.input == '{"test": "input 你好"}'
876
- assert training_data_output.system_message == "system message"
776
+
777
+ assert len(messages) == 3
778
+ system_msg = messages[0]
779
+ assert system_msg.role == "system"
780
+ assert system_msg.content == "system message"
781
+
782
+ user_msg = messages[1]
783
+ assert user_msg.role == "user"
784
+ assert user_msg.content == '{"test": "input 你好"}'
785
+
786
+ final_msg = messages[2]
787
+ assert final_msg.role == "assistant"
788
+ # Note we re-format the json
789
+ assert final_msg.content == '{"test": "repaired output"}'
877
790
 
878
791
 
879
792
  def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
@@ -884,7 +797,7 @@ def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_pat
884
797
  "train",
885
798
  DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
886
799
  path=output_path,
887
- data_strategy=FinetuneDataStrategy.final_only,
800
+ data_strategy=ChatStrategy.single_turn,
888
801
  )
889
802
 
890
803
  assert result_path == output_path
@@ -940,3 +853,24 @@ def test_serialize_r1_style_message_missing_thinking(thinking, final_output):
940
853
  ),
941
854
  ):
942
855
  serialize_r1_style_message(thinking=thinking, final_output=final_output)
856
+
857
+
858
+ def test_vertex_gemini_role_map_coverage():
859
+ """Test that VERTEX_GEMINI_ROLE_MAP covers all possible ChatMessage.role values"""
860
+ from typing import get_type_hints
861
+
862
+ # Get the Literal type from ChatMessage.role
863
+ role_type = get_type_hints(ChatMessage)["role"]
864
+ # Extract the possible values from the Literal type
865
+ possible_roles = role_type.__args__ # type: ignore
866
+
867
+ # Check that every possible role is in the map
868
+ for role in possible_roles:
869
+ assert role in VERTEX_GEMINI_ROLE_MAP, (
870
+ f"Role {role} is not mapped in VERTEX_GEMINI_ROLE_MAP"
871
+ )
872
+
873
+ # Check that there are no extra mappings
874
+ assert set(VERTEX_GEMINI_ROLE_MAP.keys()) == set(possible_roles), (
875
+ "VERTEX_GEMINI_ROLE_MAP has extra mappings"
876
+ )