kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/fine_tune/fireworks_finetune.py (new file)
@@ -0,0 +1,308 @@

from uuid import uuid4

import httpx

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
    FineTuneStatusType,
)
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
from kiln_ai.datamodel import DatasetSplit, Task
from kiln_ai.utils.config import Config


class FireworksFinetune(BaseFinetuneAdapter):
    """
    A fine-tuning adapter for Fireworks.
    """

    async def status(self) -> FineTuneStatus:
        status = await self._status()
        # update the datamodel if the status has changed
        if self.datamodel.latest_status != status.status:
            self.datamodel.latest_status = status.status
            if self.datamodel.path:
                self.datamodel.save_to_file()

        # Deploy every time we check status. This can help resolve issues, Fireworks will undeploy unused models after a time.
        if status.status == FineTuneStatusType.completed:
            deployed = await self._deploy()
            if not deployed:
                status.message = "Fine-tuning job completed but failed to deploy model."

        return status

    async def _status(self) -> FineTuneStatus:
        try:
            api_key = Config.shared().fireworks_api_key
            account_id = Config.shared().fireworks_account_id
            if not api_key or not account_id:
                return FineTuneStatus(
                    status=FineTuneStatusType.unknown,
                    message="Fireworks API key or account ID not set",
                )
            fine_tuning_job_id = self.datamodel.provider_id
            if not fine_tuning_job_id:
                return FineTuneStatus(
                    status=FineTuneStatusType.unknown,
                    message="Fine-tuning job ID not set. Can not retrieve status.",
                )
            # Fireworks uses path style IDs
            url = f"https://api.fireworks.ai/v1/{fine_tuning_job_id}"
            headers = {"Authorization": f"Bearer {api_key}"}

            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers=headers, timeout=15.0)

            if response.status_code != 200:
                return FineTuneStatus(
                    status=FineTuneStatusType.unknown,
                    message=f"Error retrieving fine-tuning job status: [{response.status_code}] {response.text}",
                )
            data = response.json()

            if "state" not in data:
                return FineTuneStatus(
                    status=FineTuneStatusType.unknown,
                    message="Invalid response from Fireworks (no state).",
                )

            state = data["state"]
            if state in ["FAILED", "DELETING"]:
                return FineTuneStatus(
                    status=FineTuneStatusType.failed,
                    message="Fine-tuning job failed",
                )
            elif state in ["CREATING", "PENDING", "RUNNING"]:
                return FineTuneStatus(
                    status=FineTuneStatusType.running,
                    message=f"Fine-tuning job is running [{state}]",
                )
            elif state == "COMPLETED":
                return FineTuneStatus(
                    status=FineTuneStatusType.completed,
                    message="Fine-tuning job completed",
                )
            else:
                return FineTuneStatus(
                    status=FineTuneStatusType.unknown,
                    message=f"Unknown fine-tuning job status [{state}]",
                )
        except Exception as e:
            return FineTuneStatus(
                status=FineTuneStatusType.unknown,
                message=f"Error retrieving fine-tuning job status: {e}",
            )

    async def _start(self, dataset: DatasetSplit) -> None:
        task = self.datamodel.parent_task()
        if not task:
            raise ValueError("Task is required to start a fine-tune")

        train_file_id = await self.generate_and_upload_jsonl(
            dataset, self.datamodel.train_split_name, task
        )

        api_key = Config.shared().fireworks_api_key
        account_id = Config.shared().fireworks_account_id
        if not api_key or not account_id:
            raise ValueError("Fireworks API key or account ID not set")

        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/fineTuningJobs"
        # Model ID != fine tune ID on Fireworks. Model is the result of the tune job.
        model_id = str(uuid4())
        # Limit the display name to 60 characters
        display_name = (
            f"Kiln AI fine-tuning [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
                :60
            ]
        )
        payload = {
            "modelId": model_id,
            "dataset": f"accounts/{account_id}/datasets/{train_file_id}",
            "displayName": display_name,
            "baseModel": self.datamodel.base_model_id,
            "conversation": {},
        }
        hyperparameters = self.create_payload_parameters(self.datamodel.parameters)
        payload.update(hyperparameters)
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code != 200:
                raise ValueError(
                    f"Failed to create fine-tuning job: [{response.status_code}] {response.text}"
                )
            data = response.json()
            if "name" not in data:
                raise ValueError(
                    f"Failed to create fine-tuning job with valid name: [{response.status_code}] {response.text}"
                )

        # name is actually the ID of the fine-tune job,
        # model ID is the model that results from the fine-tune job
        job_id = data["name"]
        self.datamodel.provider_id = job_id
        # Keep track of the expected model ID before it's deployed as a property. We move it to fine_tune_model_id after deployment.
        self.datamodel.properties["undeployed_model_id"] = (
            f"accounts/{account_id}/models/{model_id}"
        )
        if self.datamodel.path:
            self.datamodel.save_to_file()

    async def generate_and_upload_jsonl(
        self, dataset: DatasetSplit, split_name: str, task: Task
    ) -> str:
        formatter = DatasetFormatter(dataset, self.datamodel.system_message)
        # OpenAI compatible: https://docs.fireworks.ai/fine-tuning/fine-tuning-models#conversation
        # Note: Fireworks does not support tool calls (confirmed by Fireworks team) so we'll use json mode
        format = DatasetFormat.OPENAI_CHAT_JSONL
        path = formatter.dump_to_file(split_name, format)

        # First call creates the dataset
        api_key = Config.shared().fireworks_api_key
        account_id = Config.shared().fireworks_account_id
        if not api_key or not account_id:
            raise ValueError("Fireworks API key or account ID not set")
        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets"
        dataset_id = str(uuid4())
        payload = {
            "datasetId": dataset_id,
            "dataset": {
                "displayName": f"Kiln AI fine-tuning for dataset ID [{dataset.id}] split [{split_name}]",
                "userUploaded": {},
            },
        }
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        async with httpx.AsyncClient() as client:
            create_dataset_response = await client.post(
                url, json=payload, headers=headers
            )
            if create_dataset_response.status_code != 200:
                raise ValueError(
                    f"Failed to create dataset: [{create_dataset_response.status_code}] {create_dataset_response.text}"
                )

        # Second call uploads the dataset
        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets/{dataset_id}:upload"
        headers = {
            "Authorization": f"Bearer {api_key}",
        }
        async with httpx.AsyncClient() as client:
            with open(path, "rb") as f:
                files = {"file": f}
                upload_dataset_response = await client.post(
                    url,
                    headers=headers,
                    files=files,
                )
            if upload_dataset_response.status_code != 200:
                raise ValueError(
                    f"Failed to upload dataset: [{upload_dataset_response.status_code}] {upload_dataset_response.text}"
                )

        # Third call checks it's "READY"
        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets/{dataset_id}"
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            if response.status_code != 200:
                raise ValueError(
                    f"Failed to check dataset status: [{response.status_code}] {response.text}"
                )
            data = response.json()
            if data["state"] != "READY":
                raise ValueError(f"Dataset is not ready [{data['state']}]")

        return dataset_id

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                description="The number of epochs to fine-tune for. If not provided, defaults to a recommended value.",
                type="int",
                optional=True,
            ),
            FineTuneParameter(
                name="learning_rate",
                description="The learning rate to use for fine-tuning. If not provided, defaults to a recommended value.",
                type="float",
                optional=True,
            ),
            FineTuneParameter(
                name="batch_size",
                description="The batch size of dataset used in training can be configured with a positive integer less than 1024 and in power of 2. If not specified, a reasonable default value will be chosen.",
                type="int",
                optional=True,
            ),
            FineTuneParameter(
                name="lora_rank",
                description="LoRA rank refers to the dimensionality of trainable matrices in Low-Rank Adaptation fine-tuning, balancing model adaptability and computational efficiency in fine-tuning large language models. The LoRA rank used in training can be configured with a positive integer with a max value of 32. If not specified, a reasonable default value will be chosen.",
                type="int",
                optional=True,
            ),
        ]

    def create_payload_parameters(
        self, parameters: dict[str, str | int | float | bool]
    ) -> dict:
        payload = {
            "loraRank": parameters.get("lora_rank"),
            "epochs": parameters.get("epochs"),
            "learningRate": parameters.get("learning_rate"),
            "batchSize": parameters.get("batch_size"),
        }
        return {k: v for k, v in payload.items() if v is not None}

    async def _deploy(self) -> bool:
        # Now we "deploy" the model using PEFT serverless.
        # A bit complicated: most fireworks deploys are server based.
        # However, a Lora can be serverless (PEFT).
        # By calling the deploy endpoint WITHOUT first creating a deployment ID, it will only deploy if it can be done serverless.
        # https://docs.fireworks.ai/models/deploying#deploying-to-serverless
        # This endpoint will return 400 if already deployed with code 9, so we consider that a success.

        api_key = Config.shared().fireworks_api_key
        account_id = Config.shared().fireworks_account_id
        if not api_key or not account_id:
            raise ValueError("Fireworks API key or account ID not set")

        model_id = self.datamodel.properties.get("undeployed_model_id")
        if not model_id or not isinstance(model_id, str):
            return False

        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployedModels"
        # Limit the display name to 60 characters
        display_name = f"Kiln AI fine-tuned model [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
            :60
        ]
        payload = {
            "displayName": display_name,
            "model": model_id,
        }
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, headers=headers)

            # Fresh deploy worked (200) or already deployed (code=9)
            if response.status_code == 200 or response.json().get("code") == 9:
                # Update the datamodel if the model ID has changed, which makes it available to use in the UI
                if self.datamodel.fine_tune_model_id != model_id:
                    self.datamodel.fine_tune_model_id = model_id
                    if self.datamodel.path:
                        self.datamodel.save_to_file()
                return True

        return False
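For context on the upload above: Fireworks accepts OpenAI-compatible chat JSONL, and the adapter uses plain chat messages (JSON mode) rather than tool calls. Below is a minimal sketch of what one training record in that format could look like; the message contents are illustrative assumptions, not the actual output of Kiln's DatasetFormatter.

import json

# One illustrative OPENAI_CHAT_JSONL training record: a system prompt, a user turn,
# and the assistant's target output. For structured tasks the target is a JSON string
# in the assistant message (JSON mode), since Fireworks tuning does not use tool calls.
record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant for this task."},
        {"role": "user", "content": "Summarize the following text: ..."},
        {"role": "assistant", "content": '{"summary": "..."}'},
    ]
}

# Each line of the uploaded .jsonl file is one such JSON object.
print(json.dumps(record))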
kiln_ai/adapters/fine_tune/openai_finetune.py (new file)
@@ -0,0 +1,205 @@

import time

import openai
from openai.types.fine_tuning import FineTuningJob

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
    FineTuneStatusType,
)
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
from kiln_ai.datamodel import DatasetSplit, Task
from kiln_ai.utils.config import Config

oai_client = openai.AsyncOpenAI(
    api_key=Config.shared().open_ai_api_key or "",
)


class OpenAIFinetune(BaseFinetuneAdapter):
    """
    A fine-tuning adapter for OpenAI.
    """

    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """

        # Update the datamodel with the latest status if it has changed
        status = await self._status()
        if status.status != self.datamodel.latest_status:
            self.datamodel.latest_status = status.status
            if self.datamodel.path:
                self.datamodel.save_to_file()
        return status

    async def _status(self) -> FineTuneStatus:
        if not self.datamodel or not self.datamodel.provider_id:
            return FineTuneStatus(
                status=FineTuneStatusType.pending,
                message="This fine-tune has not been started or has not been assigned a provider ID.",
            )

        try:
            # Will raise an error if the job is not found, or for other issues
            response = await oai_client.fine_tuning.jobs.retrieve(
                self.datamodel.provider_id
            )

            # If the fine-tuned model has been updated, update the datamodel
            try:
                if (
                    self.datamodel.fine_tune_model_id != response.fine_tuned_model
                    or self.datamodel.base_model_id != response.model
                ):
                    self.datamodel.fine_tune_model_id = response.fine_tuned_model
                    self.datamodel.base_model_id = response.model
                    self.datamodel.save_to_file()
            except Exception:
                # Don't let this error crash the status call
                pass

        except openai.APIConnectionError:
            return FineTuneStatus(
                status=FineTuneStatusType.unknown, message="Server connection error"
            )
        except openai.RateLimitError:
            return FineTuneStatus(
                status=FineTuneStatusType.unknown,
                message="Rate limit exceeded. Could not fetch fine-tune status.",
            )
        except openai.APIStatusError as e:
            if e.status_code == 404:
                return FineTuneStatus(
                    status=FineTuneStatusType.unknown,
                    message="Job with this ID not found. It may have been deleted.",
                )
            return FineTuneStatus(
                status=FineTuneStatusType.unknown,
                message=f"Unknown error: [{str(e)}]",
            )

        if not response or not isinstance(response, FineTuningJob):
            return FineTuneStatus(
                status=FineTuneStatusType.unknown,
                message="Invalid response from OpenAI",
            )
        if response.error and response.error.code:
            return FineTuneStatus(
                status=FineTuneStatusType.failed,
                message=f"{response.error.message} [Code: {response.error.code}]",
            )
        status = response.status
        if status == "failed":
            return FineTuneStatus(
                status=FineTuneStatusType.failed,
                message="Job failed - unknown reason",
            )
        if status == "cancelled":
            return FineTuneStatus(
                status=FineTuneStatusType.failed, message="Job cancelled"
            )
        if status in ["validating_files", "running", "queued"]:
            time_to_finish_msg: str | None = None
            if response.estimated_finish is not None:
                time_to_finish_msg = f"Estimated finish time: {int(response.estimated_finish - time.time())} seconds."
            return FineTuneStatus(
                status=FineTuneStatusType.running,
                message=f"Fine tune job is running [{status}]. {time_to_finish_msg or ''}",
            )
        if status == "succeeded":
            return FineTuneStatus(
                status=FineTuneStatusType.completed, message="Training job completed"
            )
        return FineTuneStatus(
            status=FineTuneStatusType.unknown,
            message=f"Unknown status: [{status}]",
        )

    async def _start(self, dataset: DatasetSplit) -> None:
        task = self.datamodel.parent_task()
        if not task:
            raise ValueError("Task is required to start a fine-tune")

        train_file_id = await self.generate_and_upload_jsonl(
            dataset, self.datamodel.train_split_name, task
        )
        validation_file_id = None
        if self.datamodel.validation_split_name:
            validation_file_id = await self.generate_and_upload_jsonl(
                dataset, self.datamodel.validation_split_name, task
            )

        # Filter to hyperparameters which are set via the hyperparameters field (some like seed are set via the API)
        hyperparameters = {
            k: v
            for k, v in self.datamodel.parameters.items()
            if k in ["n_epochs", "learning_rate_multiplier", "batch_size"]
        }

        ft = await oai_client.fine_tuning.jobs.create(
            training_file=train_file_id,
            model=self.datamodel.base_model_id,
            validation_file=validation_file_id,
            seed=self.datamodel.parameters.get("seed"),  # type: ignore
            hyperparameters=hyperparameters,  # type: ignore
            suffix=f"kiln_ai.{self.datamodel.id}",
        )
        self.datamodel.provider_id = ft.id
        self.datamodel.fine_tune_model_id = ft.fine_tuned_model
        # Model can get more specific after fine-tune call (gpt-4o-mini to gpt-4o-mini-2024-07-18) so we update it in the datamodel
        self.datamodel.base_model_id = ft.model

        return None

    async def generate_and_upload_jsonl(
        self, dataset: DatasetSplit, split_name: str, task: Task
    ) -> str:
        formatter = DatasetFormatter(dataset, self.datamodel.system_message)
        # All OpenAI models support tool calls for structured outputs
        format = (
            DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL
            if task.output_json_schema
            else DatasetFormat.OPENAI_CHAT_JSONL
        )
        path = formatter.dump_to_file(split_name, format)

        response = await oai_client.files.create(
            file=open(path, "rb"),
            purpose="fine-tune",
        )
        id = response.id
        if not id:
            raise ValueError("Failed to upload file to OpenAI")
        return id

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="batch_size",
                type="int",
                description="Number of examples in each batch. A larger batch size means that model parameters are updated less frequently, but with lower variance. Defaults to 'auto'",
            ),
            FineTuneParameter(
                name="learning_rate_multiplier",
                type="float",
                description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid overfitting. Defaults to 'auto'",
                optional=True,
            ),
            FineTuneParameter(
                name="n_epochs",
                type="int",
                description="The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. Defaults to 'auto'",
                optional=True,
            ),
            FineTuneParameter(
                name="seed",
                type="int",
                description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. If a seed is not specified, one will be generated for you.",
                optional=True,
            ),
        ]