kiln-ai 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -23
- kiln_ai/adapters/fine_tune/dataset_formatter.py +4 -4
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +163 -15
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -9
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +3 -3
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +495 -9
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +319 -43
- kiln_ai/adapters/model_adapters/base_adapter.py +15 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +10 -5
- kiln_ai/adapters/provider_tools.py +7 -0
- kiln_ai/adapters/test_provider_tools.py +16 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/task_output.py +9 -5
- kiln_ai/datamodel/task_run.py +29 -5
- kiln_ai/datamodel/test_example_models.py +104 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/METADATA +3 -2
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/RECORD +25 -24
- kiln_ai/adapters/test_generate_docs.py +0 -69
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.13.2.dist-info → kiln_ai-0.15.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -315,7 +315,7 @@ async def test_generate_and_upload_jsonl_success(
|
|
|
315
315
|
"thinking_instructions": thinking_instructions,
|
|
316
316
|
}
|
|
317
317
|
|
|
318
|
-
assert result == mock_dataset_id
|
|
318
|
+
assert result == "kiln-" + mock_dataset_id
|
|
319
319
|
assert mock_client.post.call_count == 2
|
|
320
320
|
assert mock_client.get.call_count == 1
|
|
321
321
|
|
|
@@ -448,7 +448,7 @@ def test_available_parameters(fireworks_finetune):
|
|
|
448
448
|
assert payload_parameters == {"loraRank": 16, "epochs": 3}
|
|
449
449
|
|
|
450
450
|
|
|
451
|
-
async def test_deploy_success(fireworks_finetune, mock_api_key):
|
|
451
|
+
async def test_deploy_serverless_success(fireworks_finetune, mock_api_key):
|
|
452
452
|
# Mock response for successful deployment
|
|
453
453
|
success_response = MagicMock(spec=httpx.Response)
|
|
454
454
|
success_response.status_code = 200
|
|
@@ -467,12 +467,12 @@ async def test_deploy_success(fireworks_finetune, mock_api_key):
|
|
|
467
467
|
mock_client.post.return_value = success_response
|
|
468
468
|
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
469
469
|
|
|
470
|
-
result = await fireworks_finetune.
|
|
470
|
+
result = await fireworks_finetune._deploy_serverless()
|
|
471
471
|
assert result is True
|
|
472
472
|
assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
|
|
473
473
|
|
|
474
474
|
|
|
475
|
-
async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
|
|
475
|
+
async def test_deploy_serverless_already_deployed(fireworks_finetune, mock_api_key):
|
|
476
476
|
# Mock response for already deployed model
|
|
477
477
|
already_deployed_response = MagicMock(spec=httpx.Response)
|
|
478
478
|
already_deployed_response.status_code = 400
|
|
@@ -494,12 +494,12 @@ async def test_deploy_already_deployed(fireworks_finetune, mock_api_key):
|
|
|
494
494
|
mock_client.post.return_value = already_deployed_response
|
|
495
495
|
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
496
496
|
|
|
497
|
-
result = await fireworks_finetune.
|
|
497
|
+
result = await fireworks_finetune._deploy_serverless()
|
|
498
498
|
assert result is True
|
|
499
499
|
assert fireworks_finetune.datamodel.fine_tune_model_id == "ftm-123"
|
|
500
500
|
|
|
501
501
|
|
|
502
|
-
async def test_deploy_failure(fireworks_finetune, mock_api_key):
|
|
502
|
+
async def test_deploy_serverless_failure(fireworks_finetune, mock_api_key):
|
|
503
503
|
# Mock response for failed deployment
|
|
504
504
|
failure_response = MagicMock(spec=httpx.Response)
|
|
505
505
|
failure_response.status_code = 500
|
|
@@ -510,18 +510,28 @@ async def test_deploy_failure(fireworks_finetune, mock_api_key):
|
|
|
510
510
|
mock_client.post.return_value = failure_response
|
|
511
511
|
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
512
512
|
|
|
513
|
-
result = await fireworks_finetune.
|
|
513
|
+
result = await fireworks_finetune._deploy_serverless()
|
|
514
514
|
assert result is False
|
|
515
515
|
|
|
516
516
|
|
|
517
|
-
async def
|
|
517
|
+
async def test_deploy_serverless_missing_credentials(fireworks_finetune):
|
|
518
518
|
# Test missing API key or account ID
|
|
519
519
|
with patch.object(Config, "shared") as mock_config:
|
|
520
520
|
mock_config.return_value.fireworks_api_key = None
|
|
521
521
|
mock_config.return_value.fireworks_account_id = None
|
|
522
522
|
|
|
523
523
|
with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
|
|
524
|
-
await fireworks_finetune.
|
|
524
|
+
await fireworks_finetune._deploy_serverless()
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
async def test_deploy_server_missing_credentials(fireworks_finetune):
|
|
528
|
+
# Test missing API key or account ID
|
|
529
|
+
with patch.object(Config, "shared") as mock_config:
|
|
530
|
+
mock_config.return_value.fireworks_api_key = None
|
|
531
|
+
mock_config.return_value.fireworks_account_id = None
|
|
532
|
+
|
|
533
|
+
response = await fireworks_finetune._check_or_deploy_server()
|
|
534
|
+
assert response is False
|
|
525
535
|
|
|
526
536
|
|
|
527
537
|
async def test_deploy_missing_model_id(fireworks_finetune, mock_api_key):
|
|
@@ -564,3 +574,479 @@ async def test_status_with_deploy(fireworks_finetune, mock_api_key):
|
|
|
564
574
|
# Verify message was updated due to failed deployment
|
|
565
575
|
assert status.status == FineTuneStatusType.completed
|
|
566
576
|
assert status.message == "Fine-tuning job completed but failed to deploy model."
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
@pytest.mark.paid
|
|
580
|
+
async def test_fetch_all_deployments(fireworks_finetune):
|
|
581
|
+
deployments = await fireworks_finetune._fetch_all_deployments()
|
|
582
|
+
assert isinstance(deployments, list)
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
async def test_api_key_and_account_id(fireworks_finetune, mock_api_key):
|
|
586
|
+
# Test successful retrieval of API key and account ID
|
|
587
|
+
api_key, account_id = fireworks_finetune.api_key_and_account_id()
|
|
588
|
+
assert api_key == "test-api-key"
|
|
589
|
+
assert account_id == "test-account-id"
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
async def test_api_key_and_account_id_missing_credentials(fireworks_finetune):
|
|
593
|
+
# Test missing API key or account ID
|
|
594
|
+
with patch.object(Config, "shared") as mock_config:
|
|
595
|
+
mock_config.return_value.fireworks_api_key = None
|
|
596
|
+
mock_config.return_value.fireworks_account_id = None
|
|
597
|
+
|
|
598
|
+
with pytest.raises(ValueError, match="Fireworks API key or account ID not set"):
|
|
599
|
+
fireworks_finetune.api_key_and_account_id()
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def test_deployment_display_name(fireworks_finetune):
|
|
603
|
+
# Test with default ID and name
|
|
604
|
+
display_name = fireworks_finetune.deployment_display_name()
|
|
605
|
+
expected = f"Kiln AI fine-tuned model [ID:{fireworks_finetune.datamodel.id}][name:test-finetune]"[
|
|
606
|
+
:60
|
|
607
|
+
]
|
|
608
|
+
assert display_name == expected
|
|
609
|
+
|
|
610
|
+
# Test with a very long name to ensure 60 character limit
|
|
611
|
+
fireworks_finetune.datamodel.name = "x" * 100
|
|
612
|
+
display_name = fireworks_finetune.deployment_display_name()
|
|
613
|
+
assert len(display_name) == 60
|
|
614
|
+
assert display_name.startswith("Kiln AI fine-tuned model [ID:")
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
async def test_model_id_checking_status_completed(fireworks_finetune):
|
|
618
|
+
# Test with completed status and valid model ID
|
|
619
|
+
status_response = (
|
|
620
|
+
FineTuneStatus(status=FineTuneStatusType.completed, message=""),
|
|
621
|
+
"model-123",
|
|
622
|
+
)
|
|
623
|
+
|
|
624
|
+
with patch.object(fireworks_finetune, "_status", return_value=status_response):
|
|
625
|
+
model_id = await fireworks_finetune.model_id_checking_status()
|
|
626
|
+
assert model_id == "model-123"
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
async def test_model_id_checking_status_not_completed(fireworks_finetune):
|
|
630
|
+
# Test with non-completed status
|
|
631
|
+
status_response = (
|
|
632
|
+
FineTuneStatus(status=FineTuneStatusType.running, message=""),
|
|
633
|
+
"model-123",
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
with patch.object(fireworks_finetune, "_status", return_value=status_response):
|
|
637
|
+
model_id = await fireworks_finetune.model_id_checking_status()
|
|
638
|
+
assert model_id is None
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
async def test_model_id_checking_status_invalid_model_id(fireworks_finetune):
|
|
642
|
+
# Test with completed status but invalid model ID
|
|
643
|
+
status_response = (
|
|
644
|
+
FineTuneStatus(status=FineTuneStatusType.completed, message=""),
|
|
645
|
+
None,
|
|
646
|
+
)
|
|
647
|
+
|
|
648
|
+
with patch.object(fireworks_finetune, "_status", return_value=status_response):
|
|
649
|
+
model_id = await fireworks_finetune.model_id_checking_status()
|
|
650
|
+
assert model_id is None
|
|
651
|
+
|
|
652
|
+
# Test with non-string model ID
|
|
653
|
+
status_response = (
|
|
654
|
+
FineTuneStatus(status=FineTuneStatusType.completed, message=""),
|
|
655
|
+
{"id": "model-123"}, # Not a string
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
with patch.object(fireworks_finetune, "_status", return_value=status_response):
|
|
659
|
+
model_id = await fireworks_finetune.model_id_checking_status()
|
|
660
|
+
assert model_id is None
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
@pytest.mark.parametrize(
|
|
664
|
+
"base_model_id,expected_method",
|
|
665
|
+
[
|
|
666
|
+
("accounts/fireworks/models/llama-v3p1-8b-instruct", "_deploy_serverless"),
|
|
667
|
+
("accounts/fireworks/models/llama-v3p1-70b-instruct", "_deploy_serverless"),
|
|
668
|
+
("some-other-model", "_check_or_deploy_server"),
|
|
669
|
+
],
|
|
670
|
+
)
|
|
671
|
+
async def test_deploy_model_selection(
|
|
672
|
+
fireworks_finetune, base_model_id, expected_method, mock_api_key
|
|
673
|
+
):
|
|
674
|
+
# Set the base model ID
|
|
675
|
+
fireworks_finetune.datamodel.base_model_id = base_model_id
|
|
676
|
+
|
|
677
|
+
# Mock the deployment methods
|
|
678
|
+
with (
|
|
679
|
+
patch.object(
|
|
680
|
+
fireworks_finetune, "_deploy_serverless", return_value=True
|
|
681
|
+
) as mock_serverless,
|
|
682
|
+
patch.object(
|
|
683
|
+
fireworks_finetune, "_check_or_deploy_server", return_value=True
|
|
684
|
+
) as mock_server,
|
|
685
|
+
):
|
|
686
|
+
result = await fireworks_finetune._deploy()
|
|
687
|
+
|
|
688
|
+
# Verify the correct method was called based on the model
|
|
689
|
+
if expected_method == "_deploy_serverless":
|
|
690
|
+
mock_serverless.assert_called_once()
|
|
691
|
+
mock_server.assert_not_called()
|
|
692
|
+
else:
|
|
693
|
+
mock_serverless.assert_not_called()
|
|
694
|
+
mock_server.assert_called_once()
|
|
695
|
+
|
|
696
|
+
assert result is True
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
async def test_fetch_all_deployments_request_error(fireworks_finetune, mock_api_key):
|
|
700
|
+
# Test with error response
|
|
701
|
+
error_response = MagicMock(spec=httpx.Response)
|
|
702
|
+
error_response.status_code = 500
|
|
703
|
+
error_response.text = "Internal Server Error"
|
|
704
|
+
|
|
705
|
+
with patch("httpx.AsyncClient") as mock_client_class:
|
|
706
|
+
mock_client = AsyncMock()
|
|
707
|
+
mock_client.get.side_effect = Exception("API request failed")
|
|
708
|
+
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
709
|
+
|
|
710
|
+
with pytest.raises(Exception, match="API request failed"):
|
|
711
|
+
await fireworks_finetune._fetch_all_deployments()
|
|
712
|
+
|
|
713
|
+
# Verify API was called with correct parameters
|
|
714
|
+
mock_client.get.assert_called_once()
|
|
715
|
+
call_args = mock_client.get.call_args[1]
|
|
716
|
+
assert "params" in call_args
|
|
717
|
+
assert call_args["params"]["pageSize"] == 200
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
async def test_fetch_all_deployments_standard_case(fireworks_finetune, mock_api_key):
|
|
721
|
+
# Test with single page of results
|
|
722
|
+
mock_deployments = [
|
|
723
|
+
{"id": "deploy-1", "baseModel": "model-1", "state": "READY"},
|
|
724
|
+
{"id": "deploy-2", "baseModel": "model-2", "state": "READY"},
|
|
725
|
+
]
|
|
726
|
+
|
|
727
|
+
success_response = MagicMock(spec=httpx.Response)
|
|
728
|
+
success_response.status_code = 200
|
|
729
|
+
success_response.json.return_value = {
|
|
730
|
+
"deployments": mock_deployments,
|
|
731
|
+
"nextPageToken": None,
|
|
732
|
+
}
|
|
733
|
+
|
|
734
|
+
with patch("httpx.AsyncClient") as mock_client_class:
|
|
735
|
+
mock_client = AsyncMock()
|
|
736
|
+
mock_client.get.return_value = success_response
|
|
737
|
+
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
738
|
+
|
|
739
|
+
deployments = await fireworks_finetune._fetch_all_deployments()
|
|
740
|
+
|
|
741
|
+
# Verify API was called correctly
|
|
742
|
+
mock_client.get.assert_called_once()
|
|
743
|
+
|
|
744
|
+
# Verify correct deployments were returned
|
|
745
|
+
assert deployments == mock_deployments
|
|
746
|
+
assert len(deployments) == 2
|
|
747
|
+
assert deployments[0]["id"] == "deploy-1"
|
|
748
|
+
assert deployments[1]["id"] == "deploy-2"
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
async def test_fetch_all_deployments_paged_case(fireworks_finetune, mock_api_key):
|
|
752
|
+
# Test with multiple pages of results
|
|
753
|
+
mock_deployments_page1 = [
|
|
754
|
+
{"id": "deploy-1", "baseModel": "model-1", "state": "READY"},
|
|
755
|
+
{"id": "deploy-2", "baseModel": "model-2", "state": "READY"},
|
|
756
|
+
]
|
|
757
|
+
|
|
758
|
+
mock_deployments_page2 = [
|
|
759
|
+
{"id": "deploy-3", "baseModel": "model-3", "state": "READY"},
|
|
760
|
+
{"id": "deploy-4", "baseModel": "model-4", "state": "READY"},
|
|
761
|
+
]
|
|
762
|
+
|
|
763
|
+
page1_response = MagicMock(spec=httpx.Response)
|
|
764
|
+
page1_response.status_code = 200
|
|
765
|
+
page1_response.json.return_value = {
|
|
766
|
+
"deployments": mock_deployments_page1,
|
|
767
|
+
"nextPageToken": "page2token",
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
page2_response = MagicMock(spec=httpx.Response)
|
|
771
|
+
page2_response.status_code = 200
|
|
772
|
+
page2_response.json.return_value = {
|
|
773
|
+
"deployments": mock_deployments_page2,
|
|
774
|
+
"nextPageToken": None,
|
|
775
|
+
}
|
|
776
|
+
|
|
777
|
+
with patch("httpx.AsyncClient") as mock_client_class:
|
|
778
|
+
mock_client = AsyncMock()
|
|
779
|
+
mock_client.get.side_effect = [page1_response, page2_response]
|
|
780
|
+
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
781
|
+
|
|
782
|
+
deployments = await fireworks_finetune._fetch_all_deployments()
|
|
783
|
+
|
|
784
|
+
# Verify API was called twice (once for each page)
|
|
785
|
+
assert mock_client.get.call_count == 2
|
|
786
|
+
|
|
787
|
+
# Verify first call had no page token
|
|
788
|
+
first_call_args = mock_client.get.call_args_list[0][1]
|
|
789
|
+
assert "pageToken" not in first_call_args["params"]
|
|
790
|
+
|
|
791
|
+
# Verify second call included the page token
|
|
792
|
+
second_call_args = mock_client.get.call_args_list[1][1]
|
|
793
|
+
assert second_call_args["params"]["pageToken"] == "page2token"
|
|
794
|
+
|
|
795
|
+
# Verify all deployments from both pages were returned
|
|
796
|
+
assert len(deployments) == 4
|
|
797
|
+
assert deployments == mock_deployments_page1 + mock_deployments_page2
|
|
798
|
+
for deployment in deployments:
|
|
799
|
+
assert deployment["id"] in [
|
|
800
|
+
"deploy-1",
|
|
801
|
+
"deploy-2",
|
|
802
|
+
"deploy-3",
|
|
803
|
+
"deploy-4",
|
|
804
|
+
]
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
async def test_deploy_server_success(fireworks_finetune, mock_api_key):
|
|
808
|
+
# Mock response for successful deployment
|
|
809
|
+
success_response = MagicMock(spec=httpx.Response)
|
|
810
|
+
success_response.status_code = 200
|
|
811
|
+
success_response.json.return_value = {"baseModel": "model-123"}
|
|
812
|
+
|
|
813
|
+
status_response = (
|
|
814
|
+
FineTuneStatus(status=FineTuneStatusType.completed, message=""),
|
|
815
|
+
"model-123",
|
|
816
|
+
)
|
|
817
|
+
|
|
818
|
+
with (
|
|
819
|
+
patch("httpx.AsyncClient") as mock_client_class,
|
|
820
|
+
patch.object(
|
|
821
|
+
fireworks_finetune, "model_id_checking_status", return_value="model-123"
|
|
822
|
+
),
|
|
823
|
+
):
|
|
824
|
+
mock_client = AsyncMock()
|
|
825
|
+
mock_client.post.return_value = success_response
|
|
826
|
+
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
827
|
+
|
|
828
|
+
result = await fireworks_finetune._deploy_server()
|
|
829
|
+
|
|
830
|
+
# Verify result
|
|
831
|
+
assert result is True
|
|
832
|
+
|
|
833
|
+
# Verify fine_tune_model_id was updated
|
|
834
|
+
assert fireworks_finetune.datamodel.fine_tune_model_id == "model-123"
|
|
835
|
+
|
|
836
|
+
# Verify API was called with correct parameters
|
|
837
|
+
mock_client.post.assert_called_once()
|
|
838
|
+
call_args = mock_client.post.call_args[1]
|
|
839
|
+
assert "json" in call_args
|
|
840
|
+
assert call_args["json"]["baseModel"] == "model-123"
|
|
841
|
+
assert call_args["json"]["minReplicaCount"] == 0
|
|
842
|
+
assert "autoscalingPolicy" in call_args["json"]
|
|
843
|
+
assert call_args["json"]["autoscalingPolicy"]["scaleToZeroWindow"] == "300s"
|
|
844
|
+
|
|
845
|
+
# load the datamodel from the file and confirm the fine_tune_model_id was updated
|
|
846
|
+
loaded_datamodel = FinetuneModel.load_from_file(
|
|
847
|
+
fireworks_finetune.datamodel.path
|
|
848
|
+
)
|
|
849
|
+
assert loaded_datamodel.fine_tune_model_id == "model-123"
|
|
850
|
+
|
|
851
|
+
|
|
852
|
+
async def test_deploy_server_failure(fireworks_finetune, mock_api_key):
|
|
853
|
+
# Mock response for failed deployment
|
|
854
|
+
failure_response = MagicMock(spec=httpx.Response)
|
|
855
|
+
failure_response.status_code = 500
|
|
856
|
+
failure_response.text = "Internal Server Error"
|
|
857
|
+
|
|
858
|
+
with (
|
|
859
|
+
patch("httpx.AsyncClient") as mock_client_class,
|
|
860
|
+
patch.object(
|
|
861
|
+
fireworks_finetune, "model_id_checking_status", return_value="model-123"
|
|
862
|
+
),
|
|
863
|
+
):
|
|
864
|
+
mock_client = AsyncMock()
|
|
865
|
+
mock_client.post.return_value = failure_response
|
|
866
|
+
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
867
|
+
|
|
868
|
+
result = await fireworks_finetune._deploy_server()
|
|
869
|
+
|
|
870
|
+
# Verify result
|
|
871
|
+
assert result is False
|
|
872
|
+
|
|
873
|
+
# Verify API was called
|
|
874
|
+
mock_client.post.assert_called_once()
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
async def test_deploy_server_non_200_but_valid_response(
|
|
878
|
+
fireworks_finetune, mock_api_key
|
|
879
|
+
):
|
|
880
|
+
# Mock response with non-200 status but valid JSON response
|
|
881
|
+
mixed_response = MagicMock(spec=httpx.Response)
|
|
882
|
+
mixed_response.status_code = 200
|
|
883
|
+
mixed_response.json.return_value = {"not_baseModel": "something-else"}
|
|
884
|
+
|
|
885
|
+
with (
|
|
886
|
+
patch("httpx.AsyncClient") as mock_client_class,
|
|
887
|
+
patch.object(
|
|
888
|
+
fireworks_finetune, "model_id_checking_status", return_value="model-123"
|
|
889
|
+
),
|
|
890
|
+
):
|
|
891
|
+
mock_client = AsyncMock()
|
|
892
|
+
mock_client.post.return_value = mixed_response
|
|
893
|
+
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
894
|
+
|
|
895
|
+
result = await fireworks_finetune._deploy_server()
|
|
896
|
+
|
|
897
|
+
# Verify result - should fail because baseModel is missing
|
|
898
|
+
assert result is False
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
async def test_deploy_server_missing_model_id(fireworks_finetune, mock_api_key):
|
|
902
|
+
# Test when model_id_checking_status returns None
|
|
903
|
+
with patch.object(
|
|
904
|
+
fireworks_finetune, "model_id_checking_status", return_value=None
|
|
905
|
+
):
|
|
906
|
+
result = await fireworks_finetune._deploy_server()
|
|
907
|
+
|
|
908
|
+
# Verify result - should fail because model ID is missing
|
|
909
|
+
assert result is False
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
@pytest.mark.parametrize(
|
|
913
|
+
"state,expected_already_deployed",
|
|
914
|
+
[
|
|
915
|
+
("READY", True),
|
|
916
|
+
("CREATING", True),
|
|
917
|
+
("FAILED", False),
|
|
918
|
+
],
|
|
919
|
+
)
|
|
920
|
+
async def test_check_or_deploy_server_already_deployed(
|
|
921
|
+
fireworks_finetune, mock_api_key, state, expected_already_deployed
|
|
922
|
+
):
|
|
923
|
+
# Test when model is already deployed (should return True without calling _deploy_server)
|
|
924
|
+
|
|
925
|
+
# Set a fine_tune_model_id so we search for deployments
|
|
926
|
+
fireworks_finetune.datamodel.fine_tune_model_id = "model-123"
|
|
927
|
+
|
|
928
|
+
# Mock deployments including one matching our model ID
|
|
929
|
+
mock_deployments = [
|
|
930
|
+
{"id": "deploy-1", "baseModel": "different-model", "state": "READY"},
|
|
931
|
+
{"id": "deploy-2", "baseModel": "model-123", "state": state},
|
|
932
|
+
]
|
|
933
|
+
|
|
934
|
+
with (
|
|
935
|
+
patch.object(
|
|
936
|
+
fireworks_finetune, "_fetch_all_deployments", return_value=mock_deployments
|
|
937
|
+
) as mock_fetch,
|
|
938
|
+
patch.object(fireworks_finetune, "_deploy_server") as mock_deploy,
|
|
939
|
+
):
|
|
940
|
+
mock_deploy.return_value = True
|
|
941
|
+
result = await fireworks_finetune._check_or_deploy_server()
|
|
942
|
+
# Even true if the model is in a non-ready state, as we'll call deploy (checked below)
|
|
943
|
+
assert result is True
|
|
944
|
+
|
|
945
|
+
if expected_already_deployed:
|
|
946
|
+
assert mock_deploy.call_count == 0
|
|
947
|
+
else:
|
|
948
|
+
assert mock_deploy.call_count == 1
|
|
949
|
+
|
|
950
|
+
# Verify _fetch_all_deployments was called
|
|
951
|
+
mock_fetch.assert_called_once()
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
async def test_check_or_deploy_server_not_deployed(fireworks_finetune, mock_api_key):
|
|
955
|
+
# Test when model exists but isn't deployed (should call _deploy_server)
|
|
956
|
+
|
|
957
|
+
# Set a fine_tune_model_id so we search for deployments
|
|
958
|
+
fireworks_finetune.datamodel.fine_tune_model_id = "model-123"
|
|
959
|
+
|
|
960
|
+
# Mock deployments without our model ID
|
|
961
|
+
mock_deployments = [
|
|
962
|
+
{"id": "deploy-1", "baseModel": "different-model-1", "state": "READY"},
|
|
963
|
+
{"id": "deploy-2", "baseModel": "different-model-2", "state": "READY"},
|
|
964
|
+
]
|
|
965
|
+
|
|
966
|
+
with (
|
|
967
|
+
patch.object(
|
|
968
|
+
fireworks_finetune, "_fetch_all_deployments", return_value=mock_deployments
|
|
969
|
+
) as mock_fetch,
|
|
970
|
+
patch.object(
|
|
971
|
+
fireworks_finetune, "_deploy_server", return_value=True
|
|
972
|
+
) as mock_deploy,
|
|
973
|
+
):
|
|
974
|
+
result = await fireworks_finetune._check_or_deploy_server()
|
|
975
|
+
|
|
976
|
+
# Verify method returned True (from _deploy_server)
|
|
977
|
+
assert result is True
|
|
978
|
+
|
|
979
|
+
# Verify _fetch_all_deployments was called
|
|
980
|
+
mock_fetch.assert_called_once()
|
|
981
|
+
|
|
982
|
+
# Verify _deploy_server was called since model is not deployed
|
|
983
|
+
mock_deploy.assert_called_once()
|
|
984
|
+
|
|
985
|
+
|
|
986
|
+
async def test_check_or_deploy_server_no_model_id(fireworks_finetune, mock_api_key):
|
|
987
|
+
# Test when no fine_tune_model_id exists (should skip fetch and call _deploy_server directly)
|
|
988
|
+
|
|
989
|
+
# Ensure no fine_tune_model_id is set
|
|
990
|
+
fireworks_finetune.datamodel.fine_tune_model_id = None
|
|
991
|
+
|
|
992
|
+
with (
|
|
993
|
+
patch.object(fireworks_finetune, "_fetch_all_deployments") as mock_fetch,
|
|
994
|
+
patch.object(
|
|
995
|
+
fireworks_finetune, "_deploy_server", return_value=True
|
|
996
|
+
) as mock_deploy,
|
|
997
|
+
):
|
|
998
|
+
result = await fireworks_finetune._check_or_deploy_server()
|
|
999
|
+
|
|
1000
|
+
# Verify method returned True (from _deploy_server)
|
|
1001
|
+
assert result is True
|
|
1002
|
+
|
|
1003
|
+
# Verify _fetch_all_deployments was NOT called
|
|
1004
|
+
mock_fetch.assert_not_called()
|
|
1005
|
+
|
|
1006
|
+
# Verify _deploy_server was called directly
|
|
1007
|
+
mock_deploy.assert_called_once()
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
async def test_check_or_deploy_server_deploy_fails(fireworks_finetune, mock_api_key):
|
|
1011
|
+
# Test when deployment fails
|
|
1012
|
+
|
|
1013
|
+
# Ensure no fine_tune_model_id is set
|
|
1014
|
+
fireworks_finetune.datamodel.fine_tune_model_id = None
|
|
1015
|
+
|
|
1016
|
+
with (
|
|
1017
|
+
patch.object(
|
|
1018
|
+
fireworks_finetune, "_deploy_server", return_value=False
|
|
1019
|
+
) as mock_deploy,
|
|
1020
|
+
):
|
|
1021
|
+
result = await fireworks_finetune._check_or_deploy_server()
|
|
1022
|
+
|
|
1023
|
+
# Verify method returned False (from _deploy_server)
|
|
1024
|
+
assert result is False
|
|
1025
|
+
|
|
1026
|
+
# Verify _deploy_server was called
|
|
1027
|
+
mock_deploy.assert_called_once()
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
async def test_fetch_all_deployments_invalid_json(fireworks_finetune, mock_api_key):
|
|
1031
|
+
# Test with invalid JSON response (missing 'deployments' key)
|
|
1032
|
+
invalid_response = MagicMock(spec=httpx.Response)
|
|
1033
|
+
invalid_response.status_code = 200
|
|
1034
|
+
invalid_response.json.return_value = {
|
|
1035
|
+
"some_other_key": "value",
|
|
1036
|
+
# No 'deployments' key
|
|
1037
|
+
}
|
|
1038
|
+
invalid_response.text = '{"some_other_key": "value"}'
|
|
1039
|
+
|
|
1040
|
+
with patch("httpx.AsyncClient") as mock_client_class:
|
|
1041
|
+
mock_client = AsyncMock()
|
|
1042
|
+
mock_client.get.return_value = invalid_response
|
|
1043
|
+
mock_client_class.return_value.__aenter__.return_value = mock_client
|
|
1044
|
+
|
|
1045
|
+
with pytest.raises(
|
|
1046
|
+
ValueError,
|
|
1047
|
+
match="Invalid response from Fireworks. Expected list of deployments in 'deployments' key",
|
|
1048
|
+
):
|
|
1049
|
+
await fireworks_finetune._fetch_all_deployments()
|
|
1050
|
+
|
|
1051
|
+
# Verify API was called
|
|
1052
|
+
mock_client.get.assert_called_once()
|