kiln-ai 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai has been flagged as possibly problematic.

@@ -2,11 +2,13 @@ import json
 from abc import abstractmethod
 from typing import Dict
 
+import jsonschema
+
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.datamodel.eval import Eval, EvalConfig, EvalScores
-from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig, TaskOutputRatingType, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
@@ -72,7 +74,10 @@ class BaseEval:
         run_output = await run_adapter.invoke(parsed_input)
 
         eval_output, intermediate_outputs = await self.run_eval(run_output)
-        validate_schema(eval_output, self.score_schema)
+
+        validate_schema_with_value_error(
+            eval_output, self.score_schema, "Eval output does not match score schema."
+        )
 
         return run_output, eval_output, intermediate_outputs
 
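The renamed helper is not itself shown in this diff. Below is a minimal sketch of what validate_schema_with_value_error plausibly does, assuming (per the new jsonschema import above) it wraps jsonschema.validate and re-raises failures as a ValueError carrying the caller's message; the real implementation in kiln_ai.datamodel.json_schema may differ.

import jsonschema

def validate_schema_with_value_error(
    data: dict, schema: dict, error_prefix: str
) -> None:
    # Hypothetical sketch, not the actual kiln_ai implementation.
    try:
        jsonschema.validate(instance=data, schema=schema)
    except jsonschema.ValidationError as e:
        # Convert schema failures into a ValueError with a readable message.
        raise ValueError(f"{error_prefix} {e.message}") from e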
@@ -72,8 +72,6 @@ class BaseFinetuneAdapter(ABC):
         Create and start a fine-tune.
         """
 
-        cls.check_valid_provider_model(provider_id, provider_base_model_id)
-
         if not dataset.id:
            raise ValueError("Dataset must have an id")
 
@@ -168,9 +166,12 @@ class BaseFinetuneAdapter(ABC):
 
             # Strict type checking for numeric types
             if expected_type is float and not isinstance(value, float):
-                raise ValueError(
-                    f"Parameter {parameter.name} must be a float, got {type(value)}"
-                )
+                if isinstance(value, int):
+                    value = float(value)
+                else:
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be a float, got {type(value)}"
+                    )
             elif expected_type is int and not isinstance(value, int):
                 raise ValueError(
                     f"Parameter {parameter.name} must be an integer, got {type(value)}"
@@ -184,21 +185,3 @@ class BaseFinetuneAdapter(ABC):
         for parameter_key in parameters:
             if parameter_key not in allowed_parameters:
                 raise ValueError(f"Parameter {parameter_key} is not available")
-
-    @classmethod
-    def check_valid_provider_model(
-        cls, provider_id: str, provider_base_model_id: str
-    ) -> None:
-        """
-        Check if the provider and base model are valid.
-        """
-        for model in built_in_models:
-            for provider in model.providers:
-                if (
-                    provider.name == provider_id
-                    and provider.provider_finetune_id == provider_base_model_id
-                ):
-                    return
-        raise ValueError(
-            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
-        )
@@ -30,8 +30,8 @@ class DatasetFormat(str, Enum):
         "huggingface_chat_template_toolcall_jsonl"
     )
 
-    """Vertex Gemini 1.5 format (flash and pro)"""
-    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"
+    """Vertex Gemini format"""
+    VERTEX_GEMINI = "vertex_gemini"
 
 
 @dataclass
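Note that renaming a str-backed enum value is a breaking change for persisted data: anything serialized under the old "vertex_gemini_1_5" string no longer parses. A minimal sketch of the failure mode (class body abbreviated):

from enum import Enum

class DatasetFormat(str, Enum):
    # Abbreviated to the renamed member.
    VERTEX_GEMINI = "vertex_gemini"

DatasetFormat("vertex_gemini")  # OK
try:
    DatasetFormat("vertex_gemini_1_5")
except ValueError:
    # Values stored under the old name no longer round-trip.
    pass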
@@ -288,7 +288,7 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}
 
 
-def generate_vertex_gemini_1_5(
+def generate_vertex_gemini(
     training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate Vertex Gemini 1.5 format (flash and pro)"""
@@ -346,7 +346,7 @@ FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
     DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
-    DatasetFormat.VERTEX_GEMINI_1_5: generate_vertex_gemini_1_5,
+    DatasetFormat.VERTEX_GEMINI: generate_vertex_gemini,
 }
 
 
@@ -4,10 +4,12 @@ from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
 from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
 from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
 from kiln_ai.adapters.fine_tune.together_finetune import TogetherFinetune
+from kiln_ai.adapters.fine_tune.vertex_finetune import VertexFinetune
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 
 finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
     ModelProviderName.openai: OpenAIFinetune,
     ModelProviderName.fireworks_ai: FireworksFinetune,
     ModelProviderName.together_ai: TogetherFinetune,
+    ModelProviderName.vertex: VertexFinetune,
 }
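A short usage sketch of the registry: callers look up the adapter class by provider, so Vertex fine-tuning becomes reachable with no other call-site changes (assuming the ModelProviderName.vertex member exists, as the registration implies).

# Dispatch by provider; unsupported providers are simply absent from the dict.
adapter_cls = finetune_registry[ModelProviderName.vertex]  # VertexFinetune
supports_vertex = ModelProviderName.vertex in finetune_registry  # True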
@@ -1,4 +1,5 @@
-from typing import Tuple
+import logging
+from typing import List, Tuple
 from uuid import uuid4
 
 import httpx
@@ -13,6 +14,14 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetF
 from kiln_ai.datamodel import DatasetSplit, StructuredOutputMode, Task
 from kiln_ai.utils.config import Config
 
+logger = logging.getLogger(__name__)
+
+# https://docs.fireworks.ai/fine-tuning/fine-tuning-models#supported-base-models-loras-on-serverless
+serverless_models = [
+    "accounts/fireworks/models/llama-v3p1-8b-instruct",
+    "accounts/fireworks/models/llama-v3p1-70b-instruct",
+]
+
 
 class FireworksFinetune(BaseFinetuneAdapter):
     """
@@ -189,7 +198,8 @@ class FireworksFinetune(BaseFinetuneAdapter):
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets"
-        dataset_id = str(uuid4())
+        # First char can't be a digit: https://discord.com/channels/1137072072808472616/1363214412395184350/1363214412395184350
+        dataset_id = "kiln-" + str(uuid4())
         payload = {
             "datasetId": dataset_id,
             "dataset": {
@@ -283,32 +293,54 @@ class FireworksFinetune(BaseFinetuneAdapter):
         return {k: v for k, v in payload.items() if v is not None}
 
     async def _deploy(self) -> bool:
-        # Now we "deploy" the model using PEFT serverless.
-        # A bit complicated: most fireworks deploys are server based.
-        # However, a Lora can be serverless (PEFT).
-        # By calling the deploy endpoint WITHOUT first creating a deployment ID, it will only deploy if it can be done serverless.
-        # https://docs.fireworks.ai/models/deploying#deploying-to-serverless
-        # This endpoint will return 400 if already deployed with code 9, so we consider that a success.
+        if self.datamodel.base_model_id in serverless_models:
+            return await self._deploy_serverless()
+        else:
+            return await self._check_or_deploy_server()
 
+    def api_key_and_account_id(self) -> Tuple[str, str]:
         api_key = Config.shared().fireworks_api_key
         account_id = Config.shared().fireworks_account_id
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
+        return api_key, account_id
 
+    def deployment_display_name(self) -> str:
+        # Limit the display name to 60 characters
+        display_name = f"Kiln AI fine-tuned model [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
+            :60
+        ]
+        return display_name
+
+    async def model_id_checking_status(self) -> str | None:
         # Model ID != fine tune ID on Fireworks. Model is the result of the tune job. Call status to get it.
         status, model_id = await self._status()
         if status.status != FineTuneStatusType.completed:
-            return False
+            return None
         if not model_id or not isinstance(model_id, str):
-            return False
+            return None
+        return model_id
+
+    async def _deploy_serverless(self) -> bool:
+        # Now we "deploy" the model using PEFT serverless.
+        # A bit complicated: most fireworks deploys are server based.
+        # However, a Lora can be serverless (PEFT).
+        # By calling the deploy endpoint WITHOUT first creating a deployment ID, it will only deploy if it can be done serverless.
+        # https://docs.fireworks.ai/models/deploying#deploying-to-serverless
+        # This endpoint will return 400 if already deployed with code 9, so we consider that a success.
+
+        api_key, account_id = self.api_key_and_account_id()
 
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployedModels"
-        # Limit the display name to 60 characters
-        display_name = f"Kiln AI fine-tuned model [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
-            :60
-        ]
+        model_id = await self.model_id_checking_status()
+        if not model_id:
+            logger.error(
+                "Model ID not found - can't deploy model to Fireworks serverless"
+            )
+            return False
+
         payload = {
-            "displayName": display_name,
+            "displayName": self.deployment_display_name(),
             "model": model_id,
         }
         headers = {
@@ -327,4 +359,120 @@ class FireworksFinetune(BaseFinetuneAdapter):
                 self.datamodel.save_to_file()
                 return True
 
+        logger.error(
+            f"Failed to deploy model to Fireworks serverless: [{response.status_code}] {response.text}"
+        )
         return False
+
+    async def _check_or_deploy_server(self) -> bool:
+        """
+        Check if the model is already deployed. If not, deploy it to a dedicated server.
+        """
+
+        # Check if the model is already deployed.
+        # If its fine_tune_model_id is set, it might be deployed. However, Fireworks deletes them over time so we need to check.
+        if self.datamodel.fine_tune_model_id:
+            deployments = await self._fetch_all_deployments()
+            for deployment in deployments:
+                if deployment[
+                    "baseModel"
+                ] == self.datamodel.fine_tune_model_id and deployment["state"] in [
+                    "READY",
+                    "CREATING",
+                ]:
+                    return True
+
+        # If the model is not deployed, deploy it
+        return await self._deploy_server()
+
+    async def _deploy_server(self) -> bool:
+        # For models that are not serverless, we just need to deploy the model to a server.
+        # We use a scale-to-zero on-demand deployment. If you stop using it, it
+        # will scale to zero and charges will stop.
+        model_id = await self.model_id_checking_status()
+        if not model_id:
+            logger.error("Model ID not found - can't deploy model to Fireworks server")
+            return False
+
+        api_key, account_id = self.api_key_and_account_id()
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployments"
+
+        payload = {
+            "displayName": self.deployment_display_name(),
+            "description": "Deployed by Kiln AI",
+            # Allow scale to zero
+            "minReplicaCount": 0,
+            "autoscalingPolicy": {
+                "scaleUpWindow": "30s",
+                "scaleDownWindow": "300s",
+                # Scale to zero after 5 minutes of inactivity - this is the minimum allowed
+                "scaleToZeroWindow": "300s",
+            },
+            "baseModel": model_id,
+        }
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(url, json=payload, headers=headers)
+
+        if response.status_code == 200:
+            basemodel = response.json().get("baseModel")
+            if basemodel is not None and isinstance(basemodel, str):
+                self.datamodel.fine_tune_model_id = basemodel
+                if self.datamodel.path:
+                    self.datamodel.save_to_file()
+                return True
+
+        logger.error(
+            f"Failed to deploy model to Fireworks server: [{response.status_code}] {response.text}"
+        )
+        return False
+
+    async def _fetch_all_deployments(self) -> List[dict]:
+        """
+        Fetch all deployments for an account.
+        """
+        api_key, account_id = self.api_key_and_account_id()
+
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployments"
+
+        params = {
+            # Note: the filter param does not work for baseModel, which would have been ideal (and ideally would have been documented). Instead we fetch all and filter.
+            # Max page size
+            "pageSize": 200,
+        }
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+        }
+
+        deployments = []
+
+        # Paginate through all deployments
+        async with httpx.AsyncClient() as client:
+            while True:
+                response = await client.get(url, params=params, headers=headers)
+                json = response.json()
+                if "deployments" not in json or not isinstance(
+                    json["deployments"], list
+                ):
+                    raise ValueError(
+                        f"Invalid response from Fireworks. Expected list of deployments in 'deployments' key: [{response.status_code}] {response.text}"
+                    )
+                deployments.extend(json["deployments"])
+                next_page_token = json.get("nextPageToken")
+                if (
+                    next_page_token
+                    and isinstance(next_page_token, str)
+                    and len(next_page_token) > 0
+                ):
+                    params = {
+                        "pageSize": 200,
+                        "pageToken": next_page_token,
+                    }
+                else:
+                    break
+
+        return deployments
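A hedged usage sketch of the helper, mirroring the filter in _check_or_deploy_server above (the standalone function and how it is called are hypothetical):

from typing import List

async def live_deployments(finetune: "FireworksFinetune") -> List[dict]:
    # Hypothetical caller: deployments backed by our fine-tuned model that
    # are serving ("READY") or spinning up ("CREATING").
    deployments = await finetune._fetch_all_deployments()
    return [
        d
        for d in deployments
        if d.get("baseModel") == finetune.datamodel.fine_tune_model_id
        and d.get("state") in ("READY", "CREATING")
    ]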
@@ -98,6 +98,13 @@ def test_validate_parameters_valid():
     }
     MockFinetune.validate_parameters(valid_params)  # Should not raise
 
+    # Test valid parameters (float parameter given as an int)
+    valid_params = {
+        "learning_rate": 1,
+        "epochs": 10,
+    }
+    MockFinetune.validate_parameters(valid_params)  # Should not raise
+
 
 def test_validate_parameters_missing_required():
     # Test missing required parameter
@@ -261,15 +268,6 @@ async def test_create_and_start_no_parent_task_path():
     )
 
 
-def test_check_valid_provider_model():
-    MockFinetune.check_valid_provider_model("openai", "gpt-4o-mini-2024-07-18")
-
-    with pytest.raises(
-        ValueError, match="Provider openai with base model gpt-99 is not available"
-    ):
-        MockFinetune.check_valid_provider_model("openai", "gpt-99")
-
-
 async def test_create_and_start_invalid_train_split(mock_dataset):
     # Test with an invalid train split name
     mock_dataset.split_contents = {"valid_train": [], "valid_test": []}
@@ -15,7 +15,7 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import (
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
-    generate_vertex_gemini_1_5,
+    generate_vertex_gemini,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
@@ -447,7 +447,7 @@ def test_generate_vertex_template():
         final_output="test output",
     )
 
-    result = generate_vertex_gemini_1_5(training_data)
+    result = generate_vertex_gemini(training_data)
 
     assert result == {
         "systemInstruction": {
@@ -475,7 +475,7 @@ def test_generate_vertex_template_thinking():
         thinking_final_answer_prompt="thinking final answer prompt",
     )
 
-    result = generate_vertex_gemini_1_5(training_data)
+    result = generate_vertex_gemini(training_data)
 
     logger.info(result)