kiln-ai 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -23
- kiln_ai/adapters/fine_tune/dataset_formatter.py +4 -4
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +163 -15
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -9
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +3 -3
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +495 -9
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +319 -43
- kiln_ai/adapters/model_adapters/base_adapter.py +15 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +10 -5
- kiln_ai/adapters/provider_tools.py +7 -0
- kiln_ai/adapters/test_provider_tools.py +16 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/task_output.py +9 -5
- kiln_ai/datamodel/task_run.py +29 -5
- kiln_ai/datamodel/test_example_models.py +104 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/METADATA +3 -2
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/RECORD +25 -24
- kiln_ai/adapters/test_generate_docs.py +0 -69
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
from google.cloud import storage
|
|
7
|
+
from google.cloud.aiplatform_v1beta1 import types as gca_types
|
|
8
|
+
from vertexai.tuning import sft
|
|
9
|
+
|
|
10
|
+
from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType
|
|
11
|
+
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
|
|
12
|
+
from kiln_ai.adapters.fine_tune.vertex_finetune import VertexFinetune
|
|
13
|
+
from kiln_ai.datamodel import (
|
|
14
|
+
DatasetSplit,
|
|
15
|
+
FinetuneDataStrategy,
|
|
16
|
+
StructuredOutputMode,
|
|
17
|
+
Task,
|
|
18
|
+
)
|
|
19
|
+
from kiln_ai.datamodel import Finetune as FinetuneModel
|
|
20
|
+
from kiln_ai.datamodel.dataset_split import Train80Test20SplitDefinition
|
|
21
|
+
from kiln_ai.utils.config import Config
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def vertex_finetune(tmp_path):
    """Provide a VertexFinetune adapter wrapping a Finetune datamodel.

    The datamodel is stored at ``tmp_path / "test-finetune.kiln"``. It already
    has a provider_id ("vertex-123") and a fine-tuned model id ("ft-123").
    Its data strategy is ``final_only``.
    """
    model = FinetuneModel(
        name="test-finetune",
        provider="vertex",
        provider_id="vertex-123",
        base_model_id="gemini-2.0-pro",
        train_split_name="train",
        dataset_split_id="dataset-123",
        system_message="Test system message",
        fine_tune_model_id="ft-123",
        path=tmp_path / "test-finetune.kiln",
        data_strategy=FinetuneDataStrategy.final_only,
    )
    return VertexFinetune(datamodel=model)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.fixture
def mock_response():
    """Provide a spec'd SupervisedTuningJob mock for a successful job.

    The job has no error, is in JOB_STATE_SUCCEEDED, and reports the tuned
    model endpoint "ft-123". Tests override these attributes as needed.
    """
    job = MagicMock(spec=sft.SupervisedTuningJob)
    job.error = None
    job.state = gca_types.JobState.JOB_STATE_SUCCEEDED
    job.tuned_model_endpoint_name = "ft-123"
    return job
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.fixture
def mock_dataset():
    """Provide a DatasetSplit using the 80/20 train/test definition with no items."""
    empty_contents = {"train": [], "test": []}
    return DatasetSplit(
        id="test-dataset-123",
        name="Test Dataset",
        splits=Train80Test20SplitDefinition,
        split_contents=empty_contents,
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.fixture
def mock_task():
    """Provide a minimal Task with no output JSON schema.

    Individual tests may assign ``output_json_schema`` to exercise structured
    output handling.
    """
    return Task(
        name="Test Task",
        id="test-task-123",
        instruction="Test instruction",
        output_json_schema=None,
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
async def test_status_pending_no_provider_id(vertex_finetune):
    """A fine-tune without a provider_id is reported as pending / not started."""
    vertex_finetune.datamodel.provider_id = None

    result = await vertex_finetune.status()

    assert result.status == FineTuneStatusType.pending
    assert "This fine-tune has not been started" in result.message
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.mark.parametrize(
    "state,expected_status,message_contains",
    [
        # Terminal failures.
        (gca_types.JobState.JOB_STATE_FAILED, FineTuneStatusType.failed, "Fine Tune Job Failed"),
        (gca_types.JobState.JOB_STATE_EXPIRED, FineTuneStatusType.failed, "Fine Tune Job Failed"),
        # Cancellation (in progress or finished) is also treated as failed.
        (gca_types.JobState.JOB_STATE_CANCELLED, FineTuneStatusType.failed, "Fine Tune Job Cancelled"),
        (gca_types.JobState.JOB_STATE_CANCELLING, FineTuneStatusType.failed, "Fine Tune Job Cancelled"),
        # Not yet running.
        (gca_types.JobState.JOB_STATE_PENDING, FineTuneStatusType.pending, "Fine Tune Job Pending"),
        (gca_types.JobState.JOB_STATE_QUEUED, FineTuneStatusType.pending, "Fine Tune Job Pending"),
        # In progress.
        (gca_types.JobState.JOB_STATE_RUNNING, FineTuneStatusType.running, "Fine Tune Job Running"),
        # Success, including partial success.
        (gca_types.JobState.JOB_STATE_SUCCEEDED, FineTuneStatusType.completed, "Fine Tune Job Completed"),
        (gca_types.JobState.JOB_STATE_PARTIALLY_SUCCEEDED, FineTuneStatusType.completed, "Fine Tune Job Completed"),
        # States with no explicit mapping.
        (gca_types.JobState.JOB_STATE_PAUSED, FineTuneStatusType.unknown, "Unknown state"),
        (gca_types.JobState.JOB_STATE_UPDATING, FineTuneStatusType.unknown, "Unknown state"),
        (gca_types.JobState.JOB_STATE_UNSPECIFIED, FineTuneStatusType.unknown, "Unknown state"),
        # A raw value outside the JobState enum.
        (999, FineTuneStatusType.unknown, "Unknown state"),
    ],
)
async def test_status_job_states(
    vertex_finetune,
    mock_response,
    state,
    expected_status,
    message_contains,
):
    """Each Vertex job state maps to the expected Kiln status and message."""
    mock_response.state = state

    job_cls_path = "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob"
    with patch(job_cls_path, return_value=mock_response):
        result = await vertex_finetune.status()

    assert result.status == expected_status
    assert message_contains in result.message
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
async def test_status_with_error(vertex_finetune, mock_response):
    """A job error yields a failed status whose message includes "message [code]".

    The fixture's job state is SUCCEEDED, so this also checks that a populated
    error takes precedence over the reported state.
    """
    job_error = MagicMock()
    job_error.code = 123
    job_error.message = "Test error message"
    mock_response.error = job_error

    job_cls_path = "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob"
    with patch(job_cls_path, return_value=mock_response):
        result = await vertex_finetune.status()

    assert result.status == FineTuneStatusType.failed
    assert "Test error message [123]" in result.message
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
async def test_status_updates_model_id(vertex_finetune, mock_response):
    """status() replaces a stale fine_tune_model_id with the job's tuned endpoint."""
    vertex_finetune.datamodel.fine_tune_model_id = "old-ft-model"
    mock_response.tuned_model_endpoint_name = "new-ft-model"

    job_cls_path = "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob"
    with patch(job_cls_path, return_value=mock_response):
        result = await vertex_finetune.status()

    # The id reported by Vertex now lives on the datamodel.
    assert vertex_finetune.datamodel.fine_tune_model_id == "new-ft-model"

    # The returned status reflects the (SUCCEEDED) fixture state.
    assert result.status == FineTuneStatusType.completed
    assert result.message == "Fine Tune Job Completed"
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
async def test_status_updates_latest_status(vertex_finetune, mock_response):
    """A newly observed job state is recorded as latest_status and persisted.

    The datamodel starts as ``running``; the mocked job reports SUCCEEDED.
    """
    vertex_finetune.datamodel.latest_status = FineTuneStatusType.running
    mock_response.state = gca_types.JobState.JOB_STATE_SUCCEEDED

    job_cls_path = "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob"
    with patch(job_cls_path, return_value=mock_response):
        result = await vertex_finetune.status()

    assert result.status == FineTuneStatusType.completed
    assert vertex_finetune.datamodel.latest_status == FineTuneStatusType.completed

    # The datamodel has a path, so the updated model should exist on disk.
    assert vertex_finetune.datamodel.path.exists()
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
async def test_status_model_id_update_exception(vertex_finetune, mock_response):
    """A failure reading the tuned model id is logged as a warning, not raised.

    status() must still return the job's state (SUCCEEDED from the fixture)
    when ``tuned_model_endpoint_name`` cannot be read.
    """
    # Local import: the module-level import line only brings in AsyncMock,
    # MagicMock and patch.
    from unittest.mock import PropertyMock

    # Per the unittest.mock docs, a PropertyMock must be attached to the mock's
    # type rather than the instance. Every Mock instance gets its own subclass,
    # so this does not affect other mocks. As a data descriptor it also shadows
    # the "ft-123" value the fixture stored on the instance.
    type(mock_response).tuned_model_endpoint_name = PropertyMock(
        side_effect=Exception("Model ID error")
    )

    with (
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob",
            return_value=mock_response,
        ),
        patch("kiln_ai.adapters.fine_tune.vertex_finetune.logger") as mock_logger,
    ):
        status = await vertex_finetune.status()

    # The exception was reported through the module logger...
    mock_logger.warning.assert_called_once()
    assert (
        "Error updating fine-tune model ID" in mock_logger.warning.call_args[0][0]
    )

    # ...and did not prevent the status from being returned.
    assert status.status == FineTuneStatusType.completed
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@pytest.mark.parametrize(
    "data_strategy,thinking_instructions",
    [
        (FinetuneDataStrategy.final_and_intermediate, "Custom thinking instructions"),
        (FinetuneDataStrategy.final_only, None),
    ],
)
async def test_generate_and_upload_jsonl(
    vertex_finetune,
    mock_dataset,
    mock_task,
    data_strategy,
    thinking_instructions,
    tmp_path,
):
    """The requested split is dumped to JSONL and uploaded to a timestamped GCS blob.

    A dedicated finetune is built, rather than using the ``vertex_finetune``
    fixture, so each parametrized data strategy / thinking-instructions pair is
    applied. The formatter, storage client, clock and Config are all mocked.
    """
    finetune = VertexFinetune(
        datamodel=FinetuneModel(
            name="test-finetune",
            provider="vertex",
            provider_id="vertex-123",
            base_model_id="gemini-2.0-pro",
            train_split_name="train",
            dataset_split_id="dataset-123",
            system_message="Test system message",
            path=tmp_path / "test-finetune.kiln",
            data_strategy=data_strategy,
            thinking_instructions=thinking_instructions,
        ),
    )

    mock_path = Path("mock_path.jsonl")
    expected_uri = "gs://kiln-ai-data/1234567890/mock_path.jsonl"

    # Formatter stand-in: "writes" the split to mock_path.
    mock_formatter = MagicMock(spec=DatasetFormatter)
    mock_formatter.dump_to_file.return_value = mock_path

    # Storage stand-ins: an existing bucket that hands out a single blob.
    mock_blob = MagicMock()
    mock_blob.name = f"1234567890/{mock_path.name}"

    mock_bucket = MagicMock()
    mock_bucket.name = "kiln-ai-data"
    mock_bucket.blob.return_value = mock_blob

    mock_storage_client = MagicMock(spec=storage.Client)
    mock_storage_client.lookup_bucket.return_value = mock_bucket
    mock_storage_client.bucket.return_value = mock_bucket

    with (
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.DatasetFormatter",
            return_value=mock_formatter,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.storage.Client",
            return_value=mock_storage_client,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.time.time",
            return_value=1234567890,
        ),
        patch.object(Config, "shared") as mock_config,
    ):
        mock_config.return_value.vertex_project_id = "test-project"
        mock_config.return_value.vertex_location = "us-central1"

        result = await finetune.generate_and_upload_jsonl(
            mock_dataset, "train", mock_task, DatasetFormat.VERTEX_GEMINI
        )

    # The formatter dumped the requested split, in the requested format, using
    # the finetune's data strategy.
    mock_formatter.dump_to_file.assert_called_once_with(
        "train", DatasetFormat.VERTEX_GEMINI, data_strategy
    )
    # NOTE(review): thinking_instructions is parametrized but never checked
    # against the DatasetFormatter constructor call; consider asserting it once
    # the constructor signature is confirmed.

    # The bucket handle was requested by name.
    mock_storage_client.bucket.assert_called_once_with("kiln-ai-data")

    # The file was uploaded under a "<timestamp>/<filename>" blob name.
    mock_bucket.blob.assert_called_once_with(f"1234567890/{mock_path.name}")
    mock_blob.upload_from_filename.assert_called_once_with(mock_path)

    # The returned gs:// URI points at the uploaded blob.
    assert result == expected_uri
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
async def test_generate_and_upload_jsonl_create_bucket(
    vertex_finetune, mock_dataset, mock_task
):
    """A missing bucket is created in the configured Vertex location before upload."""
    local_file = Path("mock_path.jsonl")

    formatter = MagicMock(spec=DatasetFormatter)
    formatter.dump_to_file.return_value = local_file

    blob = MagicMock()
    blob.name = f"1234567890/{local_file.name}"

    bucket = MagicMock()
    bucket.name = "kiln-ai-data"
    bucket.blob.return_value = blob

    client = MagicMock(spec=storage.Client)
    # lookup_bucket returning None simulates a bucket that does not exist yet.
    client.lookup_bucket.return_value = None
    client.create_bucket.return_value = bucket

    with (
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.DatasetFormatter",
            return_value=formatter,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.storage.Client",
            return_value=client,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.time.time",
            return_value=1234567890,
        ),
        patch.object(Config, "shared") as mock_config,
    ):
        shared_config = mock_config.return_value
        shared_config.vertex_project_id = "test-project"
        shared_config.vertex_location = "us-central1"

        uploaded_uri = await vertex_finetune.generate_and_upload_jsonl(
            mock_dataset, "train", mock_task, DatasetFormat.VERTEX_GEMINI
        )

    # The bucket was created in the Vertex location from Config.
    client.create_bucket.assert_called_once_with("kiln-ai-data", location="us-central1")
    # The dumped file was then uploaded and its URI returned.
    blob.upload_from_filename.assert_called_once_with(local_file)
    assert uploaded_uri == "gs://kiln-ai-data/1234567890/mock_path.jsonl"
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
@pytest.mark.parametrize(
    "output_schema,expected_mode,expected_format",
    [
        # A task with an output schema is tuned in JSON mode.
        (
            '{"type": "object", "properties": {"key": {"type": "string"}}}',
            StructuredOutputMode.json_mode,
            DatasetFormat.VERTEX_GEMINI,
        ),
        # A plain-text task leaves structured_output_mode unset.
        (None, None, DatasetFormat.VERTEX_GEMINI),
    ],
)
async def test_start_success(
    vertex_finetune,
    mock_dataset,
    mock_task,
    output_schema,
    expected_mode,
    expected_format,
):
    """_start uploads the train split and launches an SFT job with the hyperparameters.

    With no validation split only one upload happens. The job's resource name is
    stored as provider_id, and structured_output_mode follows the parametrized
    expectation for the task's output schema.
    """
    vertex_finetune.datamodel.parent = mock_task
    mock_task.output_json_schema = output_schema
    vertex_finetune.datamodel.parameters = {
        "epochs": 3,
        "learning_rate_multiplier": 0.1,
        "adapter_size": 8,
    }

    sft_job = MagicMock()
    sft_job.resource_name = "vertex-ft-123"

    train_uri = "gs://kiln-ai-data/train.jsonl"
    validation_uri = "gs://kiln-ai-data/validation.jsonl"

    with (
        patch.object(
            vertex_finetune,
            "generate_and_upload_jsonl",
            side_effect=[train_uri, validation_uri],
        ) as upload_mock,
        patch("kiln_ai.adapters.fine_tune.vertex_finetune.vertexai.init") as init_mock,
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.sft.train",
            return_value=sft_job,
        ) as train_mock,
        patch.object(Config, "shared") as mock_config,
    ):
        shared_config = mock_config.return_value
        shared_config.vertex_project_id = "test-project"
        shared_config.vertex_location = "us-central1"

        # Train split only; no validation split.
        vertex_finetune.datamodel.validation_split_name = None

        await vertex_finetune._start(mock_dataset)

        init_mock.assert_called_once_with(
            project="test-project", location="us-central1"
        )

        # A single upload for the training split.
        upload_mock.assert_called_once_with(
            mock_dataset,
            vertex_finetune.datamodel.train_split_name,
            mock_task,
            expected_format,
        )

        expected_labels = {
            "source": "kiln",
            "kiln_finetune_id": str(vertex_finetune.datamodel.id),
            "kiln_task_id": str(mock_task.id),
        }
        train_mock.assert_called_once_with(
            source_model=vertex_finetune.datamodel.base_model_id,
            train_dataset=train_uri,
            validation_dataset=None,
            tuned_model_display_name=f"kiln_finetune_{vertex_finetune.datamodel.id}",
            epochs=3,
            adapter_size=8,
            learning_rate_multiplier=0.1,
            labels=expected_labels,
        )

        # The datamodel now tracks the launched job.
        assert vertex_finetune.datamodel.provider_id == "vertex-ft-123"
        assert vertex_finetune.datamodel.structured_output_mode == expected_mode
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
async def test_start_with_validation(vertex_finetune, mock_dataset, mock_task):
    """With a validation split, both splits are uploaded and passed to sft.train."""
    vertex_finetune.datamodel.parent = mock_task
    vertex_finetune.datamodel.validation_split_name = "test"

    sft_job = MagicMock()
    sft_job.resource_name = "vertex-ft-123"

    train_uri = "gs://kiln-ai-data/train.jsonl"
    validation_uri = "gs://kiln-ai-data/validation.jsonl"

    with (
        patch.object(
            vertex_finetune,
            "generate_and_upload_jsonl",
            side_effect=[train_uri, validation_uri],
        ) as upload_mock,
        patch("kiln_ai.adapters.fine_tune.vertex_finetune.vertexai.init"),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.sft.train",
            return_value=sft_job,
        ) as train_mock,
        patch.object(Config, "shared") as mock_config,
    ):
        shared_config = mock_config.return_value
        shared_config.vertex_project_id = "test-project"
        shared_config.vertex_location = "us-central1"

        await vertex_finetune._start(mock_dataset)

        # One upload per split, in either order.
        assert upload_mock.call_count == 2
        for split_name in (vertex_finetune.datamodel.train_split_name, "test"):
            upload_mock.assert_any_call(
                mock_dataset,
                split_name,
                mock_task,
                DatasetFormat.VERTEX_GEMINI,
            )

        # The uploaded validation file is handed to the training job.
        train_mock.assert_called_once()
        assert train_mock.call_args[1]["validation_dataset"] == validation_uri
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
async def test_start_no_task(vertex_finetune, mock_dataset):
    """_start raises ValueError when the finetune has no parent task."""
    vertex_finetune.datamodel.parent = None

    expected_error = "Task is required to start a fine-tune"
    with pytest.raises(ValueError, match=expected_error):
        await vertex_finetune._start(mock_dataset)
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def test_available_parameters():
    """VertexFinetune exposes exactly three hyperparameters, all optional."""
    params = VertexFinetune.available_parameters()
    assert len(params) == 3

    names = {param.name for param in params}
    for expected_name in ("learning_rate_multiplier", "epochs", "adapter_size"):
        assert expected_name in names

    assert all(param.optional for param in params)
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
@pytest.mark.parametrize(
    "project_id,location,should_raise",
    [
        ("test-project", "us-central1", False),
        # Any missing (None) or empty value must be rejected.
        ("", "us-central1", True),
        (None, "us-central1", True),
        ("test-project", "", True),
        ("test-project", None, True),
        (None, None, True),
    ],
)
def test_get_vertex_provider_location(project_id, location, should_raise):
    """The Vertex project/location pair is returned only when both are set."""
    with patch.object(Config, "shared") as mock_config:
        shared_config = mock_config.return_value
        shared_config.vertex_project_id = project_id
        shared_config.vertex_location = location

        if not should_raise:
            project, loc = VertexFinetune.get_vertex_provider_location()
            assert project == project_id
            assert loc == location
            return

        with pytest.raises(
            ValueError, match="Google Vertex project and location must be set"
        ):
            VertexFinetune.get_vertex_provider_location()
|