truss 0.10.9rc535__py3-none-any.whl → 0.10.10rc0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of truss has been flagged as potentially problematic; details are available on the registry page.

Files changed (33)
  1. truss/cli/logs/base_watcher.py +1 -1
  2. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
  3. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
  4. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +14 -7
  5. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
  6. truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
  7. truss/cli/train/metrics_watcher.py +170 -59
  8. truss/cli/train_commands.py +11 -3
  9. truss/contexts/image_builder/serving_image_builder.py +22 -39
  10. truss/remote/baseten/api.py +11 -0
  11. truss/remote/baseten/core.py +209 -1
  12. truss/remote/baseten/utils/time.py +15 -0
  13. truss/templates/base.Dockerfile.jinja +6 -23
  14. truss/templates/cache.Dockerfile.jinja +5 -5
  15. truss/templates/copy_cache_files.Dockerfile.jinja +1 -1
  16. truss/templates/docker_server/supervisord.conf.jinja +0 -1
  17. truss/templates/server/requirements.txt +1 -1
  18. truss/templates/server.Dockerfile.jinja +16 -33
  19. truss/tests/cli/train/test_deploy_checkpoints.py +446 -2
  20. truss/tests/cli/train/test_train_cli_core.py +96 -0
  21. truss/tests/remote/baseten/conftest.py +18 -0
  22. truss/tests/remote/baseten/test_api.py +49 -14
  23. truss/tests/remote/baseten/test_core.py +517 -1
  24. {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/METADATA +2 -2
  25. {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/RECORD +31 -29
  26. truss_train/definitions.py +6 -0
  27. truss_train/deployment.py +15 -2
  28. truss_train/loader.py +7 -20
  29. truss/tests/util/test_basetenpointer.py +0 -227
  30. truss/util/basetenpointer.py +0 -160
  31. {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/WHEEL +0 -0
  32. {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/entry_points.txt +0 -0
  33. {truss-0.10.9rc535.dist-info → truss-0.10.10rc0.dist-info}/licenses/LICENSE +0 -0
truss/tests/cli/train/test_deploy_checkpoints.py

@@ -1,6 +1,8 @@
 import os
 import re
+from dataclasses import dataclass
 from pathlib import Path
+from typing import Dict, List, Optional
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -23,6 +25,11 @@ from truss.cli.train.deploy_checkpoints.deploy_lora_checkpoints import (
     hydrate_lora_checkpoint,
     render_vllm_lora_truss_config,
 )
+from truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints import (
+    VLLM_WHISPER_START_COMMAND,
+    hydrate_whisper_checkpoint,
+    render_vllm_whisper_truss_config,
+)
 from truss.cli.train.types import (
     DeployCheckpointsConfigComplete,
     PrepareCheckpointResult,
@@ -505,8 +512,16 @@ def test_render_vllm_full_truss_config():
     )
 
     result = render_vllm_full_truss_config(deploy_config)
-
-    expected_vllm_command = 'sh -c "HF_TOKEN=$(cat /secrets/hf_token) vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 --port 8000 --tensor-parallel-size 2 --dtype bfloat16"'
+    expected_vllm_command = (
+        "sh -c 'HF_TOKEN=$(cat /secrets/hf_token) "
+        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+        "if [ -f /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja ]; then "
+        "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
+        "--chat-template /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja "
+        "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; else "
+        "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
+        "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; fi'"
+    )
 
     assert isinstance(result, truss_config.TrussConfig)
     assert result.model_name == "test-full-model"
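
The rewritten expectation above encodes a shell conditional: pass the checkpoint's bundled chat_template.jinja to vllm serve when it exists, otherwise serve without it. A minimal Python restatement of that branch, illustrative only (the container actually runs the sh -c string above, and build_vllm_args is a hypothetical helper, not part of truss):

    from pathlib import Path
    from typing import List

    def build_vllm_args(ckpt: str) -> List[str]:
        # Base vllm invocation, mirroring the flags pinned in the test.
        args = ["vllm", "serve", ckpt, "--port", "8000",
                "--tensor-parallel-size", "2", "--dtype", "bfloat16"]
        chat_template = Path(ckpt) / "chat_template.jinja"
        if chat_template.is_file():
            # Same effect as the `if [ -f ... ]` branch: insert --chat-template
            # right after the model path when the checkpoint ships a template.
            args[3:3] = ["--chat-template", str(chat_template)]
        return args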
@@ -674,3 +689,432 @@ def test_get_checkpoint_ids_to_deploy_single_checkpoint():
 
     # Should return the single checkpoint directly
     assert result == ["checkpoint-1"]
+
+
+def test_vllm_whisper_start_command_template():
+    """Test that the VLLM_WHISPER_START_COMMAND template renders correctly."""
+    # Test with all variables
+    result = VLLM_WHISPER_START_COMMAND.render(
+        model_path="/path/to/model",
+        envvars="CUDA_VISIBLE_DEVICES=0",
+        specify_tensor_parallelism=4,
+    )
+
+    expected = (
+        "sh -c 'CUDA_VISIBLE_DEVICES=0 "
+        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+        "vllm serve /path/to/model --port 8000 --tensor-parallel-size 4'"
+    )
+    assert result == expected
+
+    result = VLLM_WHISPER_START_COMMAND.render(
+        model_path="/path/to/model", envvars=None, specify_tensor_parallelism=1
+    )
+
+    expected = (
+        "sh -c '"
+        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+        "vllm serve /path/to/model --port 8000 --tensor-parallel-size 1'"
+    )
+    assert result == expected
+
+
+def test_hydrate_whisper_checkpoint():
+    """Test that hydrate_whisper_checkpoint creates correct WhisperCheckpoint object."""
+    job_id = "test-job-123"
+    checkpoint_id = "checkpoint-456"
+    checkpoint = {"some": "data"}
+
+    result = hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
+
+    assert result.training_job_id == job_id
+    assert result.paths == [f"rank-0/{checkpoint_id}/"]
+    assert result.model_weight_format == definitions.ModelWeightsFormat.WHISPER
+    assert isinstance(result, definitions.WhisperCheckpoint)
+
+
+@patch(
+    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
+)
+@patch(
+    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
+)
+@patch(
+    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
+)
+def test_render_vllm_whisper_truss_config(
+    mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
+):
+    """Test that render_vllm_whisper_truss_config renders truss config correctly."""
+    # Mock dependencies
+    mock_truss_config = MagicMock()
+    mock_truss_config.environment_variables = {}
+    mock_truss_config.docker_server = MagicMock()
+    mock_setup_base_config.return_value = mock_truss_config
+
+    mock_setup_env_vars.return_value = "HF_TOKEN=$(cat /secrets/hf_access_token)"
+    mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
+
+    # Create test config
+    deploy_config = DeployCheckpointsConfigComplete(
+        checkpoint_details=definitions.CheckpointList(
+            checkpoints=[
+                definitions.WhisperCheckpoint(
+                    training_job_id="job123",
+                    paths=["rank-0/checkpoint-1/"],
+                    model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+                )
+            ],
+            base_model_id="openai/whisper-large-v3",
+        ),
+        model_name="whisper-large-v3-vLLM",
+        compute=definitions.Compute(
+            accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
+        ),
+        runtime=definitions.DeployCheckpointsRuntime(
+            environment_variables={
+                "HF_TOKEN": definitions.SecretReference(name="hf_access_token")
+            }
+        ),
+        deployment_name="whisper-large-v3-vLLM",
+        model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+    )
+
+    result = render_vllm_whisper_truss_config(deploy_config)
+
+    mock_setup_base_config.assert_called_once_with(deploy_config)
+    mock_setup_env_vars.assert_called_once_with(mock_truss_config, deploy_config)
+    mock_build_full_checkpoint_string.assert_called_once_with(mock_truss_config)
+
+    assert result == mock_truss_config
+
+    expected_start_command = (
+        "sh -c 'HF_TOKEN=$(cat /secrets/hf_access_token) "
+        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+        "vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 4'"
+    )
+    assert (
+        result.environment_variables[START_COMMAND_ENVVAR_NAME]
+        == expected_start_command
+    )
+
+    assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
+
+
+@patch(
+    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
+)
+@patch(
+    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
+)
+@patch(
+    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
+)
+def test_render_vllm_whisper_truss_config_with_envvars(
+    mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
+):
+    """Test that render_vllm_whisper_truss_config handles environment variables correctly."""
+    # Mock dependencies
+    mock_truss_config = MagicMock()
+    mock_truss_config.environment_variables = {}
+    mock_truss_config.docker_server = MagicMock()
+    mock_setup_base_config.return_value = mock_truss_config
+
+    mock_setup_env_vars.return_value = "CUDA_VISIBLE_DEVICES=0,1"
+    mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
+
+    # Create test config with environment variables
+    deploy_config = DeployCheckpointsConfigComplete(
+        checkpoint_details=definitions.CheckpointList(
+            checkpoints=[
+                definitions.WhisperCheckpoint(
+                    training_job_id="job123",
+                    paths=["rank-0/checkpoint-1/"],
+                    model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+                )
+            ],
+            base_model_id="openai/whisper-large-v3",
+        ),
+        model_name="whisper-large-v3-vLLM",
+        compute=definitions.Compute(
+            accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
+        ),
+        runtime=definitions.DeployCheckpointsRuntime(
+            environment_variables={
+                "CUDA_VISIBLE_DEVICES": "0,1",
+                "HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
+            }
+        ),
+        deployment_name="whisper-large-v3-vLLM",
+        model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+    )
+
+    # Call function under test
+    result = render_vllm_whisper_truss_config(deploy_config)
+
+    # Verify environment variables are included in start command
+    expected_start_command = (
+        "sh -c 'CUDA_VISIBLE_DEVICES=0,1 "
+        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+        "vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 2'"
+    )
+    assert (
+        result.environment_variables[START_COMMAND_ENVVAR_NAME]
+        == expected_start_command
+    )
+
+
+@dataclass
+class TestCase:
+    """Test case for setup_base_truss_config function."""
+
+    desc: str
+    input_config: DeployCheckpointsConfigComplete
+    expected_model_name: str
+    expected_predict_endpoint: str
+    expected_accelerator: Optional[str]
+    expected_accelerator_count: Optional[int]
+    expected_checkpoint_paths: List[str]
+    expected_environment_variables: Dict[str, str]
+    should_raise: Optional[str] = None  # Error message if function should raise
+
+    __test__ = False  # Tell pytest this is not a test class
+
+
+def test_setup_base_truss_config():
+    """Table-driven test for setup_base_truss_config function."""
+    from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
+        setup_base_truss_config,
+    )
+
+    # Define test cases
+    test_cases = [
+        TestCase(
+            desc="LoRA checkpoint with H100 accelerator",
+            input_config=DeployCheckpointsConfigComplete(
+                checkpoint_details=definitions.CheckpointList(
+                    checkpoints=[
+                        definitions.LoRACheckpoint(
+                            training_job_id="job123",
+                            paths=["rank-0/checkpoint-1/"],
+                            model_weight_format=ModelWeightsFormat.LORA,
+                            lora_details=definitions.LoRADetails(rank=32),
+                        )
+                    ],
+                    base_model_id="google/gemma-3-27b-it",
+                ),
+                model_name="test-lora-model",
+                compute=definitions.Compute(
+                    accelerator=truss_config.AcceleratorSpec(
+                        accelerator="H100", count=4
+                    )
+                ),
+                runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                deployment_name="test-deployment",
+                model_weight_format=ModelWeightsFormat.LORA,
+            ),
+            expected_model_name="test-lora-model",
+            expected_predict_endpoint="/v1/chat/completions",
+            expected_accelerator="H100",
+            expected_accelerator_count=4,
+            expected_checkpoint_paths=["rank-0/checkpoint-1/"],
+            expected_environment_variables={
+                "VLLM_LOGGING_LEVEL": "WARNING",
+                "VLLM_USE_V1": "0",
+                "HF_HUB_ENABLE_HF_TRANSFER": "1",
+            },
+        ),
+        TestCase(
+            desc="Whisper checkpoint with A100 accelerator",
+            input_config=DeployCheckpointsConfigComplete(
+                checkpoint_details=definitions.CheckpointList(
+                    checkpoints=[
+                        definitions.WhisperCheckpoint(
+                            training_job_id="job123",
+                            paths=["rank-0/checkpoint-1/"],
+                            model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+                        )
+                    ],
+                    base_model_id="openai/whisper-large-v3",
+                ),
+                model_name="test-whisper-model",
+                compute=definitions.Compute(
+                    accelerator=truss_config.AcceleratorSpec(
+                        accelerator="A100", count=2
+                    )
+                ),
+                runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                deployment_name="test-whisper-deployment",
+                model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+            ),
+            expected_model_name="test-whisper-model",
+            expected_predict_endpoint="/v1/audio/transcriptions",
+            expected_accelerator="A100",
+            expected_accelerator_count=2,
+            expected_checkpoint_paths=["rank-0/checkpoint-1/"],
+            expected_environment_variables={
+                "VLLM_LOGGING_LEVEL": "WARNING",
+                "VLLM_USE_V1": "0",
+                "HF_HUB_ENABLE_HF_TRANSFER": "1",
+            },
+        ),
+        TestCase(
+            desc="Multiple LoRA checkpoints",
+            input_config=DeployCheckpointsConfigComplete(
+                checkpoint_details=definitions.CheckpointList(
+                    checkpoints=[
+                        definitions.LoRACheckpoint(
+                            training_job_id="job123",
+                            paths=["rank-0/checkpoint-1/"],
+                            model_weight_format=ModelWeightsFormat.LORA,
+                            lora_details=definitions.LoRADetails(rank=16),
+                        ),
+                        definitions.LoRACheckpoint(
+                            training_job_id="job123",
+                            paths=["rank-0/checkpoint-2/"],
+                            model_weight_format=ModelWeightsFormat.LORA,
+                            lora_details=definitions.LoRADetails(rank=32),
+                        ),
+                    ],
+                    base_model_id="google/gemma-3-27b-it",
+                ),
+                model_name="test-multi-checkpoint-model",
+                compute=definitions.Compute(
+                    accelerator=truss_config.AcceleratorSpec(
+                        accelerator="H100", count=4
+                    )
+                ),
+                runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                deployment_name="test-multi-deployment",
+                model_weight_format=ModelWeightsFormat.LORA,
+            ),
+            expected_model_name="test-multi-checkpoint-model",
+            expected_predict_endpoint="/v1/chat/completions",
+            expected_accelerator="H100",
+            expected_accelerator_count=4,
+            expected_checkpoint_paths=["rank-0/checkpoint-1/", "rank-0/checkpoint-2/"],
+            expected_environment_variables={
+                "VLLM_LOGGING_LEVEL": "WARNING",
+                "VLLM_USE_V1": "0",
+                "HF_HUB_ENABLE_HF_TRANSFER": "1",
+            },
+        ),
+        TestCase(
+            desc="No accelerator specified",
+            input_config=DeployCheckpointsConfigComplete(
+                checkpoint_details=definitions.CheckpointList(
+                    checkpoints=[
+                        definitions.LoRACheckpoint(
+                            training_job_id="job123",
+                            paths=["rank-0/checkpoint-1/"],
+                            model_weight_format=ModelWeightsFormat.LORA,
+                            lora_details=definitions.LoRADetails(rank=16),
+                        )
+                    ],
+                    base_model_id="google/gemma-3-27b-it",
+                ),
+                model_name="test-no-accelerator-model",
+                compute=definitions.Compute(),  # No accelerator specified
+                runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                deployment_name="test-no-accelerator-deployment",
+                model_weight_format=ModelWeightsFormat.LORA,
+            ),
+            expected_model_name="test-no-accelerator-model",
+            expected_predict_endpoint="/v1/chat/completions",
+            expected_accelerator=None,
+            expected_accelerator_count=None,
+            expected_checkpoint_paths=["rank-0/checkpoint-1/"],
+            expected_environment_variables={
+                "VLLM_LOGGING_LEVEL": "WARNING",
+                "VLLM_USE_V1": "0",
+                "HF_HUB_ENABLE_HF_TRANSFER": "1",
+            },
+        ),
+    ]
+
+    # Run test cases
+    for test_case in test_cases:
+        print(f"Running test case: {test_case.desc}")
+
+        if test_case.should_raise:
+            # Test error cases
+            with pytest.raises(Exception, match=test_case.should_raise):
+                setup_base_truss_config(test_case.input_config)
+        else:
+            # Test success cases
+            result = setup_base_truss_config(test_case.input_config)
+
+            # Verify basic structure
+            assert isinstance(result, truss_config.TrussConfig), (
+                f"Test case '{test_case.desc}': Result should be TrussConfig"
+            )
+            assert result.model_name == test_case.expected_model_name, (
+                f"Test case '{test_case.desc}': Model name mismatch"
+            )
+
+            # Verify docker server configuration
+            assert result.docker_server is not None, (
+                f"Test case '{test_case.desc}': Docker server should not be None"
+            )
+            assert result.docker_server.start_command == 'sh -c ""', (
+                f"Test case '{test_case.desc}': Start command mismatch"
+            )
+            assert result.docker_server.readiness_endpoint == "/health", (
+                f"Test case '{test_case.desc}': Readiness endpoint mismatch"
+            )
+            assert result.docker_server.liveness_endpoint == "/health", (
+                f"Test case '{test_case.desc}': Liveness endpoint mismatch"
+            )
+            assert (
+                result.docker_server.predict_endpoint
+                == test_case.expected_predict_endpoint
+            ), f"Test case '{test_case.desc}': Predict endpoint mismatch"
+            assert result.docker_server.server_port == 8000, (
+                f"Test case '{test_case.desc}': Server port mismatch"
+            )
+
+            # Verify training checkpoints
+            assert result.training_checkpoints is not None, (
+                f"Test case '{test_case.desc}': Training checkpoints should not be None"
+            )
+            assert len(result.training_checkpoints.artifact_references) == len(
+                test_case.expected_checkpoint_paths
+            ), f"Test case '{test_case.desc}': Number of checkpoint artifacts mismatch"
+
+            for i, expected_path in enumerate(test_case.expected_checkpoint_paths):
+                artifact_ref = result.training_checkpoints.artifact_references[i]
+                assert artifact_ref.paths == [expected_path], (
+                    f"Test case '{test_case.desc}': Checkpoint path {i} mismatch"
+                )
+
+            # Verify resources
+            assert result.resources is not None, (
+                f"Test case '{test_case.desc}': Resources should not be None"
+            )
+
+            if test_case.expected_accelerator:
+                assert result.resources.accelerator is not None, (
+                    f"Test case '{test_case.desc}': Accelerator should not be None"
+                )
+                assert (
+                    result.resources.accelerator.accelerator
+                    == test_case.expected_accelerator
+                ), f"Test case '{test_case.desc}': Accelerator type mismatch"
+                assert (
+                    result.resources.accelerator.count
+                    == test_case.expected_accelerator_count
+                ), f"Test case '{test_case.desc}': Accelerator count mismatch"
+            else:
+                # When no accelerator is specified, it creates an AcceleratorSpec with None values
+                assert result.resources.accelerator is not None, (
+                    f"Test case '{test_case.desc}': Accelerator should exist"
+                )
+                assert result.resources.accelerator.accelerator is None, (
+                    f"Test case '{test_case.desc}': Accelerator type should be None"
+                )
+
+            # Verify environment variables
+            for key, expected_value in test_case.expected_environment_variables.items():
+                assert result.environment_variables[key] == expected_value, (
+                    f"Test case '{test_case.desc}': Environment variable {key} mismatch"
+                )
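
For orientation, the two renders pinned by test_vllm_whisper_start_command_template fully determine the template's observable behavior. A jinja2 sketch consistent with both assertions (a reconstruction for illustration, not the constant actually shipped in deploy_whisper_checkpoints.py):

    from jinja2 import Template

    VLLM_WHISPER_START_COMMAND = Template(
        # envvars is optional: when falsy, the command starts directly with
        # the HF_TOKEN export, matching the second assertion above.
        "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
        "vllm serve {{ model_path }} --port 8000 "
        "--tensor-parallel-size {{ specify_tensor_parallelism }}'"
    )

The doubled `$$` survives rendering as a literal escape, which matters because the command is later substituted through a supervisord-style `%(ENV_...)s` expansion, as the docker_server.start_command assertion shows.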
truss/tests/cli/train/test_train_cli_core.py

@@ -51,6 +51,38 @@ def test_view_training_job_metrics(time_sleep, capfd):
                 "0": [{"timestamp": "", "value": 4321}],
                 "1": [{"timestamp": "", "value": 2222}],
             },
+            "per_node_metrics": [
+                {
+                    "node_id": "node-0",
+                    "metrics": {
+                        "cpu_usage": [{"timestamp": "", "value": 3.2}],
+                        "cpu_memory_usage_bytes": [{"timestamp": "", "value": 1234}],
+                        "gpu_utilization": {
+                            "0": [{"timestamp": "", "value": 0.2}],
+                            "1": [{"timestamp": "", "value": 0.3}],
+                        },
+                        "gpu_memory_usage_bytes": {
+                            "0": [{"timestamp": "", "value": 4321}],
+                            "1": [{"timestamp": "", "value": 2222}],
+                        },
+                    },
+                },
+                {
+                    "node_id": "node-1",
+                    "metrics": {
+                        "cpu_usage": [{"timestamp": "", "value": 2.8}],
+                        "cpu_memory_usage_bytes": [{"timestamp": "", "value": 1000}],
+                        "gpu_utilization": {
+                            "0": [{"timestamp": "", "value": 0.15}],
+                            "1": [{"timestamp": "", "value": 0.25}],
+                        },
+                        "gpu_memory_usage_bytes": {
+                            "0": [{"timestamp": "", "value": 4000}],
+                            "1": [{"timestamp": "", "value": 2000}],
+                        },
+                    },
+                },
+            ],
         },
         {
             "training_job": {
@@ -68,6 +100,38 @@ def test_view_training_job_metrics(time_sleep, capfd):
                 "0": [{"timestamp": "", "value": 4321}],
                 "1": [{"timestamp": "", "value": 2222}],
             },
+            "per_node_metrics": [
+                {
+                    "node_id": "node-0",
+                    "metrics": {
+                        "cpu_usage": [{"timestamp": "", "value": 3.2}],
+                        "cpu_memory_usage_bytes": [{"timestamp": "", "value": 1234}],
+                        "gpu_utilization": {
+                            "0": [{"timestamp": "", "value": 0.2}],
+                            "1": [{"timestamp": "", "value": 0.3}],
+                        },
+                        "gpu_memory_usage_bytes": {
+                            "0": [{"timestamp": "", "value": 4321}],
+                            "1": [{"timestamp": "", "value": 2222}],
+                        },
+                    },
+                },
+                {
+                    "node_id": "node-1",
+                    "metrics": {
+                        "cpu_usage": [{"timestamp": "", "value": 2.8}],
+                        "cpu_memory_usage_bytes": [{"timestamp": "", "value": 1000}],
+                        "gpu_utilization": {
+                            "0": [{"timestamp": "", "value": 0.15}],
+                            "1": [{"timestamp": "", "value": 0.25}],
+                        },
+                        "gpu_memory_usage_bytes": {
+                            "0": [{"timestamp": "", "value": 4000}],
+                            "1": [{"timestamp": "", "value": 2000}],
+                        },
+                    },
+                },
+            ],
         },
         {
             "training_job": {
@@ -85,6 +149,38 @@ def test_view_training_job_metrics(time_sleep, capfd):
                 "0": [{"timestamp": "", "value": 4321}],
                 "1": [{"timestamp": "", "value": 2222}],
             },
+            "per_node_metrics": [
+                {
+                    "node_id": "node-0",
+                    "metrics": {
+                        "cpu_usage": [{"timestamp": "", "value": 3.2}],
+                        "cpu_memory_usage_bytes": [{"timestamp": "", "value": 1234}],
+                        "gpu_utilization": {
+                            "0": [{"timestamp": "", "value": 0.2}],
+                            "1": [{"timestamp": "", "value": 0.3}],
+                        },
+                        "gpu_memory_usage_bytes": {
+                            "0": [{"timestamp": "", "value": 4321}],
+                            "1": [{"timestamp": "", "value": 2222}],
+                        },
+                    },
+                },
+                {
+                    "node_id": "node-1",
+                    "metrics": {
+                        "cpu_usage": [{"timestamp": "", "value": 2.8}],
+                        "cpu_memory_usage_bytes": [{"timestamp": "", "value": 1000}],
+                        "gpu_utilization": {
+                            "0": [{"timestamp": "", "value": 0.15}],
+                            "1": [{"timestamp": "", "value": 0.25}],
+                        },
+                        "gpu_memory_usage_bytes": {
+                            "0": [{"timestamp": "", "value": 4000}],
+                            "1": [{"timestamp": "", "value": 2002}],
+                        },
+                    },
+                },
+            ],
         },
     ]
 
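The fixtures above extend the mocked metrics payload with a per_node_metrics list, one entry per node, each carrying the same series types as the top-level metrics. A small helper sketch that produces an entry of that shape (the field names come from the fixtures; the helper itself is hypothetical, not part of truss):

    from typing import Dict, List, Union

    MetricPoint = Dict[str, Union[str, float]]  # {"timestamp": ..., "value": ...}

    def make_node_entry(
        node_id: str,
        cpu: float,
        mem: int,
        gpu_util: Dict[str, float],
        gpu_mem: Dict[str, int],
    ) -> dict:
        # One element of "per_node_metrics": node id plus per-series readings,
        # with GPU series keyed by device index, as in the fixtures above.
        return {
            "node_id": node_id,
            "metrics": {
                "cpu_usage": [{"timestamp": "", "value": cpu}],
                "cpu_memory_usage_bytes": [{"timestamp": "", "value": mem}],
                "gpu_utilization": {
                    k: [{"timestamp": "", "value": v}] for k, v in gpu_util.items()
                },
                "gpu_memory_usage_bytes": {
                    k: [{"timestamp": "", "value": v}] for k, v in gpu_mem.items()
                },
            },
        }

For example, make_node_entry("node-0", 3.2, 1234, {"0": 0.2, "1": 0.3}, {"0": 4321, "1": 2222}) reproduces the node-0 entry used throughout the fixtures.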
truss/tests/remote/baseten/conftest.py

@@ -0,0 +1,18 @@
+from unittest import mock
+
+import pytest
+
+from truss.remote.baseten.api import BasetenApi
+
+
+@pytest.fixture
+def mock_auth_service():
+    auth_service = mock.Mock()
+    auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
+    auth_service.authenticate.return_value = auth_token
+    return auth_service
+
+
+@pytest.fixture
+def baseten_api(mock_auth_service):
+    return BasetenApi("https://app.test.com", mock_auth_service)
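
These fixtures compose: baseten_api depends on mock_auth_service, so a test can request either or both by name. A minimal consumer sketch (illustrative; the test name is hypothetical, and the assertions only exercise the mocked wiring, not any real endpoint):

    def test_fixture_wiring(baseten_api, mock_auth_service):
        # The mocked auth service always yields the same stub token, so tests
        # can assert against the "Api-Key token" header without real
        # authentication; HTTP calls should still be patched per-test.
        token = mock_auth_service.authenticate()
        assert token.headers() == {"Authorization": "Api-Key token"}
        assert baseten_api is not None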