kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +11 -1
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
- kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.0.dist-info/RECORD +0 -36
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
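The headline addition in 0.7.0 is the new kiln_ai.adapters.fine_tune package (a base adapter, OpenAI and Fireworks implementations, a dataset formatter, and a registry). The diff below only shows the new test modules, but those tests outline how the adapter API is driven. The sketch that follows is inferred from test_base_finetune.py and is illustrative only: the subclass name is hypothetical, and the hooks it fills in mirror what the tests' MockFinetune does.

# Sketch inferred from the tests in this release; not official documentation.
from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
    FineTuneStatusType,
)
from kiln_ai.datamodel import DatasetSplit


class MyFinetune(BaseFinetuneAdapter):
    """Hypothetical adapter; the release ships real ones for OpenAI and Fireworks."""

    async def _start(self, dataset: DatasetSplit) -> None:
        pass  # start the provider-side fine-tune job here

    async def status(self) -> FineTuneStatus:
        return FineTuneStatus(status=FineTuneStatusType.pending, message="queued")

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs", type="int", description="Training epochs", optional=False
            ),
        ]


async def start_finetune(dataset: DatasetSplit):
    # Per the tests, create_and_start validates parameters and split names,
    # persists a Finetune datamodel under the dataset's parent task, and
    # returns the (adapter, datamodel) pair.
    adapter, finetune_model = await MyFinetune.create_and_start(
        dataset=dataset,
        provider_id="openai",
        provider_base_model_id="gpt-4o-mini-2024-07-18",
        train_split_name="train",
        parameters={"epochs": 10},
        system_message="Test system message",
    )
    return adapter, finetune_model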
kiln_ai/adapters/fine_tune/test_base_finetune.py

@@ -0,0 +1,290 @@
+from unittest.mock import Mock
+
+import pytest
+
+from kiln_ai.adapters.fine_tune.base_finetune import (
+    BaseFinetuneAdapter,
+    FineTuneParameter,
+    FineTuneStatus,
+    FineTuneStatusType,
+)
+from kiln_ai.datamodel import DatasetSplit, Task
+from kiln_ai.datamodel import Finetune as FinetuneModel
+
+
+class MockFinetune(BaseFinetuneAdapter):
+    """Mock implementation of BaseFinetune for testing"""
+
+    async def _start(self, dataset: DatasetSplit) -> None:
+        pass
+
+    async def status(self) -> FineTuneStatus:
+        return FineTuneStatus(status=FineTuneStatusType.pending, message="loading...")
+
+    @classmethod
+    def available_parameters(cls) -> list[FineTuneParameter]:
+        return [
+            FineTuneParameter(
+                name="learning_rate",
+                type="float",
+                description="Learning rate for training",
+            ),
+            FineTuneParameter(
+                name="epochs",
+                type="int",
+                description="Number of training epochs",
+                optional=False,
+            ),
+        ]
+
+
+@pytest.fixture
+def sample_task(tmp_path):
+    task_path = tmp_path / "task.kiln"
+    task = Task(
+        name="Test Task",
+        path=task_path,
+        description="Test task for fine-tuning",
+        instruction="Test instruction",
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def basic_finetune(sample_task):
+    return MockFinetune(
+        datamodel=FinetuneModel(
+            parent=sample_task,
+            name="test_finetune",
+            provider="test_provider",
+            provider_id="model_1234",
+            base_model_id="test_model",
+            train_split_name="train",
+            dataset_split_id="dataset-123",
+            system_message="Test system message",
+        ),
+    )
+
+
+async def test_finetune_status(basic_finetune):
+    status = await basic_finetune.status()
+    assert status.status == FineTuneStatusType.pending
+    assert status.message == "loading..."
+    assert isinstance(status, FineTuneStatus)
+
+
+def test_available_parameters():
+    params = MockFinetune.available_parameters()
+    assert len(params) == 2
+
+    learning_rate_param = params[0]
+    assert learning_rate_param.name == "learning_rate"
+    assert learning_rate_param.type == "float"
+    assert learning_rate_param.optional is True
+
+    epochs_param = params[1]
+    assert epochs_param.name == "epochs"
+    assert epochs_param.type == "int"
+    assert epochs_param.optional is False
+
+
+def test_validate_parameters_valid():
+    # Test valid parameters
+    valid_params = {
+        "learning_rate": 0.001,
+        "epochs": 10,
+    }
+    MockFinetune.validate_parameters(valid_params)  # Should not raise
+
+
+def test_validate_parameters_missing_required():
+    # Test missing required parameter
+    invalid_params = {
+        "learning_rate": 0.001,
+        # missing required 'epochs'
+    }
+    with pytest.raises(ValueError, match="Parameter epochs is required"):
+        MockFinetune.validate_parameters(invalid_params)
+
+
+def test_validate_parameters_wrong_type():
+    # Test wrong parameter types
+    invalid_params = {
+        "learning_rate": "0.001",  # string instead of float
+        "epochs": 10,
+    }
+    with pytest.raises(ValueError, match="Parameter learning_rate must be a float"):
+        MockFinetune.validate_parameters(invalid_params)
+
+    invalid_params = {
+        "learning_rate": 0.001,
+        "epochs": 10.5,  # float instead of int
+    }
+    with pytest.raises(ValueError, match="Parameter epochs must be an integer"):
+        MockFinetune.validate_parameters(invalid_params)
+
+
+def test_validate_parameters_unknown_parameter():
+    # Test unknown parameter
+    invalid_params = {
+        "learning_rate": 0.001,
+        "epochs": 10,
+        "unknown_param": "value",
+    }
+    with pytest.raises(ValueError, match="Parameter unknown_param is not available"):
+        MockFinetune.validate_parameters(invalid_params)
+
+
+@pytest.fixture
+def mock_dataset(sample_task):
+    dataset = Mock(spec=DatasetSplit)
+    dataset.id = "dataset_123"
+    dataset.parent_task.return_value = sample_task
+    dataset.split_contents = {"train": [], "validation": [], "test": []}
+    return dataset
+
+
+async def test_create_and_start_success(mock_dataset):
+    # Test successful creation with minimal parameters
+    adapter, datamodel = await MockFinetune.create_and_start(
+        dataset=mock_dataset,
+        provider_id="openai",
+        provider_base_model_id="gpt-4o-mini-2024-07-18",
+        train_split_name="train",
+        parameters={"epochs": 10},  # Required parameter
+        system_message="Test system message",
+    )
+
+    assert isinstance(adapter, MockFinetune)
+    assert isinstance(datamodel, FinetuneModel)
+    assert len(datamodel.name.split()) == 2  # 2 word memorable name
+    assert datamodel.provider == "openai"
+    assert datamodel.base_model_id == "gpt-4o-mini-2024-07-18"
+    assert datamodel.dataset_split_id == mock_dataset.id
+    assert datamodel.train_split_name == "train"
+    assert datamodel.parameters == {"epochs": 10}
+    assert datamodel.system_message == "Test system message"
+    assert datamodel.path.exists()
+
+
+async def test_create_and_start_with_all_params(mock_dataset):
+    # Test creation with all optional parameters
+    adapter, datamodel = await MockFinetune.create_and_start(
+        dataset=mock_dataset,
+        provider_id="openai",
+        provider_base_model_id="gpt-4o-mini-2024-07-18",
+        train_split_name="train",
+        parameters={"epochs": 10, "learning_rate": 0.001},
+        name="Custom Name",
+        description="Custom Description",
+        validation_split_name="test",
+        system_message="Test system message",
+    )
+
+    assert datamodel.name == "Custom Name"
+    assert datamodel.description == "Custom Description"
+    assert datamodel.validation_split_name == "test"
+    assert datamodel.parameters == {"epochs": 10, "learning_rate": 0.001}
+    assert datamodel.system_message == "Test system message"
+    assert adapter.datamodel == datamodel
+
+    # load the datamodel from the file, confirm it's saved
+    loaded_datamodel = FinetuneModel.load_from_file(datamodel.path)
+    assert loaded_datamodel.model_dump_json() == datamodel.model_dump_json()
+
+
+async def test_create_and_start_invalid_parameters(mock_dataset):
+    # Test with invalid parameters
+    with pytest.raises(ValueError, match="Parameter epochs is required"):
+        await MockFinetune.create_and_start(
+            dataset=mock_dataset,
+            provider_id="openai",
+            provider_base_model_id="gpt-4o-mini-2024-07-18",
+            train_split_name="train",
+            parameters={"learning_rate": 0.001},  # Missing required 'epochs'
+            system_message="Test system message",
+        )
+
+
+async def test_create_and_start_no_parent_task():
+    # Test with dataset that has no parent task
+    dataset = Mock(spec=DatasetSplit)
+    dataset.id = "dataset_123"
+    dataset.parent_task.return_value = None
+    dataset.split_contents = {"train": [], "validation": [], "test": []}
+
+    with pytest.raises(ValueError, match="Dataset must have a parent task with a path"):
+        await MockFinetune.create_and_start(
+            dataset=dataset,
+            provider_id="openai",
+            provider_base_model_id="gpt-4o-mini-2024-07-18",
+            train_split_name="train",
+            parameters={"epochs": 10},
+            system_message="Test system message",
+        )
+
+
+async def test_create_and_start_no_parent_task_path():
+    # Test with dataset that has parent task but no path
+    task = Mock(spec=Task)
+    task.path = None
+
+    dataset = Mock(spec=DatasetSplit)
+    dataset.id = "dataset_123"
+    dataset.parent_task.return_value = task
+    dataset.split_contents = {"train": [], "validation": [], "test": []}
+
+    with pytest.raises(ValueError, match="Dataset must have a parent task with a path"):
+        await MockFinetune.create_and_start(
+            dataset=dataset,
+            provider_id="openai",
+            provider_base_model_id="gpt-4o-mini-2024-07-18",
+            train_split_name="train",
+            parameters={"epochs": 10},
+            system_message="Test system message",
+        )
+
+
+def test_check_valid_provider_model():
+    MockFinetune.check_valid_provider_model("openai", "gpt-4o-mini-2024-07-18")
+
+    with pytest.raises(
+        ValueError, match="Provider openai with base model gpt-99 is not available"
+    ):
+        MockFinetune.check_valid_provider_model("openai", "gpt-99")
+
+
+async def test_create_and_start_invalid_train_split(mock_dataset):
+    # Test with an invalid train split name
+    mock_dataset.split_contents = {"valid_train": [], "valid_test": []}
+
+    with pytest.raises(
+        ValueError, match="Train split invalid_train not found in dataset"
+    ):
+        await MockFinetune.create_and_start(
+            dataset=mock_dataset,
+            provider_id="openai",
+            provider_base_model_id="gpt-4o-mini-2024-07-18",
+            train_split_name="invalid_train",  # Invalid train split
+            parameters={"epochs": 10},
+            system_message="Test system message",
+        )
+
+
+async def test_create_and_start_invalid_validation_split(mock_dataset):
+    # Test with an invalid validation split name
+    mock_dataset.split_contents = {"valid_train": [], "valid_test": []}
+
+    with pytest.raises(
+        ValueError, match="Validation split invalid_test not found in dataset"
+    ):
+        await MockFinetune.create_and_start(
+            dataset=mock_dataset,
+            provider_id="openai",
+            provider_base_model_id="gpt-4o-mini-2024-07-18",
+            train_split_name="valid_train",
+            validation_split_name="invalid_test",  # Invalid validation split
+            parameters={"epochs": 10},
+            system_message="Test system message",
+        )
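The second hunk below adds tests for dataset_formatter.py, which serializes a DatasetSplit into provider-ready JSONL (OpenAI chat and tool-call formats, plus Hugging Face chat-template variants). Based on those tests, usage looks roughly like the sketch below; the dataset and system message are assumed to come from an existing Kiln task, and with no explicit path the output goes to the system temp directory.

# Sketch inferred from test_dataset_formatter.py below; illustrative only.
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
from kiln_ai.datamodel import DatasetSplit


def export_train_split(dataset: DatasetSplit):
    formatter = DatasetFormatter(dataset, "system message")
    # One JSON object per TaskRun in the split; the tests show the default
    # output lands in tempfile.gettempdir() as "<dataset name>_<split>_*.jsonl".
    return formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)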
kiln_ai/adapters/fine_tune/test_dataset_formatter.py

@@ -0,0 +1,342 @@
+import json
+import tempfile
+from pathlib import Path
+from unittest.mock import Mock
+
+import pytest
+
+from kiln_ai.adapters.fine_tune.dataset_formatter import (
+    DatasetFormat,
+    DatasetFormatter,
+    generate_chat_message_response,
+    generate_chat_message_toolcall,
+    generate_huggingface_chat_template,
+    generate_huggingface_chat_template_toolcall,
+)
+from kiln_ai.datamodel import (
+    DatasetSplit,
+    DataSource,
+    DataSourceType,
+    Task,
+    TaskOutput,
+    TaskRun,
+)
+
+
+@pytest.fixture
+def mock_task():
+    task = Mock(spec=Task)
+    task_runs = [
+        TaskRun(
+            id=f"run{i}",
+            input='{"test": "input"}',
+            input_source=DataSource(
+                type=DataSourceType.human, properties={"created_by": "test"}
+            ),
+            output=TaskOutput(
+                output='{"test": "output"}',
+                source=DataSource(
+                    type=DataSourceType.synthetic,
+                    properties={
+                        "model_name": "test",
+                        "model_provider": "test",
+                        "adapter_name": "test",
+                    },
+                ),
+            ),
+        )
+        for i in range(1, 4)
+    ]
+    task.runs.return_value = task_runs
+    return task
+
+
+@pytest.fixture
+def mock_dataset(mock_task):
+    dataset = Mock(spec=DatasetSplit)
+    dataset.name = "test_dataset"
+    dataset.parent_task.return_value = mock_task
+    dataset.split_contents = {"train": ["run1", "run2"], "test": ["run3"]}
+    return dataset
+
+
+def test_generate_chat_message_response():
+    task_run = TaskRun(
+        id="run1",
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human, properties={"created_by": "test"}
+        ),
+        output=TaskOutput(
+            output="test output",
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={
+                    "model_name": "test",
+                    "model_provider": "test",
+                    "adapter_name": "test",
+                },
+            ),
+        ),
+    )
+
+    result = generate_chat_message_response(task_run, "system message")
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
+def test_generate_chat_message_toolcall():
+    task_run = TaskRun(
+        id="run1",
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human, properties={"created_by": "test"}
+        ),
+        output=TaskOutput(
+            output='{"key": "value"}',
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={
+                    "model_name": "test",
+                    "model_provider": "test",
+                    "adapter_name": "test",
+                },
+            ),
+        ),
+    )
+
+    result = generate_chat_message_toolcall(task_run, "system message")
+
+    assert result == {
+        "messages": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "task_response",
+                            "arguments": '{"key": "value"}',
+                        },
+                    }
+                ],
+            },
+        ]
+    }
+
+
+def test_generate_chat_message_toolcall_invalid_json():
+    task_run = TaskRun(
+        id="run1",
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human, properties={"created_by": "test"}
+        ),
+        output=TaskOutput(
+            output="invalid json",
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={
+                    "model_name": "test",
+                    "model_provider": "test",
+                    "adapter_name": "test",
+                },
+            ),
+        ),
+    )
+
+    with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
+        generate_chat_message_toolcall(task_run, "system message")
+
+
+def test_dataset_formatter_init_no_parent_task(mock_dataset):
+    mock_dataset.parent_task.return_value = None
+
+    with pytest.raises(ValueError, match="Dataset has no parent task"):
+        DatasetFormatter(mock_dataset, "system message")
+
+
+def test_dataset_formatter_dump_invalid_format(mock_dataset):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+
+    with pytest.raises(ValueError, match="Unsupported format"):
+        formatter.dump_to_file("train", "invalid_format")  # type: ignore
+
+
+def test_dataset_formatter_dump_invalid_split(mock_dataset):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+
+    with pytest.raises(ValueError, match="Split invalid_split not found in dataset"):
+        formatter.dump_to_file("invalid_split", DatasetFormat.OPENAI_CHAT_JSONL)
+
+
+def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+    output_path = tmp_path / "output.jsonl"
+
+    result_path = formatter.dump_to_file(
+        "train", DatasetFormat.OPENAI_CHAT_JSONL, output_path
+    )
+
+    assert result_path == output_path
+    assert output_path.exists()
+
+    # Verify file contents
+    with open(output_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2  # Should have 2 entries for train split
+        for line in lines:
+            data = json.loads(line)
+            assert "messages" in data
+            assert len(data["messages"]) == 3
+            assert data["messages"][0]["content"] == "system message"
+            assert data["messages"][1]["content"] == '{"test": "input"}'
+            assert data["messages"][2]["content"] == '{"test": "output"}'
+
+
+def test_dataset_formatter_dump_to_temp_file(mock_dataset):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+
+    result_path = formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)
+
+    assert result_path.exists()
+    assert result_path.parent == Path(tempfile.gettempdir())
+    assert result_path.name.startswith("test_dataset_train_")
+    assert result_path.name.endswith(".jsonl")
+    # Verify file contents
+    with open(result_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2
+
+
+def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
+    formatter = DatasetFormatter(mock_dataset, "system message")
+    output_path = tmp_path / "output.jsonl"
+
+    result_path = formatter.dump_to_file(
+        "train", DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL, output_path
+    )
+
+    assert result_path == output_path
+    assert output_path.exists()
+
+    # Verify file contents
+    with open(output_path) as f:
+        lines = f.readlines()
+        assert len(lines) == 2  # Should have 2 entries for train split
+        for line in lines:
+            data = json.loads(line)
+            assert "messages" in data
+            assert len(data["messages"]) == 3
+            # Check system and user messages
+            assert data["messages"][0]["content"] == "system message"
+            assert data["messages"][1]["content"] == '{"test": "input"}'
+            # Check tool call format
+            assistant_msg = data["messages"][2]
+            assert assistant_msg["content"] is None
+            assert "tool_calls" in assistant_msg
+            assert len(assistant_msg["tool_calls"]) == 1
+            tool_call = assistant_msg["tool_calls"][0]
+            assert tool_call["type"] == "function"
+            assert tool_call["function"]["name"] == "task_response"
+            assert tool_call["function"]["arguments"] == '{"test": "output"}'
+
+
+def test_generate_huggingface_chat_template():
+    task_run = TaskRun(
+        id="run1",
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human, properties={"created_by": "test"}
+        ),
+        output=TaskOutput(
+            output="test output",
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={
+                    "model_name": "test",
+                    "model_provider": "test",
+                    "adapter_name": "test",
+                },
+            ),
+        ),
+    )
+
+    result = generate_huggingface_chat_template(task_run, "system message")
+
+    assert result == {
+        "conversations": [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "test input"},
+            {"role": "assistant", "content": "test output"},
+        ]
+    }
+
+
+def test_generate_huggingface_chat_template_toolcall():
+    task_run = TaskRun(
+        id="run1",
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human, properties={"created_by": "test"}
+        ),
+        output=TaskOutput(
+            output='{"key": "value"}',
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={
+                    "model_name": "test",
+                    "model_provider": "test",
+                    "adapter_name": "test",
+                },
+            ),
+        ),
+    )
+
+    result = generate_huggingface_chat_template_toolcall(task_run, "system message")
+
+    assert result["conversations"][0] == {"role": "system", "content": "system message"}
+    assert result["conversations"][1] == {"role": "user", "content": "test input"}
+    assistant_msg = result["conversations"][2]
+    assert assistant_msg["role"] == "assistant"
+    assert len(assistant_msg["tool_calls"]) == 1
+    tool_call = assistant_msg["tool_calls"][0]
+    assert tool_call["type"] == "function"
+    assert tool_call["function"]["name"] == "task_response"
+    assert len(tool_call["function"]["id"]) == 9  # UUID is truncated to 9 chars
+    assert tool_call["function"]["id"].isalnum()  # Check ID is alphanumeric
+    assert tool_call["function"]["arguments"] == {"key": "value"}
+
+
+def test_generate_huggingface_chat_template_toolcall_invalid_json():
+    task_run = TaskRun(
+        id="run1",
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human, properties={"created_by": "test"}
+        ),
+        output=TaskOutput(
+            output="invalid json",
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={
+                    "model_name": "test",
+                    "model_provider": "test",
+                    "adapter_name": "test",
+                },
+            ),
+        ),
+    )
+
+    with pytest.raises(ValueError, match="Invalid JSON in for tool call"):
+        generate_huggingface_chat_template_toolcall(task_run, "system message")