kiln-ai 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
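The largest addition in 0.7.0 is the new kiln_ai.adapters.fine_tune package (a base fine-tune adapter, a dataset formatter, and OpenAI and Fireworks backends). The sketch below shows how the new OpenAI adapter appears to be driven, based on the test file included in the diff that follows; the class, field, and method names are taken from that diff, while the job id, file path, and prompt text are hypothetical and the exact set of required fields is an assumption.

    # Illustrative sketch only: OpenAIFinetune, Finetune, and FineTuneStatusType come
    # from this diff; the specific values below are hypothetical.
    import asyncio
    from pathlib import Path

    from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType
    from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
    from kiln_ai.datamodel import Finetune as FinetuneModel


    async def main() -> None:
        # Describe an existing OpenAI fine-tune job in the Kiln datamodel.
        finetune = OpenAIFinetune(
            datamodel=FinetuneModel(
                name="my-finetune",
                provider="openai",
                provider_id="ftjob-abc123",  # hypothetical provider job id
                base_model_id="gpt-4o",
                train_split_name="train",
                dataset_split_id="dataset-123",
                system_message="You are a helpful assistant.",
                path=Path("my-finetune.kiln"),
            ),
        )
        # Polls the provider; requires an OpenAI API key configured via
        # kiln_ai.utils.config.Config (see test_setup below).
        status = await finetune.status()
        if status.status == FineTuneStatusType.completed:
            print("fine-tuned model:", finetune.datamodel.fine_tune_model_id)
        else:
            print(status.status, status.message)


    if __name__ == "__main__":
        asyncio.run(main())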
kiln_ai/adapters/fine_tune/test_openai_finetune.py (new file)

@@ -0,0 +1,503 @@
+import time
+from pathlib import Path
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import openai
+import pytest
+from openai.types.fine_tuning import FineTuningJob
+
+from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType
+from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
+from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
+from kiln_ai.datamodel import DatasetSplit, Task, Train80Test20SplitDefinition
+from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.utils.config import Config
+
+
+@pytest.fixture
+def openai_finetune(tmp_path):
+    tmp_file = tmp_path / "test-finetune.kiln"
+    finetune = OpenAIFinetune(
+        datamodel=FinetuneModel(
+            name="test-finetune",
+            provider="openai",
+            provider_id="openai-123",
+            base_model_id="gpt-4o",
+            train_split_name="train",
+            dataset_split_id="dataset-123",
+            system_message="Test system message",
+            fine_tune_model_id="ft-123",
+            path=tmp_file,
+        ),
+    )
+    return finetune
+
+
+@pytest.fixture
+def mock_response():
+    response = MagicMock(spec=FineTuningJob)
+    response.error = None
+    response.status = "succeeded"
+    response.finished_at = time.time()
+    response.estimated_finish = None
+    response.fine_tuned_model = "ft-123"
+    response.model = "gpt-4o"
+    return response
+
+
+@pytest.fixture
+def mock_dataset():
+    return DatasetSplit(
+        id="test-dataset-123",
+        name="Test Dataset",
+        splits=Train80Test20SplitDefinition,
+        split_contents={"train": [], "test": []},
+    )
+
+
+@pytest.fixture
+def mock_task():
+    return Task(
+        id="test-task-123",
+        name="Test Task",
+        output_json_schema=None,  # Can be modified in specific tests
+        instruction="Test instruction",
+    )
+
+
+async def test_setup(openai_finetune):
+    if not Config.shared().open_ai_api_key:
+        pytest.skip("OpenAI API key not set")
+    openai_finetune.provider_id = "openai-123"
+    openai_finetune.provider = "openai"
+
+    # Real API call, with fake ID
+    status = await openai_finetune.status()
+    # fake id fails
+    assert status.status == FineTuneStatusType.unknown
+    assert "Job with this ID not found. It may have been deleted." == status.message
+
+
+@pytest.mark.parametrize(
+    "exception,expected_status,expected_message",
+    [
+        (
+            openai.APIConnectionError(request=MagicMock()),
+            FineTuneStatusType.unknown,
+            "Server connection error",
+        ),
+        (
+            openai.RateLimitError(
+                message="Rate limit exceeded", body={}, response=MagicMock()
+            ),
+            FineTuneStatusType.unknown,
+            "Rate limit exceeded",
+        ),
+        (
+            openai.APIStatusError(
+                "Not found",
+                response=MagicMock(status_code=404),
+                body={},
+            ),
+            FineTuneStatusType.unknown,
+            "Job with this ID not found",
+        ),
+        (
+            openai.APIStatusError(
+                "Server error",
+                response=MagicMock(status_code=500),
+                body={},
+            ),
+            FineTuneStatusType.unknown,
+            "Unknown error",
+        ),
+    ],
+)
+async def test_status_api_errors(
+    openai_finetune, exception, expected_status, expected_message
+):
+    with patch(
+        "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+        side_effect=exception,
+    ):
+        status = await openai_finetune.status()
+        assert status.status == expected_status
+        assert expected_message in status.message
+
+
+@pytest.mark.parametrize(
+    "job_status,expected_status,message_contains",
+    [
+        ("failed", FineTuneStatusType.failed, "Job failed"),
+        ("cancelled", FineTuneStatusType.failed, "Job cancelled"),
+        ("succeeded", FineTuneStatusType.completed, "Training job completed"),
+        ("running", FineTuneStatusType.running, "Fine tune job is running"),
+        ("queued", FineTuneStatusType.running, "Fine tune job is running"),
+        (
+            "validating_files",
+            FineTuneStatusType.running,
+            "Fine tune job is running",
+        ),
+        ("unknown_status", FineTuneStatusType.unknown, "Unknown status"),
+    ],
+)
+async def test_status_job_states(
+    openai_finetune,
+    mock_response,
+    job_status,
+    expected_status,
+    message_contains,
+):
+    mock_response.status = job_status
+
+    with patch(
+        "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+        return_value=mock_response,
+    ):
+        status = await openai_finetune.status()
+        assert status.status == expected_status
+        assert message_contains in status.message
+
+
+async def test_status_with_error_response(openai_finetune, mock_response):
+    mock_response.error = MagicMock()
+    mock_response.error.message = "Something went wrong"
+
+    with patch(
+        "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+        return_value=mock_response,
+    ):
+        status = await openai_finetune.status()
+        assert status.status == FineTuneStatusType.failed
+        assert status.message.startswith("Something went wrong [Code:")
+
+
+async def test_status_with_estimated_finish_time(openai_finetune, mock_response):
+    current_time = time.time()
+    mock_response.status = "running"
+    mock_response.estimated_finish = current_time + 300  # 5 minutes from now
+
+    with patch(
+        "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+        return_value=mock_response,
+    ):
+        status = await openai_finetune.status()
+        assert status.status == FineTuneStatusType.running
+        assert (
+            "Estimated finish time: 299 seconds" in status.message
+        )  # non zero time passes
+
+
+async def test_status_empty_response(openai_finetune):
+    with patch(
+        "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+        return_value=mock_response,
+    ):
+        status = await openai_finetune.status()
+        assert status.status == FineTuneStatusType.unknown
+        assert "Invalid response from OpenAI" in status.message
+
+
+async def test_generate_and_upload_jsonl_success(
+    openai_finetune, mock_dataset, mock_task
+):
+    mock_path = Path("mock_path.jsonl")
+    mock_file_id = "file-123"
+
+    # Mock the formatter
+    mock_formatter = MagicMock(spec=DatasetFormatter)
+    mock_formatter.dump_to_file.return_value = mock_path
+
+    # Mock the file response
+    mock_file_response = MagicMock()
+    mock_file_response.id = mock_file_id
+
+    with (
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+            return_value=mock_formatter,
+        ) as mock_formatter_class,
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+            return_value=mock_file_response,
+        ) as mock_create,
+        patch("builtins.open") as mock_open,
+    ):
+        result = await openai_finetune.generate_and_upload_jsonl(
+            mock_dataset, "train", mock_task
+        )
+
+        # Verify formatter was created with correct parameters
+        mock_formatter_class.assert_called_once_with(
+            mock_dataset, openai_finetune.datamodel.system_message
+        )
+
+        # Verify correct format was used
+        mock_formatter.dump_to_file.assert_called_once_with(
+            "train", DatasetFormat.OPENAI_CHAT_JSONL
+        )
+
+        # Verify file was opened and uploaded
+        mock_open.assert_called_once_with(mock_path, "rb")
+        mock_create.assert_called_once()
+
+        assert result == mock_file_id
+
+
+async def test_generate_and_upload_jsonl_toolcall_success(
+    openai_finetune, mock_dataset, mock_task
+):
+    mock_path = Path("mock_path.jsonl")
+    mock_file_id = "file-123"
+    mock_task.output_json_schema = '{"type": "object", "properties": {"key": {"type": "string"}}}'  # Add JSON schema
+
+    # Mock the formatter
+    mock_formatter = MagicMock(spec=DatasetFormatter)
+    mock_formatter.dump_to_file.return_value = mock_path
+
+    # Mock the file response
+    mock_file_response = MagicMock()
+    mock_file_response.id = mock_file_id
+
+    with (
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+            return_value=mock_formatter,
+        ) as mock_formatter_class,
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+            return_value=mock_file_response,
+        ) as mock_create,
+        patch("builtins.open") as mock_open,
+    ):
+        result = await openai_finetune.generate_and_upload_jsonl(
+            mock_dataset, "train", mock_task
+        )
+
+        # Verify formatter was created with correct parameters
+        mock_formatter_class.assert_called_once_with(
+            mock_dataset, openai_finetune.datamodel.system_message
+        )
+
+        # Verify correct format was used
+        mock_formatter.dump_to_file.assert_called_once_with(
+            "train", DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL
+        )
+
+        # Verify file was opened and uploaded
+        mock_open.assert_called_once_with(mock_path, "rb")
+        mock_create.assert_called_once()
+
+        assert result == mock_file_id
+
+
+async def test_generate_and_upload_jsonl_upload_failure(
+    openai_finetune, mock_dataset, mock_task
+):
+    mock_path = Path("mock_path.jsonl")
+
+    mock_formatter = MagicMock(spec=DatasetFormatter)
+    mock_formatter.dump_to_file.return_value = mock_path
+
+    # Mock response with no ID
+    mock_file_response = MagicMock()
+    mock_file_response.id = None
+
+    with (
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+            return_value=mock_formatter,
+        ),
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+            return_value=mock_file_response,
+        ),
+        patch("builtins.open"),
+    ):
+        with pytest.raises(ValueError, match="Failed to upload file to OpenAI"):
+            await openai_finetune.generate_and_upload_jsonl(
+                mock_dataset, "train", mock_task
+            )
+
+
+async def test_generate_and_upload_jsonl_api_error(
+    openai_finetune, mock_dataset, mock_task
+):
+    mock_path = Path("mock_path.jsonl")
+
+    mock_formatter = MagicMock(spec=DatasetFormatter)
+    mock_formatter.dump_to_file.return_value = mock_path
+
+    with (
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+            return_value=mock_formatter,
+        ),
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+            side_effect=openai.APIError(
+                message="API error", request=MagicMock(), body={}
+            ),
+        ),
+        patch("builtins.open"),
+    ):
+        with pytest.raises(openai.APIError):
+            await openai_finetune.generate_and_upload_jsonl(
+                mock_dataset, "train", mock_task
+            )
+
+
+async def test_start_success(openai_finetune, mock_dataset, mock_task):
+    openai_finetune.datamodel.parent = mock_task
+
+    # Mock parameters
+    openai_finetune.datamodel.parameters = {
+        "n_epochs": 3,
+        "learning_rate_multiplier": 0.1,
+        "batch_size": 4,
+        "ignored_param": "value",
+    }
+
+    # Mock the fine-tuning response
+    mock_ft_response = MagicMock()
+    mock_ft_response.id = "ft-123"
+    mock_ft_response.fine_tuned_model = None
+    mock_ft_response.model = "gpt-4o-mini-2024-07-18"
+
+    with (
+        patch.object(
+            openai_finetune,
+            "generate_and_upload_jsonl",
+            side_effect=["train-file-123", "val-file-123"],
+        ) as mock_upload,
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.create",
+            return_value=mock_ft_response,
+        ) as mock_create,
+    ):
+        await openai_finetune._start(mock_dataset)
+
+        # Verify file uploads
+        assert mock_upload.call_count == 1  # Only training file
+        mock_upload.assert_called_with(
+            mock_dataset, openai_finetune.datamodel.train_split_name, mock_task
+        )
+
+        # Verify fine-tune creation
+        mock_create.assert_called_once_with(
+            training_file="train-file-123",
+            model="gpt-4o",
+            validation_file=None,
+            seed=None,
+            hyperparameters={
+                "n_epochs": 3,
+                "learning_rate_multiplier": 0.1,
+                "batch_size": 4,
+            },
+            suffix=f"kiln_ai.{openai_finetune.datamodel.id}",
+        )
+
+        # Verify model updates
+        assert openai_finetune.datamodel.provider_id == "ft-123"
+        assert openai_finetune.datamodel.base_model_id == "gpt-4o-mini-2024-07-18"
+
+
+async def test_start_with_validation(openai_finetune, mock_dataset, mock_task):
+    openai_finetune.datamodel.parent = mock_task
+    openai_finetune.datamodel.validation_split_name = "validation"
+
+    mock_ft_response = MagicMock()
+    mock_ft_response.id = "ft-123"
+    mock_ft_response.fine_tuned_model = None
+    mock_ft_response.model = "gpt-4o-mini-2024-07-18"
+
+    with (
+        patch.object(
+            openai_finetune,
+            "generate_and_upload_jsonl",
+            side_effect=["train-file-123", "val-file-123"],
+        ) as mock_upload,
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.create",
+            return_value=mock_ft_response,
+        ) as mock_create,
+    ):
+        await openai_finetune._start(mock_dataset)
+
+        # Verify both files were uploaded
+        assert mock_upload.call_count == 2
+        mock_upload.assert_has_calls(
+            [
+                mock.call(
+                    mock_dataset, openai_finetune.datamodel.train_split_name, mock_task
+                ),
+                mock.call(mock_dataset, "validation", mock_task),
+            ]
+        )
+
+        # Verify validation file was included
+        mock_create.assert_called_once()
+        assert mock_create.call_args[1]["validation_file"] == "val-file-123"
+
+
+async def test_start_no_task(openai_finetune, mock_dataset):
+    openai_finetune.datamodel.parent = None
+    openai_finetune.datamodel.path = None
+
+    with pytest.raises(ValueError, match="Task is required to start a fine-tune"):
+        await openai_finetune._start(mock_dataset)
+
+
+async def test_status_updates_model_ids(openai_finetune, mock_response):
+    # Set up initial model IDs
+    openai_finetune.datamodel.fine_tune_model_id = "old-ft-model"
+    openai_finetune.datamodel.base_model_id = "old-base-model"
+
+    # Configure mock response with different model IDs
+    mock_response.fine_tuned_model = "new-ft-model"
+    mock_response.model = "new-base-model"
+    mock_response.status = "succeeded"
+
+    with (
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+            return_value=mock_response,
+        ),
+    ):
+        status = await openai_finetune.status()
+
+        # Verify model IDs were updated
+        assert openai_finetune.datamodel.fine_tune_model_id == "new-ft-model"
+        assert openai_finetune.datamodel.base_model_id == "new-base-model"
+
+        # Verify save was called
+        # This isn't properly mocked, so not checking
+        # assert openai_finetune.datamodel.save.called
+
+        # Verify status is still returned correctly
+        assert status.status == FineTuneStatusType.completed
+        assert status.message == "Training job completed"
+
+
+async def test_status_updates_latest_status(openai_finetune, mock_response):
+    # Set initial status
+    openai_finetune.datamodel.latest_status = FineTuneStatusType.running
+    assert openai_finetune.datamodel.latest_status == FineTuneStatusType.running
+    mock_response.status = "succeeded"
+
+    with (
+        patch(
+            "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+            return_value=mock_response,
+        ),
+    ):
+        status = await openai_finetune.status()
+
+        # Verify status was updated in datamodel
+        assert openai_finetune.datamodel.latest_status == FineTuneStatusType.completed
+        assert status.status == FineTuneStatusType.completed
+        assert status.message == "Training job completed"
+
+        # Verify file was saved
+        assert openai_finetune.datamodel.path.exists()
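The two generate_and_upload_jsonl tests above pin down how the adapter picks an output format: plain chat JSONL for free-text tasks, and the tool-call variant once the task declares an output_json_schema. Below is a minimal sketch of that selection step, assuming only what the tests show; the helper name format_for_openai is hypothetical, while DatasetFormatter, DatasetFormat, and dump_to_file come from the diff.

    # Sketch of the format-selection step exercised by the tests above.
    from pathlib import Path

    from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
    from kiln_ai.datamodel import DatasetSplit, Task


    def format_for_openai(
        dataset: DatasetSplit, split_name: str, task: Task, system_message: str
    ) -> Path:
        formatter = DatasetFormatter(dataset, system_message)
        # Plain-text tasks use chat JSONL; tasks with an output JSON schema use the
        # tool-call variant (mirrors the two *_success tests above).
        fmt = (
            DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL
            if task.output_json_schema
            else DatasetFormat.OPENAI_CHAT_JSONL
        )
        return formatter.dump_to_file(split_name, fmt)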
kiln_ai/adapters/langchain_adapters.py

@@ -1,21 +1,35 @@
-
+import os
+from os import getenv
+from typing import Any, Dict
 
+from langchain_aws import ChatBedrockConverse
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.messages.base import BaseMessage
 from langchain_core.runnables import Runnable
+from langchain_fireworks import ChatFireworks
+from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI
 from pydantic import BaseModel
 
 import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.ollama_tools import (
+    get_ollama_connection,
+    ollama_base_url,
+    ollama_model_installed,
+)
+from kiln_ai.utils.config import Config
 
 from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
-from .ml_model_list import
+from .ml_model_list import KilnModelProvider, ModelProviderName
+from .provider_tools import kiln_model_provider_from
 
 LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
 
 
-class LangChainPromptAdapter(BaseAdapter):
+class LangchainAdapter(BaseAdapter):
     _model: LangChainModelType | None = None
 
     def __init__(
@@ -51,12 +65,6 @@ class LangChainPromptAdapter(BaseAdapter):
                 "model_name and provider must be provided if custom_model is not provided"
             )
 
-    def adapter_specific_instructions(self) -> str | None:
-        # TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
-        if self.has_structured_output():
-            return "Always respond with a tool call. Never respond with a human readable message."
-        return None
-
     async def model(self) -> LangChainModelType:
         # cached model
         if self._model:
@@ -79,8 +87,13 @@ class LangChainPromptAdapter(BaseAdapter):
                 )
             output_schema["title"] = "task_response"
             output_schema["description"] = "A response from the task"
+            with_structured_output_options = await get_structured_output_options(
+                self.model_name, self.model_provider
+            )
             self._model = self._model.with_structured_output(
-                output_schema,
+                output_schema,
+                include_raw=True,
+                **with_structured_output_options,
             )
         return self._model
 
@@ -108,17 +121,16 @@ class LangChainPromptAdapter(BaseAdapter):
             )
 
             cot_messages = [*messages]
-            cot_response = base_model.
+            cot_response = await base_model.ainvoke(cot_messages)
             intermediate_outputs["chain_of_thought"] = cot_response.content
             messages.append(AIMessage(content=cot_response.content))
             messages.append(
                 SystemMessage(content="Considering the above, return a final result.")
             )
         elif cot_prompt:
-            # for plaintext output, we just add COT instructions. We still only make one call.
             messages.append(SystemMessage(content=cot_prompt))
 
-        response = chain.
+        response = await chain.ainvoke(messages)
 
         if self.has_structured_output():
             if (
@@ -160,3 +172,81 @@ class LangChainPromptAdapter(BaseAdapter):
         ):
             return response["arguments"]
         return response
+
+
+async def get_structured_output_options(
+    model_name: str, model_provider: str
+) -> Dict[str, Any]:
+    finetune_provider = await kiln_model_provider_from(model_name, model_provider)
+    if finetune_provider and finetune_provider.adapter_options.get("langchain"):
+        return finetune_provider.adapter_options["langchain"].get(
+            "with_structured_output_options", {}
+        )
+    return {}
+
+
+async def langchain_model_from(
+    name: str, provider_name: str | None = None
+) -> BaseChatModel:
+    provider = await kiln_model_provider_from(name, provider_name)
+    return await langchain_model_from_provider(provider, name)
+
+
+async def langchain_model_from_provider(
+    provider: KilnModelProvider, model_name: str
+) -> BaseChatModel:
+    if provider.name == ModelProviderName.openai:
+        api_key = Config.shared().open_ai_api_key
+        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.groq:
+        api_key = Config.shared().groq_api_key
+        if api_key is None:
+            raise ValueError(
+                "Attempted to use Groq without an API key set. "
+                "Get your API key from https://console.groq.com/keys"
+            )
+        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.amazon_bedrock:
+        api_key = Config.shared().bedrock_access_key
+        secret_key = Config.shared().bedrock_secret_key
+        # langchain doesn't allow passing these, so ugly hack to set env vars
+        os.environ["AWS_ACCESS_KEY_ID"] = api_key
+        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
+        return ChatBedrockConverse(
+            **provider.provider_options,
+        )
+    elif provider.name == ModelProviderName.fireworks_ai:
+        api_key = Config.shared().fireworks_api_key
+        return ChatFireworks(**provider.provider_options, api_key=api_key)
+    elif provider.name == ModelProviderName.ollama:
+        # Ollama model naming is pretty flexible. We try a few versions of the model name
+        potential_model_names = []
+        if "model" in provider.provider_options:
+            potential_model_names.append(provider.provider_options["model"])
+        if "model_aliases" in provider.provider_options:
+            potential_model_names.extend(provider.provider_options["model_aliases"])
+
+        # Get the list of models Ollama supports
+        ollama_connection = await get_ollama_connection()
+        if ollama_connection is None:
+            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
+
+        for model_name in potential_model_names:
+            if ollama_model_installed(ollama_connection, model_name):
+                return ChatOllama(model=model_name, base_url=ollama_base_url())
+
+        raise ValueError(f"Model {model_name} not installed on Ollama")
+    elif provider.name == ModelProviderName.openrouter:
+        api_key = Config.shared().open_router_api_key
+        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
+        return ChatOpenAI(
+            **provider.provider_options,
+            openai_api_key=api_key,  # type: ignore[arg-type]
+            openai_api_base=base_url,  # type: ignore[arg-type]
+            default_headers={
+                "HTTP-Referer": "https://getkiln.ai/openrouter",
+                "X-Title": "KilnAI",
+            },
+        )
+    else:
+        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")