kiln-ai 0.12.0__py3-none-any.whl → 0.13.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai has been flagged as potentially problematic; see the registry's advisory page for more details.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +157 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +19 -3
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +13 -7
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +8 -1
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +533 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +327 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +52 -60
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +27 -82
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +46 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/RECORD +44 -41
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,533 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import together
|
|
6
|
+
from together.types.finetune import FinetuneJobStatus as TogetherFinetuneJobStatus
|
|
7
|
+
|
|
8
|
+
from kiln_ai.adapters.fine_tune.base_finetune import (
|
|
9
|
+
FineTuneParameter,
|
|
10
|
+
FineTuneStatusType,
|
|
11
|
+
)
|
|
12
|
+
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
|
|
13
|
+
from kiln_ai.adapters.fine_tune.together_finetune import (
|
|
14
|
+
TogetherFinetune,
|
|
15
|
+
_completed_statuses,
|
|
16
|
+
_failed_statuses,
|
|
17
|
+
_pending_statuses,
|
|
18
|
+
_running_statuses,
|
|
19
|
+
)
|
|
20
|
+
from kiln_ai.datamodel import (
|
|
21
|
+
DatasetSplit,
|
|
22
|
+
StructuredOutputMode,
|
|
23
|
+
Task,
|
|
24
|
+
)
|
|
25
|
+
from kiln_ai.datamodel import Finetune as FinetuneModel
|
|
26
|
+
from kiln_ai.datamodel.dataset_split import Train80Test20SplitDefinition
|
|
27
|
+
from kiln_ai.utils.config import Config
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_together_status_categorization():
    """
    Every TogetherFinetuneJobStatus value must belong to exactly one of the
    pending / running / completed / failed categorization lists.
    """
    category_lists = [
        _pending_statuses,
        _running_statuses,
        _completed_statuses,
        _failed_statuses,
    ]

    # Anything in the enum that no category list mentions is uncategorized.
    covered = set().union(*category_lists)
    uncategorized = set(TogetherFinetuneJobStatus) - covered
    assert not uncategorized, (
        f"These statuses are not categorized: {uncategorized}"
    )

    # No status may be claimed by more than one category.
    for job_status in TogetherFinetuneJobStatus:
        hits = len([lst for lst in category_lists if job_status in lst])
        assert hits == 1, (
            f"Status '{job_status}' appears in {hits} categories (should be exactly 1)"
        )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.fixture
def finetune(tmp_path):
    """Together Finetune datamodel whose .kiln path points inside tmp_path.

    The fixture itself does not write the model to disk.
    """
    model_file = tmp_path / "test-finetune.kiln"
    return FinetuneModel(
        name="test-finetune",
        provider="together",
        provider_id="together-123",
        base_model_id="llama-v2-7b",
        train_split_name="train",
        dataset_split_id="dataset-123",
        system_message="Test system message",
        path=model_file,
    )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.fixture
def together_finetune(finetune, mock_together_client, mock_api_key):
    """TogetherFinetune adapter wrapping the `finetune` datamodel.

    Requesting mock_together_client and mock_api_key keeps the Together client
    class and Config.shared patched while the adapter is built and used.
    """
    return TogetherFinetune(datamodel=finetune)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pytest.fixture
def mock_together_client():
    """Patch `Together` in the together_finetune module; yield the client mock.

    Any `Together(...)` construction inside the module returns the yielded
    MagicMock for the duration of the test.
    """
    client = MagicMock()
    target = "kiln_ai.adapters.fine_tune.together_finetune.Together"
    with patch(target) as together_cls:
        together_cls.return_value = client
        yield client
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.fixture
def mock_api_key():
    """Patch Config.shared so Config.shared().together_api_key == "test-api-key"."""
    fake_config = MagicMock(together_api_key="test-api-key")
    with patch.object(Config, "shared", return_value=fake_config):
        yield
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def test_init_missing_api_key(finetune):
    """Building the adapter with no Together API key configured raises ValueError."""
    keyless_config = MagicMock(together_api_key=None)
    with (
        patch.object(Config, "shared", return_value=keyless_config),
        pytest.raises(ValueError, match="Together.ai API key not set"),
    ):
        TogetherFinetune(datamodel=finetune)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_init_success(mock_api_key, mock_together_client, finetune):
    """With the API key and Together client patched, construction must not raise."""
    _adapter = TogetherFinetune(datamodel=finetune)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_available_parameters():
    """TogetherFinetune exposes exactly the eleven expected tunable parameters."""
    expected_names = {
        "epochs",
        "learning_rate",
        "batch_size",
        "num_checkpoints",
        "min_lr_ratio",
        "warmup_ratio",
        "max_grad_norm",
        "weight_decay",
        "lora_rank",
        "lora_dropout",
        "lora_alpha",
    }
    params = TogetherFinetune.available_parameters()

    assert len(params) == 11
    for param in params:
        assert isinstance(param, FineTuneParameter)

    # Together with the length check, set equality means every expected name
    # is present exactly once.
    assert {param.name for param in params} == expected_names
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
async def test_status_missing_provider_id(together_finetune, mock_api_key):
    """With no provider job id on the datamodel, status() reports unknown."""
    together_finetune.datamodel.provider_id = None

    result = await together_finetune.status()

    assert result.status == FineTuneStatusType.unknown
    assert "Fine-tuning job ID not set" in result.message
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@pytest.mark.parametrize(
    "together_status,expected_status,expected_message",
    [
        (
            TogetherFinetuneJobStatus.STATUS_PENDING,
            FineTuneStatusType.pending,
            f"Fine-tuning job is pending [{TogetherFinetuneJobStatus.STATUS_PENDING}]",
        ),
        (
            TogetherFinetuneJobStatus.STATUS_RUNNING,
            FineTuneStatusType.running,
            f"Fine-tuning job is running [{TogetherFinetuneJobStatus.STATUS_RUNNING}]",
        ),
        (
            TogetherFinetuneJobStatus.STATUS_COMPLETED,
            FineTuneStatusType.completed,
            "Fine-tuning job completed",
        ),
        (
            TogetherFinetuneJobStatus.STATUS_ERROR,
            FineTuneStatusType.failed,
            f"Fine-tuning job failed [{TogetherFinetuneJobStatus.STATUS_ERROR}]",
        ),
        (
            "UNKNOWN_STATUS",
            FineTuneStatusType.unknown,
            "Unknown fine-tuning job status [UNKNOWN_STATUS]",
        ),
    ],
)
async def test_status_job_states(
    mock_together_client,
    together_finetune,
    together_status,
    expected_status,
    expected_message,
    mock_api_key,
):
    """status() maps each Together job state to a Kiln status and message,
    records the status on the datamodel, and looks the job up by provider_id.
    """
    # Fake job returned by fine_tuning.retrieve with the state under test.
    mock_job = MagicMock()
    mock_job.status = together_status
    mock_together_client.fine_tuning.retrieve.return_value = mock_job

    status = await together_finetune.status()

    # Returned status/message (previously the status was asserted twice).
    assert status.status == expected_status
    assert status.message == expected_message

    # The latest status is also recorded on the datamodel.
    assert together_finetune.datamodel.latest_status == expected_status

    # The job was fetched exactly once, by its provider id.
    mock_together_client.fine_tuning.retrieve.assert_called_once_with(
        id=together_finetune.datamodel.provider_id
    )
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
async def test_status_exception(together_finetune, mock_together_client, mock_api_key):
    """An error raised by retrieve() becomes an unknown status carrying the error text."""
    retrieve = mock_together_client.fine_tuning.retrieve
    retrieve.side_effect = Exception("API error")

    outcome = await together_finetune.status()

    assert outcome.status == FineTuneStatusType.unknown
    assert outcome.message == "Error retrieving fine-tuning job status: API error"
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@pytest.fixture
def mock_dataset():
    """In-memory DatasetSplit using Train80Test20SplitDefinition with empty contents."""
    empty_contents = {"train": [], "test": []}
    return DatasetSplit(
        id="test-dataset-123",
        name="Test Dataset",
        splits=Train80Test20SplitDefinition,
        split_contents=empty_contents,
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@pytest.fixture
def mock_task():
    """Task with no output schema; individual tests may set output_json_schema."""
    task_fields = dict(
        id="test-task-123",
        name="Test Task",
        output_json_schema=None,
        instruction="Test instruction",
    )
    return Task(**task_fields)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
async def test_generate_and_upload_jsonl_success(
    together_finetune, mock_dataset, mock_task, mock_together_client, mock_api_key
):
    """generate_and_upload_jsonl dumps the split to a file, uploads that file to
    Together with the fine-tune purpose, and returns the uploaded file's id.
    """
    # Stand-in formatter whose dump_to_file reports a fixed path.
    mock_formatter = MagicMock(spec=DatasetFormatter)
    mock_path = Path("mock_path.jsonl")
    mock_formatter.dump_to_file.return_value = mock_path

    # Fake Together upload response carrying the file id.
    mock_file = MagicMock()
    mock_file.id = "file-123"
    mock_together_client.files.upload.return_value = mock_file

    with patch(
        "kiln_ai.adapters.fine_tune.together_finetune.DatasetFormatter",
        return_value=mock_formatter,
    ):
        result = await together_finetune.generate_and_upload_jsonl(
            mock_dataset, "train", mock_task, DatasetFormat.OPENAI_CHAT_JSONL
        )

    # The dataset was dumped to a file exactly once. The formatter's
    # constructor arguments are not inspected by this test.
    mock_formatter.dump_to_file.assert_called_once()

    # That file was uploaded once, with the fine-tune purpose and check=True.
    mock_together_client.files.upload.assert_called_once_with(
        file=mock_path,
        purpose=together.types.files.FilePurpose.FineTune,
        check=True,
    )

    # The uploaded file's id is returned.
    assert result == "file-123"
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
async def test_generate_and_upload_jsonl_error(
    together_finetune, mock_dataset, mock_task, mock_together_client, mock_api_key
):
    """A failing upload surfaces as ValueError("Failed to upload dataset: ...")."""
    formatter = MagicMock(spec=DatasetFormatter)
    formatter.dump_to_file.return_value = Path("mock_path.jsonl")
    mock_together_client.files.upload.side_effect = Exception("Upload failed")

    formatter_patch = patch(
        "kiln_ai.adapters.fine_tune.together_finetune.DatasetFormatter",
        return_value=formatter,
    )
    with formatter_patch, pytest.raises(
        ValueError, match="Failed to upload dataset: Upload failed"
    ):
        await together_finetune.generate_and_upload_jsonl(
            mock_dataset, "train", mock_task, DatasetFormat.OPENAI_CHAT_JSONL
        )
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
@pytest.mark.parametrize(
    "output_schema,expected_mode,expected_format,validation_file",
    [
        (
            '{"type": "object", "properties": {"key": {"type": "string"}}}',
            StructuredOutputMode.json_custom_instructions,
            DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL,
            True,
        ),
        (None, None, DatasetFormat.OPENAI_CHAT_JSONL, False),
    ],
)
async def test_start_success(
    together_finetune,
    mock_dataset,
    mock_task,
    mock_together_client,
    mock_api_key,
    output_schema,
    expected_mode,
    expected_format,
    validation_file,
):
    """_start uploads the train split (and the validation split when one is set),
    creates a LoRA fine-tuning job, and records the job id, model id and
    structured output mode on the datamodel.
    """
    # Task with or without an output schema, attached as the finetune's parent.
    mock_task.output_json_schema = output_schema
    together_finetune.datamodel.parent = mock_task

    # Every upload (train and validation) returns this same fake file id.
    mock_file_id = "file-123"

    # Fake job returned by fine_tuning.create.
    mock_job = MagicMock()
    mock_job.id = "job-123"
    mock_job.output_name = "model-123"
    mock_together_client.fine_tuning.create.return_value = mock_job

    with patch.object(
        together_finetune,
        "generate_and_upload_jsonl",
        AsyncMock(return_value=mock_file_id),
    ):
        if validation_file:
            together_finetune.datamodel.validation_split_name = "validation"

        await together_finetune._start(mock_dataset)

        # Inspect upload calls while the AsyncMock is still patched in.
        upload_calls = together_finetune.generate_and_upload_jsonl.call_args_list
        # One upload for the train split, plus one when a validation split is set.
        assert len(upload_calls) == (2 if validation_file else 1)

        first_call = upload_calls[0].args
        assert first_call[0] == mock_dataset
        assert first_call[1] == together_finetune.datamodel.train_split_name
        assert first_call[2] == mock_task
        assert first_call[3] == expected_format
        if validation_file:
            second_call = upload_calls[1].args
            assert second_call[0] == mock_dataset
            assert second_call[1] == "validation"
            assert second_call[2] == mock_task
            assert second_call[3] == expected_format

        # The Together job is created once with the uploaded file ids.
        mock_together_client.fine_tuning.create.assert_called_once_with(
            training_file=mock_file_id,
            validation_file=mock_file_id if validation_file else "",
            model=together_finetune.datamodel.base_model_id,
            lora=True,
            suffix=f"kiln_ai_{together_finetune.datamodel.id}"[:40],
            wandb_api_key=Config.shared().wandb_api_key,
            wandb_project_name="Kiln_AI",
        )

        # The datamodel is updated from the job response.
        assert together_finetune.datamodel.provider_id == "job-123"
        assert together_finetune.datamodel.fine_tune_model_id == "model-123"
        assert together_finetune.datamodel.structured_output_mode == expected_mode
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
async def test_start_missing_task(together_finetune, mock_dataset, mock_api_key):
    """_start raises ValueError when the finetune datamodel has no parent Task."""
    datamodel = together_finetune.datamodel
    datamodel.parent = None
    datamodel.save_to_file()

    with pytest.raises(ValueError, match="Task is required to start a fine-tune"):
        await together_finetune._start(mock_dataset)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
async def test_deploy_always_succeeds(together_finetune, mock_api_key):
    """Together deploys fine-tunes itself, so _deploy() is expected to return True."""
    deployed = await together_finetune._deploy()
    assert deployed is True
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def test_augment_system_message(mock_task):
    """augment_system_message leaves plaintext tasks unchanged and appends a
    JSON-only instruction when the task declares an output schema.
    """
    base_message = "You are a helpful assistant."

    # No output schema: the message comes back unchanged.
    plain_result = TogetherFinetune.augment_system_message(base_message, mock_task)
    assert plain_result == "You are a helpful assistant."

    # With an output schema: JSON-only instructions are appended.
    mock_task.output_json_schema = (
        '{"type": "object", "properties": {"key": {"type": "string"}}}'
    )
    expected_json_message = (
        "You are a helpful assistant.\n\n"
        "Return only JSON. Do not include any non JSON text.\n"
    )
    json_result = TogetherFinetune.augment_system_message(base_message, mock_task)
    assert json_result == expected_json_message
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
@pytest.mark.parametrize(
    "parameter_name,input_value,expected_key,expected_value,should_exist",
    [
        # learning_rate tests
        ("learning_rate", 0.001, "learning_rate", 0.001, True),
        ("learning_rate", "not_a_float", "learning_rate", None, False),
        # epochs tests
        ("epochs", 5, "n_epochs", 5, True),
        ("epochs", "not_an_int", "n_epochs", None, False),
        # num_checkpoints tests
        ("num_checkpoints", 3, "n_checkpoints", 3, True),
        ("num_checkpoints", "not_an_int", "n_checkpoints", None, False),
        # batch_size tests
        ("batch_size", 32, "batch_size", 32, True),
        ("batch_size", "not_an_int", "batch_size", None, False),
        # min_lr_ratio tests
        ("min_lr_ratio", 0.1, "min_lr_ratio", 0.1, True),
        ("min_lr_ratio", "not_a_float", "min_lr_ratio", None, False),
        # warmup_ratio tests
        ("warmup_ratio", 0.2, "warmup_ratio", 0.2, True),
        ("warmup_ratio", "not_a_float", "warmup_ratio", None, False),
        # max_grad_norm tests
        ("max_grad_norm", 5.0, "max_grad_norm", 5.0, True),
        ("max_grad_norm", "not_a_float", "max_grad_norm", None, False),
        # weight_decay tests
        ("weight_decay", 0.01, "weight_decay", 0.01, True),
        ("weight_decay", "not_a_float", "weight_decay", None, False),
        # lora_rank tests
        ("lora_rank", 16, "lora_r", 16, True),
        ("lora_rank", "not_an_int", "lora_r", None, False),
        # lora_dropout tests
        ("lora_dropout", 0.1, "lora_dropout", 0.1, True),
        ("lora_dropout", "not_a_float", "lora_dropout", None, False),
        # lora_alpha tests
        ("lora_alpha", 32.0, "lora_alpha", 32.0, True),
        ("lora_alpha", "not_a_float", "lora_alpha", None, False),
    ],
)
def test_build_finetune_parameters(
    together_finetune,
    parameter_name,
    input_value,
    expected_key,
    expected_value,
    should_exist,
):
    """Each Kiln parameter maps to its Together key when its value is valid and is
    omitted when the value has the wrong type; lora and suffix are always set.
    """
    datamodel = together_finetune.datamodel
    datamodel.parameters = {parameter_name: input_value}
    datamodel.id = "test-display-name"

    built = together_finetune._build_finetune_parameters()

    # Always-present keys.
    assert built["lora"] is True
    assert built["suffix"] == "kiln_ai_test-display-name"

    # Invalid values must be dropped entirely.
    if not should_exist:
        assert expected_key not in built
        return

    assert expected_key in built
    assert built[expected_key] == expected_value
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
@pytest.mark.parametrize(
    "parameters,expected_params",
    [
        # Test default values when parameters are empty
        (
            {},
            {
                "lora": True,
                "suffix": "kiln_ai_1234",
            },
        ),
        # Test multiple parameters together
        (
            {
                "epochs": 3,
                "learning_rate": 0.0005,
                "batch_size": 16,
                "num_checkpoints": 2,
                "min_lr_ratio": 0.1,
                "warmup_ratio": 0.2,
                "max_grad_norm": 2.0,
                "weight_decay": 0.01,
            },
            {
                "lora": True,
                "n_epochs": 3,
                "learning_rate": 0.0005,
                "batch_size": 16,
                "n_checkpoints": 2,
                "min_lr_ratio": 0.1,
                "warmup_ratio": 0.2,
                "max_grad_norm": 2.0,
                "weight_decay": 0.01,
                "suffix": "kiln_ai_1234",
            },
        ),
        # Test mix of valid and invalid parameters
        (
            {
                "epochs": "invalid",
                "learning_rate": 0.001,
                "batch_size": "invalid",
                "num_checkpoints": "invalid",
            },
            {
                "lora": True,
                "learning_rate": 0.001,
                "suffix": "kiln_ai_1234",
            },
        ),
    ],
)
def test_build_finetune_parameters_combinations(
    together_finetune, parameters, expected_params
):
    """For whole parameter sets, the built dict must match exactly: valid entries
    are renamed/kept, invalid ones dropped, plus the fixed lora and suffix keys.
    """
    datamodel = together_finetune.datamodel
    datamodel.parameters = parameters
    datamodel.id = "1234"

    assert together_finetune._build_finetune_parameters() == expected_params
|