kiln-ai 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

@@ -0,0 +1,586 @@
1
+ import time
2
+ from pathlib import Path
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
+
5
+ import pytest
6
+ from google.cloud import storage
7
+ from google.cloud.aiplatform_v1beta1 import types as gca_types
8
+ from vertexai.tuning import sft
9
+
10
+ from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType
11
+ from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
12
+ from kiln_ai.adapters.fine_tune.vertex_finetune import VertexFinetune
13
+ from kiln_ai.datamodel import (
14
+ DatasetSplit,
15
+ FinetuneDataStrategy,
16
+ StructuredOutputMode,
17
+ Task,
18
+ )
19
+ from kiln_ai.datamodel import Finetune as FinetuneModel
20
+ from kiln_ai.datamodel.dataset_split import Train80Test20SplitDefinition
21
+ from kiln_ai.utils.config import Config
22
+
23
+
24
@pytest.fixture
def vertex_finetune(tmp_path):
    """Build a ``VertexFinetune`` adapter whose Kiln file lives in ``tmp_path``.

    The datamodel already carries a provider job id (``vertex-123``) and a
    tuned model id (``ft-123``). Tests that need the "not started" path clear
    ``provider_id`` themselves.
    """
    kiln_file = tmp_path / "test-finetune.kiln"
    datamodel = FinetuneModel(
        name="test-finetune",
        provider="vertex",
        provider_id="vertex-123",
        base_model_id="gemini-2.0-pro",
        train_split_name="train",
        dataset_split_id="dataset-123",
        system_message="Test system message",
        fine_tune_model_id="ft-123",
        path=kiln_file,
        data_strategy=FinetuneDataStrategy.final_only,
    )
    adapter = VertexFinetune(datamodel=datamodel)
    return adapter
42
+
43
+
44
@pytest.fixture
def mock_response():
    """Return a spec'd stand-in for a successful ``sft.SupervisedTuningJob``.

    Defaults: no error, ``JOB_STATE_SUCCEEDED``, and tuned endpoint ``ft-123``.
    Individual tests overwrite these attributes to reach other code paths.
    """
    job = MagicMock(spec=sft.SupervisedTuningJob)
    job.error = None
    job.state = gca_types.JobState.JOB_STATE_SUCCEEDED
    job.tuned_model_endpoint_name = "ft-123"
    return job
52
+
53
+
54
@pytest.fixture
def mock_dataset():
    """Return a ``DatasetSplit`` using the 80/20 definition with empty splits."""
    empty_contents = {"train": [], "test": []}
    split = DatasetSplit(
        id="test-dataset-123",
        name="Test Dataset",
        splits=Train80Test20SplitDefinition,
        split_contents=empty_contents,
    )
    return split
62
+
63
+
64
@pytest.fixture
def mock_task():
    """Return a minimal ``Task`` with no output schema.

    ``output_json_schema`` starts as ``None``; tests that need structured
    output assign a schema string to it.
    """
    task_fields = dict(
        id="test-task-123",
        name="Test Task",
        output_json_schema=None,
        instruction="Test instruction",
    )
    return Task(**task_fields)
72
+
73
+
74
async def test_status_pending_no_provider_id(vertex_finetune):
    """Without a provider job id, ``status()`` reports ``pending`` / not started."""
    finetune_model = vertex_finetune.datamodel
    finetune_model.provider_id = None

    result = await vertex_finetune.status()

    assert result.status == FineTuneStatusType.pending
    assert "This fine-tune has not been started" in result.message
80
+
81
+
82
# (Vertex job states, expected Kiln status, expected message fragment),
# grouped by outcome. Expanded below into one parametrize case per state,
# preserving this order.
_JOB_STATE_GROUPS = (
    (
        (gca_types.JobState.JOB_STATE_FAILED, gca_types.JobState.JOB_STATE_EXPIRED),
        FineTuneStatusType.failed,
        "Fine Tune Job Failed",
    ),
    (
        (
            gca_types.JobState.JOB_STATE_CANCELLED,
            gca_types.JobState.JOB_STATE_CANCELLING,
        ),
        FineTuneStatusType.failed,
        "Fine Tune Job Cancelled",
    ),
    (
        (gca_types.JobState.JOB_STATE_PENDING, gca_types.JobState.JOB_STATE_QUEUED),
        FineTuneStatusType.pending,
        "Fine Tune Job Pending",
    ),
    (
        (gca_types.JobState.JOB_STATE_RUNNING,),
        FineTuneStatusType.running,
        "Fine Tune Job Running",
    ),
    (
        (
            gca_types.JobState.JOB_STATE_SUCCEEDED,
            gca_types.JobState.JOB_STATE_PARTIALLY_SUCCEEDED,
        ),
        FineTuneStatusType.completed,
        "Fine Tune Job Completed",
    ),
    (
        (
            gca_types.JobState.JOB_STATE_PAUSED,
            gca_types.JobState.JOB_STATE_UPDATING,
            gca_types.JobState.JOB_STATE_UNSPECIFIED,
            999,  # not a real JobState: exercises the unknown-state fallback
        ),
        FineTuneStatusType.unknown,
        "Unknown state",
    ),
)

_JOB_STATE_CASES = [
    (state, expected_status, message)
    for states, expected_status, message in _JOB_STATE_GROUPS
    for state in states
]


@pytest.mark.parametrize("state,expected_status,message_contains", _JOB_STATE_CASES)
async def test_status_job_states(
    vertex_finetune,
    mock_response,
    state,
    expected_status,
    message_contains,
):
    """Each Vertex ``JobState`` maps to the expected Kiln status and message."""
    mock_response.state = state
    job_patch = patch(
        "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob",
        return_value=mock_response,
    )

    with job_patch:
        result = await vertex_finetune.status()

    assert result.status == expected_status
    assert message_contains in result.message
164
+
165
+
166
async def test_status_with_error(vertex_finetune, mock_response):
    """A job carrying an error is reported as ``failed`` with "message [code]"."""
    job_error = MagicMock()
    job_error.code = 123
    job_error.message = "Test error message"
    mock_response.error = job_error

    job_patch = patch(
        "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob",
        return_value=mock_response,
    )
    with job_patch:
        result = await vertex_finetune.status()

    assert result.status == FineTuneStatusType.failed
    assert "Test error message [123]" in result.message
179
+
180
+
181
async def test_status_updates_model_id(vertex_finetune, mock_response):
    """``status()`` copies the job's tuned endpoint name onto the datamodel."""
    vertex_finetune.datamodel.fine_tune_model_id = "old-ft-model"
    mock_response.tuned_model_endpoint_name = "new-ft-model"

    job_patch = patch(
        "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob",
        return_value=mock_response,
    )
    with job_patch:
        result = await vertex_finetune.status()

    # The stored id is replaced, and the successful job is still reported.
    assert vertex_finetune.datamodel.fine_tune_model_id == "new-ft-model"
    assert result.status == FineTuneStatusType.completed
    assert result.message == "Fine Tune Job Completed"
200
+
201
+
202
async def test_status_updates_latest_status(vertex_finetune, mock_response):
    """A changed job state is written back to ``latest_status`` and saved.

    The datamodel starts as ``running`` and the job reports success. After
    ``status()``, both the returned and the stored status are ``completed``,
    and the Kiln file exists on disk.
    """
    vertex_finetune.datamodel.latest_status = FineTuneStatusType.running
    mock_response.state = gca_types.JobState.JOB_STATE_SUCCEEDED

    job_patch = patch(
        "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob",
        return_value=mock_response,
    )
    with job_patch:
        result = await vertex_finetune.status()

    assert vertex_finetune.datamodel.latest_status == FineTuneStatusType.completed
    assert result.status == FineTuneStatusType.completed
    assert vertex_finetune.datamodel.path.exists()
221
+
222
+
223
async def test_status_model_id_update_exception(vertex_finetune, mock_response):
    """A failure reading the tuned model id is logged as a warning, not raised.

    ``status()`` must still return the job status. With the fixture's default
    ``JOB_STATE_SUCCEEDED`` that status is ``completed``. It must also emit
    exactly one warning mentioning the model-id update.
    """

    def raise_exception(self):
        raise Exception("Model ID error")

    # Install a raising property on the mock's class. A property is a data
    # descriptor, so it takes precedence over the instance attribute that the
    # fixture set. MagicMock gives every instance its own subclass, so this
    # does not leak into other tests.
    type(mock_response).tuned_model_endpoint_name = property(raise_exception)

    with (
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.sft.SupervisedTuningJob",
            return_value=mock_response,
        ),
        patch("kiln_ai.adapters.fine_tune.vertex_finetune.logger") as mock_logger,
    ):
        status = await vertex_finetune.status()

        # The failure is surfaced as a single warning.
        mock_logger.warning.assert_called_once()
        assert (
            "Error updating fine-tune model ID" in mock_logger.warning.call_args[0][0]
        )

        # The status is still returned despite the exception.
        assert status.status == FineTuneStatusType.completed
250
+
251
+
252
@pytest.mark.parametrize(
    "data_strategy,thinking_instructions",
    [
        (FinetuneDataStrategy.final_and_intermediate, "Custom thinking instructions"),
        (FinetuneDataStrategy.final_only, None),
    ],
)
async def test_generate_and_upload_jsonl(
    mock_dataset,
    mock_task,
    data_strategy,
    thinking_instructions,
    tmp_path,
):
    """Dump a split to JSONL and upload it under a timestamped GCS prefix.

    The adapter is built here, rather than taken from the ``vertex_finetune``
    fixture, so that each parametrized data strategy and thinking-instructions
    pair is set on the datamodel. The formatter, GCS client, clock and Kiln
    config are all patched.
    """
    finetune = VertexFinetune(
        datamodel=FinetuneModel(
            name="test-finetune",
            provider="vertex",
            provider_id="vertex-123",
            base_model_id="gemini-2.0-pro",
            train_split_name="train",
            dataset_split_id="dataset-123",
            system_message="Test system message",
            path=tmp_path / "test-finetune.kiln",
            data_strategy=data_strategy,
            thinking_instructions=thinking_instructions,
        ),
    )

    mock_path = Path("mock_path.jsonl")
    expected_uri = "gs://kiln-ai-data/1234567890/mock_path.jsonl"

    # The patched formatter reports that it wrote the split to mock_path.
    mock_formatter = MagicMock(spec=DatasetFormatter)
    mock_formatter.dump_to_file.return_value = mock_path

    # GCS fakes: lookup_bucket finds an existing bucket, and blobs are keyed
    # as "<timestamp>/<file name>".
    mock_bucket = MagicMock()
    mock_bucket.name = "kiln-ai-data"

    mock_blob = MagicMock()
    mock_blob.name = f"1234567890/{mock_path.name}"
    mock_bucket.blob.return_value = mock_blob

    mock_storage_client = MagicMock(spec=storage.Client)
    mock_storage_client.lookup_bucket.return_value = mock_bucket
    mock_storage_client.bucket.return_value = mock_bucket

    with (
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.DatasetFormatter",
            return_value=mock_formatter,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.storage.Client",
            return_value=mock_storage_client,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.time.time",
            return_value=1234567890,
        ),
        patch.object(Config, "shared") as mock_config,
    ):
        mock_config.return_value.vertex_project_id = "test-project"
        mock_config.return_value.vertex_location = "us-central1"

        result = await finetune.generate_and_upload_jsonl(
            mock_dataset, "train", mock_task, DatasetFormat.VERTEX_GEMINI
        )

        # The split is dumped once in Vertex Gemini format, honoring the strategy.
        mock_formatter.dump_to_file.assert_called_once_with(
            "train", DatasetFormat.VERTEX_GEMINI, data_strategy
        )

        # The (existing) bucket is obtained by name from the storage client.
        mock_storage_client.bucket.assert_called_once_with("kiln-ai-data")

        # The blob is keyed by the patched timestamp and uploaded from disk.
        mock_bucket.blob.assert_called_once_with(f"1234567890/{mock_path.name}")
        mock_blob.upload_from_filename.assert_called_once_with(mock_path)

        # The returned value is the gs:// URI of the uploaded blob.
        assert result == expected_uri
339
+
340
+
341
async def test_generate_and_upload_jsonl_create_bucket(
    vertex_finetune, mock_dataset, mock_task
):
    """When ``lookup_bucket`` finds nothing, the bucket is created and used.

    The bucket is created in the configured Vertex location. The JSONL is
    uploaded to the new bucket, and its gs:// URI is returned.
    """
    jsonl_path = Path("mock_path.jsonl")

    formatter = MagicMock(spec=DatasetFormatter)
    formatter.dump_to_file.return_value = jsonl_path

    blob = MagicMock()
    blob.name = f"1234567890/{jsonl_path.name}"

    created_bucket = MagicMock()
    created_bucket.name = "kiln-ai-data"
    created_bucket.blob.return_value = blob

    client = MagicMock(spec=storage.Client)
    client.lookup_bucket.return_value = None  # simulate a missing bucket
    client.create_bucket.return_value = created_bucket

    with (
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.DatasetFormatter",
            return_value=formatter,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.storage.Client",
            return_value=client,
        ),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.time.time",
            return_value=1234567890,
        ),
        patch.object(Config, "shared") as shared_config,
    ):
        settings = shared_config.return_value
        settings.vertex_project_id = "test-project"
        settings.vertex_location = "us-central1"

        uploaded_uri = await vertex_finetune.generate_and_upload_jsonl(
            mock_dataset, "train", mock_task, DatasetFormat.VERTEX_GEMINI
        )

    client.create_bucket.assert_called_once_with(
        "kiln-ai-data", location="us-central1"
    )
    blob.upload_from_filename.assert_called_once_with(jsonl_path)
    assert uploaded_uri == "gs://kiln-ai-data/1234567890/mock_path.jsonl"
396
+
397
+
398
@pytest.mark.parametrize(
    "output_schema,expected_mode,expected_format",
    [
        (
            '{"type": "object", "properties": {"key": {"type": "string"}}}',
            StructuredOutputMode.json_mode,
            DatasetFormat.VERTEX_GEMINI,
        ),
        (None, None, DatasetFormat.VERTEX_GEMINI),
    ],
)
async def test_start_success(
    vertex_finetune,
    mock_dataset,
    mock_task,
    output_schema,
    expected_mode,
    expected_format,
):
    """Starting a training-only job uploads one file and calls ``sft.train``.

    The test also checks that the Vertex job's resource name is stored as
    ``provider_id``. It checks that ``structured_output_mode`` follows the
    task's output schema: JSON mode with a schema, ``None`` without one.
    """
    vertex_finetune.datamodel.parent = mock_task
    mock_task.output_json_schema = output_schema
    vertex_finetune.datamodel.parameters = {
        "epochs": 3,
        "learning_rate_multiplier": 0.1,
        "adapter_size": 8,
    }

    sft_job = MagicMock()
    sft_job.resource_name = "vertex-ft-123"

    train_uri = "gs://kiln-ai-data/train.jsonl"
    validation_uri = "gs://kiln-ai-data/validation.jsonl"

    with (
        patch.object(
            vertex_finetune,
            "generate_and_upload_jsonl",
            side_effect=[train_uri, validation_uri],
        ) as upload,
        patch("kiln_ai.adapters.fine_tune.vertex_finetune.vertexai.init") as init,
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.sft.train",
            return_value=sft_job,
        ) as train,
        patch.object(Config, "shared") as shared_config,
    ):
        settings = shared_config.return_value
        settings.vertex_project_id = "test-project"
        settings.vertex_location = "us-central1"

        # No validation split, so only the training file should be uploaded.
        vertex_finetune.datamodel.validation_split_name = None

        await vertex_finetune._start(mock_dataset)

    init.assert_called_once_with(project="test-project", location="us-central1")

    upload.assert_called_once_with(
        mock_dataset,
        vertex_finetune.datamodel.train_split_name,
        mock_task,
        expected_format,
    )

    expected_labels = {
        "source": "kiln",
        "kiln_finetune_id": str(vertex_finetune.datamodel.id),
        "kiln_task_id": str(mock_task.id),
    }
    train.assert_called_once_with(
        source_model=vertex_finetune.datamodel.base_model_id,
        train_dataset=train_uri,
        validation_dataset=None,
        tuned_model_display_name=f"kiln_finetune_{vertex_finetune.datamodel.id}",
        epochs=3,
        adapter_size=8,
        learning_rate_multiplier=0.1,
        labels=expected_labels,
    )

    assert vertex_finetune.datamodel.provider_id == "vertex-ft-123"
    assert vertex_finetune.datamodel.structured_output_mode == expected_mode
488
+
489
+
490
async def test_start_with_validation(vertex_finetune, mock_dataset, mock_task):
    """With a validation split, both splits are uploaded and passed to training."""
    vertex_finetune.datamodel.parent = mock_task
    vertex_finetune.datamodel.validation_split_name = "test"

    sft_job = MagicMock()
    sft_job.resource_name = "vertex-ft-123"

    train_uri = "gs://kiln-ai-data/train.jsonl"
    validation_uri = "gs://kiln-ai-data/validation.jsonl"

    with (
        patch.object(
            vertex_finetune,
            "generate_and_upload_jsonl",
            side_effect=[train_uri, validation_uri],
        ) as upload,
        patch("kiln_ai.adapters.fine_tune.vertex_finetune.vertexai.init"),
        patch(
            "kiln_ai.adapters.fine_tune.vertex_finetune.sft.train",
            return_value=sft_job,
        ) as train,
        patch.object(Config, "shared") as shared_config,
    ):
        settings = shared_config.return_value
        settings.vertex_project_id = "test-project"
        settings.vertex_location = "us-central1"

        await vertex_finetune._start(mock_dataset)

    # One upload per split: the training split, then the "test" validation split.
    assert upload.call_count == 2
    for split_name in (vertex_finetune.datamodel.train_split_name, "test"):
        upload.assert_any_call(
            mock_dataset,
            split_name,
            mock_task,
            DatasetFormat.VERTEX_GEMINI,
        )

    # The validation file's URI is passed through to training.
    train.assert_called_once()
    assert train.call_args.kwargs["validation_dataset"] == validation_uri
538
+
539
+
540
async def test_start_no_task(vertex_finetune, mock_dataset):
    """``_start`` raises ``ValueError`` when the fine-tune has no parent task."""
    vertex_finetune.datamodel.parent = None

    missing_task_error = pytest.raises(
        ValueError, match="Task is required to start a fine-tune"
    )
    with missing_task_error:
        await vertex_finetune._start(mock_dataset)
546
+
547
+
548
def test_available_parameters():
    """Vertex exposes exactly three tunable hyperparameters, all optional."""
    params = VertexFinetune.available_parameters()
    assert len(params) == 3

    exposed_names = {param.name for param in params}
    for required_name in ("learning_rate_multiplier", "epochs", "adapter_size"):
        assert required_name in exposed_names

    assert all(param.optional for param in params)
560
+
561
+
562
@pytest.mark.parametrize(
    "project_id,location,should_raise",
    [
        ("test-project", "us-central1", False),
        ("", "us-central1", True),
        (None, "us-central1", True),
        ("test-project", "", True),
        ("test-project", None, True),
        (None, None, True),
    ],
)
def test_get_vertex_provider_location(project_id, location, should_raise):
    """Both project and location must be set; an empty or ``None`` value raises."""
    with patch.object(Config, "shared") as shared_config:
        settings = shared_config.return_value
        settings.vertex_project_id = project_id
        settings.vertex_location = location

        if not should_raise:
            resolved_project, resolved_location = (
                VertexFinetune.get_vertex_provider_location()
            )
            assert (resolved_project, resolved_location) == (project_id, location)
            return

        with pytest.raises(
            ValueError, match="Google Vertex project and location must be set"
        ):
            VertexFinetune.get_vertex_provider_location()