kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +11 -1
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
- kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.0.dist-info/RECORD +0 -36
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
3
|
+
|
|
4
|
+
import httpx
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from kiln_ai.adapters.fine_tune.base_finetune import (
|
|
8
|
+
FineTuneParameter,
|
|
9
|
+
FineTuneStatus,
|
|
10
|
+
FineTuneStatusType,
|
|
11
|
+
)
|
|
12
|
+
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
|
|
13
|
+
from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
|
|
14
|
+
from kiln_ai.datamodel import (
|
|
15
|
+
DatasetSplit,
|
|
16
|
+
Task,
|
|
17
|
+
Train80Test20SplitDefinition,
|
|
18
|
+
)
|
|
19
|
+
from kiln_ai.datamodel import Finetune as FinetuneModel
|
|
20
|
+
from kiln_ai.utils.config import Config
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.fixture
def fireworks_finetune(tmp_path):
    """Build a FireworksFinetune adapter around a Finetune datamodel.

    The datamodel's path lives under pytest's ``tmp_path``. It already carries
    a provider job ID ("fw-123") and an ``undeployed_model_id`` property
    ("ftm-123") so that status and deploy tests have IDs to work with.
    """
    datamodel = FinetuneModel(
        name="test-finetune",
        provider="fireworks",
        provider_id="fw-123",
        base_model_id="llama-v2-7b",
        train_split_name="train",
        dataset_split_id="dataset-123",
        system_message="Test system message",
        path=tmp_path / "test-finetune.kiln",
        properties={"undeployed_model_id": "ftm-123"},
    )
    return FireworksFinetune(datamodel=datamodel)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.fixture
def mock_response():
    """Return a spec'd ``httpx.Response`` mock for an HTTP 200 job-status reply."""
    payload = {"state": "COMPLETED", "model": "llama-v2-7b"}
    fake = MagicMock(spec=httpx.Response)
    fake.status_code = 200
    fake.json.return_value = payload
    return fake
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.fixture
def mock_client():
    """Return a mock ``httpx.AsyncClient``.

    Because the mock is built with ``spec``, unittest.mock gives its coroutine
    methods (e.g. ``get``) AsyncMock children, so tests can await them.
    """
    return MagicMock(spec=httpx.AsyncClient)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@pytest.fixture
def mock_api_key():
    """Patch ``Config.shared`` so Fireworks credentials appear to be configured.

    The patch is active while the requesting test runs. It is undone when the
    fixture resumes after ``yield`` and the ``with`` block exits.
    """
    with patch.object(Config, "shared") as shared_config:
        fake_config = shared_config.return_value
        fake_config.fireworks_api_key = "test-api-key"
        fake_config.fireworks_account_id = "test-account-id"
        yield
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
async def test_setup(fireworks_finetune):
    """Smoke-test ``status()`` against the real Fireworks API with a fake job ID.

    The test is skipped unless real Fireworks credentials are configured.

    It deliberately does NOT request the ``mock_api_key`` fixture. That fixture
    patches ``Config.shared`` with dummy credentials, which would turn the skip
    guard below into dead code. The test would then always send a real network
    request with a bogus key.
    """
    config = Config.shared()
    if not config.fireworks_api_key or not config.fireworks_account_id:
        pytest.skip("Fireworks API key or account ID not set")

    # Real API call, with a fake job ID: the lookup is expected to fail.
    status = await fireworks_finetune.status()
    assert status.status == FineTuneStatusType.unknown
    assert "Error retrieving fine-tuning job status" in status.message
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
async def test_status_missing_credentials(fireworks_finetune):
    """``status()`` reports an unknown state when Fireworks credentials are absent."""
    with patch.object(Config, "shared") as shared_config:
        fake_config = shared_config.return_value
        fake_config.fireworks_api_key = None
        fake_config.fireworks_account_id = None

        result = await fireworks_finetune.status()

    assert result.status == FineTuneStatusType.unknown
    assert result.message == "Fireworks API key or account ID not set"
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
async def test_status_missing_provider_id(fireworks_finetune, mock_api_key):
    """Without a provider job ID, ``status()`` reports an unknown state."""
    fireworks_finetune.datamodel.provider_id = None

    result = await fireworks_finetune.status()

    assert result.status == FineTuneStatusType.unknown
    assert "Fine-tuning job ID not set" in result.message
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.mark.parametrize(
    "status_code,expected_status,expected_message",
    [
        (
            code,
            FineTuneStatusType.unknown,
            f"Error retrieving fine-tuning job status: [{code}]",
        )
        for code in (401, 404, 500)
    ],
)
async def test_status_api_errors(
    fireworks_finetune,
    mock_response,
    mock_client,
    status_code,
    expected_status,
    expected_message,
    mock_api_key,
):
    """Non-200 replies surface as an unknown status that quotes the HTTP code."""
    mock_response.status_code = status_code
    mock_response.text = "Error message"
    mock_client.get.return_value = mock_response

    with patch("httpx.AsyncClient") as client_factory:
        # Route the client obtained via ``async with httpx.AsyncClient()`` to
        # our mock (this wiring assumes status() opens the client that way).
        client_factory.return_value.__aenter__.return_value = mock_client
        result = await fireworks_finetune.status()

    assert result.status == expected_status
    assert expected_message in result.message
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@pytest.mark.parametrize(
    "state,expected_status,message",
    [
        ("FAILED", FineTuneStatusType.failed, "Fine-tuning job failed"),
        ("DELETING", FineTuneStatusType.failed, "Fine-tuning job failed"),
        ("COMPLETED", FineTuneStatusType.completed, "Fine-tuning job completed"),
        *[
            (s, FineTuneStatusType.running, f"Fine-tuning job is running [{s}]")
            for s in ("CREATING", "PENDING", "RUNNING")
        ],
        *[
            (s, FineTuneStatusType.unknown, f"Unknown fine-tuning job status [{s}]")
            for s in ("UNKNOWN_STATE", "UNSPECIFIED_STATE")
        ],
    ],
)
async def test_status_job_states(
    fireworks_finetune,
    mock_response,
    mock_client,
    state,
    expected_status,
    message,
    mock_api_key,
):
    """Each Fireworks job ``state`` maps to the expected status type and message."""
    mock_response.json.return_value = {"state": state}
    mock_client.get.return_value = mock_response

    with (
        patch("httpx.AsyncClient") as client_factory,
        # Stub a successful deploy so COMPLETED keeps the plain "completed" message.
        patch.object(fireworks_finetune, "_deploy", return_value=True),
    ):
        client_factory.return_value.__aenter__.return_value = mock_client
        result = await fireworks_finetune.status()

    assert result.status == expected_status
    assert result.message == message
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
async def test_status_invalid_response(
    fireworks_finetune, mock_response, mock_client, mock_api_key
):
    """A 200 reply whose JSON lacks ``state`` is reported as an unknown status."""
    mock_response.json.return_value = {"no_state_field": "value"}
    mock_client.get.return_value = mock_response

    with patch("httpx.AsyncClient") as client_factory:
        client_factory.return_value.__aenter__.return_value = mock_client
        result = await fireworks_finetune.status()

    assert result.status == FineTuneStatusType.unknown
    assert "Invalid response from Fireworks" in result.message
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
async def test_status_request_exception(fireworks_finetune, mock_client, mock_api_key):
    """An exception raised by the HTTP GET is folded into the status message."""
    mock_client.get.side_effect = Exception("Connection error")

    with patch("httpx.AsyncClient") as client_factory:
        client_factory.return_value.__aenter__.return_value = mock_client
        result = await fireworks_finetune.status()

    expected = "Error retrieving fine-tuning job status: Connection error"
    assert result.status == FineTuneStatusType.unknown
    assert result.message == expected
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@pytest.fixture
def mock_dataset():
    """An empty DatasetSplit using the ``Train80Test20SplitDefinition`` splits."""
    empty_contents = {"train": [], "test": []}
    return DatasetSplit(
        id="test-dataset-123",
        name="Test Dataset",
        splits=Train80Test20SplitDefinition,
        split_contents=empty_contents,
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@pytest.fixture
def mock_task():
    """A minimal Task with no output JSON schema."""
    return Task(
        name="Test Task",
        id="test-task-123",
        instruction="Test instruction",
        output_json_schema=None,  # Can be modified in specific tests
    )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
async def test_generate_and_upload_jsonl_success(
    fireworks_finetune, mock_dataset, mock_task, mock_api_key
):
    """Happy path for dumping a split to JSONL and uploading it to Fireworks.

    Three HTTP calls are mocked:
    - two POSTs (create dataset, upload file);
    - one GET that reports the dataset state as READY.

    The returned dataset ID comes from the patched ``uuid4``.
    """
    mock_path = Path("mock_path.jsonl")
    mock_dataset_id = "dataset-123"

    # Formatter instance handed out by the patched DatasetFormatter class
    mock_formatter = MagicMock(spec=DatasetFormatter)
    mock_formatter.dump_to_file.return_value = mock_path

    # Mock responses for the three API calls
    create_response = MagicMock(spec=httpx.Response)
    create_response.status_code = 200

    upload_response = MagicMock(spec=httpx.Response)
    upload_response.status_code = 200

    status_response = MagicMock(spec=httpx.Response)
    status_response.status_code = 200
    status_response.json.return_value = {"state": "READY"}

    with (
        patch(
            "kiln_ai.adapters.fine_tune.fireworks_finetune.DatasetFormatter",
            return_value=mock_formatter,
        ) as mock_formatter_class,
        patch("httpx.AsyncClient") as mock_client_class,
        # Presumably the upload opens the dumped file; stub open() so none is needed.
        patch("builtins.open"),
        # Pin uuid4 so the returned dataset ID is predictable.
        patch(
            "kiln_ai.adapters.fine_tune.fireworks_finetune.uuid4",
            return_value=mock_dataset_id,
        ),
    ):
        mock_client = AsyncMock()
        mock_client.post = AsyncMock(side_effect=[create_response, upload_response])
        mock_client.get = AsyncMock(return_value=status_response)
        mock_client_class.return_value.__aenter__.return_value = mock_client

        result = await fireworks_finetune.generate_and_upload_jsonl(
            mock_dataset, "train", mock_task
        )

    # Verify a formatter was created...
    mock_formatter_class.assert_called_once()
    # ...and that it dumped the requested split in OpenAI chat JSONL format.
    mock_formatter.dump_to_file.assert_called_once_with(
        "train", DatasetFormat.OPENAI_CHAT_JSONL
    )

    assert result == mock_dataset_id
    assert mock_client.post.call_count == 2
    assert mock_client.get.call_count == 1
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
async def test_start_success(fireworks_finetune, mock_dataset, mock_task, mock_api_key):
    """``_start()`` uploads the train split and stores the job name Fireworks returns."""
    fireworks_finetune.datamodel.parent = mock_task
    mock_dataset_id = "dataset-123"
    mock_model_id = "ft-model-123"

    # Create-job reply whose "name" should become the datamodel's provider_id
    create_response = MagicMock(spec=httpx.Response)
    create_response.status_code = 200
    create_response.json.return_value = {"name": mock_model_id}

    client = AsyncMock()
    client.post.return_value = create_response

    with (
        patch.object(
            fireworks_finetune,
            "generate_and_upload_jsonl",
            return_value=mock_dataset_id,
        ),
        patch("httpx.AsyncClient") as client_factory,
    ):
        client_factory.return_value.__aenter__.return_value = client

        await fireworks_finetune._start(mock_dataset)

        # Must be checked while patch.object is active: on exit the mock
        # attribute is removed from the instance.
        fireworks_finetune.generate_and_upload_jsonl.assert_called_once_with(
            mock_dataset, fireworks_finetune.datamodel.train_split_name, mock_task
        )

        assert fireworks_finetune.datamodel.provider_id == mock_model_id
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
async def test_start_api_error(
    fireworks_finetune, mock_dataset, mock_task, mock_api_key
):
    """A 500 reply to the create-job request makes ``_start()`` raise ValueError."""
    fireworks_finetune.datamodel.parent = mock_task

    error_response = MagicMock(spec=httpx.Response)
    error_response.status_code = 500
    error_response.text = "Internal Server Error"

    client = AsyncMock()
    client.post.return_value = error_response

    with (
        patch.object(
            fireworks_finetune,
            "generate_and_upload_jsonl",
            return_value="dataset-123",
        ),
        patch("httpx.AsyncClient") as client_factory,
    ):
        client_factory.return_value.__aenter__.return_value = client

        with pytest.raises(ValueError, match="Failed to create fine-tuning job"):
            await fireworks_finetune._start(mock_dataset)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def test_available_parameters(fireworks_finetune):
    """Four hyperparameters are exposed and translated to camelCase payload keys."""
    parameters = fireworks_finetune.available_parameters()
    assert len(parameters) == 4
    assert all(isinstance(p, FineTuneParameter) for p in parameters)

    # (user-supplied parameters, expected Fireworks payload); absent keys stay absent
    cases = [
        (
            {"lora_rank": 16, "epochs": 3, "learning_rate": 0.001, "batch_size": 32},
            {"loraRank": 16, "epochs": 3, "learningRate": 0.001, "batchSize": 32},
        ),
        ({}, {}),
        ({"lora_rank": 16, "epochs": 3}, {"loraRank": 16, "epochs": 3}),
    ]
    for given, expected in cases:
        assert fireworks_finetune.create_payload_parameters(given) == expected
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
async def test_deploy_success(fireworks_finetune, mock_api_key):
    """A 200 deploy reply returns True and records the deployed model ID."""
    ok_response = MagicMock(spec=httpx.Response)
    ok_response.status_code = 200
    # Nothing is recorded before deployment
    assert fireworks_finetune.datamodel.fine_tune_model_id is None

    client = AsyncMock()
    client.post.return_value = ok_response

    with patch("httpx.AsyncClient") as client_factory:
        client_factory.return_value.__aenter__.return_value = client
        deployed = await fireworks_finetune._deploy()

    assert deployed is True
    # The fixture's "undeployed_model_id" property becomes the fine-tune model ID
    assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
    """An HTTP 400 carrying error code 9 ("already deployed") counts as success."""
    conflict_response = MagicMock(spec=httpx.Response)
    conflict_response.status_code = 400
    conflict_response.json.return_value = {
        "code": 9,
        "message": "Model already deployed",
    }

    client = AsyncMock()
    client.post.return_value = conflict_response

    with patch("httpx.AsyncClient") as client_factory:
        client_factory.return_value.__aenter__.return_value = client
        deployed = await fireworks_finetune._deploy()

    assert deployed is True
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
async def test_deploy_failure(fireworks_finetune, mock_api_key):
    """A 500 deploy reply (error code 1) makes ``_deploy()`` return False."""
    failed_response = MagicMock(spec=httpx.Response)
    failed_response.status_code = 500
    failed_response.json.return_value = {"code": 1}

    client = AsyncMock()
    client.post.return_value = failed_response

    with patch("httpx.AsyncClient") as client_factory:
        client_factory.return_value.__aenter__.return_value = client
        deployed = await fireworks_finetune._deploy()

    assert deployed is False
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
async def test_deploy_missing_credentials(fireworks_finetune):
    """``_deploy()`` raises ValueError when Fireworks credentials are absent."""
    with patch.object(Config, "shared") as shared_config:
        fake_config = shared_config.return_value
        fake_config.fireworks_api_key = None
        fake_config.fireworks_account_id = None

        # Must run inside the patch so _deploy() sees the missing credentials
        with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
            await fireworks_finetune._deploy()
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
async def test_deploy_missing_model_id(fireworks_finetune, mock_api_key):
    """With no ``undeployed_model_id`` property value, ``_deploy()`` returns False."""
    fireworks_finetune.datamodel.properties["undeployed_model_id"] = None

    assert await fireworks_finetune._deploy() is False
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
async def test_status_with_deploy(fireworks_finetune, mock_api_key):
    """A completed job triggers ``_deploy()``; a failed deploy amends the message."""
    completed = FineTuneStatus(
        status=FineTuneStatusType.completed, message="Fine-tuning job completed"
    )

    with (
        patch.object(
            fireworks_finetune, "_status", return_value=completed
        ) as status_mock,
        patch.object(
            fireworks_finetune, "_deploy", return_value=False
        ) as deploy_mock,
    ):
        result = await fireworks_finetune.status()

    # status() consulted _status() once...
    status_mock.assert_called_once()
    # ...and, seeing "completed", attempted a deployment.
    deploy_mock.assert_called_once()

    # The failed deployment keeps the completed status but changes the message
    assert result.status == FineTuneStatusType.completed
    assert result.message == "Fine-tuning job completed but failed to deploy model."
|