kiln-ai 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

@@ -315,7 +315,7 @@ async def test_generate_and_upload_jsonl_success(
315
315
  "thinking_instructions": thinking_instructions,
316
316
  }
317
317
 
318
- assert result == mock_dataset_id
318
+ assert result == "kiln-" + mock_dataset_id
319
319
  assert mock_client.post.call_count == 2
320
320
  assert mock_client.get.call_count == 1
321
321
 
@@ -448,7 +448,7 @@ def test_available_parameters(fireworks_finetune):
448
448
  assert payload_parameters == {"loraRank": 16, "epochs": 3}
449
449
 
450
450
 
451
- async def test_deploy_success(fireworks_finetune, mock_api_key):
451
+ async def test_deploy_serverless_success(fireworks_finetune, mock_api_key):
452
452
  # Mock response for successful deployment
453
453
  success_response = MagicMock(spec=httpx.Response)
454
454
  success_response.status_code = 200
@@ -467,12 +467,12 @@ async def test_deploy_success(fireworks_finetune, mock_api_key):
467
467
  mock_client.post.return_value = success_response
468
468
  mock_client_class.return_value.__aenter__.return_value = mock_client
469
469
 
470
- result = await fireworks_finetune._deploy()
470
+ result = await fireworks_finetune._deploy_serverless()
471
471
  assert result is True
472
472
  assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
473
473
 
474
474
 
475
- async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
475
+ async def test_deploy_serverless_already_deployed(fireworks_finetune, mock_api_key):
476
476
  # Mock response for already deployed model
477
477
  already_deployed_response = MagicMock(spec=httpx.Response)
478
478
  already_deployed_response.status_code = 400
@@ -494,12 +494,12 @@ async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
494
494
  mock_client.post.return_value = already_deployed_response
495
495
  mock_client_class.return_value.__aenter__.return_value = mock_client
496
496
 
497
- result = await fireworks_finetune._deploy()
497
+ result = await fireworks_finetune._deploy_serverless()
498
498
  assert result is True
499
499
  assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
500
500
 
501
501
 
502
- async def test_deploy_failure(fireworks_finetune, mock_api_key):
502
+ async def test_deploy_serverless_failure(fireworks_finetune, mock_api_key):
503
503
  # Mock response for failed deployment
504
504
  failure_response = MagicMock(spec=httpx.Response)
505
505
  failure_response.status_code = 500
@@ -510,18 +510,28 @@ async def test_deploy_failure(fireworks_finetune, mock_api_key):
510
510
  mock_client.post.return_value = failure_response
511
511
  mock_client_class.return_value.__aenter__.return_value = mock_client
512
512
 
513
- result = await fireworks_finetune._deploy()
513
+ result = await fireworks_finetune._deploy_serverless()
514
514
  assert result is False
515
515
 
516
516
 
517
- async def test_deploy_missing_credentials(fireworks_finetune):
517
+ async def test_deploy_serverless_missing_credentials(fireworks_finetune):
518
518
  # Test missing API key or account ID
519
519
  with patch.object(Config, "shared") as mock_config:
520
520
  mock_config.return_value.fireworks_api_key = None
521
521
  mock_config.return_value.fireworks_account_id = None
522
522
 
523
523
  with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
524
- await fireworks_finetune._deploy()
524
+ await fireworks_finetune._deploy_serverless()
525
+
526
+
527
+ async def test_deploy_server_missing_credentials(fireworks_finetune):
528
+ # Test missing API key or account ID
529
+ with patch.object(Config, "shared") as mock_config:
530
+ mock_config.return_value.fireworks_api_key = None
531
+ mock_config.return_value.fireworks_account_id = None
532
+
533
+ response = await fireworks_finetune._check_or_deploy_server()
534
+ assert response is False
525
535
 
526
536
 
527
537
  async def test_deploy_missing_model_id(fireworks_finetune, mock_api_key):
@@ -564,3 +574,479 @@ async def test_status_with_deploy(fireworks_finetune, mock_api_key):
564
574
  # Verify message was updated due to failed deployment
565
575
  assert status.status == FineTuneStatusType.completed
566
576
  assert status.message == "Fine-tuning job completed but failed to deploy model."
577
+
578
+
579
+ @pytest.mark.paid
580
+ async def test_fetch_all_deployments(fireworks_finetune):
581
+ deployments = await fireworks_finetune._fetch_all_deployments()
582
+ assert isinstance(deployments, list)
583
+
584
+
585
+ async def test_api_key_and_account_id(fireworks_finetune, mock_api_key):
586
+ # Test successful retrieval of API key and account ID
587
+ api_key, account_id = fireworks_finetune.api_key_and_account_id()
588
+ assert api_key == "test-api-key"
589
+ assert account_id == "test-account-id"
590
+
591
+
592
+ async def test_api_key_and_account_id_missing_credentials(fireworks_finetune):
593
+ # Test missing API key or account ID
594
+ with patch.object(Config, "shared") as mock_config:
595
+ mock_config.return_value.fireworks_api_key = None
596
+ mock_config.return_value.fireworks_account_id = None
597
+
598
+ with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
599
+ fireworks_finetune.api_key_and_account_id()
600
+
601
+
602
+ def test_deployment_display_name(fireworks_finetune):
603
+ # Test with default ID and name
604
+ display_name = fireworks_finetune.deployment_display_name()
605
+ expected = f"Kiln AI fine-tuned model [ID:{fireworks_finetune.datamodel.id}][name:test-finetune]"[
606
+ :60
607
+ ]
608
+ assert display_name == expected
609
+
610
+ # Test with a very long name to ensure 60 character limit
611
+ fireworks_finetune.datamodel.name = "x" * 100
612
+ display_name = fireworks_finetune.deployment_display_name()
613
+ assert len(display_name) == 60
614
+ assert display_name.startswith("Kiln AI fine-tuned model [ID:")
615
+
616
+
617
+ async def test_model_id_checking_status_completed(fireworks_finetune):
618
+ # Test with completed status and valid model ID
619
+ status_response = (
620
+ FineTuneStatus(status=FineTuneStatusType.completed, message=""),
621
+ "model-123",
622
+ )
623
+
624
+ with patch.object(fireworks_finetune, "_status", return_value=status_response):
625
+ model_id = await fireworks_finetune.model_id_checking_status()
626
+ assert model_id == "model-123"
627
+
628
+
629
+ async def test_model_id_checking_status_not_completed(fireworks_finetune):
630
+ # Test with non-completed status
631
+ status_response = (
632
+ FineTuneStatus(status=FineTuneStatusType.running, message=""),
633
+ "model-123",
634
+ )
635
+
636
+ with patch.object(fireworks_finetune, "_status", return_value=status_response):
637
+ model_id = await fireworks_finetune.model_id_checking_status()
638
+ assert model_id is None
639
+
640
+
641
+ async def test_model_id_checking_status_invalid_model_id(fireworks_finetune):
642
+ # Test with completed status but invalid model ID
643
+ status_response = (
644
+ FineTuneStatus(status=FineTuneStatusType.completed, message=""),
645
+ None,
646
+ )
647
+
648
+ with patch.object(fireworks_finetune, "_status", return_value=status_response):
649
+ model_id = await fireworks_finetune.model_id_checking_status()
650
+ assert model_id is None
651
+
652
+ # Test with non-string model ID
653
+ status_response = (
654
+ FineTuneStatus(status=FineTuneStatusType.completed, message=""),
655
+ {"id": "model-123"}, # Not a string
656
+ )
657
+
658
+ with patch.object(fireworks_finetune, "_status", return_value=status_response):
659
+ model_id = await fireworks_finetune.model_id_checking_status()
660
+ assert model_id is None
661
+
662
+
663
+ @pytest.mark.parametrize(
664
+ "base_model_id,expected_method",
665
+ [
666
+ ("accounts/fireworks/models/llama-v3p1-8b-instruct", "_deploy_serverless"),
667
+ ("accounts/fireworks/models/llama-v3p1-70b-instruct", "_deploy_serverless"),
668
+ ("some-other-model", "_check_or_deploy_server"),
669
+ ],
670
+ )
671
+ async def test_deploy_model_selection(
672
+ fireworks_finetune, base_model_id, expected_method, mock_api_key
673
+ ):
674
+ # Set the base model ID
675
+ fireworks_finetune.datamodel.base_model_id = base_model_id
676
+
677
+ # Mock the deployment methods
678
+ with (
679
+ patch.object(
680
+ fireworks_finetune, "_deploy_serverless", return_value=True
681
+ ) as mock_serverless,
682
+ patch.object(
683
+ fireworks_finetune, "_check_or_deploy_server", return_value=True
684
+ ) as mock_server,
685
+ ):
686
+ result = await fireworks_finetune._deploy()
687
+
688
+ # Verify the correct method was called based on the model
689
+ if expected_method == "_deploy_serverless":
690
+ mock_serverless.assert_called_once()
691
+ mock_server.assert_not_called()
692
+ else:
693
+ mock_serverless.assert_not_called()
694
+ mock_server.assert_called_once()
695
+
696
+ assert result is True
697
+
698
+
699
+ async def test_fetch_all_deployments_request_error(fireworks_finetune, mock_api_key):
700
+ # Test with error response
701
+ error_response = MagicMock(spec=httpx.Response)
702
+ error_response.status_code = 500
703
+ error_response.text = "Internal Server Error"
704
+
705
+ with patch("httpx.AsyncClient") as mock_client_class:
706
+ mock_client = AsyncMock()
707
+ mock_client.get.side_effect = Exception("API request failed")
708
+ mock_client_class.return_value.__aenter__.return_value = mock_client
709
+
710
+ with pytest.raises(Exception, match="API request failed"):
711
+ await fireworks_finetune._fetch_all_deployments()
712
+
713
+ # Verify API was called with correct parameters
714
+ mock_client.get.assert_called_once()
715
+ call_args = mock_client.get.call_args[1]
716
+ assert "params" in call_args
717
+ assert call_args["params"]["pageSize"] == 200
718
+
719
+
720
+ async def test_fetch_all_deployments_standard_case(fireworks_finetune, mock_api_key):
721
+ # Test with single page of results
722
+ mock_deployments = [
723
+ {"id": "deploy-1", "baseModel": "model-1", "state": "READY"},
724
+ {"id": "deploy-2", "baseModel": "model-2", "state": "READY"},
725
+ ]
726
+
727
+ success_response = MagicMock(spec=httpx.Response)
728
+ success_response.status_code = 200
729
+ success_response.json.return_value = {
730
+ "deployments": mock_deployments,
731
+ "nextPageToken": None,
732
+ }
733
+
734
+ with patch("httpx.AsyncClient") as mock_client_class:
735
+ mock_client = AsyncMock()
736
+ mock_client.get.return_value = success_response
737
+ mock_client_class.return_value.__aenter__.return_value = mock_client
738
+
739
+ deployments = await fireworks_finetune._fetch_all_deployments()
740
+
741
+ # Verify API was called correctly
742
+ mock_client.get.assert_called_once()
743
+
744
+ # Verify correct deployments were returned
745
+ assert deployments == mock_deployments
746
+ assert len(deployments) == 2
747
+ assert deployments[0]["id"] == "deploy-1"
748
+ assert deployments[1]["id"] == "deploy-2"
749
+
750
+
751
+ async def test_fetch_all_deployments_paged_case(fireworks_finetune, mock_api_key):
752
+ # Test with multiple pages of results
753
+ mock_deployments_page1 = [
754
+ {"id": "deploy-1", "baseModel": "model-1", "state": "READY"},
755
+ {"id": "deploy-2", "baseModel": "model-2", "state": "READY"},
756
+ ]
757
+
758
+ mock_deployments_page2 = [
759
+ {"id": "deploy-3", "baseModel": "model-3", "state": "READY"},
760
+ {"id": "deploy-4", "baseModel": "model-4", "state": "READY"},
761
+ ]
762
+
763
+ page1_response = MagicMock(spec=httpx.Response)
764
+ page1_response.status_code = 200
765
+ page1_response.json.return_value = {
766
+ "deployments": mock_deployments_page1,
767
+ "nextPageToken": "page2token",
768
+ }
769
+
770
+ page2_response = MagicMock(spec=httpx.Response)
771
+ page2_response.status_code = 200
772
+ page2_response.json.return_value = {
773
+ "deployments": mock_deployments_page2,
774
+ "nextPageToken": None,
775
+ }
776
+
777
+ with patch("httpx.AsyncClient") as mock_client_class:
778
+ mock_client = AsyncMock()
779
+ mock_client.get.side_effect = [page1_response, page2_response]
780
+ mock_client_class.return_value.__aenter__.return_value = mock_client
781
+
782
+ deployments = await fireworks_finetune._fetch_all_deployments()
783
+
784
+ # Verify API was called twice (once for each page)
785
+ assert mock_client.get.call_count == 2
786
+
787
+ # Verify first call had no page token
788
+ first_call_args = mock_client.get.call_args_list[0][1]
789
+ assert "pageToken" not in first_call_args["params"]
790
+
791
+ # Verify second call included the page token
792
+ second_call_args = mock_client.get.call_args_list[1][1]
793
+ assert second_call_args["params"]["pageToken"] == "page2token"
794
+
795
+ # Verify all deployments from both pages were returned
796
+ assert len(deployments) == 4
797
+ assert deployments == mock_deployments_page1 + mock_deployments_page2
798
+ for deployment in deployments:
799
+ assert deployment["id"] in [
800
+ "deploy-1",
801
+ "deploy-2",
802
+ "deploy-3",
803
+ "deploy-4",
804
+ ]
805
+
806
+
807
+ async def test_deploy_server_success(fireworks_finetune, mock_api_key):
808
+ # Mock response for successful deployment
809
+ success_response = MagicMock(spec=httpx.Response)
810
+ success_response.status_code = 200
811
+ success_response.json.return_value = {"baseModel": "model-123"}
812
+
813
+ status_response = (
814
+ FineTuneStatus(status=FineTuneStatusType.completed, message=""),
815
+ "model-123",
816
+ )
817
+
818
+ with (
819
+ patch("httpx.AsyncClient") as mock_client_class,
820
+ patch.object(
821
+ fireworks_finetune, "model_id_checking_status", return_value="model-123"
822
+ ),
823
+ ):
824
+ mock_client = AsyncMock()
825
+ mock_client.post.return_value = success_response
826
+ mock_client_class.return_value.__aenter__.return_value = mock_client
827
+
828
+ result = await fireworks_finetune._deploy_server()
829
+
830
+ # Verify result
831
+ assert result is True
832
+
833
+ # Verify fine_tune_model_id was updated
834
+ assert fireworks_finetune.datamodel.fine_tune_model_id == "model-123"
835
+
836
+ # Verify API was called with correct parameters
837
+ mock_client.post.assert_called_once()
838
+ call_args = mock_client.post.call_args[1]
839
+ assert "json" in call_args
840
+ assert call_args["json"]["baseModel"] == "model-123"
841
+ assert call_args["json"]["minReplicaCount"] == 0
842
+ assert "autoscalingPolicy" in call_args["json"]
843
+ assert call_args["json"]["autoscalingPolicy"]["scaleToZeroWindow"] == "300s"
844
+
845
+ # load the datamodel from the file and confirm the fine_tune_model_id was updated
846
+ loaded_datamodel = FinetuneModel.load_from_file(
847
+ fireworks_finetune.datamodel.path
848
+ )
849
+ assert loaded_datamodel.fine_tune_model_id == "model-123"
850
+
851
+
852
+ async def test_deploy_server_failure(fireworks_finetune, mock_api_key):
853
+ # Mock response for failed deployment
854
+ failure_response = MagicMock(spec=httpx.Response)
855
+ failure_response.status_code = 500
856
+ failure_response.text = "Internal Server Error"
857
+
858
+ with (
859
+ patch("httpx.AsyncClient") as mock_client_class,
860
+ patch.object(
861
+ fireworks_finetune, "model_id_checking_status", return_value="model-123"
862
+ ),
863
+ ):
864
+ mock_client = AsyncMock()
865
+ mock_client.post.return_value = failure_response
866
+ mock_client_class.return_value.__aenter__.return_value = mock_client
867
+
868
+ result = await fireworks_finetune._deploy_server()
869
+
870
+ # Verify result
871
+ assert result is False
872
+
873
+ # Verify API was called
874
+ mock_client.post.assert_called_once()
875
+
876
+
877
+ async def test_deploy_server_non_200_but_valid_response(
878
+ fireworks_finetune, mock_api_key
879
+ ):
880
+ # Mock response with a 200 status but a JSON body that lacks the expected "baseModel" key
881
+ mixed_response = MagicMock(spec=httpx.Response)
882
+ mixed_response.status_code = 200
883
+ mixed_response.json.return_value = {"not_baseModel": "something-else"}
884
+
885
+ with (
886
+ patch("httpx.AsyncClient") as mock_client_class,
887
+ patch.object(
888
+ fireworks_finetune, "model_id_checking_status", return_value="model-123"
889
+ ),
890
+ ):
891
+ mock_client = AsyncMock()
892
+ mock_client.post.return_value = mixed_response
893
+ mock_client_class.return_value.__aenter__.return_value = mock_client
894
+
895
+ result = await fireworks_finetune._deploy_server()
896
+
897
+ # Verify result - should fail because baseModel is missing
898
+ assert result is False
899
+
900
+
901
+ async def test_deploy_server_missing_model_id(fireworks_finetune, mock_api_key):
902
+ # Test when model_id_checking_status returns None
903
+ with patch.object(
904
+ fireworks_finetune, "model_id_checking_status", return_value=None
905
+ ):
906
+ result = await fireworks_finetune._deploy_server()
907
+
908
+ # Verify result - should fail because model ID is missing
909
+ assert result is False
910
+
911
+
912
+ @pytest.mark.parametrize(
913
+ "state,expected_already_deployed",
914
+ [
915
+ ("READY", True),
916
+ ("CREATING", True),
917
+ ("FAILED", False),
918
+ ],
919
+ )
920
+ async def test_check_or_deploy_server_already_deployed(
921
+ fireworks_finetune, mock_api_key, state, expected_already_deployed
922
+ ):
923
+ # Test when model is already deployed (should return True without calling _deploy_server)
924
+
925
+ # Set a fine_tune_model_id so we search for deployments
926
+ fireworks_finetune.datamodel.fine_tune_model_id = "model-123"
927
+
928
+ # Mock deployments including one matching our model ID
929
+ mock_deployments = [
930
+ {"id": "deploy-1", "baseModel": "different-model", "state": "READY"},
931
+ {"id": "deploy-2", "baseModel": "model-123", "state": state},
932
+ ]
933
+
934
+ with (
935
+ patch.object(
936
+ fireworks_finetune, "_fetch_all_deployments", return_value=mock_deployments
937
+ ) as mock_fetch,
938
+ patch.object(fireworks_finetune, "_deploy_server") as mock_deploy,
939
+ ):
940
+ mock_deploy.return_value = True
941
+ result = await fireworks_finetune._check_or_deploy_server()
942
+ # Still True when the matching deployment is in a non-ready state, because _deploy_server is called in that case (checked below)
943
+ assert result is True
944
+
945
+ if expected_already_deployed:
946
+ assert mock_deploy.call_count == 0
947
+ else:
948
+ assert mock_deploy.call_count == 1
949
+
950
+ # Verify _fetch_all_deployments was called
951
+ mock_fetch.assert_called_once()
952
+
953
+
954
+ async def test_check_or_deploy_server_not_deployed(fireworks_finetune, mock_api_key):
955
+ # Test when model exists but isn't deployed (should call _deploy_server)
956
+
957
+ # Set a fine_tune_model_id so we search for deployments
958
+ fireworks_finetune.datamodel.fine_tune_model_id = "model-123"
959
+
960
+ # Mock deployments without our model ID
961
+ mock_deployments = [
962
+ {"id": "deploy-1", "baseModel": "different-model-1", "state": "READY"},
963
+ {"id": "deploy-2", "baseModel": "different-model-2", "state": "READY"},
964
+ ]
965
+
966
+ with (
967
+ patch.object(
968
+ fireworks_finetune, "_fetch_all_deployments", return_value=mock_deployments
969
+ ) as mock_fetch,
970
+ patch.object(
971
+ fireworks_finetune, "_deploy_server", return_value=True
972
+ ) as mock_deploy,
973
+ ):
974
+ result = await fireworks_finetune._check_or_deploy_server()
975
+
976
+ # Verify method returned True (from _deploy_server)
977
+ assert result is True
978
+
979
+ # Verify _fetch_all_deployments was called
980
+ mock_fetch.assert_called_once()
981
+
982
+ # Verify _deploy_server was called since model is not deployed
983
+ mock_deploy.assert_called_once()
984
+
985
+
986
+ async def test_check_or_deploy_server_no_model_id(fireworks_finetune, mock_api_key):
987
+ # Test when no fine_tune_model_id exists (should skip fetch and call _deploy_server directly)
988
+
989
+ # Ensure no fine_tune_model_id is set
990
+ fireworks_finetune.datamodel.fine_tune_model_id = None
991
+
992
+ with (
993
+ patch.object(fireworks_finetune, "_fetch_all_deployments") as mock_fetch,
994
+ patch.object(
995
+ fireworks_finetune, "_deploy_server", return_value=True
996
+ ) as mock_deploy,
997
+ ):
998
+ result = await fireworks_finetune._check_or_deploy_server()
999
+
1000
+ # Verify method returned True (from _deploy_server)
1001
+ assert result is True
1002
+
1003
+ # Verify _fetch_all_deployments was NOT called
1004
+ mock_fetch.assert_not_called()
1005
+
1006
+ # Verify _deploy_server was called directly
1007
+ mock_deploy.assert_called_once()
1008
+
1009
+
1010
+ async def test_check_or_deploy_server_deploy_fails(fireworks_finetune, mock_api_key):
1011
+ # Test when deployment fails
1012
+
1013
+ # Ensure no fine_tune_model_id is set
1014
+ fireworks_finetune.datamodel.fine_tune_model_id = None
1015
+
1016
+ with (
1017
+ patch.object(
1018
+ fireworks_finetune, "_deploy_server", return_value=False
1019
+ ) as mock_deploy,
1020
+ ):
1021
+ result = await fireworks_finetune._check_or_deploy_server()
1022
+
1023
+ # Verify method returned False (from _deploy_server)
1024
+ assert result is False
1025
+
1026
+ # Verify _deploy_server was called
1027
+ mock_deploy.assert_called_once()
1028
+
1029
+
1030
+ async def test_fetch_all_deployments_invalid_json(fireworks_finetune, mock_api_key):
1031
+ # Test with invalid JSON response (missing 'deployments' key)
1032
+ invalid_response = MagicMock(spec=httpx.Response)
1033
+ invalid_response.status_code = 200
1034
+ invalid_response.json.return_value = {
1035
+ "some_other_key": "value",
1036
+ # No 'deployments' key
1037
+ }
1038
+ invalid_response.text = '{"some_other_key": "value"}'
1039
+
1040
+ with patch("httpx.AsyncClient") as mock_client_class:
1041
+ mock_client = AsyncMock()
1042
+ mock_client.get.return_value = invalid_response
1043
+ mock_client_class.return_value.__aenter__.return_value = mock_client
1044
+
1045
+ with pytest.raises(
1046
+ ValueError,
1047
+ match="Invalid response from Fireworks. Expected list of deployments in 'deployments' key",
1048
+ ):
1049
+ await fireworks_finetune._fetch_all_deployments()
1050
+
1051
+ # Verify API was called
1052
+ mock_client.get.assert_called_once()