truss 0.10.9rc535__py3-none-any.whl → 0.10.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/cli/logs/base_watcher.py +1 -1
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +14 -7
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
- truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
- truss/cli/train/metrics_watcher.py +170 -59
- truss/cli/train_commands.py +11 -3
- truss/contexts/image_builder/serving_image_builder.py +22 -39
- truss/remote/baseten/api.py +11 -0
- truss/remote/baseten/core.py +209 -1
- truss/remote/baseten/utils/time.py +15 -0
- truss/templates/base.Dockerfile.jinja +6 -23
- truss/templates/cache.Dockerfile.jinja +5 -5
- truss/templates/copy_cache_files.Dockerfile.jinja +1 -1
- truss/templates/docker_server/supervisord.conf.jinja +0 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +16 -33
- truss/tests/cli/train/test_deploy_checkpoints.py +446 -2
- truss/tests/cli/train/test_train_cli_core.py +96 -0
- truss/tests/remote/baseten/conftest.py +18 -0
- truss/tests/remote/baseten/test_api.py +49 -14
- truss/tests/remote/baseten/test_core.py +517 -1
- {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/METADATA +2 -2
- {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/RECORD +31 -29
- truss_train/definitions.py +6 -0
- truss_train/deployment.py +15 -2
- truss_train/loader.py +7 -20
- truss/tests/util/test_basetenpointer.py +0 -227
- truss/util/basetenpointer.py +0 -160
- {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/WHEEL +0 -0
- {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/entry_points.txt +0 -0
- {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import re
|
|
3
|
+
from dataclasses import dataclass
|
|
3
4
|
from pathlib import Path
|
|
5
|
+
from typing import Dict, List, Optional
|
|
4
6
|
from unittest.mock import MagicMock, patch
|
|
5
7
|
|
|
6
8
|
import pytest
|
|
@@ -23,6 +25,11 @@ from truss.cli.train.deploy_checkpoints.deploy_lora_checkpoints import (
|
|
|
23
25
|
hydrate_lora_checkpoint,
|
|
24
26
|
render_vllm_lora_truss_config,
|
|
25
27
|
)
|
|
28
|
+
from truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints import (
|
|
29
|
+
VLLM_WHISPER_START_COMMAND,
|
|
30
|
+
hydrate_whisper_checkpoint,
|
|
31
|
+
render_vllm_whisper_truss_config,
|
|
32
|
+
)
|
|
26
33
|
from truss.cli.train.types import (
|
|
27
34
|
DeployCheckpointsConfigComplete,
|
|
28
35
|
PrepareCheckpointResult,
|
|
@@ -505,8 +512,16 @@ def test_render_vllm_full_truss_config():
|
|
|
505
512
|
)
|
|
506
513
|
|
|
507
514
|
result = render_vllm_full_truss_config(deploy_config)
|
|
508
|
-
|
|
509
|
-
|
|
515
|
+
expected_vllm_command = (
|
|
516
|
+
"sh -c 'HF_TOKEN=$(cat /secrets/hf_token) "
|
|
517
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
518
|
+
"if [ -f /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja ]; then "
|
|
519
|
+
"vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
|
|
520
|
+
"--chat-template /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja "
|
|
521
|
+
"--port 8000 --tensor-parallel-size 2 --dtype bfloat16; else "
|
|
522
|
+
"vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
|
|
523
|
+
"--port 8000 --tensor-parallel-size 2 --dtype bfloat16; fi'"
|
|
524
|
+
)
|
|
510
525
|
|
|
511
526
|
assert isinstance(result, truss_config.TrussConfig)
|
|
512
527
|
assert result.model_name == "test-full-model"
|
|
@@ -674,3 +689,432 @@ def test_get_checkpoint_ids_to_deploy_single_checkpoint():
|
|
|
674
689
|
|
|
675
690
|
# Should return the single checkpoint directly
|
|
676
691
|
assert result == ["checkpoint-1"]
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def test_vllm_whisper_start_command_template():
|
|
695
|
+
"""Test that the VLLM_WHISPER_START_COMMAND template renders correctly."""
|
|
696
|
+
# Test with all variables
|
|
697
|
+
result = VLLM_WHISPER_START_COMMAND.render(
|
|
698
|
+
model_path="/path/to/model",
|
|
699
|
+
envvars="CUDA_VISIBLE_DEVICES=0",
|
|
700
|
+
specify_tensor_parallelism=4,
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
expected = (
|
|
704
|
+
"sh -c 'CUDA_VISIBLE_DEVICES=0 "
|
|
705
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
706
|
+
"vllm serve /path/to/model --port 8000 --tensor-parallel-size 4'"
|
|
707
|
+
)
|
|
708
|
+
assert result == expected
|
|
709
|
+
|
|
710
|
+
result = VLLM_WHISPER_START_COMMAND.render(
|
|
711
|
+
model_path="/path/to/model", envvars=None, specify_tensor_parallelism=1
|
|
712
|
+
)
|
|
713
|
+
|
|
714
|
+
expected = (
|
|
715
|
+
"sh -c '"
|
|
716
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
717
|
+
"vllm serve /path/to/model --port 8000 --tensor-parallel-size 1'"
|
|
718
|
+
)
|
|
719
|
+
assert result == expected
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def test_hydrate_whisper_checkpoint():
|
|
723
|
+
"""Test that hydrate_whisper_checkpoint creates correct WhisperCheckpoint object."""
|
|
724
|
+
job_id = "test-job-123"
|
|
725
|
+
checkpoint_id = "checkpoint-456"
|
|
726
|
+
checkpoint = {"some": "data"}
|
|
727
|
+
|
|
728
|
+
result = hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
|
|
729
|
+
|
|
730
|
+
assert result.training_job_id == job_id
|
|
731
|
+
assert result.paths == [f"rank-0/{checkpoint_id}/"]
|
|
732
|
+
assert result.model_weight_format == definitions.ModelWeightsFormat.WHISPER
|
|
733
|
+
assert isinstance(result, definitions.WhisperCheckpoint)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
@patch(
|
|
737
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
|
|
738
|
+
)
|
|
739
|
+
@patch(
|
|
740
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
|
|
741
|
+
)
|
|
742
|
+
@patch(
|
|
743
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
|
|
744
|
+
)
|
|
745
|
+
def test_render_vllm_whisper_truss_config(
|
|
746
|
+
mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
|
|
747
|
+
):
|
|
748
|
+
"""Test that render_vllm_whisper_truss_config renders truss config correctly."""
|
|
749
|
+
# Mock dependencies
|
|
750
|
+
mock_truss_config = MagicMock()
|
|
751
|
+
mock_truss_config.environment_variables = {}
|
|
752
|
+
mock_truss_config.docker_server = MagicMock()
|
|
753
|
+
mock_setup_base_config.return_value = mock_truss_config
|
|
754
|
+
|
|
755
|
+
mock_setup_env_vars.return_value = "HF_TOKEN=$(cat /secrets/hf_access_token)"
|
|
756
|
+
mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
|
|
757
|
+
|
|
758
|
+
# Create test config
|
|
759
|
+
deploy_config = DeployCheckpointsConfigComplete(
|
|
760
|
+
checkpoint_details=definitions.CheckpointList(
|
|
761
|
+
checkpoints=[
|
|
762
|
+
definitions.WhisperCheckpoint(
|
|
763
|
+
training_job_id="job123",
|
|
764
|
+
paths=["rank-0/checkpoint-1/"],
|
|
765
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
766
|
+
)
|
|
767
|
+
],
|
|
768
|
+
base_model_id="openai/whisper-large-v3",
|
|
769
|
+
),
|
|
770
|
+
model_name="whisper-large-v3-vLLM",
|
|
771
|
+
compute=definitions.Compute(
|
|
772
|
+
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
|
|
773
|
+
),
|
|
774
|
+
runtime=definitions.DeployCheckpointsRuntime(
|
|
775
|
+
environment_variables={
|
|
776
|
+
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
|
|
777
|
+
}
|
|
778
|
+
),
|
|
779
|
+
deployment_name="whisper-large-v3-vLLM",
|
|
780
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
781
|
+
)
|
|
782
|
+
|
|
783
|
+
result = render_vllm_whisper_truss_config(deploy_config)
|
|
784
|
+
|
|
785
|
+
mock_setup_base_config.assert_called_once_with(deploy_config)
|
|
786
|
+
mock_setup_env_vars.assert_called_once_with(mock_truss_config, deploy_config)
|
|
787
|
+
mock_build_full_checkpoint_string.assert_called_once_with(mock_truss_config)
|
|
788
|
+
|
|
789
|
+
assert result == mock_truss_config
|
|
790
|
+
|
|
791
|
+
expected_start_command = (
|
|
792
|
+
"sh -c 'HF_TOKEN=$(cat /secrets/hf_access_token) "
|
|
793
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
794
|
+
"vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 4'"
|
|
795
|
+
)
|
|
796
|
+
assert (
|
|
797
|
+
result.environment_variables[START_COMMAND_ENVVAR_NAME]
|
|
798
|
+
== expected_start_command
|
|
799
|
+
)
|
|
800
|
+
|
|
801
|
+
assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
@patch(
|
|
805
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
|
|
806
|
+
)
|
|
807
|
+
@patch(
|
|
808
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
|
|
809
|
+
)
|
|
810
|
+
@patch(
|
|
811
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
|
|
812
|
+
)
|
|
813
|
+
def test_render_vllm_whisper_truss_config_with_envvars(
|
|
814
|
+
mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
|
|
815
|
+
):
|
|
816
|
+
"""Test that render_vllm_whisper_truss_config handles environment variables correctly."""
|
|
817
|
+
# Mock dependencies
|
|
818
|
+
mock_truss_config = MagicMock()
|
|
819
|
+
mock_truss_config.environment_variables = {}
|
|
820
|
+
mock_truss_config.docker_server = MagicMock()
|
|
821
|
+
mock_setup_base_config.return_value = mock_truss_config
|
|
822
|
+
|
|
823
|
+
mock_setup_env_vars.return_value = "CUDA_VISIBLE_DEVICES=0,1"
|
|
824
|
+
mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
|
|
825
|
+
|
|
826
|
+
# Create test config with environment variables
|
|
827
|
+
deploy_config = DeployCheckpointsConfigComplete(
|
|
828
|
+
checkpoint_details=definitions.CheckpointList(
|
|
829
|
+
checkpoints=[
|
|
830
|
+
definitions.WhisperCheckpoint(
|
|
831
|
+
training_job_id="job123",
|
|
832
|
+
paths=["rank-0/checkpoint-1/"],
|
|
833
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
834
|
+
)
|
|
835
|
+
],
|
|
836
|
+
base_model_id="openai/whisper-large-v3",
|
|
837
|
+
),
|
|
838
|
+
model_name="whisper-large-v3-vLLM",
|
|
839
|
+
compute=definitions.Compute(
|
|
840
|
+
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
|
|
841
|
+
),
|
|
842
|
+
runtime=definitions.DeployCheckpointsRuntime(
|
|
843
|
+
environment_variables={
|
|
844
|
+
"CUDA_VISIBLE_DEVICES": "0,1",
|
|
845
|
+
"HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
|
|
846
|
+
}
|
|
847
|
+
),
|
|
848
|
+
deployment_name="whisper-large-v3-vLLM",
|
|
849
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
# Call function under test
|
|
853
|
+
result = render_vllm_whisper_truss_config(deploy_config)
|
|
854
|
+
|
|
855
|
+
# Verify environment variables are included in start command
|
|
856
|
+
expected_start_command = (
|
|
857
|
+
"sh -c 'CUDA_VISIBLE_DEVICES=0,1 "
|
|
858
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
859
|
+
"vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 2'"
|
|
860
|
+
)
|
|
861
|
+
assert (
|
|
862
|
+
result.environment_variables[START_COMMAND_ENVVAR_NAME]
|
|
863
|
+
== expected_start_command
|
|
864
|
+
)
|
|
865
|
+
|
|
866
|
+
|
|
867
|
+
@dataclass
|
|
868
|
+
class TestCase:
|
|
869
|
+
"""Test case for setup_base_truss_config function."""
|
|
870
|
+
|
|
871
|
+
desc: str
|
|
872
|
+
input_config: DeployCheckpointsConfigComplete
|
|
873
|
+
expected_model_name: str
|
|
874
|
+
expected_predict_endpoint: str
|
|
875
|
+
expected_accelerator: Optional[str]
|
|
876
|
+
expected_accelerator_count: Optional[int]
|
|
877
|
+
expected_checkpoint_paths: List[str]
|
|
878
|
+
expected_environment_variables: Dict[str, str]
|
|
879
|
+
should_raise: Optional[str] = None # Error message if function should raise
|
|
880
|
+
|
|
881
|
+
__test__ = False # Tell pytest this is not a test class
|
|
882
|
+
|
|
883
|
+
|
|
884
|
+
def test_setup_base_truss_config():
|
|
885
|
+
"""Table-driven test for setup_base_truss_config function."""
|
|
886
|
+
from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
|
|
887
|
+
setup_base_truss_config,
|
|
888
|
+
)
|
|
889
|
+
|
|
890
|
+
# Define test cases
|
|
891
|
+
test_cases = [
|
|
892
|
+
TestCase(
|
|
893
|
+
desc="LoRA checkpoint with H100 accelerator",
|
|
894
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
895
|
+
checkpoint_details=definitions.CheckpointList(
|
|
896
|
+
checkpoints=[
|
|
897
|
+
definitions.LoRACheckpoint(
|
|
898
|
+
training_job_id="job123",
|
|
899
|
+
paths=["rank-0/checkpoint-1/"],
|
|
900
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
901
|
+
lora_details=definitions.LoRADetails(rank=32),
|
|
902
|
+
)
|
|
903
|
+
],
|
|
904
|
+
base_model_id="google/gemma-3-27b-it",
|
|
905
|
+
),
|
|
906
|
+
model_name="test-lora-model",
|
|
907
|
+
compute=definitions.Compute(
|
|
908
|
+
accelerator=truss_config.AcceleratorSpec(
|
|
909
|
+
accelerator="H100", count=4
|
|
910
|
+
)
|
|
911
|
+
),
|
|
912
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
913
|
+
deployment_name="test-deployment",
|
|
914
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
915
|
+
),
|
|
916
|
+
expected_model_name="test-lora-model",
|
|
917
|
+
expected_predict_endpoint="/v1/chat/completions",
|
|
918
|
+
expected_accelerator="H100",
|
|
919
|
+
expected_accelerator_count=4,
|
|
920
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/"],
|
|
921
|
+
expected_environment_variables={
|
|
922
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
923
|
+
"VLLM_USE_V1": "0",
|
|
924
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
925
|
+
},
|
|
926
|
+
),
|
|
927
|
+
TestCase(
|
|
928
|
+
desc="Whisper checkpoint with A100 accelerator",
|
|
929
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
930
|
+
checkpoint_details=definitions.CheckpointList(
|
|
931
|
+
checkpoints=[
|
|
932
|
+
definitions.WhisperCheckpoint(
|
|
933
|
+
training_job_id="job123",
|
|
934
|
+
paths=["rank-0/checkpoint-1/"],
|
|
935
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
936
|
+
)
|
|
937
|
+
],
|
|
938
|
+
base_model_id="openai/whisper-large-v3",
|
|
939
|
+
),
|
|
940
|
+
model_name="test-whisper-model",
|
|
941
|
+
compute=definitions.Compute(
|
|
942
|
+
accelerator=truss_config.AcceleratorSpec(
|
|
943
|
+
accelerator="A100", count=2
|
|
944
|
+
)
|
|
945
|
+
),
|
|
946
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
947
|
+
deployment_name="test-whisper-deployment",
|
|
948
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
949
|
+
),
|
|
950
|
+
expected_model_name="test-whisper-model",
|
|
951
|
+
expected_predict_endpoint="/v1/audio/transcriptions",
|
|
952
|
+
expected_accelerator="A100",
|
|
953
|
+
expected_accelerator_count=2,
|
|
954
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/"],
|
|
955
|
+
expected_environment_variables={
|
|
956
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
957
|
+
"VLLM_USE_V1": "0",
|
|
958
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
959
|
+
},
|
|
960
|
+
),
|
|
961
|
+
TestCase(
|
|
962
|
+
desc="Multiple LoRA checkpoints",
|
|
963
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
964
|
+
checkpoint_details=definitions.CheckpointList(
|
|
965
|
+
checkpoints=[
|
|
966
|
+
definitions.LoRACheckpoint(
|
|
967
|
+
training_job_id="job123",
|
|
968
|
+
paths=["rank-0/checkpoint-1/"],
|
|
969
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
970
|
+
lora_details=definitions.LoRADetails(rank=16),
|
|
971
|
+
),
|
|
972
|
+
definitions.LoRACheckpoint(
|
|
973
|
+
training_job_id="job123",
|
|
974
|
+
paths=["rank-0/checkpoint-2/"],
|
|
975
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
976
|
+
lora_details=definitions.LoRADetails(rank=32),
|
|
977
|
+
),
|
|
978
|
+
],
|
|
979
|
+
base_model_id="google/gemma-3-27b-it",
|
|
980
|
+
),
|
|
981
|
+
model_name="test-multi-checkpoint-model",
|
|
982
|
+
compute=definitions.Compute(
|
|
983
|
+
accelerator=truss_config.AcceleratorSpec(
|
|
984
|
+
accelerator="H100", count=4
|
|
985
|
+
)
|
|
986
|
+
),
|
|
987
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
988
|
+
deployment_name="test-multi-deployment",
|
|
989
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
990
|
+
),
|
|
991
|
+
expected_model_name="test-multi-checkpoint-model",
|
|
992
|
+
expected_predict_endpoint="/v1/chat/completions",
|
|
993
|
+
expected_accelerator="H100",
|
|
994
|
+
expected_accelerator_count=4,
|
|
995
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/", "rank-0/checkpoint-2/"],
|
|
996
|
+
expected_environment_variables={
|
|
997
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
998
|
+
"VLLM_USE_V1": "0",
|
|
999
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
1000
|
+
},
|
|
1001
|
+
),
|
|
1002
|
+
TestCase(
|
|
1003
|
+
desc="No accelerator specified",
|
|
1004
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
1005
|
+
checkpoint_details=definitions.CheckpointList(
|
|
1006
|
+
checkpoints=[
|
|
1007
|
+
definitions.LoRACheckpoint(
|
|
1008
|
+
training_job_id="job123",
|
|
1009
|
+
paths=["rank-0/checkpoint-1/"],
|
|
1010
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
1011
|
+
lora_details=definitions.LoRADetails(rank=16),
|
|
1012
|
+
)
|
|
1013
|
+
],
|
|
1014
|
+
base_model_id="google/gemma-3-27b-it",
|
|
1015
|
+
),
|
|
1016
|
+
model_name="test-no-accelerator-model",
|
|
1017
|
+
compute=definitions.Compute(), # No accelerator specified
|
|
1018
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
1019
|
+
deployment_name="test-no-accelerator-deployment",
|
|
1020
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
1021
|
+
),
|
|
1022
|
+
expected_model_name="test-no-accelerator-model",
|
|
1023
|
+
expected_predict_endpoint="/v1/chat/completions",
|
|
1024
|
+
expected_accelerator=None,
|
|
1025
|
+
expected_accelerator_count=None,
|
|
1026
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/"],
|
|
1027
|
+
expected_environment_variables={
|
|
1028
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
1029
|
+
"VLLM_USE_V1": "0",
|
|
1030
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
1031
|
+
},
|
|
1032
|
+
),
|
|
1033
|
+
]
|
|
1034
|
+
|
|
1035
|
+
# Run test cases
|
|
1036
|
+
for test_case in test_cases:
|
|
1037
|
+
print(f"Running test case: {test_case.desc}")
|
|
1038
|
+
|
|
1039
|
+
if test_case.should_raise:
|
|
1040
|
+
# Test error cases
|
|
1041
|
+
with pytest.raises(Exception, match=test_case.should_raise):
|
|
1042
|
+
setup_base_truss_config(test_case.input_config)
|
|
1043
|
+
else:
|
|
1044
|
+
# Test success cases
|
|
1045
|
+
result = setup_base_truss_config(test_case.input_config)
|
|
1046
|
+
|
|
1047
|
+
# Verify basic structure
|
|
1048
|
+
assert isinstance(result, truss_config.TrussConfig), (
|
|
1049
|
+
f"Test case '{test_case.desc}': Result should be TrussConfig"
|
|
1050
|
+
)
|
|
1051
|
+
assert result.model_name == test_case.expected_model_name, (
|
|
1052
|
+
f"Test case '{test_case.desc}': Model name mismatch"
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1055
|
+
# Verify docker server configuration
|
|
1056
|
+
assert result.docker_server is not None, (
|
|
1057
|
+
f"Test case '{test_case.desc}': Docker server should not be None"
|
|
1058
|
+
)
|
|
1059
|
+
assert result.docker_server.start_command == 'sh -c ""', (
|
|
1060
|
+
f"Test case '{test_case.desc}': Start command mismatch"
|
|
1061
|
+
)
|
|
1062
|
+
assert result.docker_server.readiness_endpoint == "/health", (
|
|
1063
|
+
f"Test case '{test_case.desc}': Readiness endpoint mismatch"
|
|
1064
|
+
)
|
|
1065
|
+
assert result.docker_server.liveness_endpoint == "/health", (
|
|
1066
|
+
f"Test case '{test_case.desc}': Liveness endpoint mismatch"
|
|
1067
|
+
)
|
|
1068
|
+
assert (
|
|
1069
|
+
result.docker_server.predict_endpoint
|
|
1070
|
+
== test_case.expected_predict_endpoint
|
|
1071
|
+
), f"Test case '{test_case.desc}': Predict endpoint mismatch"
|
|
1072
|
+
assert result.docker_server.server_port == 8000, (
|
|
1073
|
+
f"Test case '{test_case.desc}': Server port mismatch"
|
|
1074
|
+
)
|
|
1075
|
+
|
|
1076
|
+
# Verify training checkpoints
|
|
1077
|
+
assert result.training_checkpoints is not None, (
|
|
1078
|
+
f"Test case '{test_case.desc}': Training checkpoints should not be None"
|
|
1079
|
+
)
|
|
1080
|
+
assert len(result.training_checkpoints.artifact_references) == len(
|
|
1081
|
+
test_case.expected_checkpoint_paths
|
|
1082
|
+
), f"Test case '{test_case.desc}': Number of checkpoint artifacts mismatch"
|
|
1083
|
+
|
|
1084
|
+
for i, expected_path in enumerate(test_case.expected_checkpoint_paths):
|
|
1085
|
+
artifact_ref = result.training_checkpoints.artifact_references[i]
|
|
1086
|
+
assert artifact_ref.paths == [expected_path], (
|
|
1087
|
+
f"Test case '{test_case.desc}': Checkpoint path {i} mismatch"
|
|
1088
|
+
)
|
|
1089
|
+
|
|
1090
|
+
# Verify resources
|
|
1091
|
+
assert result.resources is not None, (
|
|
1092
|
+
f"Test case '{test_case.desc}': Resources should not be None"
|
|
1093
|
+
)
|
|
1094
|
+
|
|
1095
|
+
if test_case.expected_accelerator:
|
|
1096
|
+
assert result.resources.accelerator is not None, (
|
|
1097
|
+
f"Test case '{test_case.desc}': Accelerator should not be None"
|
|
1098
|
+
)
|
|
1099
|
+
assert (
|
|
1100
|
+
result.resources.accelerator.accelerator
|
|
1101
|
+
== test_case.expected_accelerator
|
|
1102
|
+
), f"Test case '{test_case.desc}': Accelerator type mismatch"
|
|
1103
|
+
assert (
|
|
1104
|
+
result.resources.accelerator.count
|
|
1105
|
+
== test_case.expected_accelerator_count
|
|
1106
|
+
), f"Test case '{test_case.desc}': Accelerator count mismatch"
|
|
1107
|
+
else:
|
|
1108
|
+
# When no accelerator is specified, it creates an AcceleratorSpec with None values
|
|
1109
|
+
assert result.resources.accelerator is not None, (
|
|
1110
|
+
f"Test case '{test_case.desc}': Accelerator should exist"
|
|
1111
|
+
)
|
|
1112
|
+
assert result.resources.accelerator.accelerator is None, (
|
|
1113
|
+
f"Test case '{test_case.desc}': Accelerator type should be None"
|
|
1114
|
+
)
|
|
1115
|
+
|
|
1116
|
+
# Verify environment variables
|
|
1117
|
+
for key, expected_value in test_case.expected_environment_variables.items():
|
|
1118
|
+
assert result.environment_variables[key] == expected_value, (
|
|
1119
|
+
f"Test case '{test_case.desc}': Environment variable {key} mismatch"
|
|
1120
|
+
)
|
|
@@ -51,6 +51,38 @@ def test_view_training_job_metrics(time_sleep, capfd):
|
|
|
51
51
|
"0": [{"timestamp": "", "value": 4321}],
|
|
52
52
|
"1": [{"timestamp": "", "value": 2222}],
|
|
53
53
|
},
|
|
54
|
+
"per_node_metrics": [
|
|
55
|
+
{
|
|
56
|
+
"node_id": "node-0",
|
|
57
|
+
"metrics": {
|
|
58
|
+
"cpu_usage": [{"timestamp": "", "value": 3.2}],
|
|
59
|
+
"cpu_memory_usage_bytes": [{"timestamp": "", "value": 1234}],
|
|
60
|
+
"gpu_utilization": {
|
|
61
|
+
"0": [{"timestamp": "", "value": 0.2}],
|
|
62
|
+
"1": [{"timestamp": "", "value": 0.3}],
|
|
63
|
+
},
|
|
64
|
+
"gpu_memory_usage_bytes": {
|
|
65
|
+
"0": [{"timestamp": "", "value": 4321}],
|
|
66
|
+
"1": [{"timestamp": "", "value": 2222}],
|
|
67
|
+
},
|
|
68
|
+
},
|
|
69
|
+
},
|
|
70
|
+
{
|
|
71
|
+
"node_id": "node-1",
|
|
72
|
+
"metrics": {
|
|
73
|
+
"cpu_usage": [{"timestamp": "", "value": 2.8}],
|
|
74
|
+
"cpu_memory_usage_bytes": [{"timestamp": "", "value": 1000}],
|
|
75
|
+
"gpu_utilization": {
|
|
76
|
+
"0": [{"timestamp": "", "value": 0.15}],
|
|
77
|
+
"1": [{"timestamp": "", "value": 0.25}],
|
|
78
|
+
},
|
|
79
|
+
"gpu_memory_usage_bytes": {
|
|
80
|
+
"0": [{"timestamp": "", "value": 4000}],
|
|
81
|
+
"1": [{"timestamp": "", "value": 2000}],
|
|
82
|
+
},
|
|
83
|
+
},
|
|
84
|
+
},
|
|
85
|
+
],
|
|
54
86
|
},
|
|
55
87
|
{
|
|
56
88
|
"training_job": {
|
|
@@ -68,6 +100,38 @@ def test_view_training_job_metrics(time_sleep, capfd):
|
|
|
68
100
|
"0": [{"timestamp": "", "value": 4321}],
|
|
69
101
|
"1": [{"timestamp": "", "value": 2222}],
|
|
70
102
|
},
|
|
103
|
+
"per_node_metrics": [
|
|
104
|
+
{
|
|
105
|
+
"node_id": "node-0",
|
|
106
|
+
"metrics": {
|
|
107
|
+
"cpu_usage": [{"timestamp": "", "value": 3.2}],
|
|
108
|
+
"cpu_memory_usage_bytes": [{"timestamp": "", "value": 1234}],
|
|
109
|
+
"gpu_utilization": {
|
|
110
|
+
"0": [{"timestamp": "", "value": 0.2}],
|
|
111
|
+
"1": [{"timestamp": "", "value": 0.3}],
|
|
112
|
+
},
|
|
113
|
+
"gpu_memory_usage_bytes": {
|
|
114
|
+
"0": [{"timestamp": "", "value": 4321}],
|
|
115
|
+
"1": [{"timestamp": "", "value": 2222}],
|
|
116
|
+
},
|
|
117
|
+
},
|
|
118
|
+
},
|
|
119
|
+
{
|
|
120
|
+
"node_id": "node-1",
|
|
121
|
+
"metrics": {
|
|
122
|
+
"cpu_usage": [{"timestamp": "", "value": 2.8}],
|
|
123
|
+
"cpu_memory_usage_bytes": [{"timestamp": "", "value": 1000}],
|
|
124
|
+
"gpu_utilization": {
|
|
125
|
+
"0": [{"timestamp": "", "value": 0.15}],
|
|
126
|
+
"1": [{"timestamp": "", "value": 0.25}],
|
|
127
|
+
},
|
|
128
|
+
"gpu_memory_usage_bytes": {
|
|
129
|
+
"0": [{"timestamp": "", "value": 4000}],
|
|
130
|
+
"1": [{"timestamp": "", "value": 2000}],
|
|
131
|
+
},
|
|
132
|
+
},
|
|
133
|
+
},
|
|
134
|
+
],
|
|
71
135
|
},
|
|
72
136
|
{
|
|
73
137
|
"training_job": {
|
|
@@ -85,6 +149,38 @@ def test_view_training_job_metrics(time_sleep, capfd):
|
|
|
85
149
|
"0": [{"timestamp": "", "value": 4321}],
|
|
86
150
|
"1": [{"timestamp": "", "value": 2222}],
|
|
87
151
|
},
|
|
152
|
+
"per_node_metrics": [
|
|
153
|
+
{
|
|
154
|
+
"node_id": "node-0",
|
|
155
|
+
"metrics": {
|
|
156
|
+
"cpu_usage": [{"timestamp": "", "value": 3.2}],
|
|
157
|
+
"cpu_memory_usage_bytes": [{"timestamp": "", "value": 1234}],
|
|
158
|
+
"gpu_utilization": {
|
|
159
|
+
"0": [{"timestamp": "", "value": 0.2}],
|
|
160
|
+
"1": [{"timestamp": "", "value": 0.3}],
|
|
161
|
+
},
|
|
162
|
+
"gpu_memory_usage_bytes": {
|
|
163
|
+
"0": [{"timestamp": "", "value": 4321}],
|
|
164
|
+
"1": [{"timestamp": "", "value": 2222}],
|
|
165
|
+
},
|
|
166
|
+
},
|
|
167
|
+
},
|
|
168
|
+
{
|
|
169
|
+
"node_id": "node-1",
|
|
170
|
+
"metrics": {
|
|
171
|
+
"cpu_usage": [{"timestamp": "", "value": 2.8}],
|
|
172
|
+
"cpu_memory_usage_bytes": [{"timestamp": "", "value": 1000}],
|
|
173
|
+
"gpu_utilization": {
|
|
174
|
+
"0": [{"timestamp": "", "value": 0.15}],
|
|
175
|
+
"1": [{"timestamp": "", "value": 0.25}],
|
|
176
|
+
},
|
|
177
|
+
"gpu_memory_usage_bytes": {
|
|
178
|
+
"0": [{"timestamp": "", "value": 4000}],
|
|
179
|
+
"1": [{"timestamp": "", "value": 2002}],
|
|
180
|
+
},
|
|
181
|
+
},
|
|
182
|
+
},
|
|
183
|
+
],
|
|
88
184
|
},
|
|
89
185
|
]
|
|
90
186
|
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from unittest import mock
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from truss.remote.baseten.api import BasetenApi
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.fixture
|
|
9
|
+
def mock_auth_service():
|
|
10
|
+
auth_service = mock.Mock()
|
|
11
|
+
auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
|
|
12
|
+
auth_service.authenticate.return_value = auth_token
|
|
13
|
+
return auth_service
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture
|
|
17
|
+
def baseten_api(mock_auth_service):
|
|
18
|
+
return BasetenApi("https://app.test.com", mock_auth_service)
|