kiln-ai 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to the supported public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai was flagged as potentially problematic.

Files changed (40)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +218 -304
  16. kiln_ai/adapters/ollama_tools.py +114 -0
  17. kiln_ai/adapters/provider_tools.py +295 -0
  18. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  19. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  20. kiln_ai/adapters/test_ollama_tools.py +42 -0
  21. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  22. kiln_ai/adapters/test_provider_tools.py +312 -0
  23. kiln_ai/adapters/test_structured_output.py +22 -43
  24. kiln_ai/datamodel/__init__.py +235 -22
  25. kiln_ai/datamodel/basemodel.py +30 -0
  26. kiln_ai/datamodel/registry.py +31 -0
  27. kiln_ai/datamodel/test_basemodel.py +29 -1
  28. kiln_ai/datamodel/test_dataset_split.py +234 -0
  29. kiln_ai/datamodel/test_example_models.py +12 -0
  30. kiln_ai/datamodel/test_models.py +91 -1
  31. kiln_ai/datamodel/test_registry.py +96 -0
  32. kiln_ai/utils/config.py +9 -0
  33. kiln_ai/utils/name_generator.py +125 -0
  34. kiln_ai/utils/test_name_generator.py +47 -0
  35. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  36. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  37. kiln_ai/adapters/test_ml_model_list.py +0 -181
  38. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  39. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
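
The headline addition in 0.7.0 is the new kiln_ai.adapters.fine_tune package (items 4-13 above), with Fireworks and OpenAI backends, a dataset formatter, and a finetune registry. Going by the test file reproduced below, a Fireworks job is driven through a FireworksFinetune adapter wrapping a Finetune datamodel, and an async status() call polls the job and deploys the model on completion. The sketch below is inferred from those tests rather than from any official documentation: check_job is a hypothetical helper, all field values are illustrative, and the exact constructor requirements may differ.

    # Usage sketch inferred from the tests below; not official kiln-ai documentation.
    import asyncio

    from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType
    from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
    from kiln_ai.datamodel import Finetune as FinetuneModel
    from kiln_ai.utils.config import Config


    async def check_job() -> None:
        # The adapter reads credentials from the shared config singleton,
        # which is also how the tests inject fake keys.
        config = Config.shared()
        if not config.fireworks_api_key or not config.fireworks_account_id:
            raise RuntimeError("Fireworks API key or account ID not set")

        finetune = FireworksFinetune(
            datamodel=FinetuneModel(
                name="my-finetune",  # all values illustrative
                provider="fireworks",
                provider_id="fw-123",  # Fireworks fine-tuning job ID
                base_model_id="llama-v2-7b",
                train_split_name="train",
                dataset_split_id="dataset-123",
                system_message="You are a helpful assistant.",
            ),
        )

        # Polls Fireworks; per the tests, a completed job also triggers deployment.
        status = await finetune.status()
        if status.status == FineTuneStatusType.completed:
            print(f"Completed: {status.message}")
        else:
            print(f"{status.status}: {status.message}")


    if __name__ == "__main__":
        asyncio.run(check_job())
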
kiln_ai/adapters/fine_tune/test_fireworks_finetune.py
@@ -0,0 +1,455 @@
+ from pathlib import Path
+ from unittest.mock import AsyncMock, MagicMock, patch
+
+ import httpx
+ import pytest
+
+ from kiln_ai.adapters.fine_tune.base_finetune import (
+     FineTuneParameter,
+     FineTuneStatus,
+     FineTuneStatusType,
+ )
+ from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
+ from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
+ from kiln_ai.datamodel import (
+     DatasetSplit,
+     Task,
+     Train80Test20SplitDefinition,
+ )
+ from kiln_ai.datamodel import Finetune as FinetuneModel
+ from kiln_ai.utils.config import Config
+
+
+ @pytest.fixture
+ def fireworks_finetune(tmp_path):
+     tmp_file = tmp_path / "test-finetune.kiln"
+     finetune = FireworksFinetune(
+         datamodel=FinetuneModel(
+             name="test-finetune",
+             provider="fireworks",
+             provider_id="fw-123",
+             base_model_id="llama-v2-7b",
+             train_split_name="train",
+             dataset_split_id="dataset-123",
+             system_message="Test system message",
+             path=tmp_file,
+             properties={"undeployed_model_id": "ftm-123"},
+         ),
+     )
+     return finetune
+
+
+ @pytest.fixture
+ def mock_response():
+     response = MagicMock(spec=httpx.Response)
+     response.status_code = 200
+     response.json.return_value = {
+         "state": "COMPLETED",
+         "model": "llama-v2-7b",
+     }
+     return response
+
+
+ @pytest.fixture
+ def mock_client():
+     client = MagicMock(spec=httpx.AsyncClient)
+     return client
+
+
+ @pytest.fixture
+ def mock_api_key():
+     with patch.object(Config, "shared") as mock_config:
+         mock_config.return_value.fireworks_api_key = "test-api-key"
+         mock_config.return_value.fireworks_account_id = "test-account-id"
+         yield
+
+
+ async def test_setup(fireworks_finetune, mock_api_key):
+     if (
+         not Config.shared().fireworks_api_key
+         or not Config.shared().fireworks_account_id
+     ):
+         pytest.skip("Fireworks API key or account ID not set")
+
+     # Real API call, with fake ID
+     status = await fireworks_finetune.status()
+     assert status.status == FineTuneStatusType.unknown
+     assert "Error retrieving fine-tuning job status" in status.message
+
+
+ async def test_status_missing_credentials(fireworks_finetune):
+     with patch.object(Config, "shared") as mock_config:
+         mock_config.return_value.fireworks_api_key = None
+         mock_config.return_value.fireworks_account_id = None
+
+         status = await fireworks_finetune.status()
+         assert status.status == FineTuneStatusType.unknown
+         assert "Fireworks API key or account ID not set" == status.message
+
+
+ async def test_status_missing_provider_id(fireworks_finetune, mock_api_key):
+     fireworks_finetune.datamodel.provider_id = None
+
+     status = await fireworks_finetune.status()
+     assert status.status == FineTuneStatusType.unknown
+     assert "Fine-tuning job ID not set" in status.message
+
+
+ @pytest.mark.parametrize(
+     "status_code,expected_status,expected_message",
+     [
+         (
+             401,
+             FineTuneStatusType.unknown,
+             "Error retrieving fine-tuning job status: [401]",
+         ),
+         (
+             404,
+             FineTuneStatusType.unknown,
+             "Error retrieving fine-tuning job status: [404]",
+         ),
+         (
+             500,
+             FineTuneStatusType.unknown,
+             "Error retrieving fine-tuning job status: [500]",
+         ),
+     ],
+ )
+ async def test_status_api_errors(
+     fireworks_finetune,
+     mock_response,
+     mock_client,
+     status_code,
+     expected_status,
+     expected_message,
+     mock_api_key,
+ ):
+     mock_response.status_code = status_code
+     mock_response.text = "Error message"
+     mock_client.get.return_value = mock_response
+
+     with patch("httpx.AsyncClient") as mock_client_class:
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+         status = await fireworks_finetune.status()
+         assert status.status == expected_status
+         assert expected_message in status.message
+
+
+ @pytest.mark.parametrize(
+     "state,expected_status,message",
+     [
+         ("FAILED", FineTuneStatusType.failed, "Fine-tuning job failed"),
+         ("DELETING", FineTuneStatusType.failed, "Fine-tuning job failed"),
+         ("COMPLETED", FineTuneStatusType.completed, "Fine-tuning job completed"),
+         (
+             "CREATING",
+             FineTuneStatusType.running,
+             "Fine-tuning job is running [CREATING]",
+         ),
+         ("PENDING", FineTuneStatusType.running, "Fine-tuning job is running [PENDING]"),
+         ("RUNNING", FineTuneStatusType.running, "Fine-tuning job is running [RUNNING]"),
+         (
+             "UNKNOWN_STATE",
+             FineTuneStatusType.unknown,
+             "Unknown fine-tuning job status [UNKNOWN_STATE]",
+         ),
+         (
+             "UNSPECIFIED_STATE",
+             FineTuneStatusType.unknown,
+             "Unknown fine-tuning job status [UNSPECIFIED_STATE]",
+         ),
+     ],
+ )
+ async def test_status_job_states(
+     fireworks_finetune,
+     mock_response,
+     mock_client,
+     state,
+     expected_status,
+     message,
+     mock_api_key,
+ ):
+     mock_response.json.return_value = {"state": state}
+     mock_client.get.return_value = mock_response
+
+     with (
+         patch("httpx.AsyncClient") as mock_client_class,
+         patch.object(fireworks_finetune, "_deploy", return_value=True),
+     ):
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+         status = await fireworks_finetune.status()
+         assert status.status == expected_status
+         assert message == status.message
+
+
+ async def test_status_invalid_response(
+     fireworks_finetune, mock_response, mock_client, mock_api_key
+ ):
+     mock_response.json.return_value = {"no_state_field": "value"}
+     mock_client.get.return_value = mock_response
+
+     with patch("httpx.AsyncClient") as mock_client_class:
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+         status = await fireworks_finetune.status()
+         assert status.status == FineTuneStatusType.unknown
+         assert "Invalid response from Fireworks" in status.message
+
+
+ async def test_status_request_exception(fireworks_finetune, mock_client, mock_api_key):
+     mock_client.get.side_effect = Exception("Connection error")
+
+     with patch("httpx.AsyncClient") as mock_client_class:
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+         status = await fireworks_finetune.status()
+         assert status.status == FineTuneStatusType.unknown
+         assert (
+             "Error retrieving fine-tuning job status: Connection error"
+             == status.message
+         )
+
+
+ @pytest.fixture
+ def mock_dataset():
+     return DatasetSplit(
+         id="test-dataset-123",
+         name="Test Dataset",
+         splits=Train80Test20SplitDefinition,
+         split_contents={"train": [], "test": []},
+     )
+
+
+ @pytest.fixture
+ def mock_task():
+     return Task(
+         id="test-task-123",
+         name="Test Task",
+         output_json_schema=None,  # Can be modified in specific tests
+         instruction="Test instruction",
+     )
+
+
+ async def test_generate_and_upload_jsonl_success(
+     fireworks_finetune, mock_dataset, mock_task, mock_api_key
+ ):
+     mock_path = Path("mock_path.jsonl")
+     mock_dataset_id = "dataset-123"
+
+     # Mock the formatter
+     mock_formatter = MagicMock(spec=DatasetFormatter)
+     mock_formatter.dump_to_file.return_value = mock_path
+
+     # Mock responses for the three API calls
+     create_response = MagicMock(spec=httpx.Response)
+     create_response.status_code = 200
+
+     upload_response = MagicMock(spec=httpx.Response)
+     upload_response.status_code = 200
+
+     status_response = MagicMock(spec=httpx.Response)
+     status_response.status_code = 200
+     status_response.json.return_value = {"state": "READY"}
+
+     with (
+         patch(
+             "kiln_ai.adapters.fine_tune.fireworks_finetune.DatasetFormatter",
+             return_value=mock_formatter,
+         ),
+         patch("httpx.AsyncClient") as mock_client_class,
+         patch("builtins.open"),
+         patch(
+             "kiln_ai.adapters.fine_tune.fireworks_finetune.uuid4",
+             return_value=mock_dataset_id,
+         ),
+     ):
+         mock_client = AsyncMock()
+         mock_client.post = AsyncMock(side_effect=[create_response, upload_response])
+         mock_client.get = AsyncMock(return_value=status_response)
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+
+         result = await fireworks_finetune.generate_and_upload_jsonl(
+             mock_dataset, "train", mock_task
+         )
+
+         # Verify formatter was created with correct parameters
+         mock_formatter.dump_to_file.assert_called_once_with(
+             "train", DatasetFormat.OPENAI_CHAT_JSONL
+         )
+
+         assert result == mock_dataset_id
+         assert mock_client.post.call_count == 2
+         assert mock_client.get.call_count == 1
+
+
+ async def test_start_success(fireworks_finetune, mock_dataset, mock_task, mock_api_key):
+     fireworks_finetune.datamodel.parent = mock_task
+     mock_dataset_id = "dataset-123"
+     mock_model_id = "ft-model-123"
+
+     # Mock response for create fine-tuning job
+     create_response = MagicMock(spec=httpx.Response)
+     create_response.status_code = 200
+     create_response.json.return_value = {"name": mock_model_id}
+
+     with (
+         patch.object(
+             fireworks_finetune,
+             "generate_and_upload_jsonl",
+             return_value=mock_dataset_id,
+         ),
+         patch("httpx.AsyncClient") as mock_client_class,
+     ):
+         mock_client = AsyncMock()
+         mock_client.post.return_value = create_response
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+
+         await fireworks_finetune._start(mock_dataset)
+
+         # Verify dataset was uploaded
+         fireworks_finetune.generate_and_upload_jsonl.assert_called_once_with(
+             mock_dataset, fireworks_finetune.datamodel.train_split_name, mock_task
+         )
+
+         # Verify model ID was updated
+         assert fireworks_finetune.datamodel.provider_id == mock_model_id
+
+
+ async def test_start_api_error(
+     fireworks_finetune, mock_dataset, mock_task, mock_api_key
+ ):
+     fireworks_finetune.datamodel.parent = mock_task
+     mock_dataset_id = "dataset-123"
+
+     # Mock error response
+     error_response = MagicMock(spec=httpx.Response)
+     error_response.status_code = 500
+     error_response.text = "Internal Server Error"
+
+     with (
+         patch.object(
+             fireworks_finetune,
+             "generate_and_upload_jsonl",
+             return_value=mock_dataset_id,
+         ),
+         patch("httpx.AsyncClient") as mock_client_class,
+     ):
+         mock_client = AsyncMock()
+         mock_client.post.return_value = error_response
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+
+         with pytest.raises(ValueError, match="Failed to create fine-tuning job"):
+             await fireworks_finetune._start(mock_dataset)
+
+
+ def test_available_parameters(fireworks_finetune):
+     parameters = fireworks_finetune.available_parameters()
+     assert len(parameters) == 4
+     assert all(isinstance(p, FineTuneParameter) for p in parameters)
+
+     payload_parameters = fireworks_finetune.create_payload_parameters(
+         {"lora_rank": 16, "epochs": 3, "learning_rate": 0.001, "batch_size": 32}
+     )
+     assert payload_parameters == {
+         "loraRank": 16,
+         "epochs": 3,
+         "learningRate": 0.001,
+         "batchSize": 32,
+     }
+     payload_parameters = fireworks_finetune.create_payload_parameters({})
+     assert payload_parameters == {}
+
+     payload_parameters = fireworks_finetune.create_payload_parameters(
+         {"lora_rank": 16, "epochs": 3}
+     )
+     assert payload_parameters == {"loraRank": 16, "epochs": 3}
+
+
+ async def test_deploy_success(fireworks_finetune, mock_api_key):
+     # Mock response for successful deployment
+     success_response = MagicMock(spec=httpx.Response)
+     success_response.status_code = 200
+     assert fireworks_finetune.datamodel.fine_tune_model_id is None
+
+     with patch("httpx.AsyncClient") as mock_client_class:
+         mock_client = AsyncMock()
+         mock_client.post.return_value = success_response
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+
+         result = await fireworks_finetune._deploy()
+         assert result is True
+         assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
+
+
+ async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
+     # Mock response for already deployed model
+     already_deployed_response = MagicMock(spec=httpx.Response)
+     already_deployed_response.status_code = 400
+     already_deployed_response.json.return_value = {
+         "code": 9,
+         "message": "Model already deployed",
+     }
+
+     with patch("httpx.AsyncClient") as mock_client_class:
+         mock_client = AsyncMock()
+         mock_client.post.return_value = already_deployed_response
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+
+         result = await fireworks_finetune._deploy()
+         assert result is True
+
+
+ async def test_deploy_failure(fireworks_finetune, mock_api_key):
+     # Mock response for failed deployment
+     failure_response = MagicMock(spec=httpx.Response)
+     failure_response.status_code = 500
+     failure_response.json.return_value = {"code": 1}
+
+     with patch("httpx.AsyncClient") as mock_client_class:
+         mock_client = AsyncMock()
+         mock_client.post.return_value = failure_response
+         mock_client_class.return_value.__aenter__.return_value = mock_client
+
+         result = await fireworks_finetune._deploy()
+         assert result is False
+
+
+ async def test_deploy_missing_credentials(fireworks_finetune):
+     # Test missing API key or account ID
+     with patch.object(Config, "shared") as mock_config:
+         mock_config.return_value.fireworks_api_key = None
+         mock_config.return_value.fireworks_account_id = None
+
+         with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
+             await fireworks_finetune._deploy()
+
+
+ async def test_deploy_missing_model_id(fireworks_finetune, mock_api_key):
+     # Test missing model ID
+     fireworks_finetune.datamodel.properties["undeployed_model_id"] = None
+
+     response = await fireworks_finetune._deploy()
+     assert response is False
+
+
+ async def test_status_with_deploy(fireworks_finetune, mock_api_key):
+     # Mock _status to return completed
+     mock_status_response = FineTuneStatus(
+         status=FineTuneStatusType.completed, message="Fine-tuning job completed"
+     )
+
+     with (
+         patch.object(
+             fireworks_finetune, "_status", return_value=mock_status_response
+         ) as mock_status,
+         patch.object(fireworks_finetune, "_deploy", return_value=False) as mock_deploy,
+     ):
+         status = await fireworks_finetune.status()
+
+         # Verify _status was called
+         mock_status.assert_called_once()
+
+         # Verify _deploy was called since status was completed
+         mock_deploy.assert_called_once()
+
+         # Verify message was updated due to failed deployment
+         assert status.status == FineTuneStatusType.completed
+         assert status.message == "Fine-tuning job completed but failed to deploy model."
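
One pattern recurs throughout this file: instead of hitting the network, each test patches httpx.AsyncClient and points the async context manager's __aenter__ at a pre-configured mock client. Stripped to its essentials, the pattern looks like the sketch below. Here fetch_state is a hypothetical stand-in for the adapter's HTTP code, and the @pytest.mark.asyncio marker assumes pytest-asyncio is installed (the kiln-ai tests carry no marker, so they presumably run with pytest-asyncio in auto mode).

    # Minimal, self-contained illustration of the mocking pattern used above.
    from unittest.mock import AsyncMock, MagicMock, patch

    import httpx
    import pytest


    async def fetch_state(url: str) -> str:
        # Hypothetical stand-in for the adapter's HTTP call.
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            return response.json()["state"]


    @pytest.mark.asyncio
    async def test_fetch_state_mocked():
        # A canned response object, just like the mock_response fixture above.
        response = MagicMock(spec=httpx.Response)
        response.status_code = 200
        response.json.return_value = {"state": "COMPLETED"}

        mock_client = AsyncMock()
        mock_client.get.return_value = response

        with patch("httpx.AsyncClient") as mock_client_class:
            # Route `async with httpx.AsyncClient()` to the mock client.
            mock_client_class.return_value.__aenter__.return_value = mock_client
            assert await fetch_state("https://example.com/job") == "COMPLETED"

Because __aenter__ is wired on the patched class's return value, the production code needs no test hooks: any `async with httpx.AsyncClient() as client:` block transparently receives the mock.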