kiln-ai 0.13.2__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

--- a/kiln_ai/adapters/fine_tune/base_finetune.py
+++ b/kiln_ai/adapters/fine_tune/base_finetune.py
@@ -72,8 +72,6 @@ class BaseFinetuneAdapter(ABC):
         Create and start a fine-tune.
         """
 
-        cls.check_valid_provider_model(provider_id, provider_base_model_id)
-
         if not dataset.id:
             raise ValueError("Dataset must have an id")
 
@@ -184,21 +182,3 @@ class BaseFinetuneAdapter(ABC):
         for parameter_key in parameters:
             if parameter_key not in allowed_parameters:
                 raise ValueError(f"Parameter {parameter_key} is not available")
-
-    @classmethod
-    def check_valid_provider_model(
-        cls, provider_id: str, provider_base_model_id: str
-    ) -> None:
-        """
-        Check if the provider and base model are valid.
-        """
-        for model in built_in_models:
-            for provider in model.providers:
-                if (
-                    provider.name == provider_id
-                    and provider.provider_finetune_id == provider_base_model_id
-                ):
-                    return
-        raise ValueError(
-            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
-        )

--- a/kiln_ai/adapters/fine_tune/fireworks_finetune.py
+++ b/kiln_ai/adapters/fine_tune/fireworks_finetune.py
@@ -1,4 +1,5 @@
-from typing import Tuple
+import logging
+from typing import List, Tuple
 from uuid import uuid4
 
 import httpx
@@ -13,6 +14,14 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetF
 from kiln_ai.datamodel import DatasetSplit, StructuredOutputMode, Task
 from kiln_ai.utils.config import Config
 
+logger = logging.getLogger(__name__)
+
+# https://docs.fireworks.ai/fine-tuning/fine-tuning-models#supported-base-models-loras-on-serverless
+serverless_models = [
+    "accounts/fireworks/models/llama-v3p1-8b-instruct",
+    "accounts/fireworks/models/llama-v3p1-70b-instruct",
+]
+
 
 class FireworksFinetune(BaseFinetuneAdapter):
     """
@@ -283,32 +292,54 @@ class FireworksFinetune(BaseFinetuneAdapter):
         return {k: v for k, v in payload.items() if v is not None}
 
     async def _deploy(self) -> bool:
-        # Now we "deploy" the model using PEFT serverless.
-        # A bit complicated: most fireworks deploys are server based.
-        # However, a Lora can be serverless (PEFT).
-        # By calling the deploy endpoint WITHOUT first creating a deployment ID, it will only deploy if it can be done serverless.
-        # https://docs.fireworks.ai/models/deploying#deploying-to-serverless
-        # This endpoint will return 400 if already deployed with code 9, so we consider that a success.
+        if self.datamodel.base_model_id in serverless_models:
+            return await self._deploy_serverless()
+        else:
+            return await self._check_or_deploy_server()
 
+    def api_key_and_account_id(self) -> Tuple[str, str]:
         api_key = Config.shared().fireworks_api_key
         account_id = Config.shared().fireworks_account_id
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
+        return api_key, account_id
 
+    def deployment_display_name(self) -> str:
+        # Limit the display name to 60 characters
+        display_name = f"Kiln AI fine-tuned model [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
+            :60
+        ]
+        return display_name
+
+    async def model_id_checking_status(self) -> str | None:
         # Model ID != fine tune ID on Fireworks. Model is the result of the tune job. Call status to get it.
         status, model_id = await self._status()
         if status.status != FineTuneStatusType.completed:
-            return False
+            return None
         if not model_id or not isinstance(model_id, str):
-            return False
+            return None
+        return model_id
+
+    async def _deploy_serverless(self) -> bool:
+        # Now we "deploy" the model using PEFT serverless.
+        # A bit complicated: most fireworks deploys are server based.
+        # However, a Lora can be serverless (PEFT).
+        # By calling the deploy endpoint WITHOUT first creating a deployment ID, it will only deploy if it can be done serverless.
+        # https://docs.fireworks.ai/models/deploying#deploying-to-serverless
+        # This endpoint will return 400 if already deployed with code 9, so we consider that a success.
+
+        api_key, account_id = self.api_key_and_account_id()
 
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployedModels"
-        # Limit the display name to 60 characters
-        display_name = f"Kiln AI fine-tuned model [ID:{self.datamodel.id}][name:{self.datamodel.name}]"[
-            :60
-        ]
+        model_id = await self.model_id_checking_status()
+        if not model_id:
+            logger.error(
+                "Model ID not found - can't deploy model to Fireworks serverless"
+            )
+            return False
+
         payload = {
-            "displayName": display_name,
+            "displayName": self.deployment_display_name(),
             "model": model_id,
         }
         headers = {
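
One subtlety worth noting: `deployment_display_name` enforces the 60-character cap by slicing the formatted string, which truncates silently rather than raising. A self-contained sketch of that behavior, with made-up ID and name values:

```python
# Illustration of the 60-character cap applied by deployment_display_name
# (the ID and name below are made up).
datamodel_id = "123456789012"
datamodel_name = "x" * 100  # deliberately longer than the cap

display_name = (
    f"Kiln AI fine-tuned model [ID:{datamodel_id}][name:{datamodel_name}]"[:60]
)

print(len(display_name))  # 60 -- slicing truncates silently, never raises
print(display_name.startswith("Kiln AI fine-tuned model [ID:"))  # True
```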
@@ -327,4 +358,120 @@ class FireworksFinetune(BaseFinetuneAdapter):
             self.datamodel.save_to_file()
             return True
 
+        logger.error(
+            f"Failed to deploy model to Fireworks serverless: [{response.status_code}] {response.text}"
+        )
         return False
+
+    async def _check_or_deploy_server(self) -> bool:
+        """
+        Check if the model is already deployed. If not, deploy it to a dedicated server.
+        """
+
+        # Check if the model is already deployed
+        # If its fine_tune_model_id is set, it might be deployed. However, Fireworks deletes them over time so we need to check.
+        if self.datamodel.fine_tune_model_id:
+            deployments = await self._fetch_all_deployments()
+            for deployment in deployments:
+                if deployment[
+                    "baseModel"
+                ] == self.datamodel.fine_tune_model_id and deployment["state"] in [
+                    "READY",
+                    "CREATING",
+                ]:
+                    return True
+
+        # If the model is not deployed, deploy it
+        return await self._deploy_server()
+
+    async def _deploy_server(self) -> bool:
+        # For models that are not serverless, we just need to deploy the model to a server.
+        # We use a scale-to-zero on-demand deployment. If you stop using it, it
+        # will scale to zero and charges will stop.
+        model_id = await self.model_id_checking_status()
+        if not model_id:
+            logger.error("Model ID not found - can't deploy model to Fireworks server")
+            return False
+
+        api_key, account_id = self.api_key_and_account_id()
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployments"
+
+        payload = {
+            "displayName": self.deployment_display_name(),
+            "description": "Deployed by Kiln AI",
+            # Allow scale to zero
+            "minReplicaCount": 0,
+            "autoscalingPolicy": {
+                "scaleUpWindow": "30s",
+                "scaleDownWindow": "300s",
+                # Scale to zero after 5 minutes of inactivity - this is the minimum allowed
+                "scaleToZeroWindow": "300s",
+            },
+            "baseModel": model_id,
+        }
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(url, json=payload, headers=headers)
+
+        if response.status_code == 200:
+            basemodel = response.json().get("baseModel")
+            if basemodel is not None and isinstance(basemodel, str):
+                self.datamodel.fine_tune_model_id = basemodel
+                if self.datamodel.path:
+                    self.datamodel.save_to_file()
+                return True
+
+        logger.error(
+            f"Failed to deploy model to Fireworks server: [{response.status_code}] {response.text}"
+        )
+        return False
+
+    async def _fetch_all_deployments(self) -> List[dict]:
+        """
+        Fetch all deployments for an account.
+        """
+        api_key, account_id = self.api_key_and_account_id()
+
+        url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployments"
+
+        params = {
+            # Note: filter param does not work for baseModel, which would have been ideal, and ideally would have been documented. Instead we'll fetch all and filter.
+            # Max page size
+            "pageSize": 200,
+        }
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+        }
+
+        deployments = []
+
+        # Paginate through all deployments
+        async with httpx.AsyncClient() as client:
+            while True:
+                response = await client.get(url, params=params, headers=headers)
+                json = response.json()
+                if "deployments" not in json or not isinstance(
+                    json["deployments"], list
+                ):
+                    raise ValueError(
+                        f"Invalid response from Fireworks. Expected list of deployments in 'deployments' key: [{response.status_code}] {response.text}"
+                    )
+                deployments.extend(json["deployments"])
+                next_page_token = json.get("nextPageToken")
+                if (
+                    next_page_token
+                    and isinstance(next_page_token, str)
+                    and len(next_page_token) > 0
+                ):
+                    params = {
+                        "pageSize": 200,
+                        "pageToken": next_page_token,
+                    }
+                else:
+                    break
+
+        return deployments
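
The pagination in `_fetch_all_deployments` follows the common `pageToken` pattern: request a page, accumulate results, and repeat while the response carries a non-empty `nextPageToken`. A stripped-down sketch of the same loop under those assumptions (validation and error handling omitted; credentials are placeholders):

```python
import asyncio

import httpx


async def fetch_all_deployments(api_key: str, account_id: str) -> list[dict]:
    """Same pageToken loop as _fetch_all_deployments above, minus validation."""
    url = f"https://api.fireworks.ai/v1/accounts/{account_id}/deployments"
    headers = {"Authorization": f"Bearer {api_key}"}
    params: dict = {"pageSize": 200}  # max page size, per the diff above
    deployments: list[dict] = []
    async with httpx.AsyncClient() as client:
        while True:
            body = (await client.get(url, params=params, headers=headers)).json()
            deployments.extend(body.get("deployments", []))
            token = body.get("nextPageToken")
            if not token:  # empty or missing token ends pagination
                break
            params = {"pageSize": 200, "pageToken": token}
    return deployments


# Example usage (requires real credentials):
# asyncio.run(fetch_all_deployments("fw-api-key", "my-account"))
```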

--- a/kiln_ai/adapters/fine_tune/test_base_finetune.py
+++ b/kiln_ai/adapters/fine_tune/test_base_finetune.py
@@ -261,15 +261,6 @@ async def test_create_and_start_no_parent_task_path():
     )
 
 
-def test_check_valid_provider_model():
-    MockFinetune.check_valid_provider_model("openai", "gpt-4o-mini-2024-07-18")
-
-    with pytest.raises(
-        ValueError, match="Provider openai with base model gpt-99 is not available"
-    ):
-        MockFinetune.check_valid_provider_model("openai", "gpt-99")
-
-
 async def test_create_and_start_invalid_train_split(mock_dataset):
     # Test with an invalid train split name
     mock_dataset.split_contents = {"valid_train": [], "valid_test": []}

--- a/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py
+++ b/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py
@@ -448,7 +448,7 @@ def test_available_parameters(fireworks_finetune):
     assert payload_parameters == {"loraRank": 16, "epochs": 3}
 
 
-async def test_deploy_success(fireworks_finetune, mock_api_key):
+async def test_deploy_serverless_success(fireworks_finetune, mock_api_key):
     # Mock response for successful deployment
     success_response = MagicMock(spec=httpx.Response)
     success_response.status_code = 200
@@ -467,12 +467,12 @@ async def test_deploy_success(fireworks_finetune, mock_api_key):
         mock_client.post.return_value = success_response
         mock_client_class.return_value.__aenter__.return_value = mock_client
 
-        result = await fireworks_finetune._deploy()
+        result = await fireworks_finetune._deploy_serverless()
         assert result is True
         assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
 
 
-async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
+async def test_deploy_serverless_already_deployed(fireworks_finetune, mock_api_key):
     # Mock response for already deployed model
     already_deployed_response = MagicMock(spec=httpx.Response)
     already_deployed_response.status_code = 400
@@ -494,12 +494,12 @@ async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
         mock_client.post.return_value = already_deployed_response
         mock_client_class.return_value.__aenter__.return_value = mock_client
 
-        result = await fireworks_finetune._deploy()
+        result = await fireworks_finetune._deploy_serverless()
        assert result is True
         assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
 
 
-async def test_deploy_failure(fireworks_finetune, mock_api_key):
+async def test_deploy_serverless_failure(fireworks_finetune, mock_api_key):
     # Mock response for failed deployment
     failure_response = MagicMock(spec=httpx.Response)
     failure_response.status_code = 500
@@ -510,18 +510,28 @@ async def test_deploy_failure(fireworks_finetune, mock_api_key):
         mock_client.post.return_value = failure_response
         mock_client_class.return_value.__aenter__.return_value = mock_client
 
-        result = await fireworks_finetune._deploy()
+        result = await fireworks_finetune._deploy_serverless()
         assert result is False
 
 
-async def test_deploy_missing_credentials(fireworks_finetune):
+async def test_deploy_serverless_missing_credentials(fireworks_finetune):
     # Test missing API key or account ID
     with patch.object(Config, "shared") as mock_config:
         mock_config.return_value.fireworks_api_key = None
         mock_config.return_value.fireworks_account_id = None
 
         with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
-            await fireworks_finetune._deploy()
+            await fireworks_finetune._deploy_serverless()
+
+
+async def test_deploy_server_missing_credentials(fireworks_finetune):
+    # Test missing API key or account ID
+    with patch.object(Config, "shared") as mock_config:
+        mock_config.return_value.fireworks_api_key = None
+        mock_config.return_value.fireworks_account_id = None
+
+        response = await fireworks_finetune._check_or_deploy_server()
+        assert response is False
 
 
 async def test_deploy_missing_model_id(fireworks_finetune, mock_api_key):
@@ -564,3 +574,479 @@ async def test_status_with_deploy(fireworks_finetune, mock_api_key):
     # Verify message was updated due to failed deployment
     assert status.status == FineTuneStatusType.completed
     assert status.message == "Fine-tuning job completed but failed to deploy model."
+
+
+@pytest.mark.paid
+async def test_fetch_all_deployments(fireworks_finetune):
+    deployments = await fireworks_finetune._fetch_all_deployments()
+    assert isinstance(deployments, list)
+
+
+async def test_api_key_and_account_id(fireworks_finetune, mock_api_key):
+    # Test successful retrieval of API key and account ID
+    api_key, account_id = fireworks_finetune.api_key_and_account_id()
+    assert api_key == "test-api-key"
+    assert account_id == "test-account-id"
+
+
+async def test_api_key_and_account_id_missing_credentials(fireworks_finetune):
+    # Test missing API key or account ID
+    with patch.object(Config, "shared") as mock_config:
+        mock_config.return_value.fireworks_api_key = None
+        mock_config.return_value.fireworks_account_id = None
+
+        with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
+            fireworks_finetune.api_key_and_account_id()
+
+
+def test_deployment_display_name(fireworks_finetune):
+    # Test with default ID and name
+    display_name = fireworks_finetune.deployment_display_name()
+    expected = f"Kiln AI fine-tuned model [ID:{fireworks_finetune.datamodel.id}][name:test-finetune]"[
+        :60
+    ]
+    assert display_name == expected
+
+    # Test with a very long name to ensure 60 character limit
+    fireworks_finetune.datamodel.name = "x" * 100
+    display_name = fireworks_finetune.deployment_display_name()
+    assert len(display_name) == 60
+    assert display_name.startswith("Kiln AI fine-tuned model [ID:")
+
+
+async def test_model_id_checking_status_completed(fireworks_finetune):
+    # Test with completed status and valid model ID
+    status_response = (
+        FineTuneStatus(status=FineTuneStatusType.completed, message=""),
+        "model-123",
+    )
+
+    with patch.object(fireworks_finetune, "_status", return_value=status_response):
+        model_id = await fireworks_finetune.model_id_checking_status()
+        assert model_id == "model-123"
+
+
+async def test_model_id_checking_status_not_completed(fireworks_finetune):
+    # Test with non-completed status
+    status_response = (
+        FineTuneStatus(status=FineTuneStatusType.running, message=""),
+        "model-123",
+    )
+
+    with patch.object(fireworks_finetune, "_status", return_value=status_response):
+        model_id = await fireworks_finetune.model_id_checking_status()
+        assert model_id is None
+
+
+async def test_model_id_checking_status_invalid_model_id(fireworks_finetune):
+    # Test with completed status but invalid model ID
+    status_response = (
+        FineTuneStatus(status=FineTuneStatusType.completed, message=""),
+        None,
+    )
+
+    with patch.object(fireworks_finetune, "_status", return_value=status_response):
+        model_id = await fireworks_finetune.model_id_checking_status()
+        assert model_id is None
+
+    # Test with non-string model ID
+    status_response = (
+        FineTuneStatus(status=FineTuneStatusType.completed, message=""),
+        {"id": "model-123"},  # Not a string
+    )
+
+    with patch.object(fireworks_finetune, "_status", return_value=status_response):
+        model_id = await fireworks_finetune.model_id_checking_status()
+        assert model_id is None
+
+
+@pytest.mark.parametrize(
+    "base_model_id,expected_method",
+    [
+        ("accounts/fireworks/models/llama-v3p1-8b-instruct", "_deploy_serverless"),
+        ("accounts/fireworks/models/llama-v3p1-70b-instruct", "_deploy_serverless"),
+        ("some-other-model", "_check_or_deploy_server"),
+    ],
+)
+async def test_deploy_model_selection(
+    fireworks_finetune, base_model_id, expected_method, mock_api_key
+):
+    # Set the base model ID
+    fireworks_finetune.datamodel.base_model_id = base_model_id
+
+    # Mock the deployment methods
+    with (
+        patch.object(
+            fireworks_finetune, "_deploy_serverless", return_value=True
+        ) as mock_serverless,
+        patch.object(
+            fireworks_finetune, "_check_or_deploy_server", return_value=True
+        ) as mock_server,
+    ):
+        result = await fireworks_finetune._deploy()
+
+        # Verify the correct method was called based on the model
+        if expected_method == "_deploy_serverless":
+            mock_serverless.assert_called_once()
+            mock_server.assert_not_called()
+        else:
+            mock_serverless.assert_not_called()
+            mock_server.assert_called_once()
+
+        assert result is True
+
+
+async def test_fetch_all_deployments_request_error(fireworks_finetune, mock_api_key):
+    # Test with error response
+    error_response = MagicMock(spec=httpx.Response)
+    error_response.status_code = 500
+    error_response.text = "Internal Server Error"
+
+    with patch("httpx.AsyncClient") as mock_client_class:
+        mock_client = AsyncMock()
+        mock_client.get.side_effect = Exception("API request failed")
+        mock_client_class.return_value.__aenter__.return_value = mock_client
+
+        with pytest.raises(Exception, match="API request failed"):
+            await fireworks_finetune._fetch_all_deployments()
+
+        # Verify API was called with correct parameters
+        mock_client.get.assert_called_once()
+        call_args = mock_client.get.call_args[1]
+        assert "params" in call_args
+        assert call_args["params"]["pageSize"] == 200
+
+
+async def test_fetch_all_deployments_standard_case(fireworks_finetune, mock_api_key):
+    # Test with single page of results
+    mock_deployments = [
+        {"id": "deploy-1", "baseModel": "model-1", "state": "READY"},
+        {"id": "deploy-2", "baseModel": "model-2", "state": "READY"},
+    ]
+
+    success_response = MagicMock(spec=httpx.Response)
+    success_response.status_code = 200
+    success_response.json.return_value = {
+        "deployments": mock_deployments,
+        "nextPageToken": None,
+    }
+
+    with patch("httpx.AsyncClient") as mock_client_class:
+        mock_client = AsyncMock()
+        mock_client.get.return_value = success_response
+        mock_client_class.return_value.__aenter__.return_value = mock_client
+
+        deployments = await fireworks_finetune._fetch_all_deployments()
+
+        # Verify API was called correctly
+        mock_client.get.assert_called_once()
+
+        # Verify correct deployments were returned
+        assert deployments == mock_deployments
+        assert len(deployments) == 2
+        assert deployments[0]["id"] == "deploy-1"
+        assert deployments[1]["id"] == "deploy-2"
+
+
+async def test_fetch_all_deployments_paged_case(fireworks_finetune, mock_api_key):
+    # Test with multiple pages of results
+    mock_deployments_page1 = [
+        {"id": "deploy-1", "baseModel": "model-1", "state": "READY"},
+        {"id": "deploy-2", "baseModel": "model-2", "state": "READY"},
+    ]
+
+    mock_deployments_page2 = [
+        {"id": "deploy-3", "baseModel": "model-3", "state": "READY"},
+        {"id": "deploy-4", "baseModel": "model-4", "state": "READY"},
+    ]
+
+    page1_response = MagicMock(spec=httpx.Response)
+    page1_response.status_code = 200
+    page1_response.json.return_value = {
+        "deployments": mock_deployments_page1,
+        "nextPageToken": "page2token",
+    }
+
+    page2_response = MagicMock(spec=httpx.Response)
+    page2_response.status_code = 200
+    page2_response.json.return_value = {
+        "deployments": mock_deployments_page2,
+        "nextPageToken": None,
+    }
+
+    with patch("httpx.AsyncClient") as mock_client_class:
+        mock_client = AsyncMock()
+        mock_client.get.side_effect = [page1_response, page2_response]
+        mock_client_class.return_value.__aenter__.return_value = mock_client
+
+        deployments = await fireworks_finetune._fetch_all_deployments()
+
+        # Verify API was called twice (once for each page)
+        assert mock_client.get.call_count == 2
+
+        # Verify first call had no page token
+        first_call_args = mock_client.get.call_args_list[0][1]
+        assert "pageToken" not in first_call_args["params"]
+
+        # Verify second call included the page token
+        second_call_args = mock_client.get.call_args_list[1][1]
+        assert second_call_args["params"]["pageToken"] == "page2token"
+
+        # Verify all deployments from both pages were returned
+        assert len(deployments) == 4
+        assert deployments == mock_deployments_page1 + mock_deployments_page2
+        for deployment in deployments:
+            assert deployment["id"] in [
+                "deploy-1",
+                "deploy-2",
+                "deploy-3",
+                "deploy-4",
+            ]
+
+
+async def test_deploy_server_success(fireworks_finetune, mock_api_key):
+    # Mock response for successful deployment
+    success_response = MagicMock(spec=httpx.Response)
+    success_response.status_code = 200
+    success_response.json.return_value = {"baseModel": "model-123"}
+
+    status_response = (
+        FineTuneStatus(status=FineTuneStatusType.completed, message=""),
+        "model-123",
+    )
+
+    with (
+        patch("httpx.AsyncClient") as mock_client_class,
+        patch.object(
+            fireworks_finetune, "model_id_checking_status", return_value="model-123"
+        ),
+    ):
+        mock_client = AsyncMock()
+        mock_client.post.return_value = success_response
+        mock_client_class.return_value.__aenter__.return_value = mock_client
+
+        result = await fireworks_finetune._deploy_server()
+
+        # Verify result
+        assert result is True
+
+        # Verify fine_tune_model_id was updated
+        assert fireworks_finetune.datamodel.fine_tune_model_id == "model-123"
+
+        # Verify API was called with correct parameters
+        mock_client.post.assert_called_once()
+        call_args = mock_client.post.call_args[1]
+        assert "json" in call_args
+        assert call_args["json"]["baseModel"] == "model-123"
+        assert call_args["json"]["minReplicaCount"] == 0
+        assert "autoscalingPolicy" in call_args["json"]
+        assert call_args["json"]["autoscalingPolicy"]["scaleToZeroWindow"] == "300s"
+
+        # load the datamodel from the file and confirm the fine_tune_model_id was updated
+        loaded_datamodel = FinetuneModel.load_from_file(
+            fireworks_finetune.datamodel.path
+        )
+        assert loaded_datamodel.fine_tune_model_id == "model-123"
+
+
+async def test_deploy_server_failure(fireworks_finetune, mock_api_key):
+    # Mock response for failed deployment
+    failure_response = MagicMock(spec=httpx.Response)
+    failure_response.status_code = 500
+    failure_response.text = "Internal Server Error"
+
+    with (
+        patch("httpx.AsyncClient") as mock_client_class,
+        patch.object(
+            fireworks_finetune, "model_id_checking_status", return_value="model-123"
+        ),
+    ):
+        mock_client = AsyncMock()
+        mock_client.post.return_value = failure_response
+        mock_client_class.return_value.__aenter__.return_value = mock_client
+
+        result = await fireworks_finetune._deploy_server()
+
+        # Verify result
+        assert result is False
+
+        # Verify API was called
+        mock_client.post.assert_called_once()
+
+
+async def test_deploy_server_non_200_but_valid_response(
+    fireworks_finetune, mock_api_key
+):
+    # Mock a 200 response whose JSON is missing the expected 'baseModel' key
+    mixed_response = MagicMock(spec=httpx.Response)
+    mixed_response.status_code = 200
+    mixed_response.json.return_value = {"not_baseModel": "something-else"}
+
+    with (
+        patch("httpx.AsyncClient") as mock_client_class,
+        patch.object(
+            fireworks_finetune, "model_id_checking_status", return_value="model-123"
+        ),
+    ):
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mixed_response
+        mock_client_class.return_value.__aenter__.return_value = mock_client
+
+        result = await fireworks_finetune._deploy_server()
+
+        # Verify result - should fail because baseModel is missing
+        assert result is False
+
+
+async def test_deploy_server_missing_model_id(fireworks_finetune, mock_api_key):
+    # Test when model_id_checking_status returns None
+    with patch.object(
+        fireworks_finetune, "model_id_checking_status", return_value=None
+    ):
+        result = await fireworks_finetune._deploy_server()
+
+        # Verify result - should fail because model ID is missing
+        assert result is False
+
+
+@pytest.mark.parametrize(
+    "state,expected_already_deployed",
+    [
+        ("READY", True),
+        ("CREATING", True),
+        ("FAILED", False),
+    ],
+)
+async def test_check_or_deploy_server_already_deployed(
+    fireworks_finetune, mock_api_key, state, expected_already_deployed
+):
+    # Test when model is already deployed (should return True without calling _deploy_server)
+
+    # Set a fine_tune_model_id so we search for deployments
+    fireworks_finetune.datamodel.fine_tune_model_id = "model-123"
+
+    # Mock deployments including one matching our model ID
+    mock_deployments = [
+        {"id": "deploy-1", "baseModel": "different-model", "state": "READY"},
+        {"id": "deploy-2", "baseModel": "model-123", "state": state},
+    ]
+
+    with (
+        patch.object(
+            fireworks_finetune, "_fetch_all_deployments", return_value=mock_deployments
+        ) as mock_fetch,
+        patch.object(fireworks_finetune, "_deploy_server") as mock_deploy,
+    ):
+        mock_deploy.return_value = True
+        result = await fireworks_finetune._check_or_deploy_server()
+        # Even true if the model is in a non-ready state, as we'll call deploy (checked below)
+        assert result is True
+
+        if expected_already_deployed:
+            assert mock_deploy.call_count == 0
+        else:
+            assert mock_deploy.call_count == 1
+
+        # Verify _fetch_all_deployments was called
+        mock_fetch.assert_called_once()
+
+
+async def test_check_or_deploy_server_not_deployed(fireworks_finetune, mock_api_key):
+    # Test when model exists but isn't deployed (should call _deploy_server)
+
+    # Set a fine_tune_model_id so we search for deployments
+    fireworks_finetune.datamodel.fine_tune_model_id = "model-123"
+
+    # Mock deployments without our model ID
+    mock_deployments = [
+        {"id": "deploy-1", "baseModel": "different-model-1", "state": "READY"},
+        {"id": "deploy-2", "baseModel": "different-model-2", "state": "READY"},
+    ]
+
+    with (
+        patch.object(
+            fireworks_finetune, "_fetch_all_deployments", return_value=mock_deployments
+        ) as mock_fetch,
+        patch.object(
+            fireworks_finetune, "_deploy_server", return_value=True
+        ) as mock_deploy,
+    ):
+        result = await fireworks_finetune._check_or_deploy_server()
+
+        # Verify method returned True (from _deploy_server)
+        assert result is True
+
+        # Verify _fetch_all_deployments was called
+        mock_fetch.assert_called_once()
+
+        # Verify _deploy_server was called since model is not deployed
+        mock_deploy.assert_called_once()
+
+
+async def test_check_or_deploy_server_no_model_id(fireworks_finetune, mock_api_key):
+    # Test when no fine_tune_model_id exists (should skip fetch and call _deploy_server directly)
+
+    # Ensure no fine_tune_model_id is set
+    fireworks_finetune.datamodel.fine_tune_model_id = None
+
+    with (
+        patch.object(fireworks_finetune, "_fetch_all_deployments") as mock_fetch,
+        patch.object(
+            fireworks_finetune, "_deploy_server", return_value=True
+        ) as mock_deploy,
+    ):
+        result = await fireworks_finetune._check_or_deploy_server()
+
+        # Verify method returned True (from _deploy_server)
+        assert result is True
+
+        # Verify _fetch_all_deployments was NOT called
+        mock_fetch.assert_not_called()
+
+        # Verify _deploy_server was called directly
+        mock_deploy.assert_called_once()
+
+
+async def test_check_or_deploy_server_deploy_fails(fireworks_finetune, mock_api_key):
+    # Test when deployment fails
+
+    # Ensure no fine_tune_model_id is set
+    fireworks_finetune.datamodel.fine_tune_model_id = None
+
+    with (
+        patch.object(
+            fireworks_finetune, "_deploy_server", return_value=False
+        ) as mock_deploy,
+    ):
+        result = await fireworks_finetune._check_or_deploy_server()
+
+        # Verify method returned False (from _deploy_server)
+        assert result is False
+
+        # Verify _deploy_server was called
+        mock_deploy.assert_called_once()
+
+
+async def test_fetch_all_deployments_invalid_json(fireworks_finetune, mock_api_key):
+    # Test with invalid JSON response (missing 'deployments' key)
+    invalid_response = MagicMock(spec=httpx.Response)
+    invalid_response.status_code = 200
+    invalid_response.json.return_value = {
+        "some_other_key": "value",
+        # No 'deployments' key
+    }
+    invalid_response.text = '{"some_other_key": "value"}'
+
+    with patch("httpx.AsyncClient") as mock_client_class:
+        mock_client = AsyncMock()
+        mock_client.get.return_value = invalid_response
+        mock_client_class.return_value.__aenter__.return_value = mock_client
+
+        with pytest.raises(
+            ValueError,
+            match="Invalid response from Fireworks. Expected list of deployments in 'deployments' key",
+        ):
+            await fireworks_finetune._fetch_all_deployments()
+
+        # Verify API was called
+        mock_client.get.assert_called_once()

--- a/kiln_ai/adapters/ml_model_list.py
+++ b/kiln_ai/adapters/ml_model_list.py
@@ -133,7 +133,7 @@ class KilnModelProvider(BaseModel):
         supports_structured_output: Whether the provider supports structured output formats
         supports_data_gen: Whether the provider supports data generation
         untested_model: Whether the model is untested (typically user added). The supports_ fields are not applicable.
-        provider_finetune_id: The finetune ID for the provider, if applicable
+        provider_finetune_id: The finetune ID for the provider, if applicable. Some providers like Fireworks load these from an API.
         structured_output_mode: The mode we should use to call the model for structured output, if it was trained with structured output.
         parser: A parser to use for the model, if applicable
         reasoning_capable: Whether the model is designed to output thinking in a structured format (eg <think></think>). If so we don't use COT across 2 calls, and ask for thinking and final response in the same call.
@@ -576,7 +576,6 @@ built_in_models: List[KilnModel] = [
             # JSON mode not ideal (no schema), but tool calling doesn't work on 8b
             structured_output_mode=StructuredOutputMode.json_instruction_and_object,
             supports_data_gen=False,
-            provider_finetune_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
             model_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
         ),
         KilnModelProvider(
@@ -618,7 +617,6 @@ built_in_models: List[KilnModel] = [
             name=ModelProviderName.fireworks_ai,
             # Tool calling forces schema -- fireworks doesn't support json_schema, just json_mode
            structured_output_mode=StructuredOutputMode.function_calling_weak,
-            provider_finetune_id="accounts/fireworks/models/llama-v3p1-70b-instruct",
             model_id="accounts/fireworks/models/llama-v3p1-70b-instruct",
         ),
         KilnModelProvider(
@@ -764,7 +762,6 @@ built_in_models: List[KilnModel] = [
         ),
         KilnModelProvider(
             name=ModelProviderName.fireworks_ai,
-            provider_finetune_id="accounts/fireworks/models/llama-v3p2-3b-instruct",
             supports_structured_output=False,
             supports_data_gen=False,
             model_id="accounts/fireworks/models/llama-v3p2-3b-instruct",
@@ -890,8 +887,6 @@ built_in_models: List[KilnModel] = [
         ),
         KilnModelProvider(
             name=ModelProviderName.fireworks_ai,
-            # Finetuning not live yet
-            # provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
             # Tool calling forces schema -- fireworks doesn't support json_schema, just json_mode
             structured_output_mode=StructuredOutputMode.function_calling_weak,
             model_id="accounts/fireworks/models/llama-v3p3-70b-instruct",

--- kiln_ai-0.13.2.dist-info/METADATA
+++ kiln_ai-0.14.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: kiln-ai
-Version: 0.13.2
+Version: 0.14.0
 Summary: Kiln AI
 Project-URL: Homepage, https://getkiln.ai
 Project-URL: Repository, https://github.com/Kiln-AI/kiln
@@ -1,7 +1,7 @@
1
1
  kiln_ai/__init__.py,sha256=Sc4z8LRVFMwJUoc_DPVUriSXTZ6PO9MaJ80PhRbKyB8,34
2
2
  kiln_ai/adapters/__init__.py,sha256=XjGmWagEyOEVwVIAxjN5rYNsQWIEACT5DB7MMTxdPss,1005
3
3
  kiln_ai/adapters/adapter_registry.py,sha256=KmMHYQ3mxpjVLE6D-hMNWCGt6Cw9JvnFn6nMb48GE8Y,9166
4
- kiln_ai/adapters/ml_model_list.py,sha256=u1nFkJm_UD1IZjBBoynmWnhx_aPkuvSuHVI69Thma3w,58939
4
+ kiln_ai/adapters/ml_model_list.py,sha256=f_z1daFR_w4-ccJ4OWwqlIMY0ILFJt4X5LdQb3AMt_c,58592
5
5
  kiln_ai/adapters/ollama_tools.py,sha256=uObtLWfqKb9RXHN-TGGw2Y1FQlEMe0u8FgszI0zQn6U,3550
6
6
  kiln_ai/adapters/prompt_builders.py,sha256=LYHTIaisQMBFtWDRIGo1QJgOsmQ-NBpQ8fI4eImHxaQ,15269
7
7
  kiln_ai/adapters/provider_tools.py,sha256=UL3XEnnxs1TrbqPPxxHSvnL7aBd84ggh38lI0yEsX6A,14725
@@ -26,14 +26,14 @@ kiln_ai/adapters/eval/test_eval_runner.py,sha256=82WPE_frNRTSQ2lylqT0inkqcDgM72n
 kiln_ai/adapters/eval/test_g_eval.py,sha256=-Stx7E0D-WAH1HWrRSp48CiGsf-no1SHeFF9IqVXeMI,16433
 kiln_ai/adapters/eval/test_g_eval_data.py,sha256=8caiZfLWnXVX8alrBPrH7L7gqqSS9vO7u6PzcHurQcA,27769
 kiln_ai/adapters/fine_tune/__init__.py,sha256=DxdTR60chwgck1aEoVYWyfWi6Ed2ZkdJj0lar-SEAj4,257
-kiln_ai/adapters/fine_tune/base_finetune.py,sha256=MxSnBiapWfZQw5UmkYAtC0QXj2zDeF9Ows0k0g3p1IA,6455
+kiln_ai/adapters/fine_tune/base_finetune.py,sha256=ORTclQTQYksMWPu7vNoD7wBzOIqNVK0YOwFEnvsKPWA,5759
 kiln_ai/adapters/fine_tune/dataset_formatter.py,sha256=qRhSSkMhTWn13OMb6LKPVwAU7uY4bB49GDiVSuhDkNg,14449
 kiln_ai/adapters/fine_tune/finetune_registry.py,sha256=CvcEVxtKwjgCMA-oYH9Tpjn1DVWmMzgHpXJOZ0YQA8k,610
-kiln_ai/adapters/fine_tune/fireworks_finetune.py,sha256=ZBS45ji9j88fFd3O9OentAUflAz716YEmz9176Ln7bU,14284
+kiln_ai/adapters/fine_tune/fireworks_finetune.py,sha256=OlXp8j6Afwvk6-ySwA3Q7iuqBlKO7VLeAfNCnB3pZPI,19963
 kiln_ai/adapters/fine_tune/openai_finetune.py,sha256=Dz9E_0BWfrIkvv8ArZe-RKPwbIKPZ3v8rfbc3JELyTY,8571
-kiln_ai/adapters/fine_tune/test_base_finetune.py,sha256=0zWxFYrDGVuoQNQmi9vVUEkBc4mstfHnsUjQmiJA-sE,10864
+kiln_ai/adapters/fine_tune/test_base_finetune.py,sha256=sjuDgJDA_dynGRelx9_wXdssaxAYIuEG-Z8NzRx9Hl0,10559
 kiln_ai/adapters/fine_tune/test_dataset_formatter.py,sha256=T3jbFZooLVBaGCE0LUVxwPxzM3l8IY41zUj3jPk-Zi8,24027
-kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py,sha256=vBvkbTYVvsimxM6fTSeOnVdFldovV5flc1qT9QjPuNE,18961
+kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py,sha256=oLyLEG4TwW452lV2mvUo-wImLxzSwOuoKKeYFuGh3k8,36744
 kiln_ai/adapters/fine_tune/test_openai_finetune.py,sha256=H63Xk2PNHbt5Ev5IQpdR9JZ4uz-Huo2gfuC4mHHqe0w,20011
 kiln_ai/adapters/fine_tune/test_together_finetune.py,sha256=BUJFsyq_g77gU0JN3hg6FMBvqb0DIyTeAek-wxomKIg,18090
 kiln_ai/adapters/fine_tune/together_finetune.py,sha256=EbMPsTyKMubfwOalkFLiNFlMFIRKxLibzMTyLeUkle4,14010
@@ -97,7 +97,7 @@ kiln_ai/utils/name_generator.py,sha256=v26TgpCwQbhQFcZvzgjZvURinjrOyyFhxpsI6NQrH
 kiln_ai/utils/test_config.py,sha256=Jw3nMFeIgZUsZDRJJY2HpB-2EkR2NoZ-rDe_o9oA7ws,9174
 kiln_ai/utils/test_dataset_import.py,sha256=ZZOt7zqtaEIlMMx0VNXyRegDvnVqbWY2bcz-iMY_Oag,17427
 kiln_ai/utils/test_name_geneator.py,sha256=9-hSTBshyakqlPbFnNcggwLrL7lcPTitauBYHg9jFWI,1513
-kiln_ai-0.13.2.dist-info/METADATA,sha256=VVYhbE6IrTwP496RZ4ZcMizIJFW6Sur7a3qlwiUD3D4,12231
-kiln_ai-0.13.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-kiln_ai-0.13.2.dist-info/licenses/LICENSE.txt,sha256=_NA5pnTYgRRr4qH6lE3X-TuZJ8iRcMUi5ASoGr-lEx8,1209
-kiln_ai-0.13.2.dist-info/RECORD,,
+kiln_ai-0.14.0.dist-info/METADATA,sha256=EjgZOnknE7P9uW5BsIFJZYQAN-aUQ817SAEXjtqtjK0,12231
+kiln_ai-0.14.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+kiln_ai-0.14.0.dist-info/licenses/LICENSE.txt,sha256=_NA5pnTYgRRr4qH6lE3X-TuZJ8iRcMUi5ASoGr-lEx8,1209
+kiln_ai-0.14.0.dist-info/RECORD,,