kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +193 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
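Most of the churn above lands in the fine-tune dataset pipeline. For orientation, here is a minimal sketch of the export flow the reworked tests below exercise, assembled from calls that appear in the diff; treat the exact dump_to_file signature as an approximation and read it from the source:

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
)
from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy


def export_for_finetuning(dataset: DatasetSplit):
    # "train" must be one of the dataset's splits.
    formatter = DatasetFormatter(dataset, "system message")

    # final_only: serialize just the input/output pairs.
    final_path = formatter.dump_to_file(
        "train",
        DatasetFormat.OPENAI_CHAT_JSONL,
        data_strategy=FinetuneDataStrategy.final_only,
    )

    # final_and_intermediate: also serialize chain-of-thought turns.
    cot_path = formatter.dump_to_file(
        "train",
        DatasetFormat.OPENAI_CHAT_JSONL,
        data_strategy=FinetuneDataStrategy.final_and_intermediate,
    )
    return final_path, cot_path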
kiln_ai/adapters/fine_tune/test_dataset_formatter.py

@@ -8,15 +8,20 @@ import pytest
 from kiln_ai.adapters.fine_tune.dataset_formatter import (
     DatasetFormat,
     DatasetFormatter,
+    ModelTrainingData,
+    build_training_data,
     generate_chat_message_response,
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
+    generate_vertex_gemini_1_5,
 )
+from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
     DatasetSplit,
     DataSource,
     DataSourceType,
+    FinetuneDataStrategy,
     Task,
     TaskOutput,
     TaskRun,
@@ -25,32 +30,60 @@ from kiln_ai.datamodel import (

 @pytest.fixture
 def mock_task():
-    task = Mock(spec=Task)
+    task = Mock(spec=Task, thinking_instruction=None)
     task_runs = [
-        TaskRun(
-            id=f"run{i}",
-            input='{"test": "input"}',
-            input_source=DataSource(
-                type=DataSourceType.human, properties={"created_by": "test"}
-            ),
-            output=TaskOutput(
-                output='{"test": "output"}',
-                source=DataSource(
-                    type=DataSourceType.synthetic,
-                    properties={
-                        "model_name": "test",
-                        "model_provider": "test",
-                        "adapter_name": "test",
+        Mock(
+            spec=TaskRun,
+            **{
+                "id": f"run{i}",
+                "input": '{"test": "input 你好"}',
+                "repaired_output": None,
+                "intermediate_outputs": {},
+                "input_source": Mock(
+                    spec=DataSource,
+                    **{
+                        "type": DataSourceType.human,
+                        "properties": {"created_by": "test"},
                     },
                 ),
-            ),
+                "output": Mock(
+                    spec=TaskOutput,
+                    **{
+                        "output": '{"test": "output 你好"}',
+                        "source": Mock(
+                            spec=DataSource,
+                            **{
+                                "type": DataSourceType.synthetic,
+                                "properties": {
+                                    "model_name": "test",
+                                    "model_provider": "test",
+                                    "adapter_name": "test",
+                                },
+                            },
+                        ),
+                    },
+                ),
+            },
         )
         for i in range(1, 4)
     ]
+
+    # Set up parent_task reference for each TaskRun
+    for run in task_runs:
+        run.parent_task = Mock(return_value=task)
+
     task.runs.return_value = task_runs
     return task


+@pytest.fixture
+def mock_intermediate_outputs(mock_task):
+    for run in mock_task.runs():
+        run.intermediate_outputs = {"reasoning": "thinking output"}
+    mock_task.thinking_instruction = "thinking instructions"
+    return mock_task
+
+
 @pytest.fixture
 def mock_dataset(mock_task):
     dataset = Mock(spec=DatasetSplit)
@@ -61,26 +94,13 @@ def mock_dataset(mock_task):


 def test_generate_chat_message_response():
-    task_run = TaskRun(
-        id="run1",
+    thinking_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="test output",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="test output",
     )

-    result = generate_chat_message_response(
+    result = generate_chat_message_response(thinking_data)

     assert result == {
         "messages": [
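These tests now build a ModelTrainingData value instead of a full TaskRun. Its definition is not part of this diff; the following is a dataclass sketch consistent with how the tests construct and query it (field names come from the diff, everything else is assumed):

from dataclasses import dataclass


@dataclass
class ModelTrainingData:
    # Sketch only: the real class lives in dataset_formatter.py (not shown here).
    input: str
    system_message: str
    final_output: str
    thinking: str | None = None
    thinking_instructions: str | None = None
    thinking_final_answer_prompt: str | None = None

    def supports_cot(self) -> bool:
        # The tests expect True only when all three thinking fields are set.
        return (
            self.thinking is not None
            and self.thinking_instructions is not None
            and self.thinking_final_answer_prompt is not None
        )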
@@ -91,32 +111,80 @@ def test_generate_chat_message_response():
     }


+def test_generate_chat_message_response_thinking():
+    thinking_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_chat_message_response(thinking_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
 def test_generate_chat_message_toolcall():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
+        input="test input 你好",
+        system_message="system message 你好",
+        final_output='{"key": "value 你好"}',
+    )
+
+    result = generate_chat_message_toolcall(training_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message 你好"},
+            {"role": "user", "content": "test input 你好"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            "arguments": '{"key": "value 你好"}',
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+def test_generate_chat_message_toolcall_thinking():
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output='{"key": "value"}',
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
     )

-    result = generate_chat_message_toolcall(
+    result = generate_chat_message_toolcall(training_data)

     assert result == {
         "messages": [
             {"role": "system", "content": "system message"},
             {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
             {
                 "role": "assistant",
                 "content": None,
@@ -136,27 +204,14 @@ def test_generate_chat_message_toolcall():


 def test_generate_chat_message_toolcall_invalid_json():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )

     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_chat_message_toolcall(
+        generate_chat_message_toolcall(training_data)


 def test_dataset_formatter_init_no_parent_task(mock_dataset):
@@ -170,14 +225,20 @@ def test_dataset_formatter_dump_invalid_format(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")

     with pytest.raises(ValueError, match="Unsupported format"):
-        formatter.dump_to_file(
+        formatter.dump_to_file(
+            "train", "invalid_format", FinetuneDataStrategy.final_only
+        )  # type: ignore


 def test_dataset_formatter_dump_invalid_split(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")

     with pytest.raises(ValueError, match="Split invalid_split not found in dataset"):
-        formatter.dump_to_file(
+        formatter.dump_to_file(
+            "invalid_split",
+            DatasetFormat.OPENAI_CHAT_JSONL,
+            FinetuneDataStrategy.final_only,
+        )


 def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
@@ -185,7 +246,10 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"

     result_path = formatter.dump_to_file(
-        "train",
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
     )

     assert result_path == output_path
@@ -200,23 +264,38 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
         assert "messages" in data
         assert len(data["messages"]) == 3
         assert data["messages"][0]["content"] == "system message"
-        assert data["messages"][1]["content"] == '{"test": "input"}'
-        assert data["messages"][2]["content"] == '{"test": "output"}'
+        assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+        # Raw chat doesn't fix json issues, like extra spaces
+        assert data["messages"][2]["content"] == '{"test": "output 你好"}'


 def test_dataset_formatter_dump_to_temp_file(mock_dataset):
-    formatter = DatasetFormatter(mock_dataset, "system message")
+    formatter = DatasetFormatter(mock_dataset, "system message 你好")

-    result_path = formatter.dump_to_file(
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )

     assert result_path.exists()
     assert result_path.parent == Path(tempfile.gettempdir())
-
+    # Test our nice naming
+    assert result_path.name.startswith(
+        "test_dataset -- split-train -- format-openai_chat_jsonl -- no-cot.jsonl"
+    )
     assert result_path.name.endswith(".jsonl")
     # Verify file contents
     with open(result_path) as f:
         lines = f.readlines()
         assert len(lines) == 2
+        # check non-ascii characters are not escaped
+        assert "你好" in lines[0]
+        assert "你好" in lines[1]
+
+        # confirm didn't use COT for final_only
+        assert "thinking output" not in lines[0]
+        assert "thinking instructions" not in lines[0]


 def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
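The expected filenames in the temp-file tests imply a deterministic naming scheme for dumped datasets. A hypothetical reconstruction of the pattern follows; the real formatting lives in dataset_formatter.py and may differ (it likely appends a unique suffix, since the test uses startswith):

def dump_filename(dataset_name: str, split: str, fmt: str, include_cot: bool) -> str:
    # Matches "test_dataset -- split-train -- format-openai_chat_jsonl -- no-cot.jsonl"
    cot_part = "cot" if include_cot else "no-cot"
    return f"{dataset_name} -- split-{split} -- format-{fmt} -- {cot_part}.jsonl"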
@@ -224,7 +303,10 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"

     result_path = formatter.dump_to_file(
-        "train",
+        "train",
+        DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
     )

     assert result_path == output_path
@@ -240,7 +322,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
         assert len(data["messages"]) == 3
         # Check system and user messages
         assert data["messages"][0]["content"] == "system message"
-        assert data["messages"][1]["content"] == '{"test": "input"}'
+        assert data["messages"][1]["content"] == '{"test": "input 你好"}'
         # Check tool call format
         assistant_msg = data["messages"][2]
         assert assistant_msg["content"] is None
@@ -249,30 +331,78 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
         tool_call = assistant_msg["tool_calls"][0]
         assert tool_call["type"] == "function"
         assert tool_call["function"]["name"] == "task_response"
-        assert tool_call["function"]["arguments"] == '{"test": "output"}'
+        assert tool_call["function"]["arguments"] == '{"test": "output 你好"}'
+
+
+def test_dataset_formatter_dump_with_intermediate_data(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset,
+        "system message 你好",
+        thinking_instructions="thinking instructions",
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "thinking output" in line
+            assert "thinking instructions" in line
+
+
+def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset, "custom system message 你好", "custom thinking instructions"
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "custom system message 你好" in line
+            assert "custom thinking instructions" in line
+            assert "thinking output" in line


 def test_generate_huggingface_chat_template():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="test output",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="test output",
     )

-    result = generate_huggingface_chat_template(
+    result = generate_huggingface_chat_template(training_data)

     assert result == {
         "conversations": [
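The chat, HuggingFace, and Vertex generators in these hunks all assert the same six-turn layout when thinking data is present. A hypothetical helper illustrating that shared (role, content) ordering; this is an illustration of the tested behavior, not the library's internal code:

def cot_turns(td):
    # td is a ModelTrainingData-like object (see the sketch earlier).
    turns = [("system", td.system_message), ("user", td.input)]
    if td.supports_cot():
        # Thinking runs insert three extra turns before the final answer.
        turns += [
            ("user", td.thinking_instructions),
            ("assistant", td.thinking),
            ("user", td.thinking_final_answer_prompt),
        ]
    turns.append(("assistant", td.final_output))
    return turns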
@@ -283,27 +413,96 @@ def test_generate_huggingface_chat_template():
     }


+def test_generate_huggingface_chat_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_huggingface_chat_template(training_data)
+
+    assert result == {
+        "conversations": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
+def test_generate_vertex_template():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
+def test_generate_vertex_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    print(result)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "user", "parts": [{"text": "thinking instructions"}]},
+            {"role": "model", "parts": [{"text": "thinking output"}]},
+            {"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
 def test_generate_huggingface_chat_template_toolcall():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output='{"key": "value"}',
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
     )

-    result = generate_huggingface_chat_template_toolcall(
+    result = generate_huggingface_chat_template_toolcall(training_data)

     assert result["conversations"][0] == {"role": "system", "content": "system message"}
     assert result["conversations"][1] == {"role": "user", "content": "test input"}
@@ -318,25 +517,166 @@ def test_generate_huggingface_chat_template_toolcall():
     assert tool_call["function"]["arguments"] == {"key": "value"}


+def test_generate_huggingface_chat_template_toolcall_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_huggingface_chat_template_toolcall(training_data)
+
+    assert result["conversations"][0] == {"role": "system", "content": "system message"}
+    assert result["conversations"][1] == {"role": "user", "content": "test input"}
+    assert result["conversations"][2] == {
+        "role": "user",
+        "content": "thinking instructions",
+    }
+    assert result["conversations"][3] == {
+        "role": "assistant",
+        "content": "thinking output",
+    }
+    assert result["conversations"][4] == {
+        "role": "user",
+        "content": "thinking final answer prompt",
+    }
+
+    assistant_msg = result["conversations"][5]
+    assert assistant_msg["role"] == "assistant"
+    assert len(assistant_msg["tool_calls"]) == 1
+    tool_call = assistant_msg["tool_calls"][0]
+    assert tool_call["type"] == "function"
+    assert tool_call["function"]["name"] == "task_response"
+    assert len(tool_call["function"]["id"]) == 9  # UUID is truncated to 9 chars
+    assert tool_call["function"]["id"].isalnum()  # Check ID is alphanumeric
+    assert tool_call["function"]["arguments"] == {"key": "value"}
+
+
 def test_generate_huggingface_chat_template_toolcall_invalid_json():
-    task_run = TaskRun(
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        input_source=DataSource(
-            type=DataSourceType.human, properties={"created_by": "test"}
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )

     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_huggingface_chat_template_toolcall(
+        generate_huggingface_chat_template_toolcall(training_data)
+
+
+def test_build_training_data(mock_task):
+    # Non repaired should use original output
+    mock_task_run = mock_task.runs()[0]
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert not training_data_output.supports_cot()
+
+
+def test_build_training_data_with_COT(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "cot output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_thinking(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    # It should just use the reasoning output if both thinking and chain_of_thought are present
+    mock_task_run.intermediate_outputs = {
+        "reasoning": "thinking output",
+        "chain_of_thought": "cot output",
+    }
+    mock_task.thinking_instruction = "thinking instructions"
+    assert mock_task.thinking_instruction == "thinking instructions"
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "thinking output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_repaired_output(mock_task):
+    # use repaired output if available
+    mock_task_run = mock_task.runs()[0]
+    mock_task_run.repair_instructions = "repair instructions"
+    mock_task_run.repaired_output = TaskOutput(
+        output='{"test": "repaired output"}',
+        source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+    )
+
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "repaired output"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+
+
+def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+    output_path = tmp_path / "output.jsonl"
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )
+
+    assert result_path == output_path
+    assert output_path.exists()
+
+    # Verify file contents
+    with open(output_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2  # Should have 2 entries for train split
+        for line in lines:
+            data = json.loads(line)
+            assert "messages" in data
+            assert len(data["messages"]) == 3
+            # Check system and user messages
+            assert data["messages"][0]["content"] == "system message"
+            assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+            # Check JSON format
+            assistant_msg = data["messages"][2]
+            assert assistant_msg["role"] == "assistant"
+            # Verify the content is valid JSON
+            assert assistant_msg["content"] == '{"test": "output 你好"}'
+            json_content = json.loads(assistant_msg["content"])
+            assert json_content == {"test": "output 你好"}