kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (57)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +77 -5
  3. kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  8. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  9. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  10. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
  11. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
  12. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  13. kiln_ai/adapters/ml_model_list.py +323 -94
  14. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  15. kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
  16. kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
  17. kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
  18. kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
  19. kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
  20. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
  21. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
  22. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
  23. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
  24. kiln_ai/adapters/parsers/__init__.py +10 -0
  25. kiln_ai/adapters/parsers/base_parser.py +12 -0
  26. kiln_ai/adapters/parsers/json_parser.py +37 -0
  27. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  28. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  29. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  30. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  31. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  32. kiln_ai/adapters/prompt_builders.py +126 -20
  33. kiln_ai/adapters/provider_tools.py +91 -36
  34. kiln_ai/adapters/repair/repair_task.py +17 -6
  35. kiln_ai/adapters/repair/test_repair_task.py +4 -4
  36. kiln_ai/adapters/run_output.py +8 -0
  37. kiln_ai/adapters/test_adapter_registry.py +177 -0
  38. kiln_ai/adapters/test_generate_docs.py +69 -0
  39. kiln_ai/adapters/test_prompt_adaptors.py +8 -4
  40. kiln_ai/adapters/test_prompt_builders.py +190 -29
  41. kiln_ai/adapters/test_provider_tools.py +268 -46
  42. kiln_ai/datamodel/__init__.py +199 -12
  43. kiln_ai/datamodel/basemodel.py +31 -11
  44. kiln_ai/datamodel/json_schema.py +8 -3
  45. kiln_ai/datamodel/model_cache.py +8 -3
  46. kiln_ai/datamodel/test_basemodel.py +81 -2
  47. kiln_ai/datamodel/test_dataset_split.py +100 -3
  48. kiln_ai/datamodel/test_example_models.py +25 -4
  49. kiln_ai/datamodel/test_model_cache.py +24 -0
  50. kiln_ai/datamodel/test_model_perf.py +125 -0
  51. kiln_ai/datamodel/test_models.py +129 -0
  52. kiln_ai/utils/exhaustive_error.py +6 -0
  53. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
  54. kiln_ai-0.11.1.dist-info/RECORD +76 -0
  55. kiln_ai-0.8.0.dist-info/RECORD +0 -58
  56. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
  57. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
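
The diff below appears to cover a single file from the list above, kiln_ai/adapters/fine_tune/test_dataset_formatter.py (entry 10, +469 -129). The updated tests exercise a reworked formatter API: the per-example generators now take a ModelTrainingData value instead of a TaskRun plus a system message, and DatasetFormatter.dump_to_file now takes a FinetuneDataStrategy. A minimal caller-side sketch of the new call, inferred from the tests below rather than from kiln-ai's documentation (the dataset argument is assumed to be a kiln_ai.datamodel.DatasetSplit with a "train" split):

from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
from kiln_ai.datamodel import FinetuneDataStrategy


def dump_train_split(dataset):
    # `dataset` stands in for a DatasetSplit instance (assumption, see lead-in).
    formatter = DatasetFormatter(dataset, "system message")
    # 0.8.0 accepted a bare positional path: dump_to_file("train", fmt, output_path).
    # 0.11.x additionally requires a data strategy, as the updated tests show.
    return formatter.dump_to_file(
        "train",
        DatasetFormat.OPENAI_CHAT_JSONL,
        data_strategy=FinetuneDataStrategy.final_only,
    )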
@@ -8,15 +8,20 @@ import pytest
 from kiln_ai.adapters.fine_tune.dataset_formatter import (
     DatasetFormat,
     DatasetFormatter,
+    ModelTrainingData,
+    build_training_data,
     generate_chat_message_response,
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
+    generate_vertex_gemini_1_5,
 )
+from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
     DatasetSplit,
     DataSource,
     DataSourceType,
+    FinetuneDataStrategy,
     Task,
     TaskOutput,
     TaskRun,
@@ -25,32 +30,60 @@ from kiln_ai.datamodel import (

 @pytest.fixture
 def mock_task():
-    task = Mock(spec=Task)
+    task = Mock(spec=Task, thinking_instruction=None)
     task_runs = [
-        TaskRun(
-            id=f"run{i}",
-            input='{"test": "input"}',
-            input_source=DataSource(
-                type=DataSourceType.human, properties={"created_by": "test"}
-            ),
-            output=TaskOutput(
-                output='{"test": "output"}',
-                source=DataSource(
-                    type=DataSourceType.synthetic,
-                    properties={
-                        "model_name": "test",
-                        "model_provider": "test",
-                        "adapter_name": "test",
+        Mock(
+            spec=TaskRun,
+            **{
+                "id": f"run{i}",
+                "input": '{"test": "input 你好"}',
+                "repaired_output": None,
+                "intermediate_outputs": {},
+                "input_source": Mock(
+                    spec=DataSource,
+                    **{
+                        "type": DataSourceType.human,
+                        "properties": {"created_by": "test"},
                     },
                 ),
-            ),
+                "output": Mock(
+                    spec=TaskOutput,
+                    **{
+                        "output": '{"test": "output 你好"}',
+                        "source": Mock(
+                            spec=DataSource,
+                            **{
+                                "type": DataSourceType.synthetic,
+                                "properties": {
+                                    "model_name": "test",
+                                    "model_provider": "test",
+                                    "adapter_name": "test",
+                                },
+                            },
+                        ),
+                    },
+                ),
+            },
         )
         for i in range(1, 4)
     ]
+
+    # Set up parent_task reference for each TaskRun
+    for run in task_runs:
+        run.parent_task = Mock(return_value=task)
+

     task.runs.return_value = task_runs
     return task


+@pytest.fixture
+def mock_intermediate_outputs(mock_task):
+    for run in mock_task.runs():
+        run.intermediate_outputs = {"reasoning": "thinking output"}
+    mock_task.thinking_instruction = "thinking instructions"
+    return mock_task
+
+
 @pytest.fixture
 def mock_dataset(mock_task):
     dataset = Mock(spec=DatasetSplit)
@@ -61,26 +94,13 @@ def mock_dataset(mock_task):


 def test_generate_chat_message_response():
-    task_run = TaskRun(
-        id="run1",
+    thinking_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="test output",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="test output",
     )

-    result = generate_chat_message_response(task_run, "system message")
+    result = generate_chat_message_response(thinking_data)

     assert result == {
         "messages": [
@@ -91,32 +111,80 @@
     }


+def test_generate_chat_message_response_thinking():
+    thinking_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_chat_message_response(thinking_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
 def test_generate_chat_message_toolcall():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
+        input="test input 你好",
+        system_message="system message 你好",
+        final_output='{"key": "value 你好"}',
+    )
+
+    result = generate_chat_message_toolcall(training_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message 你好"},
+            {"role": "user", "content": "test input 你好"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            "arguments": '{"key": "value 你好"}',
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+def test_generate_chat_message_toolcall_thinking():
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output='{"key": "value"}',
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
     )

-    result = generate_chat_message_toolcall(task_run, "system message")
+    result = generate_chat_message_toolcall(training_data)

     assert result == {
         "messages": [
             {"role": "system", "content": "system message"},
             {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
             {
                 "role": "assistant",
                 "content": None,
@@ -136,27 +204,14 @@


 def test_generate_chat_message_toolcall_invalid_json():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )

     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_chat_message_toolcall(task_run, "system message")
+        generate_chat_message_toolcall(training_data)


 def test_dataset_formatter_init_no_parent_task(mock_dataset):
@@ -170,14 +225,20 @@ def test_dataset_formatter_dump_invalid_format(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")

     with pytest.raises(ValueError, match="Unsupported format"):
-        formatter.dump_to_file("train", "invalid_format")  # type: ignore
+        formatter.dump_to_file(
+            "train", "invalid_format", FinetuneDataStrategy.final_only
+        )  # type: ignore


 def test_dataset_formatter_dump_invalid_split(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")

     with pytest.raises(ValueError, match="Split invalid_split not found in dataset"):
-        formatter.dump_to_file("invalid_split", DatasetFormat.OPENAI_CHAT_JSONL)
+        formatter.dump_to_file(
+            "invalid_split",
+            DatasetFormat.OPENAI_CHAT_JSONL,
+            FinetuneDataStrategy.final_only,
+        )


 def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
@@ -185,7 +246,10 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"

     result_path = formatter.dump_to_file(
-        "train", DatasetFormat.OPENAI_CHAT_JSONL, output_path
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
     )

     assert result_path == output_path
@@ -200,23 +264,38 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
         assert "messages" in data
         assert len(data["messages"]) == 3
         assert data["messages"][0]["content"] == "system message"
-        assert data["messages"][1]["content"] == '{"test": "input"}'
-        assert data["messages"][2]["content"] == '{"test": "output"}'
+        assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+        # Raw chat doesn't fix json issues, like extra spaces
+        assert data["messages"][2]["content"] == '{"test": "output 你好"}'


 def test_dataset_formatter_dump_to_temp_file(mock_dataset):
-    formatter = DatasetFormatter(mock_dataset, "system message")
+    formatter = DatasetFormatter(mock_dataset, "system message 你好")

-    result_path = formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )

     assert result_path.exists()
     assert result_path.parent == Path(tempfile.gettempdir())
-    assert result_path.name.startswith("test_dataset_train_")
+    # Test our nice naming
+    assert result_path.name.startswith(
+        "test_dataset -- split-train -- format-openai_chat_jsonl -- no-cot.jsonl"
+    )
     assert result_path.name.endswith(".jsonl")
     # Verify file contents
     with open(result_path) as f:
         lines = f.readlines()
         assert len(lines) == 2
+        # check non-ascii characters are not escaped
+        assert "你好" in lines[0]
+        assert "你好" in lines[1]
+
+        # confirm didn't use COT for final_only
+        assert "thinking output" not in lines[0]
+        assert "thinking instructions" not in lines[0]


 def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
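
The temp-file test above asserts that 你好 survives unescaped in the written JSONL. That behaviour corresponds to serializing with ASCII escaping disabled; a standalone illustration using plain json, independent of kiln's own writer:

import json

example = {"messages": [{"role": "user", "content": '{"test": "input 你好"}'}]}
print(json.dumps(example))                      # default: escapes to \u4f60\u597d
print(json.dumps(example, ensure_ascii=False))  # what the test expects: 你好 stays readable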
@@ -224,7 +303,10 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"

     result_path = formatter.dump_to_file(
-        "train", DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL, output_path
+        "train",
+        DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
     )

     assert result_path == output_path
@@ -240,7 +322,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
         assert len(data["messages"]) == 3
         # Check system and user messages
         assert data["messages"][0]["content"] == "system message"
-        assert data["messages"][1]["content"] == '{"test": "input"}'
+        assert data["messages"][1]["content"] == '{"test": "input 你好"}'
         # Check tool call format
         assistant_msg = data["messages"][2]
         assert assistant_msg["content"] is None
@@ -249,30 +331,78 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
         tool_call = assistant_msg["tool_calls"][0]
         assert tool_call["type"] == "function"
         assert tool_call["function"]["name"] == "task_response"
-        assert tool_call["function"]["arguments"] == '{"test": "output"}'
+        assert tool_call["function"]["arguments"] == '{"test": "output 你好"}'
+
+
+def test_dataset_formatter_dump_with_intermediate_data(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset,
+        "system message 你好",
+        thinking_instructions="thinking instructions",
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "thinking output" in line
+            assert "thinking instructions" in line
+
+
+def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset, "custom system message 你好", "custom thinking instructions"
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "custom system message 你好" in line
+            assert "custom thinking instructions" in line
+            assert "thinking output" in line


 def test_generate_huggingface_chat_template():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="test output",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="test output",
     )

-    result = generate_huggingface_chat_template(task_run, "system message")
+    result = generate_huggingface_chat_template(training_data)

     assert result == {
         "conversations": [
@@ -283,27 +413,96 @@
     }


+def test_generate_huggingface_chat_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_huggingface_chat_template(training_data)
+
+    assert result == {
+        "conversations": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
+def test_generate_vertex_template():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
+def test_generate_vertex_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    print(result)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "user", "parts": [{"text": "thinking instructions"}]},
+            {"role": "model", "parts": [{"text": "thinking output"}]},
+            {"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
 def test_generate_huggingface_chat_template_toolcall():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output='{"key": "value"}',
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
     )

-    result = generate_huggingface_chat_template_toolcall(task_run, "system message")
+    result = generate_huggingface_chat_template_toolcall(training_data)

     assert result["conversations"][0] == {"role": "system", "content": "system message"}
     assert result["conversations"][1] == {"role": "user", "content": "test input"}
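
The two vertex tests above pin down the Gemini 1.5 supervised-tuning record shape: a systemInstruction block plus alternating user/model contents, with thinking turns inserted as extra user/model pairs. A small illustration of serializing one such record to a JSONL line, using plain json and independent of kiln's own writer:

import json

record = {
    "systemInstruction": {"role": "system", "parts": [{"text": "system message"}]},
    "contents": [
        {"role": "user", "parts": [{"text": "test input"}]},
        {"role": "model", "parts": [{"text": "test output"}]},
    ],
}
print(json.dumps(record, ensure_ascii=False))  # one record per line in the tuning file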
@@ -318,25 +517,166 @@ def test_generate_huggingface_chat_template_toolcall():
         assert tool_call["function"]["arguments"] == {"key": "value"}


+def test_generate_huggingface_chat_template_toolcall_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_huggingface_chat_template_toolcall(training_data)
+
+    assert result["conversations"][0] == {"role": "system", "content": "system message"}
+    assert result["conversations"][1] == {"role": "user", "content": "test input"}
+    assert result["conversations"][2] == {
+        "role": "user",
+        "content": "thinking instructions",
+    }
+    assert result["conversations"][3] == {
+        "role": "assistant",
+        "content": "thinking output",
+    }
+    assert result["conversations"][4] == {
+        "role": "user",
+        "content": "thinking final answer prompt",
+    }
+
+    assistant_msg = result["conversations"][5]
+    assert assistant_msg["role"] == "assistant"
+    assert len(assistant_msg["tool_calls"]) == 1
+    tool_call = assistant_msg["tool_calls"][0]
+    assert tool_call["type"] == "function"
+    assert tool_call["function"]["name"] == "task_response"
+    assert len(tool_call["function"]["id"]) == 9  # UUID is truncated to 9 chars
+    assert tool_call["function"]["id"].isalnum()  # Check ID is alphanumeric
+    assert tool_call["function"]["arguments"] == {"key": "value"}
+
+
 def test_generate_huggingface_chat_template_toolcall_invalid_json():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )

     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_huggingface_chat_template_toolcall(task_run, "system message")
+        generate_huggingface_chat_template_toolcall(training_data)
+
+
+def test_build_training_data(mock_task):
+    # Non repaired should use original output
+    mock_task_run = mock_task.runs()[0]
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert not training_data_output.supports_cot()
+
+
+def test_build_training_data_with_COT(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "cot output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_thinking(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    # It should just use the reasoning output if both thinking and chain_of_thought are present
+    mock_task_run.intermediate_outputs = {
+        "reasoning": "thinking output",
+        "chain_of_thought": "cot output",
+    }
+    mock_task.thinking_instruction = "thinking instructions"
+    assert mock_task.thinking_instruction == "thinking instructions"
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "thinking output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_repaired_output(mock_task):
+    # use repaired output if available
+    mock_task_run = mock_task.runs()[0]
+    mock_task_run.repair_instructions = "repair instructions"
+    mock_task_run.repaired_output = TaskOutput(
+        output='{"test": "repaired output"}',
+        source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+    )
+
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "repaired output"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+
+
+def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+    output_path = tmp_path / "output.jsonl"
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )
+
+    assert result_path == output_path
+    assert output_path.exists()
+
+    # Verify file contents
+    with open(output_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2  # Should have 2 entries for train split
+        for line in lines:
+            data = json.loads(line)
+            assert "messages" in data
+            assert len(data["messages"]) == 3
+            # Check system and user messages
+            assert data["messages"][0]["content"] == "system message"
+            assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+            # Check JSON format
+            assistant_msg = data["messages"][2]
+            assert assistant_msg["role"] == "assistant"
+            # Verify the content is valid JSON
+            assert assistant_msg["content"] == '{"test": "output 你好"}'
+            json_content = json.loads(assistant_msg["content"])
+            assert json_content == {"test": "output 你好"}
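
Read together, the build_training_data tests above imply a small set of precedence rules: a repaired output replaces the original output, and when intermediate data is requested the "reasoning" entry wins over "chain_of_thought", with COT_FINAL_ANSWER_PROMPT attached as the final-answer turn. A condensed sketch of that behaviour, assuming the signature the tests use (task_run, system_message, an include-chain-of-thought flag, optional thinking_instructions); this is an inference from the assertions, not the library's implementation:

from kiln_ai.adapters.fine_tune.dataset_formatter import ModelTrainingData
from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT


def build_training_data_sketch(task_run, system_message, include_cot, thinking_instructions=None):
    # Hypothetical condensation of the behaviour asserted by the tests above.
    # Repaired output, when present, wins over the original output.
    final_output = task_run.output.output
    if task_run.repaired_output is not None:
        final_output = task_run.repaired_output.output

    thinking = None
    final_answer_prompt = None
    if include_cot:
        intermediate = task_run.intermediate_outputs or {}
        # "reasoning" is preferred over "chain_of_thought" when both are present.
        thinking = intermediate.get("reasoning") or intermediate.get("chain_of_thought")
        if thinking is not None:
            final_answer_prompt = COT_FINAL_ANSWER_PROMPT

    return ModelTrainingData(
        input=task_run.input,
        system_message=system_message,
        final_output=final_output,
        thinking=thinking,
        thinking_instructions=thinking_instructions if thinking else None,
        thinking_final_answer_prompt=final_answer_prompt,
    )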