kiln-ai 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -23
- kiln_ai/adapters/fine_tune/dataset_formatter.py +4 -4
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +163 -15
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -9
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +3 -3
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +495 -9
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +319 -43
- kiln_ai/adapters/model_adapters/base_adapter.py +15 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +10 -5
- kiln_ai/adapters/provider_tools.py +7 -0
- kiln_ai/adapters/test_provider_tools.py +16 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/task_output.py +9 -5
- kiln_ai/datamodel/task_run.py +29 -5
- kiln_ai/datamodel/test_example_models.py +104 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/METADATA +3 -2
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/RECORD +25 -24
- kiln_ai/adapters/test_generate_docs.py +0 -69
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/eval/base_eval.py
@@ -2,11 +2,13 @@ import json
 from abc import abstractmethod
 from typing import Dict
 
+import jsonschema
+
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.datamodel.eval import Eval, EvalConfig, EvalScores
-from kiln_ai.datamodel.json_schema import
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig, TaskOutputRatingType, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
@@ -72,7 +74,10 @@ class BaseEval:
         run_output = await run_adapter.invoke(parsed_input)
 
         eval_output, intermediate_outputs = await self.run_eval(run_output)
-
+
+        validate_schema_with_value_error(
+            eval_output, self.score_schema, "Eval output does not match score schema."
+        )
 
         return run_output, eval_output, intermediate_outputs
 
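For context, the new validate_schema_with_value_error helper imported above plausibly wraps standard JSON Schema validation and re-raises failures as a ValueError with a readable prefix, matching how it is called in this hunk. A minimal sketch assuming a jsonschema-backed implementation; the body is illustrative, not the package's actual code:

import jsonschema

def validate_schema_with_value_error(instance, schema: dict, message_prefix: str) -> None:
    # Assumed behavior: run ordinary JSON Schema validation, but surface
    # failures as ValueError so callers get a friendly, prefixed message.
    try:
        jsonschema.validate(instance=instance, schema=schema)
    except jsonschema.ValidationError as e:
        raise ValueError(f"{message_prefix} {e.message}") from e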
kiln_ai/adapters/fine_tune/base_finetune.py
@@ -72,8 +72,6 @@ class BaseFinetuneAdapter(ABC):
         Create and start a fine-tune.
         """
 
-        cls.check_valid_provider_model(provider_id, provider_base_model_id)
-
         if not dataset.id:
             raise ValueError("Dataset must have an id")
 
@@ -168,9 +166,12 @@ class BaseFinetuneAdapter(ABC):
 
             # Strict type checking for numeric types
             if expected_type is float and not isinstance(value, float):
-
-
-
+                if isinstance(value, int):
+                    value = float(value)
+                else:
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be a float, got {type(value)}"
+                    )
 
             elif expected_type is int and not isinstance(value, int):
                 raise ValueError(
                     f"Parameter {parameter.name} must be an integer, got {type(value)}"
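The float branch above now promotes ints instead of rejecting them, while the int branch stays strict. A standalone sketch of the resulting rule (the coerce_numeric name is hypothetical; in the package the check lives inside parameter validation):

def coerce_numeric(name: str, value, expected_type: type):
    # Ints are promoted when a float is expected; floats are never
    # silently truncated when an int is expected.
    if expected_type is float and not isinstance(value, float):
        if isinstance(value, int):
            return float(value)
        raise ValueError(f"Parameter {name} must be a float, got {type(value)}")
    if expected_type is int and not isinstance(value, int):
        raise ValueError(f"Parameter {name} must be an integer, got {type(value)}")
    return value

assert coerce_numeric("learning_rate", 1, float) == 1.0  # int -> float, accepted
assert coerce_numeric("epochs", 10, int) == 10           # already an int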
@@ -184,21 +185,3 @@ class BaseFinetuneAdapter(ABC):
         for parameter_key in parameters:
             if parameter_key not in allowed_parameters:
                 raise ValueError(f"Parameter {parameter_key} is not available")
-
-    @classmethod
-    def check_valid_provider_model(
-        cls, provider_id: str, provider_base_model_id: str
-    ) -> None:
-        """
-        Check if the provider and base model are valid.
-        """
-        for model in built_in_models:
-            for provider in model.providers:
-                if (
-                    provider.name == provider_id
-                    and provider.provider_finetune_id == provider_base_model_id
-                ):
-                    return
-        raise ValueError(
-            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
-        )
kiln_ai/adapters/fine_tune/dataset_formatter.py
@@ -30,8 +30,8 @@ class DatasetFormat(str, Enum):
         "huggingface_chat_template_toolcall_jsonl"
     )
 
-    """Vertex Gemini
-
+    """Vertex Gemini format"""
+    VERTEX_GEMINI = "vertex_gemini"
 
 
 @dataclass
@@ -288,7 +288,7 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}
 
 
-def
+def generate_vertex_gemini(
     training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate Vertex Gemini 1.5 format (flash and pro)"""
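For reference, the record shape generate_vertex_gemini appears to emit, judging from the test assertions later in this diff (the systemInstruction key) and Vertex AI's published Gemini supervised-tuning format; the values below are placeholders, not output from the real function:

# One JSON-serializable supervised-tuning record per training example (assumed shape).
record = {
    "systemInstruction": {
        "role": "system",
        "parts": [{"text": "You are a helpful assistant."}],
    },
    "contents": [
        {"role": "user", "parts": [{"text": "example input"}]},
        {"role": "model", "parts": [{"text": "example output"}]},
    ],
}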
@@ -346,7 +346,7 @@ FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
     DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
-    DatasetFormat.
+    DatasetFormat.VERTEX_GEMINI: generate_vertex_gemini,
 }
 
 
kiln_ai/adapters/fine_tune/finetune_registry.py
@@ -4,10 +4,12 @@ from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
 from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
 from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
 from kiln_ai.adapters.fine_tune.together_finetune import TogetherFinetune
+from kiln_ai.adapters.fine_tune.vertex_finetune import VertexFinetune
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 
 finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
     ModelProviderName.openai: OpenAIFinetune,
     ModelProviderName.fireworks_ai: FireworksFinetune,
     ModelProviderName.together_ai: TogetherFinetune,
+    ModelProviderName.vertex: VertexFinetune,
 }
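The registry above maps each provider enum to its adapter class, so adding Vertex support is a single dict entry and dispatch stays a lookup. A self-contained sketch of the same pattern (simplified stand-in names, not the kiln_ai types):

from enum import Enum

class Provider(str, Enum):
    openai = "openai"
    vertex = "vertex"

class OpenAIFinetune: ...
class VertexFinetune: ...

FINETUNE_REGISTRY: dict[Provider, type] = {
    Provider.openai: OpenAIFinetune,
    Provider.vertex: VertexFinetune,
}

def adapter_for(provider: Provider) -> type:
    # Fail loudly for providers without fine-tuning support.
    if provider not in FINETUNE_REGISTRY:
        raise ValueError(f"Fine-tuning is not supported for provider: {provider}")
    return FINETUNE_REGISTRY[provider]

assert adapter_for(Provider.vertex) is VertexFinetune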
kiln_ai/adapters/fine_tune/fireworks_finetune.py
@@ -1,4 +1,5 @@
-
+import logging
+from typing import List, Tuple
 from uuid import uuid4
 
 import httpx
@@ -13,6 +14,14 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetF
 from kiln_ai.datamodel import DatasetSplit, StructuredOutputMode, Task
 from kiln_ai.utils.config import Config
 
+logger = logging.getLogger(__name__)
+
+# https://docs.fireworks.ai/fine-tuning/fine-tuning-models#supported-base-models-loras-on-serverless
+serverless_models = [
+    "accounts/fireworks/models/llama-v3p1-8b-instruct",
+    "accounts/fireworks/models/llama-v3p1-70b-instruct",
+]
+
 
 class FireworksFinetune(BaseFinetuneAdapter):
     """
@@ -189,7 +198,8 @@ class FireworksFinetune(BaseFinetuneAdapter):
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets"
-
+        # First char can't be a digit: https://discord.com/channels/1137072072808472616/1363214412395184350/1363214412395184350
+        dataset_id = "kiln-" + str(uuid4())
         payload = {
             "datasetId": dataset_id,
             "dataset": {
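The comment in this hunk explains the prefix: Fireworks rejects dataset IDs that begin with a digit, and a bare uuid4 starts with one of the ten digit characters out of sixteen possible leading hex characters. A quick illustration of the fix:

from uuid import uuid4

# A raw uuid4 can start with 0-9, which Fireworks rejects for dataset IDs;
# prefixing a fixed letter string makes the ID always start with a letter.
dataset_id = "kiln-" + str(uuid4())
assert not dataset_id[0].isdigit()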
@@ -283,32 +293,54 @@ class FireworksFinetune(BaseFinetuneAdapter):
         return {k: v for k, v in payload.items() if v is not None}
 
     async def _deploy(self) -> bool:
-
-
-
-
-        # https://docs.fireworks.ai/models/deploying#deploying-to-serverless
-        # This endpoint will return 400 if already deployed with code 9, so we consider that a success.
+        if self.datamodel.base_model_id in serverless_models:
+            return await self._deploy_serverless()
+        else:
+            return await self._check_or_deploy_server()
 
+    def api_key_and_account_id(self) -> Tuple[str, str]:
         api_key = Config.shared().fireworks_api_key
         account_id = Config.shared().fireworks_account_id
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
+        return api_key, account_id
 
+    def deployment_display_name(self) -> str:
+        # Limit the display name to 60 characters
+        display_name = f"Kiln AI fine-tuned model [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
+            :60
+        ]
+        return display_name
+
+    async def model_id_checking_status(self) -> str | None:
         # Model ID != fine tune ID on Fireworks. Model is the result of the tune job. Call status to get it.
         status, model_id = await self._status()
         if status.status != FineTuneStatusType.completed:
-            return
+            return None
         if not model_id or not isinstance(model_id, str):
-            return
+            return None
+        return model_id
+
+    async def _deploy_serverless(self) -> bool:
+        # Now we "deploy" the model using PEFT serverless.
+        # A bit complicated: most fireworks deploys are server based.
+        # However, a Lora can be serverless (PEFT).
+        # By calling the deploy endpoint WITHOUT first creating a deployment ID, it will only deploy if it can be done serverless.
+        # https://docs.fireworks.ai/models/deploying#deploying-to-serverless
+        # This endpoint will return 400 if already deployed with code 9, so we consider that a success.
+
+        api_key, account_id = self.api_key_and_account_id()
 
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployedModels"
-
-
-
-
+        model_id = await self.model_id_checking_status()
+        if not model_id:
+            logger.error(
+                "Model ID not found - can't deploy model to Fireworks serverless"
+            )
+            return False
+
         payload = {
-            "displayName":
+            "displayName": self.deployment_display_name(),
             "model": model_id,
         }
         headers = {
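The comment above notes that the serverless deploy endpoint returns HTTP 400 with error code 9 when the model is already deployed, and the adapter counts that as success. A hedged sketch of that decision; the error-body shape {"code": 9} is an assumption drawn from the comment, not verified against Fireworks' API docs:

import httpx

async def deploy_serverless(
    client: httpx.AsyncClient, url: str, payload: dict, headers: dict
) -> bool:
    response = await client.post(url, json=payload, headers=headers)
    if response.status_code == 200:
        return True  # freshly deployed
    if response.status_code == 400 and response.json().get("code") == 9:
        return True  # already deployed: treat as success, per the comment above
    return False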
@@ -327,4 +359,120 @@ class FireworksFinetune(BaseFinetuneAdapter):
                 self.datamodel.save_to_file()
             return True
 
+        logger.error(
+            f"Failed to deploy model to Fireworks serverless: [{response.status_code}] {response.text}"
+        )
         return False
+
+    async def _check_or_deploy_server(self) -> bool:
+        """
+        Check if the model is already deployed. If not, deploy it to a dedicated server.
+        """
+
+        # Check if the model is already deployed
+        # If it's fine_tune_model_id is set, it might be deployed. However, Fireworks deletes them over time so we need to check.
+        if self.datamodel.fine_tune_model_id:
+            deployments = await self._fetch_all_deployments()
+            for deployment in deployments:
+                if deployment[
+                    "baseModel"
+                ] == self.datamodel.fine_tune_model_id and deployment["state"] in [
+                    "READY",
+                    "CREATING",
+                ]:
+                    return True
+
+        # If the model is not deployed, deploy it
+        return await self._deploy_server()
+
+    async def _deploy_server(self) -> bool:
+        # For models that are not serverless, we just need to deploy the model to a server.
+        # We use a scale-to-zero on-demand deployment. If you stop using it, it
+        # will scale to zero and charges will stop.
+        model_id = await self.model_id_checking_status()
+        if not model_id:
+            logger.error("Model ID not found - can't deploy model to Fireworks server")
+            return False
+
+        api_key, account_id = self.api_key_and_account_id()
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployments"
+
+        payload = {
+            "displayName": self.deployment_display_name(),
+            "description": "Deployed by Kiln AI",
+            # Allow scale to zero
+            "minReplicaCount": 0,
+            "autoscalingPolicy": {
+                "scaleUpWindow": "30s",
+                "scaleDownWindow": "300s",
+                # Scale to zero after 5 minutes of inactivity - this is the minimum allowed
+                "scaleToZeroWindow": "300s",
+            },
+            "baseModel": model_id,
+        }
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(url, json=payload, headers=headers)
+
+        if response.status_code == 200:
+            basemodel = response.json().get("baseModel")
+            if basemodel is not None and isinstance(basemodel, str):
+                self.datamodel.fine_tune_model_id = basemodel
+                if self.datamodel.path:
+                    self.datamodel.save_to_file()
+                return True
+
+        logger.error(
+            f"Failed to deploy model to Fireworks server: [{response.status_code}] {response.text}"
+        )
+        return False
+
+    async def _fetch_all_deployments(self) -> List[dict]:
+        """
+        Fetch all deployments for an account.
+        """
+        api_key, account_id = self.api_key_and_account_id()
+
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployments"
+
+        params = {
+            # Note: filter param does not work for baseModel, which would have been ideal, and ideally would have been documented. Instead we'll fetch all and filter.
+            # Max page size
+            "pageSize": 200,
+        }
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+        }
+
+        deployments = []
+
+        # Paginate through all deployments
+        async with httpx.AsyncClient() as client:
+            while True:
+                response = await client.get(url, params=params, headers=headers)
+                json = response.json()
+                if "deployments" not in json or not isinstance(
+                    json["deployments"], list
+                ):
+                    raise ValueError(
+                        f"Invalid response from Fireworks. Expected list of deployments in 'deployments' key: [{response.status_code}] {response.text}"
+                    )
+                deployments.extend(json["deployments"])
+                next_page_token = json.get("nextPageToken")
+                if (
+                    next_page_token
+                    and isinstance(next_page_token, str)
+                    and len(next_page_token) > 0
+                ):
+                    params = {
+                        "pageSize": 200,
+                        "pageToken": next_page_token,
+                    }
+                else:
+                    break
+
+        return deployments
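The _fetch_all_deployments loop above is standard page-token pagination: request a page, accumulate its results, and repeat while the response carries a non-empty nextPageToken. Reduced to its core, with a hypothetical fetch_page standing in for the HTTP call:

async def fetch_page(page_token: str | None) -> dict:
    # Stand-in for the real GET request; a Fireworks-style response carries
    # the items plus an empty nextPageToken on the final page.
    return {"deployments": [], "nextPageToken": ""}

async def fetch_all_deployments() -> list[dict]:
    deployments: list[dict] = []
    token: str | None = None
    while True:
        page = await fetch_page(token)
        deployments.extend(page["deployments"])
        token = page.get("nextPageToken")
        if not token:  # missing or empty token means we reached the last page
            return deployments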
kiln_ai/adapters/fine_tune/test_base_finetune.py
@@ -98,6 +98,13 @@ def test_validate_parameters_valid():
     }
     MockFinetune.validate_parameters(valid_params)  # Should not raise
 
+    # Test valid parameters (float as int)
+    valid_params = {
+        "learning_rate": 1,
+        "epochs": 10,
+    }
+    MockFinetune.validate_parameters(valid_params)  # Should not raise
+
 
 def test_validate_parameters_missing_required():
     # Test missing required parameter
@@ -261,15 +268,6 @@ async def test_create_and_start_no_parent_task_path():
     )
 
 
-def test_check_valid_provider_model():
-    MockFinetune.check_valid_provider_model("openai", "gpt-4o-mini-2024-07-18")
-
-    with pytest.raises(
-        ValueError, match="Provider openai with base model gpt-99 is not available"
-    ):
-        MockFinetune.check_valid_provider_model("openai", "gpt-99")
-
-
 async def test_create_and_start_invalid_train_split(mock_dataset):
     # Test with an invalid train split name
     mock_dataset.split_contents = {"valid_train": [], "valid_test": []}
kiln_ai/adapters/fine_tune/test_dataset_formatter.py
@@ -15,7 +15,7 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import (
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
-
+    generate_vertex_gemini,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
@@ -447,7 +447,7 @@ def test_generate_vertex_template():
         final_output="test output",
     )
 
-    result =
+    result = generate_vertex_gemini(training_data)
 
     assert result == {
         "systemInstruction": {
@@ -475,7 +475,7 @@ def test_generate_vertex_template_thinking():
         thinking_final_answer_prompt="thinking final answer prompt",
     )
 
-    result =
+    result = generate_vertex_gemini(training_data)
 
     logger.info(result)
 