kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (88)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +81 -10
  3. kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +267 -0
  7. kiln_ai/adapters/eval/g_eval.py +367 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  15. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  16. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  21. kiln_ai/adapters/ml_model_list.py +434 -93
  22. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  23. kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
  24. kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
  25. kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
  26. kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
  27. kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
  28. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
  29. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
  30. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
  31. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
  32. kiln_ai/adapters/ollama_tools.py +0 -1
  33. kiln_ai/adapters/parsers/__init__.py +10 -0
  34. kiln_ai/adapters/parsers/base_parser.py +12 -0
  35. kiln_ai/adapters/parsers/json_parser.py +37 -0
  36. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  37. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  38. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  39. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  40. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  41. kiln_ai/adapters/prompt_builders.py +193 -49
  42. kiln_ai/adapters/provider_tools.py +91 -36
  43. kiln_ai/adapters/repair/repair_task.py +18 -19
  44. kiln_ai/adapters/repair/test_repair_task.py +7 -7
  45. kiln_ai/adapters/run_output.py +11 -0
  46. kiln_ai/adapters/test_adapter_registry.py +177 -0
  47. kiln_ai/adapters/test_generate_docs.py +69 -0
  48. kiln_ai/adapters/test_ollama_tools.py +0 -1
  49. kiln_ai/adapters/test_prompt_adaptors.py +25 -18
  50. kiln_ai/adapters/test_prompt_builders.py +265 -44
  51. kiln_ai/adapters/test_provider_tools.py +268 -46
  52. kiln_ai/datamodel/__init__.py +51 -772
  53. kiln_ai/datamodel/basemodel.py +31 -11
  54. kiln_ai/datamodel/datamodel_enums.py +58 -0
  55. kiln_ai/datamodel/dataset_filters.py +114 -0
  56. kiln_ai/datamodel/dataset_split.py +170 -0
  57. kiln_ai/datamodel/eval.py +298 -0
  58. kiln_ai/datamodel/finetune.py +105 -0
  59. kiln_ai/datamodel/json_schema.py +14 -3
  60. kiln_ai/datamodel/model_cache.py +8 -3
  61. kiln_ai/datamodel/project.py +23 -0
  62. kiln_ai/datamodel/prompt.py +37 -0
  63. kiln_ai/datamodel/prompt_id.py +83 -0
  64. kiln_ai/datamodel/strict_mode.py +24 -0
  65. kiln_ai/datamodel/task.py +181 -0
  66. kiln_ai/datamodel/task_output.py +321 -0
  67. kiln_ai/datamodel/task_run.py +164 -0
  68. kiln_ai/datamodel/test_basemodel.py +80 -2
  69. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  70. kiln_ai/datamodel/test_dataset_split.py +127 -6
  71. kiln_ai/datamodel/test_datasource.py +3 -2
  72. kiln_ai/datamodel/test_eval_model.py +635 -0
  73. kiln_ai/datamodel/test_example_models.py +34 -17
  74. kiln_ai/datamodel/test_json_schema.py +23 -0
  75. kiln_ai/datamodel/test_model_cache.py +24 -0
  76. kiln_ai/datamodel/test_model_perf.py +125 -0
  77. kiln_ai/datamodel/test_models.py +131 -2
  78. kiln_ai/datamodel/test_prompt_id.py +129 -0
  79. kiln_ai/datamodel/test_task.py +159 -0
  80. kiln_ai/utils/config.py +6 -1
  81. kiln_ai/utils/exhaustive_error.py +6 -0
  82. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
  83. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  84. kiln_ai/adapters/base_adapter.py +0 -191
  85. kiln_ai/adapters/langchain_adapters.py +0 -256
  86. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  87. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  88. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/fine_tune/test_dataset_formatter.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import tempfile
 from pathlib import Path
 from unittest.mock import Mock
@@ -8,49 +9,84 @@ import pytest
 from kiln_ai.adapters.fine_tune.dataset_formatter import (
     DatasetFormat,
     DatasetFormatter,
+    ModelTrainingData,
+    build_training_data,
     generate_chat_message_response,
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
+    generate_vertex_gemini_1_5,
 )
+from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
     DatasetSplit,
     DataSource,
     DataSourceType,
+    FinetuneDataStrategy,
     Task,
     TaskOutput,
     TaskRun,
 )

+logger = logging.getLogger(__name__)
+

 @pytest.fixture
 def mock_task():
-    task = Mock(spec=Task)
+    task = Mock(spec=Task, thinking_instruction=None)
     task_runs = [
-        TaskRun(
-            id=f"run{i}",
-            input='{"test": "input"}',
-            input_source=DataSource(
-                type=DataSourceType.human, properties={"created_by": "test"}
-            ),
-            output=TaskOutput(
-                output='{"test": "output"}',
-                source=DataSource(
-                    type=DataSourceType.synthetic,
-                    properties={
-                        "model_name": "test",
-                        "model_provider": "test",
-                        "adapter_name": "test",
+        Mock(
+            spec=TaskRun,
+            **{
+                "id": f"run{i}",
+                "input": '{"test": "input 你好"}',
+                "repaired_output": None,
+                "intermediate_outputs": {},
+                "input_source": Mock(
+                    spec=DataSource,
+                    **{
+                        "type": DataSourceType.human,
+                        "properties": {"created_by": "test"},
+                    },
+                ),
+                "output": Mock(
+                    spec=TaskOutput,
+                    **{
+                        "output": '{"test": "output 你好"}',
+                        "source": Mock(
+                            spec=DataSource,
+                            **{
+                                "type": DataSourceType.synthetic,
+                                "properties": {
+                                    "model_name": "test",
+                                    "model_provider": "test",
+                                    "adapter_name": "test",
+                                },
+                            },
+                        ),
                     },
                 ),
-            ),
+            },
         )
         for i in range(1, 4)
     ]
+
+    # Set up parent_task reference for each TaskRun
+    for run in task_runs:
+        run.parent_task = Mock(return_value=task)
+
     task.runs.return_value = task_runs
     return task


+@pytest.fixture
+def mock_intermediate_outputs(mock_task):
+    for run in mock_task.runs():
+        run.intermediate_outputs = {"reasoning": "thinking output"}
+    mock_task.thinking_instruction = "thinking instructions"
+    return mock_task
+
+
 @pytest.fixture
 def mock_dataset(mock_task):
     dataset = Mock(spec=DatasetSplit)
@@ -61,26 +97,13 @@ def mock_dataset(mock_task):


 def test_generate_chat_message_response():
-    task_run = TaskRun(
-        id="run1",
+    thinking_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="test output",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="test output",
     )

-    result = generate_chat_message_response(task_run, "system message")
+    result = generate_chat_message_response(thinking_data)

     assert result == {
         "messages": [
@@ -91,32 +114,80 @@ def test_generate_chat_message_response():
     }


+def test_generate_chat_message_response_thinking():
+    thinking_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_chat_message_response(thinking_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
 def test_generate_chat_message_toolcall():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
+        input="test input 你好",
+        system_message="system message 你好",
+        final_output='{"key": "value 你好"}',
+    )
+
+    result = generate_chat_message_toolcall(training_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message 你好"},
+            {"role": "user", "content": "test input 你好"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            "arguments": '{"key": "value 你好"}',
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+def test_generate_chat_message_toolcall_thinking():
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output='{"key": "value"}',
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
     )

-    result = generate_chat_message_toolcall(task_run, "system message")
+    result = generate_chat_message_toolcall(training_data)

     assert result == {
         "messages": [
             {"role": "system", "content": "system message"},
             {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
             {
                 "role": "assistant",
                 "content": None,
@@ -136,27 +207,14 @@ def test_generate_chat_message_toolcall():


 def test_generate_chat_message_toolcall_invalid_json():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )

     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_chat_message_toolcall(task_run, "system message")
+        generate_chat_message_toolcall(training_data)


 def test_dataset_formatter_init_no_parent_task(mock_dataset):
@@ -170,14 +228,20 @@ def test_dataset_formatter_dump_invalid_format(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")

     with pytest.raises(ValueError, match="Unsupported format"):
-        formatter.dump_to_file("train", "invalid_format")  # type: ignore
+        formatter.dump_to_file(
+            "train", "invalid_format", FinetuneDataStrategy.final_only
+        )  # type: ignore


 def test_dataset_formatter_dump_invalid_split(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")

     with pytest.raises(ValueError, match="Split invalid_split not found in dataset"):
-        formatter.dump_to_file("invalid_split", DatasetFormat.OPENAI_CHAT_JSONL)
+        formatter.dump_to_file(
+            "invalid_split",
+            DatasetFormat.OPENAI_CHAT_JSONL,
+            FinetuneDataStrategy.final_only,
+        )


 def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
@@ -185,7 +249,10 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"

     result_path = formatter.dump_to_file(
-        "train", DatasetFormat.OPENAI_CHAT_JSONL, output_path
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
     )

     assert result_path == output_path
@@ -200,23 +267,38 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
             assert "messages" in data
             assert len(data["messages"]) == 3
             assert data["messages"][0]["content"] == "system message"
-            assert data["messages"][1]["content"] == '{"test": "input"}'
-            assert data["messages"][2]["content"] == '{"test": "output"}'
+            assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+            # Raw chat doesn't fix json issues, like extra spaces
+            assert data["messages"][2]["content"] == '{"test": "output 你好"}'


 def test_dataset_formatter_dump_to_temp_file(mock_dataset):
-    formatter = DatasetFormatter(mock_dataset, "system message")
+    formatter = DatasetFormatter(mock_dataset, "system message 你好")

-    result_path = formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )

     assert result_path.exists()
     assert result_path.parent == Path(tempfile.gettempdir())
-    assert result_path.name.startswith("test_dataset_train_")
+    # Test our nice naming
+    assert result_path.name.startswith(
+        "test_dataset -- split-train -- format-openai_chat_jsonl -- no-cot.jsonl"
+    )
     assert result_path.name.endswith(".jsonl")
     # Verify file contents
     with open(result_path) as f:
         lines = f.readlines()
         assert len(lines) == 2
+        # check non-ascii characters are not escaped
+        assert "你好" in lines[0]
+        assert "你好" in lines[1]
+
+        # confirm didn't use COT for final_only
+        assert "thinking output" not in lines[0]
+        assert "thinking instructions" not in lines[0]


 def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
@@ -224,7 +306,10 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"

     result_path = formatter.dump_to_file(
-        "train", DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL, output_path
+        "train",
+        DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
     )

     assert result_path == output_path
@@ -240,7 +325,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
             assert len(data["messages"]) == 3
             # Check system and user messages
             assert data["messages"][0]["content"] == "system message"
-            assert data["messages"][1]["content"] == '{"test": "input"}'
+            assert data["messages"][1]["content"] == '{"test": "input 你好"}'
             # Check tool call format
             assistant_msg = data["messages"][2]
             assert assistant_msg["content"] is None
@@ -249,61 +334,178 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
             tool_call = assistant_msg["tool_calls"][0]
             assert tool_call["type"] == "function"
             assert tool_call["function"]["name"] == "task_response"
-            assert tool_call["function"]["arguments"] == '{"test": "output"}'
+            assert tool_call["function"]["arguments"] == '{"test": "output 你好"}'
+
+
+def test_dataset_formatter_dump_with_intermediate_data(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset,
+        "system message 你好",
+        thinking_instructions="thinking instructions",
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "thinking output" in line
+            assert "thinking instructions" in line
+
+
+def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset, "custom system message 你好", "custom thinking instructions"
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "custom system message 你好" in line
+            assert "custom thinking instructions" in line
+            assert "thinking output" in line


 def test_generate_huggingface_chat_template():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="test output",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="test output",
+    )
+
+    result = generate_huggingface_chat_template(training_data)
+
+    assert result == {
+        "conversations": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
+def test_generate_huggingface_chat_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
     )

-    result = generate_huggingface_chat_template(task_run, "system message")
+    result = generate_huggingface_chat_template(training_data)

     assert result == {
         "conversations": [
             {"role": "system", "content": "system message"},
             {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
             {"role": "assistant", "content": "test output"},
         ]
     }


+def test_generate_vertex_template():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
+def test_generate_vertex_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    logger.info(result)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "user", "parts": [{"text": "thinking instructions"}]},
+            {"role": "model", "parts": [{"text": "thinking output"}]},
+            {"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
 def test_generate_huggingface_chat_template_toolcall():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output='{"key": "value"}',
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
     )

-    result = generate_huggingface_chat_template_toolcall(task_run, "system message")
+    result = generate_huggingface_chat_template_toolcall(training_data)

     assert result["conversations"][0] == {"role": "system", "content": "system message"}
     assert result["conversations"][1] == {"role": "user", "content": "test input"}
@@ -318,25 +520,166 @@ def test_generate_huggingface_chat_template_toolcall():
     assert tool_call["function"]["arguments"] == {"key": "value"}


+def test_generate_huggingface_chat_template_toolcall_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_huggingface_chat_template_toolcall(training_data)
+
+    assert result["conversations"][0] == {"role": "system", "content": "system message"}
+    assert result["conversations"][1] == {"role": "user", "content": "test input"}
+    assert result["conversations"][2] == {
+        "role": "user",
+        "content": "thinking instructions",
+    }
+    assert result["conversations"][3] == {
+        "role": "assistant",
+        "content": "thinking output",
+    }
+    assert result["conversations"][4] == {
+        "role": "user",
+        "content": "thinking final answer prompt",
+    }
+
+    assistant_msg = result["conversations"][5]
+    assert assistant_msg["role"] == "assistant"
+    assert len(assistant_msg["tool_calls"]) == 1
+    tool_call = assistant_msg["tool_calls"][0]
+    assert tool_call["type"] == "function"
+    assert tool_call["function"]["name"] == "task_response"
+    assert len(tool_call["function"]["id"]) == 9  # UUID is truncated to 9 chars
+    assert tool_call["function"]["id"].isalnum()  # Check ID is alphanumeric
+    assert tool_call["function"]["arguments"] == {"key": "value"}
+
+
 def test_generate_huggingface_chat_template_toolcall_invalid_json():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )

     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_huggingface_chat_template_toolcall(task_run, "system message")
+        generate_huggingface_chat_template_toolcall(training_data)
+
+
+def test_build_training_data(mock_task):
+    # Non repaired should use original output
+    mock_task_run = mock_task.runs()[0]
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert not training_data_output.supports_cot()
+
+
+def test_build_training_data_with_COT(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "cot output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_thinking(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    # It should just use the reasoning output if both thinking and chain_of_thought are present
+    mock_task_run.intermediate_outputs = {
+        "reasoning": "thinking output",
+        "chain_of_thought": "cot output",
+    }
+    mock_task.thinking_instruction = "thinking instructions"
+    assert mock_task.thinking_instruction == "thinking instructions"
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "thinking output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_repaired_output(mock_task):
+    # use repaired output if available
+    mock_task_run = mock_task.runs()[0]
+    mock_task_run.repair_instructions = "repair instructions"
+    mock_task_run.repaired_output = TaskOutput(
+        output='{"test": "repaired output"}',
+        source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+    )
+
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "repaired output"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+
+
+def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+    output_path = tmp_path / "output.jsonl"
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )
+
+    assert result_path == output_path
+    assert output_path.exists()
+
+    # Verify file contents
+    with open(output_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2  # Should have 2 entries for train split
+        for line in lines:
+            data = json.loads(line)
+            assert "messages" in data
+            assert len(data["messages"]) == 3
+            # Check system and user messages
+            assert data["messages"][0]["content"] == "system message"
+            assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+            # Check JSON format
+            assistant_msg = data["messages"][2]
+            assert assistant_msg["role"] == "assistant"
+            # Verify the content is valid JSON
+            assert assistant_msg["content"] == '{"test": "output 你好"}'
+            json_content = json.loads(assistant_msg["content"])
+            assert json_content == {"test": "output 你好"}