kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
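
Highlights of the API changes covered by this diff: adapters were split into model_adapters and parsers packages, an eval subsystem was added, and the fine-tune dataset formatter grew chain-of-thought support. Concretely, DatasetFormatter.dump_to_file now takes a DatasetFormat plus a data_strategy, and the formatter accepts optional thinking_instructions. A minimal sketch of the new call shape, inferred from the updated tests below; the dataset object is assumed to be a real kiln_ai.datamodel.DatasetSplit (the tests substitute mocks):

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
)
from kiln_ai.datamodel import FinetuneDataStrategy

# `dataset` is assumed to be an existing DatasetSplit with a "train" split.
formatter = DatasetFormatter(
    dataset,
    "system message",
    thinking_instructions="Think step by step.",  # only used by COT strategies
)

# final_only trains on final answers alone; final_and_intermediate also writes
# the thinking/chain-of-thought turns captured in each run's intermediate_outputs.
jsonl_path = formatter.dump_to_file(
    "train",
    DatasetFormat.OPENAI_CHAT_JSONL,
    data_strategy=FinetuneDataStrategy.final_and_intermediate,
)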
kiln_ai/adapters/fine_tune/test_dataset_formatter.py

@@ -1,4 +1,5 @@
 import json
+import logging
 import tempfile
 from pathlib import Path
 from unittest.mock import Mock
@@ -8,49 +9,84 @@ import pytest
 from kiln_ai.adapters.fine_tune.dataset_formatter import (
     DatasetFormat,
     DatasetFormatter,
+    ModelTrainingData,
+    build_training_data,
     generate_chat_message_response,
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
+    generate_vertex_gemini_1_5,
 )
+from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
     DatasetSplit,
     DataSource,
     DataSourceType,
+    FinetuneDataStrategy,
     Task,
     TaskOutput,
     TaskRun,
 )
 
+logger = logging.getLogger(__name__)
+
 
 @pytest.fixture
 def mock_task():
-    task = Mock(spec=Task)
+    task = Mock(spec=Task, thinking_instruction=None)
     task_runs = [
-        …
+        Mock(
+            spec=TaskRun,
+            **{
+                "id": f"run{i}",
+                "input": '{"test": "input 你好"}',
+                "repaired_output": None,
+                "intermediate_outputs": {},
+                "input_source": Mock(
+                    spec=DataSource,
+                    **{
+                        "type": DataSourceType.human,
+                        "properties": {"created_by": "test"},
+                    },
+                ),
+                "output": Mock(
+                    spec=TaskOutput,
+                    **{
+                        "output": '{"test": "output 你好"}',
+                        "source": Mock(
+                            spec=DataSource,
+                            **{
+                                "type": DataSourceType.synthetic,
+                                "properties": {
+                                    "model_name": "test",
+                                    "model_provider": "test",
+                                    "adapter_name": "test",
+                                },
+                            },
+                        ),
                     },
                 ),
-        …
+            },
         )
         for i in range(1, 4)
     ]
+
+    # Set up parent_task reference for each TaskRun
+    for run in task_runs:
+        run.parent_task = Mock(return_value=task)
+
     task.runs.return_value = task_runs
     return task
 
 
+@pytest.fixture
+def mock_intermediate_outputs(mock_task):
+    for run in mock_task.runs():
+        run.intermediate_outputs = {"reasoning": "thinking output"}
+    mock_task.thinking_instruction = "thinking instructions"
+    return mock_task
+
+
 @pytest.fixture
 def mock_dataset(mock_task):
     dataset = Mock(spec=DatasetSplit)
@@ -61,26 +97,13 @@ def mock_dataset(mock_task):
 
 
 def test_generate_chat_message_response():
-    …
-        id="run1",
+    thinking_data = ModelTrainingData(
         input="test input",
-        …
-        ),
-        output=TaskOutput(
-            output="test output",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="test output",
     )
 
-    result = generate_chat_message_response(
+    result = generate_chat_message_response(thinking_data)
 
     assert result == {
         "messages": [
@@ -91,32 +114,80 @@
     }
 
 
+def test_generate_chat_message_response_thinking():
+    thinking_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_chat_message_response(thinking_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
 def test_generate_chat_message_toolcall():
-    …
+    training_data = ModelTrainingData(
+        input="test input 你好",
+        system_message="system message 你好",
+        final_output='{"key": "value 你好"}',
+    )
+
+    result = generate_chat_message_toolcall(training_data)
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message 你好"},
+            {"role": "user", "content": "test input 你好"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            "arguments": '{"key": "value 你好"}',
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+def test_generate_chat_message_toolcall_thinking():
+    training_data = ModelTrainingData(
         input="test input",
-        …
-        source=DataSource(
-            type=DataSourceType.synthetic,
-            properties={
-                "model_name": "test",
-                "model_provider": "test",
-                "adapter_name": "test",
-            },
-        ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
     )
 
-    result = generate_chat_message_toolcall(
+    result = generate_chat_message_toolcall(training_data)
 
     assert result == {
         "messages": [
             {"role": "system", "content": "system message"},
             {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
             {
                 "role": "assistant",
                 "content": None,
@@ -136,27 +207,14 @@ def test_generate_chat_message_toolcall():
 
 
 def test_generate_chat_message_toolcall_invalid_json():
-    …
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        …
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )
 
     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_chat_message_toolcall(
+        generate_chat_message_toolcall(training_data)
 
 
 def test_dataset_formatter_init_no_parent_task(mock_dataset):
@@ -170,14 +228,20 @@ def test_dataset_formatter_dump_invalid_format(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")
 
     with pytest.raises(ValueError, match="Unsupported format"):
-        formatter.dump_to_file(
+        formatter.dump_to_file(
+            "train", "invalid_format", FinetuneDataStrategy.final_only
+        )  # type: ignore
 
 
 def test_dataset_formatter_dump_invalid_split(mock_dataset):
     formatter = DatasetFormatter(mock_dataset, "system message")
 
     with pytest.raises(ValueError, match="Split invalid_split not found in dataset"):
-        formatter.dump_to_file(
+        formatter.dump_to_file(
+            "invalid_split",
+            DatasetFormat.OPENAI_CHAT_JSONL,
+            FinetuneDataStrategy.final_only,
+        )
 
 
 def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
@@ -185,7 +249,10 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"
 
     result_path = formatter.dump_to_file(
-        "train",
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
     )
 
     assert result_path == output_path
@@ -200,23 +267,38 @@ def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
            assert "messages" in data
            assert len(data["messages"]) == 3
            assert data["messages"][0]["content"] == "system message"
-           assert data["messages"][1]["content"] == '{"test": "input"}'
-           …
+           assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+           # Raw chat doesn't fix json issues, like extra spaces
+           assert data["messages"][2]["content"] == '{"test": "output 你好"}'
 
 
 def test_dataset_formatter_dump_to_temp_file(mock_dataset):
-    formatter = DatasetFormatter(mock_dataset, "system message")
+    formatter = DatasetFormatter(mock_dataset, "system message 你好")
 
-    result_path = formatter.dump_to_file(
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )
 
     assert result_path.exists()
     assert result_path.parent == Path(tempfile.gettempdir())
-    …
+    # Test our nice naming
+    assert result_path.name.startswith(
+        "test_dataset -- split-train -- format-openai_chat_jsonl -- no-cot.jsonl"
+    )
     assert result_path.name.endswith(".jsonl")
     # Verify file contents
     with open(result_path) as f:
         lines = f.readlines()
         assert len(lines) == 2
+        # check non-ascii characters are not escaped
+        assert "你好" in lines[0]
+        assert "你好" in lines[1]
+
+        # confirm didn't use COT for final_only
+        assert "thinking output" not in lines[0]
+        assert "thinking instructions" not in lines[0]
 
 
 def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
@@ -224,7 +306,10 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
     output_path = tmp_path / "output.jsonl"
 
     result_path = formatter.dump_to_file(
-        "train",
+        "train",
+        DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
    )
 
     assert result_path == output_path
@@ -240,7 +325,7 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
            assert len(data["messages"]) == 3
            # Check system and user messages
            assert data["messages"][0]["content"] == "system message"
-           assert data["messages"][1]["content"] == '{"test": "input"}'
+           assert data["messages"][1]["content"] == '{"test": "input 你好"}'
            # Check tool call format
            assistant_msg = data["messages"][2]
            assert assistant_msg["content"] is None
@@ -249,61 +334,178 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
            tool_call = assistant_msg["tool_calls"][0]
            assert tool_call["type"] == "function"
            assert tool_call["function"]["name"] == "task_response"
-           assert tool_call["function"]["arguments"] == '{"test": "output"}'
+           assert tool_call["function"]["arguments"] == '{"test": "output 你好"}'
+
+
+def test_dataset_formatter_dump_with_intermediate_data(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset,
+        "system message 你好",
+        thinking_instructions="thinking instructions",
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "thinking output" in line
+            assert "thinking instructions" in line
+
+
+def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
+    mock_dataset, mock_intermediate_outputs
+):
+    formatter = DatasetFormatter(
+        mock_dataset, "custom system message 你好", "custom thinking instructions"
+    )
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSONL,
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    # Test our nice naming, with cot
+    assert (
+        result_path.name
+        == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+    )
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+        for line in lines:
+            assert "custom system message 你好" in line
+            assert "custom thinking instructions" in line
+            assert "thinking output" in line
 
 
 def test_generate_huggingface_chat_template():
-    …
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        …
+        system_message="system message",
+        final_output="test output",
+    )
+
+    result = generate_huggingface_chat_template(training_data)
+
+    assert result == {
+        "conversations": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
+def test_generate_huggingface_chat_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
     )
 
-    result = generate_huggingface_chat_template(
+    result = generate_huggingface_chat_template(training_data)
 
     assert result == {
         "conversations": [
             {"role": "system", "content": "system message"},
             {"role": "user", "content": "test input"},
+            {"role": "user", "content": "thinking instructions"},
+            {"role": "assistant", "content": "thinking output"},
+            {"role": "user", "content": "thinking final answer prompt"},
             {"role": "assistant", "content": "test output"},
         ]
     }
 
 
+def test_generate_vertex_template():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
+def test_generate_vertex_template_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output="test output",
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_vertex_gemini_1_5(training_data)
+
+    logger.info(result)
+
+    assert result == {
+        "systemInstruction": {
+            "role": "system",
+            "parts": [
+                {
+                    "text": "system message",
+                }
+            ],
+        },
+        "contents": [
+            {"role": "user", "parts": [{"text": "test input"}]},
+            {"role": "user", "parts": [{"text": "thinking instructions"}]},
+            {"role": "model", "parts": [{"text": "thinking output"}]},
+            {"role": "user", "parts": [{"text": "thinking final answer prompt"}]},
+            {"role": "model", "parts": [{"text": "test output"}]},
+        ],
+    }
+
+
 def test_generate_huggingface_chat_template_toolcall():
-    …
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        …
-        ),
-        output=TaskOutput(
-            output='{"key": "value"}',
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output='{"key": "value"}',
     )
 
-    result = generate_huggingface_chat_template_toolcall(
+    result = generate_huggingface_chat_template_toolcall(training_data)
 
     assert result["conversations"][0] == {"role": "system", "content": "system message"}
     assert result["conversations"][1] == {"role": "user", "content": "test input"}
@@ -318,25 +520,166 @@ def test_generate_huggingface_chat_template_toolcall():
     assert tool_call["function"]["arguments"] == {"key": "value"}
 
 
+def test_generate_huggingface_chat_template_toolcall_thinking():
+    training_data = ModelTrainingData(
+        input="test input",
+        system_message="system message",
+        final_output='{"key": "value"}',
+        thinking="thinking output",
+        thinking_instructions="thinking instructions",
+        thinking_final_answer_prompt="thinking final answer prompt",
+    )
+
+    result = generate_huggingface_chat_template_toolcall(training_data)
+
+    assert result["conversations"][0] == {"role": "system", "content": "system message"}
+    assert result["conversations"][1] == {"role": "user", "content": "test input"}
+    assert result["conversations"][2] == {
+        "role": "user",
+        "content": "thinking instructions",
+    }
+    assert result["conversations"][3] == {
+        "role": "assistant",
+        "content": "thinking output",
+    }
+    assert result["conversations"][4] == {
+        "role": "user",
+        "content": "thinking final answer prompt",
+    }
+
+    assistant_msg = result["conversations"][5]
+    assert assistant_msg["role"] == "assistant"
+    assert len(assistant_msg["tool_calls"]) == 1
+    tool_call = assistant_msg["tool_calls"][0]
+    assert tool_call["type"] == "function"
+    assert tool_call["function"]["name"] == "task_response"
+    assert len(tool_call["function"]["id"]) == 9  # UUID is truncated to 9 chars
+    assert tool_call["function"]["id"].isalnum()  # Check ID is alphanumeric
+    assert tool_call["function"]["arguments"] == {"key": "value"}
+
+
 def test_generate_huggingface_chat_template_toolcall_invalid_json():
-    …
-        id="run1",
+    training_data = ModelTrainingData(
         input="test input",
-        …
-        ),
-        output=TaskOutput(
-            output="invalid json",
-            source=DataSource(
-                type=DataSourceType.synthetic,
-                properties={
-                    "model_name": "test",
-                    "model_provider": "test",
-                    "adapter_name": "test",
-                },
-            ),
-        ),
+        system_message="system message",
+        final_output="invalid json",
     )
 
     with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
-        generate_huggingface_chat_template_toolcall(
+        generate_huggingface_chat_template_toolcall(training_data)
+
+
+def test_build_training_data(mock_task):
+    # Non repaired should use original output
+    mock_task_run = mock_task.runs()[0]
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert not training_data_output.supports_cot()
+
+
+def test_build_training_data_with_COT(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "cot output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_thinking(mock_task):
+    # Setup with needed fields for thinking
+    mock_task_run = mock_task.runs()[0]
+    assert mock_task_run.parent_task() == mock_task
+    # It should just use the reasoning output if both thinking and chain_of_thought are present
+    mock_task_run.intermediate_outputs = {
+        "reasoning": "thinking output",
+        "chain_of_thought": "cot output",
+    }
+    mock_task.thinking_instruction = "thinking instructions"
+    assert mock_task.thinking_instruction == "thinking instructions"
+
+    training_data_output = build_training_data(
+        mock_task_run,
+        "system message",
+        True,
+        thinking_instructions="thinking instructions",
+    )
+    assert training_data_output.final_output == '{"test": "output 你好"}'
+    assert training_data_output.thinking == "thinking output"
+    assert training_data_output.thinking_instructions == "thinking instructions"
+    assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+    assert training_data_output.supports_cot()
+
+
+def test_build_training_data_with_repaired_output(mock_task):
+    # use repaired output if available
+    mock_task_run = mock_task.runs()[0]
+    mock_task_run.repair_instructions = "repair instructions"
+    mock_task_run.repaired_output = TaskOutput(
+        output='{"test": "repaired output"}',
+        source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+    )
+
+    training_data_output = build_training_data(mock_task_run, "system message", False)
+    assert training_data_output.final_output == '{"test": "repaired output"}'
+    assert training_data_output.thinking is None
+    assert training_data_output.thinking_instructions is None
+    assert training_data_output.thinking_final_answer_prompt is None
+    assert training_data_output.input == '{"test": "input 你好"}'
+    assert training_data_output.system_message == "system message"
+
+
+def test_dataset_formatter_dump_to_file_json_schema_format(mock_dataset, tmp_path):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+    output_path = tmp_path / "output.jsonl"
+
+    result_path = formatter.dump_to_file(
+        "train",
+        DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
+        path=output_path,
+        data_strategy=FinetuneDataStrategy.final_only,
+    )
+
+    assert result_path == output_path
+    assert output_path.exists()
+
+    # Verify file contents
+    with open(output_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2  # Should have 2 entries for train split
+        for line in lines:
+            data = json.loads(line)
+            assert "messages" in data
+            assert len(data["messages"]) == 3
+            # Check system and user messages
+            assert data["messages"][0]["content"] == "system message"
+            assert data["messages"][1]["content"] == '{"test": "input 你好"}'
+            # Check JSON format
+            assistant_msg = data["messages"][2]
+            assert assistant_msg["role"] == "assistant"
+            # Verify the content is valid JSON
+            assert assistant_msg["content"] == '{"test": "output 你好"}'
+            json_content = json.loads(assistant_msg["content"])
+            assert json_content == {"test": "output 你好"}
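
For reference, the thinking-enabled formats asserted above all emit the same six-turn layout: system message, user input, thinking instructions, assistant thinking, final-answer prompt, assistant final output. A small sketch using the same helpers the tests exercise; the field values here are illustrative, not from the diff:

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    ModelTrainingData,
    generate_chat_message_response,
)

row = ModelTrainingData(
    input="What is 2 + 2?",
    system_message="You are a careful math tutor.",
    thinking="Adding two and two gives four.",
    thinking_instructions="Think step by step before answering.",
    thinking_final_answer_prompt="Now give your final answer.",
    final_output="4",
)

# Yields {"messages": [...]} with the six-turn layout described above;
# omitting the thinking fields collapses it to system/user/assistant.
print(generate_chat_message_response(row))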
|