kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic.

Files changed (72)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +9 -65
  9. kiln_ai/adapters/eval/g_eval.py +26 -8
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +556 -45
  23. kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
  30. kiln_ai/adapters/parsers/base_parser.py +0 -3
  31. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  32. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  33. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  34. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  35. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  36. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  37. kiln_ai/adapters/prompt_builders.py +14 -17
  38. kiln_ai/adapters/provider_tools.py +39 -4
  39. kiln_ai/adapters/repair/test_repair_task.py +27 -5
  40. kiln_ai/adapters/test_adapter_registry.py +88 -28
  41. kiln_ai/adapters/test_ml_model_list.py +158 -0
  42. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  43. kiln_ai/adapters/test_prompt_builders.py +27 -19
  44. kiln_ai/adapters/test_provider_tools.py +130 -12
  45. kiln_ai/datamodel/__init__.py +2 -2
  46. kiln_ai/datamodel/datamodel_enums.py +43 -4
  47. kiln_ai/datamodel/dataset_filters.py +69 -1
  48. kiln_ai/datamodel/dataset_split.py +4 -0
  49. kiln_ai/datamodel/eval.py +8 -0
  50. kiln_ai/datamodel/finetune.py +13 -7
  51. kiln_ai/datamodel/prompt_id.py +1 -0
  52. kiln_ai/datamodel/task.py +68 -7
  53. kiln_ai/datamodel/task_output.py +1 -1
  54. kiln_ai/datamodel/task_run.py +39 -7
  55. kiln_ai/datamodel/test_basemodel.py +5 -8
  56. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  57. kiln_ai/datamodel/test_dataset_split.py +2 -8
  58. kiln_ai/datamodel/test_example_models.py +54 -0
  59. kiln_ai/datamodel/test_models.py +80 -9
  60. kiln_ai/datamodel/test_task.py +168 -2
  61. kiln_ai/utils/async_job_runner.py +106 -0
  62. kiln_ai/utils/config.py +3 -2
  63. kiln_ai/utils/dataset_import.py +81 -19
  64. kiln_ai/utils/logging.py +165 -0
  65. kiln_ai/utils/test_async_job_runner.py +199 -0
  66. kiln_ai/utils/test_config.py +23 -0
  67. kiln_ai/utils/test_dataset_import.py +272 -10
  68. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  69. kiln_ai-0.17.0.dist-info/RECORD +113 -0
  70. kiln_ai-0.15.0.dist-info/RECORD +0 -104
  71. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  72. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
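Of these, the deepest change is the fine-tuning data pipeline, visible in the diff of kiln_ai/adapters/fine_tune/test_dataset_formatter.py below: ModelTrainingData and build_training_data are replaced by build_training_chat, which returns a list of ChatMessage objects from the new kiln_ai.adapters.chat package, and the FinetuneDataStrategy enum is superseded by ChatStrategy (single_turn, two_message_cot, two_message_cot_legacy, single_turn_r1_thinking). A minimal sketch of the new call shape, inferred from the updated tests; the task_run value is a hypothetical TaskRun loaded elsewhere:

from kiln_ai.adapters.fine_tune.dataset_formatter import build_training_chat
from kiln_ai.datamodel.datamodel_enums import ChatStrategy

# task_run is a hypothetical kiln_ai TaskRun loaded elsewhere.
messages = build_training_chat(
    task_run,
    "system message",
    data_strategy=ChatStrategy.single_turn,  # was FinetuneDataStrategy.final_only
)
for message in messages:
    # Plain ChatMessage objects (role + content) replace the old
    # ModelTrainingData bundle of thinking/final-output fields.
    print(message.role, message.content)

Modeling each training example as an explicit chat transcript lets every formatter in the diff (OpenAI chat, tool calls, HuggingFace templates, Vertex Gemini) consume the same message list instead of re-deriving turns from a struct.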
kiln_ai/adapters/fine_tune/test_dataset_formatter.py
@@ -1,32 +1,34 @@
 import json
 import logging
+import re
 import tempfile
 from pathlib import Path
 from unittest.mock import Mock

 import pytest

+from kiln_ai.adapters.chat.chat_formatter import COT_FINAL_ANSWER_PROMPT, ChatMessage
 from kiln_ai.adapters.fine_tune.dataset_formatter import (
+    VERTEX_GEMINI_ROLE_MAP,
     DatasetFormat,
     DatasetFormatter,
-    ModelTrainingData,
-    build_training_data,
+    build_training_chat,
     generate_chat_message_response,
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
     generate_vertex_gemini,
+    serialize_r1_style_message,
 )
-from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
     DatasetSplit,
     DataSource,
     DataSourceType,
-    FinetuneDataStrategy,
     Task,
     TaskOutput,
     TaskRun,
 )
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy

 logger = logging.getLogger(__name__)

@@ -42,6 +44,7 @@ def mock_task():
             "input": '{"test": "input 你好"}',
             "repaired_output": None,
             "intermediate_outputs": {},
+            "thinking_training_data": Mock(return_value=None),
             "input_source": Mock(
                 spec=DataSource,
                 **{
@@ -83,6 +86,7 @@ def mock_task():
 def mock_intermediate_outputs(mock_task):
     for run in mock_task.runs():
         run.intermediate_outputs = {"reasoning": "thinking output"}
+        run.thinking_training_data.return_value = "thinking output"
     mock_task.thinking_instruction = "thinking instructions"
     return mock_task

@@ -96,41 +100,56 @@ def mock_dataset(mock_task):
     return dataset


-def test_generate_chat_message_response():
-    thinking_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="test output",
-    )
+@pytest.fixture
+def mock_training_chat_short():
+    return [
+        ChatMessage(role="system", content="system message"),
+        ChatMessage(
+            role="user",
+            content="test input",
+        ),
+        ChatMessage(role="assistant", content="test output"),
+    ]

-    result = generate_chat_message_response(thinking_data)

-    assert result == {
-        "messages": [
-            {"role": "system", "content": "system message"},
-            {"role": "user", "content": "test input"},
-            {"role": "assistant", "content": "test output"},
-        ]
-    }
+@pytest.fixture
+def mock_training_chat_two_step_plaintext():
+    return [
+        ChatMessage(role="system", content="system message"),
+        ChatMessage(
+            role="user",
+            content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
+        ),
+        ChatMessage(role="assistant", content="thinking output"),
+        ChatMessage(role="user", content="thinking final answer prompt"),
+        ChatMessage(role="assistant", content="test output"),
+    ]


-def test_generate_chat_message_response_thinking():
-    thinking_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="test output",
-        thinking="thinking output",
-        thinking_instructions="thinking instructions",
-        thinking_final_answer_prompt="thinking final answer prompt",
-    )
+@pytest.fixture
+def mock_training_chat_two_step_json():
+    return [
+        ChatMessage(role="system", content="system message"),
+        ChatMessage(
+            role="user",
+            content="The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
+        ),
+        ChatMessage(role="assistant", content="thinking output"),
+        ChatMessage(role="user", content="thinking final answer prompt"),
+        ChatMessage(role="assistant", content='{"a":"你好"}'),
+    ]

-    result = generate_chat_message_response(thinking_data)
+
+def test_generate_chat_message_response(mock_training_chat_two_step_plaintext):
+    result = generate_chat_message_response(mock_training_chat_two_step_plaintext)

     assert result == {
         "messages": [
             {"role": "system", "content": "system message"},
-            {"role": "user", "content": "test input"},
-            {"role": "user", "content": "thinking instructions"},
+            {
+                "role": "user",
+                "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
+            },
             {"role": "assistant", "content": "thinking output"},
             {"role": "user", "content": "thinking final answer prompt"},
             {"role": "assistant", "content": "test output"},
@@ -138,54 +157,33 @@ def test_generate_chat_message_response_thinking():
     }


-def test_generate_chat_message_toolcall():
-    training_data = ModelTrainingData(
-        input="test input 你好",
-        system_message="system message 你好",
-        final_output='{"key": "value 你好"}',
-    )
-
-    result = generate_chat_message_toolcall(training_data)
+def test_generate_chat_message_response_json(mock_training_chat_two_step_json):
+    result = generate_chat_message_response(mock_training_chat_two_step_json)

     assert result == {
         "messages": [
-            {"role": "system", "content": "system message 你好"},
-            {"role": "user", "content": "test input 你好"},
+            {"role": "system", "content": "system message"},
             {
-                "role": "assistant",
-                "content": None,
-                "tool_calls": [
-                    {
-                        "id": "call_1",
-                        "type": "function",
-                        "function": {
-                            "name": "task_response",
-                            "arguments": '{"key": "value 你好"}',
-                        },
-                    }
-                ],
+                "role": "user",
+                "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
             },
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
+            {"role": "assistant", "content": '{"a":"你好"}'},
         ]
     }


-def test_generate_chat_message_toolcall_thinking():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output='{"key": "value"}',
-        thinking="thinking output",
-        thinking_instructions="thinking instructions",
-        thinking_final_answer_prompt="thinking final answer prompt",
-    )
-
-    result = generate_chat_message_toolcall(training_data)
+def test_generate_chat_message_toolcall(mock_training_chat_two_step_json):
+    result = generate_chat_message_toolcall(mock_training_chat_two_step_json)

     assert result == {
         "messages": [
             {"role": "system", "content": "system message"},
-            {"role": "user", "content": "test input"},
-            {"role": "user", "content": "thinking instructions"},
+            {
+                "role": "user",
+                "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
+            },
             {"role": "assistant", "content": "thinking output"},
             {"role": "user", "content": "thinking final answer prompt"},
             {
@@ -197,7 +195,7 @@ def test_generate_chat_message_toolcall_thinking():
                     "type": "function",
                     "function": {
                         "name": "task_response",
-                        "arguments": '{"key": "value"}',
+                        "arguments": '{"a": "你好"}',
                     },
                 }
             ],
@@ -206,31 +204,17 @@ def test_generate_chat_message_toolcall_thinking():
     }


-def test_generate_chat_message_toolcall_invalid_json():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="invalid json",
-    )
-
-    with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_chat_message_toolcall(training_data)
-
-
-def test_dataset_formatter_init_no_parent_task(mock_dataset):
-    mock_dataset.parent_task.return_value = None
-
-    with pytest.raises(ValueError, match="Dataset has no parent task"):
-        DatasetFormatter(mock_dataset, "system message")
+def test_generate_chat_message_toolcall_invalid_json(mock_training_chat_two_step_json):
+    mock_training_chat_two_step_json[-1].content = "invalid json"
+    with pytest.raises(ValueError, match="^Last message is not JSON"):
+        generate_chat_message_toolcall(mock_training_chat_two_step_json)


 def test_dataset_formatter_dump_invalid_format(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")

     with pytest.raises(ValueError, match="Unsupported format"):
-        formatter.dump_to_file(
-            "train", "invalid_format", FinetuneDataStrategy.final_only
-        )  # type: ignore
+        formatter.dump_to_file("train", "invalid_format", ChatStrategy.single_turn)


 def test_dataset_formatter_dump_invalid_split(mock_dataset):
@@ -240,7 +224,7 @@ def test_dataset_formatter_dump_invalid_split(mock_dataset):
         formatter.dump_to_file(
             "invalid_split",
             DatasetFormat.OPENAI_CHAT_JSONL,
-            FinetuneDataStrategy.final_only,
+            ChatStrategy.single_turn,
         )


@@ -252,7 +236,7 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
         "train",
         DatasetFormat.OPENAI_CHAT_JSONL,
         path=output_path,
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
     )

     assert result_path == output_path
@@ -278,7 +262,7 @@ def test_dataset_formatter_dump_to_temp_file(mock_dataset):
     result_path = formatter.dump_to_file(
         "train",
         DatasetFormat.OPENAI_CHAT_JSONL,
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
     )

     assert result_path.exists()
@@ -309,7 +293,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
         "train",
         DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
         path=output_path,
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
     )

     assert result_path == output_path
@@ -349,7 +333,7 @@ def test_dataset_formatter_dump_with_intermediate_data(
     result_path = formatter.dump_to_file(
         "train",
         DatasetFormat.OPENAI_CHAT_JSONL,
-        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+        data_strategy=ChatStrategy.two_message_cot_legacy,
     )

     assert result_path.exists()
@@ -368,17 +352,19 @@
         assert "thinking instructions" in line


-def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
+def test_dataset_formatter_dump_with_intermediate_data_r1_style(
     mock_dataset, mock_intermediate_outputs
 ):
     formatter = DatasetFormatter(
-        mock_dataset, "custom system message 你好", "custom thinking instructions"
+        mock_dataset,
+        "system message 你好",
+        thinking_instructions=None,
     )

     result_path = formatter.dump_to_file(
         "train",
         DatasetFormat.OPENAI_CHAT_JSONL,
-        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+        data_strategy=ChatStrategy.single_turn_r1_thinking,
     )

     assert result_path.exists()
@@ -393,46 +379,50 @@
         lines = f.readlines()
         assert len(lines) == 2
         for line in lines:
-            assert "custom system message 你好" in line
-            assert "custom thinking instructions" in line
-            assert "thinking output" in line
+            assert "<think>" in line
+            assert "</think>" in line


-def test_generate_huggingface_chat_template():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="test output",
+def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset, "custom system message 你好", "custom thinking instructions"
     )

-    result = generate_huggingface_chat_template(training_data)
-
-    assert result == {
-        "conversations": [
-            {"role": "system", "content": "system message"},
-            {"role": "user", "content": "test input"},
-            {"role": "assistant", "content": "test output"},
-        ]
-    }
-
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=ChatStrategy.two_message_cot_legacy,
+    )

-def test_generate_huggingface_chat_template_thinking():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="test output",
-        thinking="thinking output",
-        thinking_instructions="thinking instructions",
-        thinking_final_answer_prompt="thinking final answer prompt",
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
     )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "custom system message 你好" in line
+            assert "custom thinking instructions" in line
+            assert "thinking output" in line

-    result = generate_huggingface_chat_template(training_data)
+
+def test_generate_huggingface_chat_template(mock_training_chat_two_step_plaintext):
+    result = generate_huggingface_chat_template(mock_training_chat_two_step_plaintext)

     assert result == {
         "conversations": [
             {"role": "system", "content": "system message"},
-            {"role": "user", "content": "test input"},
-            {"role": "user", "content": "thinking instructions"},
+            {
+                "role": "user",
+                "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
+            },
             {"role": "assistant", "content": "thinking output"},
             {"role": "user", "content": "thinking final answer prompt"},
             {"role": "assistant", "content": "test output"},
@@ -440,14 +430,8 @@ def test_generate_huggingface_chat_template_thinking():
     }


-def test_generate_vertex_template():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="test output",
-    )
-
-    result = generate_vertex_gemini(training_data)
+def test_generate_vertex_template(mock_training_chat_short):
+    result = generate_vertex_gemini(mock_training_chat_short)

     assert result == {
         "systemInstruction": {
@@ -465,19 +449,8 @@ def test_generate_vertex_template():
     }


-def test_generate_vertex_template_thinking():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="test output",
-        thinking="thinking output",
-        thinking_instructions="thinking instructions",
-        thinking_final_answer_prompt="thinking final answer prompt",
-    )
-
-    result = generate_vertex_gemini(training_data)
-
-    logger.info(result)
+def test_generate_vertex_template_thinking(mock_training_chat_two_step_plaintext):
+    result = generate_vertex_gemini(mock_training_chat_two_step_plaintext)

     assert result == {
         "systemInstruction": {
@@ -489,8 +462,14 @@
             ],
         },
         "contents": [
-            {"role": "user", "parts": [{"text": "test input"}]},
-            {"role": "user", "parts": [{"text": "thinking instructions"}]},
+            {
+                "role": "user",
+                "parts": [
+                    {
+                        "text": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
+                    }
+                ],
+            },
             {"role": "model", "parts": [{"text": "thinking output"}]},
             {"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
             {"role": "model", "parts": [{"text": "test output"}]},
@@ -499,13 +478,13 @@

 def test_generate_huggingface_chat_template_toolcall():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output='{"key": "value"}',
-    )
+    messages = [
+        ChatMessage("system", "system message"),
+        ChatMessage("user", "test input"),
+        ChatMessage("assistant", '{"key":"value"}'),
+    ]

-    result = generate_huggingface_chat_template_toolcall(training_data)
+    result = generate_huggingface_chat_template_toolcall(messages)

     assert result["conversations"][0] == {"role": "system", "content": "system message"}
     assert result["conversations"][1] == {"role": "user", "content": "test input"}
@@ -520,34 +499,28 @@
     assert tool_call["function"]["arguments"] == {"key": "value"}


-def test_generate_huggingface_chat_template_toolcall_thinking():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output='{"key": "value"}',
-        thinking="thinking output",
-        thinking_instructions="thinking instructions",
-        thinking_final_answer_prompt="thinking final answer prompt",
+def test_generate_huggingface_chat_template_toolcall_thinking(
+    mock_training_chat_two_step_json,
+):
+    result = generate_huggingface_chat_template_toolcall(
+        mock_training_chat_two_step_json
     )

-    result = generate_huggingface_chat_template_toolcall(training_data)
-
     assert result["conversations"][0] == {"role": "system", "content": "system message"}
-    assert result["conversations"][1] == {"role": "user", "content": "test input"}
-    assert result["conversations"][2] == {
+    assert result["conversations"][1] == {
         "role": "user",
-        "content": "thinking instructions",
+        "content": "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions",
     }
-    assert result["conversations"][3] == {
+    assert result["conversations"][2] == {
         "role": "assistant",
         "content": "thinking output",
     }
-    assert result["conversations"][4] == {
+    assert result["conversations"][3] == {
         "role": "user",
         "content": "thinking final answer prompt",
     }

-    assistant_msg = result["conversations"][5]
+    assistant_msg = result["conversations"][4]
     assert assistant_msg["role"] == "assistant"
     assert len(assistant_msg["tool_calls"]) == 1
     tool_call = assistant_msg["tool_calls"][0]
@@ -555,31 +528,39 @@ def test_generate_huggingface_chat_template_toolcall_thinking():
     assert tool_call["function"]["name"] == "task_response"
     assert len(tool_call["function"]["id"]) == 9  # UUID is truncated to 9 chars
     assert tool_call["function"]["id"].isalnum()  # Check ID is alphanumeric
-    assert tool_call["function"]["arguments"] == {"key": "value"}
+    assert tool_call["function"]["arguments"] == {"a": "你好"}


-def test_generate_huggingface_chat_template_toolcall_invalid_json():
-    training_data = ModelTrainingData(
-        input="test input",
-        system_message="system message",
-        final_output="invalid json",
-    )
+def test_generate_huggingface_chat_template_toolcall_invalid_json(
+    mock_training_chat_two_step_json,
+):
+    mock_training_chat_two_step_json[-1].content = "invalid json"

-    with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_huggingface_chat_template_toolcall(training_data)
+    with pytest.raises(ValueError, match="^Last message is not JSON"):
+        generate_huggingface_chat_template_toolcall(mock_training_chat_two_step_json)


-def test_build_training_data(mock_task):
+def test_build_training_chat(mock_task):
     # Non repaired should use original output
     mock_task_run = mock_task.runs()[0]
-    training_data_output = build_training_data(mock_task_run, "system message", False)
-    assert training_data_output.final_output == '{"test": "output 你好"}'
-    assert training_data_output.thinking is None
-    assert training_data_output.thinking_instructions is None
-    assert training_data_output.thinking_final_answer_prompt is None
-    assert training_data_output.input == '{"test": "input 你好"}'
-    assert training_data_output.system_message == "system message"
-    assert not training_data_output.supports_cot()
+    messages = build_training_chat(
+        mock_task_run,
+        "system message",
+        data_strategy=ChatStrategy.single_turn,
+    )
+
+    assert len(messages) == 3
+    system_msg = messages[0]
+    assert system_msg.role == "system"
+    assert system_msg.content == "system message"
+
+    user_msg = messages[1]
+    assert user_msg.role == "user"
+    assert user_msg.content == '{"test": "input 你好"}'
+
+    final_msg = messages[2]
+    assert final_msg.role == "assistant"
+    assert final_msg.content == '{"test": "output 你好"}'


 def test_build_training_data_with_COT(mock_task):
@@ -587,20 +568,109 @@ def test_build_training_data_with_COT(mock_task):
     mock_task_run = mock_task.runs()[0]
     assert mock_task_run.parent_task() == mock_task
     mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
+    mock_task_run.thinking_training_data.return_value = "cot output"

-    training_data_output = build_training_data(
+    messages = build_training_chat(
         mock_task_run,
         "system message",
-        True,
+        data_strategy=ChatStrategy.two_message_cot,
         thinking_instructions="thinking instructions",
     )
-    assert training_data_output.final_output == '{"test": "output 你好"}'
-    assert training_data_output.thinking == "cot output"
-    assert training_data_output.thinking_instructions == "thinking instructions"
-    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
-    assert training_data_output.input == '{"test": "input 你好"}'
-    assert training_data_output.system_message == "system message"
-    assert training_data_output.supports_cot()
+
+    assert len(messages) == 5
+    system_msg = messages[0]
+    assert system_msg.role == "system"
+    assert system_msg.content == "system message"
+
+    user_msg = messages[1]
+    assert user_msg.role == "user"
+    assert (
+        user_msg.content
+        == 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
+    )
+
+    assistant_msg = messages[2]
+    assert assistant_msg.role == "assistant"
+    assert assistant_msg.content == "cot output"
+
+    final_answer_prompt_msg = messages[3]
+    assert final_answer_prompt_msg.role == "user"
+    assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
+
+    final_msg = messages[4]
+    assert final_msg.role == "assistant"
+    assert final_msg.content == '{"test": "output 你好"}'
+
+
+def test_build_training_data_with_COT_legacy(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
+    mock_task_run.thinking_training_data.return_value = "cot output"
+
+    messages = build_training_chat(
+        mock_task_run,
+        "system message",
+        data_strategy=ChatStrategy.two_message_cot_legacy,
+        thinking_instructions="thinking instructions",
+    )
+
+    assert len(messages) == 6
+    system_msg = messages[0]
+    assert system_msg.role == "system"
+    assert system_msg.content == "system message"
+
+    user_msg = messages[1]
+    assert user_msg.role == "user"
+    assert user_msg.content == '{"test": "input 你好"}'
+
+    cot_msg = messages[2]
+    assert cot_msg.role == "system"
+    assert cot_msg.content == "thinking instructions"
+
+    assistant_msg = messages[3]
+    assert assistant_msg.role == "assistant"
+    assert assistant_msg.content == "cot output"
+
+    final_answer_prompt_msg = messages[4]
+    assert final_answer_prompt_msg.role == "user"
+    assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
+
+    final_msg = messages[5]
+    assert final_msg.role == "assistant"
+    assert final_msg.content == '{"test": "output 你好"}'
+
+
+def test_build_training_data_with_COT_r1_style(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
+    mock_task_run.thinking_training_data.return_value = "cot output"
+
+    messages = build_training_chat(
+        mock_task_run,
+        "system message",
+        data_strategy=ChatStrategy.single_turn_r1_thinking,
+        thinking_instructions=None,
+    )
+
+    assert len(messages) == 3
+    system_msg = messages[0]
+    assert system_msg.role == "system"
+    assert system_msg.content == "system message"
+
+    user_msg = messages[1]
+    assert user_msg.role == "user"
+    assert user_msg.content == '{"test": "input 你好"}'
+
+    final_msg = messages[2]
+    assert final_msg.role == "assistant"
+    assert (
+        final_msg.content
+        == '<think>\ncot output\n</think>\n\n{"test": "output 你好"}'
+    )


 def test_build_training_data_with_thinking(mock_task):
@@ -612,22 +682,78 @@ def test_build_training_data_with_thinking(mock_task):
         "reasoning": "thinking output",
         "chain_of_thought": "cot output",
     }
+    mock_task_run.thinking_training_data.return_value = "thinking output"
     mock_task.thinking_instruction = "thinking instructions"
     assert mock_task.thinking_instruction == "thinking instructions"

-    training_data_output = build_training_data(
+    messages = build_training_chat(
         mock_task_run,
         "system message",
-        True,
+        ChatStrategy.two_message_cot,
         thinking_instructions="thinking instructions",
     )
-    assert training_data_output.final_output == '{"test": "output 你好"}'
-    assert training_data_output.thinking == "thinking output"
-    assert training_data_output.thinking_instructions == "thinking instructions"
-    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
-    assert training_data_output.input == '{"test": "input 你好"}'
-    assert training_data_output.system_message == "system message"
-    assert training_data_output.supports_cot()
+
+    assert len(messages) == 5
+    system_msg = messages[0]
+    assert system_msg.role == "system"
+    assert system_msg.content == "system message"
+
+    user_msg = messages[1]
+    assert user_msg.role == "user"
+    assert (
+        user_msg.content
+        == 'The input is:\n<user_input>\n{"test": "input 你好"}\n</user_input>\n\nthinking instructions'
+    )
+
+    assistant_msg = messages[2]
+    assert assistant_msg.role == "assistant"
+    assert assistant_msg.content == "thinking output"
+
+    final_answer_prompt_msg = messages[3]
+    assert final_answer_prompt_msg.role == "user"
+    assert final_answer_prompt_msg.content == COT_FINAL_ANSWER_PROMPT
+
+    final_msg = messages[4]
+    assert final_msg.role == "assistant"
+    assert final_msg.content == '{"test": "output 你好"}'
+
+
+def test_build_training_data_with_thinking_r1_style(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    # It should just use the reasoning output if both thinking and chain_of_thought are present
+    mock_task_run.intermediate_outputs = {
+        "reasoning": "thinking output",
+        "chain_of_thought": "cot output",
+    }
+    mock_task_run.thinking_training_data.return_value = "thinking output"
+    mock_task.thinking_instruction = "thinking instructions"
+
+    assert mock_task.thinking_instruction == "thinking instructions"
+
+    messages = build_training_chat(
+        mock_task_run,
+        "system message",
+        ChatStrategy.single_turn_r1_thinking,
+        thinking_instructions=None,
+    )
+
+    assert len(messages) == 3
+    system_msg = messages[0]
+    assert system_msg.role == "system"
+    assert system_msg.content == "system message"
+
+    user_msg = messages[1]
+    assert user_msg.role == "user"
+    assert user_msg.content == '{"test": "input 你好"}'
+
+    final_msg = messages[2]
+    assert final_msg.role == "assistant"
+    assert (
+        final_msg.content
+        == '<think>\nthinking output\n</think>\n\n{"test": "output 你好"}'
    )


 def test_build_training_data_with_repaired_output(mock_task):
@@ -642,13 +768,25 @@ def test_build_training_data_with_repaired_output(mock_task):
         ),
     )

-    training_data_output = build_training_data(mock_task_run, "system message", False)
-    assert training_data_output.final_output == '{"test": "repaired output"}'
-    assert training_data_output.thinking is None
-    assert training_data_output.thinking_instructions is None
-    assert training_data_output.thinking_final_answer_prompt is None
-    assert training_data_output.input == '{"test": "input 你好"}'
-    assert training_data_output.system_message == "system message"
+    messages = build_training_chat(
+        mock_task_run,
+        "system message",
+        data_strategy=ChatStrategy.single_turn,
+    )
+
+    assert len(messages) == 3
+    system_msg = messages[0]
+    assert system_msg.role == "system"
+    assert system_msg.content == "system message"
+
+    user_msg = messages[1]
+    assert user_msg.role == "user"
+    assert user_msg.content == '{"test": "input 你好"}'
+
+    final_msg = messages[2]
+    assert final_msg.role == "assistant"
+    # Note we re-format the json
+    assert final_msg.content == '{"test": "repaired output"}'


 def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
@@ -659,7 +797,7 @@ def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
         "train",
         DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
         path=output_path,
-        data_strategy=FinetuneDataStrategy.final_only,
+        data_strategy=ChatStrategy.single_turn,
     )

     assert result_path == output_path
@@ -683,3 +821,56 @@ def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
     assert assistant_msg["content"] == '{"test": "output 你好"}'
     json_content = json.loads(assistant_msg["content"])
     assert json_content == {"test": "output 你好"}
+
+
+@pytest.mark.parametrize(
+    "thinking,final_output,expected_output",
+    [
+        ("thinking", "final output", "<think>\nthinking\n</think>\n\nfinal output"),
+        ("thinking", '{"name":"joe"}', '<think>\nthinking\n</think>\n\n{"name":"joe"}'),
+    ],
+)
+def test_serialize_r1_style_message(thinking, final_output, expected_output):
+    assert (
+        serialize_r1_style_message(thinking=thinking, final_output=final_output)
+        == expected_output
+    )
+
+
+@pytest.mark.parametrize(
+    "thinking,final_output",
+    [
+        (None, "final output"),
+        ("", "final output"),
+        (" ", "final output"),
+    ],
+)
+def test_serialize_r1_style_message_missing_thinking(thinking, final_output):
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Thinking data is required when fine-tuning thinking models (R1, QwQ, etc). Please ensure your fine-tuning dataset contains reasoning or chain of thought output for every entry."
+        ),
+    ):
+        serialize_r1_style_message(thinking=thinking, final_output=final_output)
+
+
+def test_vertex_gemini_role_map_coverage():
+    """Test that VERTEX_GEMINI_ROLE_MAP covers all possible ChatMessage.role values"""
+    from typing import Literal, get_type_hints
+
+    # Get the Literal type from ChatMessage.role
+    role_type = get_type_hints(ChatMessage)["role"]
+    # Extract the possible values from the Literal type
+    possible_roles = role_type.__args__  # type: ignore
+
+    # Check that every possible role is in the map
+    for role in possible_roles:
+        assert role in VERTEX_GEMINI_ROLE_MAP, (
+            f"Role {role} is not mapped in VERTEX_GEMINI_ROLE_MAP"
+        )
+
+    # Check that there are no extra mappings
+    assert set(VERTEX_GEMINI_ROLE_MAP.keys()) == set(possible_roles), (
+        "VERTEX_GEMINI_ROLE_MAP has extra mappings"
+    )