truss 0.10.9rc601__py3-none-any.whl → 0.10.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry has flagged this version of truss as possibly problematic.
Files changed (32)
  1. truss/base/constants.py +0 -1
  2. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
  3. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
  4. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +2 -2
  5. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
  6. truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
  7. truss/cli/train_commands.py +11 -3
  8. truss/contexts/image_builder/cache_warmer.py +1 -3
  9. truss/contexts/image_builder/serving_image_builder.py +24 -32
  10. truss/remote/baseten/api.py +11 -0
  11. truss/remote/baseten/core.py +209 -1
  12. truss/remote/baseten/utils/time.py +15 -0
  13. truss/templates/server/model_wrapper.py +0 -12
  14. truss/templates/server/requirements.txt +1 -1
  15. truss/templates/server/truss_server.py +0 -13
  16. truss/templates/server.Dockerfile.jinja +1 -1
  17. truss/tests/cli/train/test_deploy_checkpoints.py +436 -0
  18. truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
  19. truss/tests/remote/baseten/conftest.py +18 -0
  20. truss/tests/remote/baseten/test_api.py +49 -14
  21. truss/tests/remote/baseten/test_core.py +517 -1
  22. truss/tests/test_data/test_openai/model/model.py +0 -3
  23. truss/truss_handle/truss_handle.py +0 -1
  24. {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/METADATA +2 -2
  25. {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/RECORD +30 -28
  26. truss_train/definitions.py +6 -0
  27. truss_train/deployment.py +15 -2
  28. truss/tests/util/test_basetenpointer.py +0 -227
  29. truss/util/basetenpointer.py +0 -160
  30. {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/WHEEL +0 -0
  31. {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/entry_points.txt +0 -0
  32. {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/licenses/LICENSE +0 -0
truss/templates/server/model_wrapper.py
@@ -59,7 +59,6 @@ class MethodName(str, enum.Enum):
      CHAT_COMPLETIONS = enum.auto()
      COMPLETIONS = enum.auto()
      IS_HEALTHY = enum.auto()
-     MESSAGES = enum.auto()
      POSTPROCESS = enum.auto()
      PREDICT = enum.auto()
      PREPROCESS = enum.auto()
@@ -245,7 +244,6 @@ class ModelDescriptor:
      is_healthy: Optional[MethodDescriptor]
      completions: Optional[MethodDescriptor]
      chat_completions: Optional[MethodDescriptor]
-     messages: Optional[MethodDescriptor]
      websocket: Optional[MethodDescriptor]

      @cached_property
@@ -293,7 +291,6 @@ class ModelDescriptor:
          setup = cls._safe_extract_descriptor(model_cls, MethodName.SETUP_ENVIRONMENT)
          completions = cls._safe_extract_descriptor(model_cls, MethodName.COMPLETIONS)
          chats = cls._safe_extract_descriptor(model_cls, MethodName.CHAT_COMPLETIONS)
-         messages = cls._safe_extract_descriptor(model_cls, MethodName.MESSAGES)
          is_healthy = cls._safe_extract_descriptor(model_cls, MethodName.IS_HEALTHY)
          if is_healthy and is_healthy.arg_config != ArgConfig.NONE:
              raise errors.ModelDefinitionError(
@@ -362,7 +359,6 @@ class ModelDescriptor:
              is_healthy=is_healthy,
              completions=completions,
              chat_completions=chats,
-             messages=messages,
              websocket=websocket,
          )

@@ -929,14 +925,6 @@ class ModelWrapper:
          )
          return await self._execute_model_endpoint(inputs, request, descriptor)

-     async def messages(
-         self, inputs: InputType, request: starlette.requests.Request
-     ) -> OutputType:
-         descriptor = self._get_descriptor_or_raise(
-             self.model_descriptor.messages, MethodName.MESSAGES
-         )
-         return await self._execute_model_endpoint(inputs, request, descriptor)
-
      async def websocket(self, ws: WebSocket) -> None:
          descriptor = self.model_descriptor.websocket
          assert descriptor, "websocket can only be invoked if present on model."
truss/templates/server/requirements.txt
@@ -18,7 +18,7 @@ psutil>=5.9.4
  python-json-logger>=2.0.2
  pyyaml>=6.0.0
  requests>=2.31.0
- truss-transfer==0.0.27
+ truss-transfer==0.0.29
  uvicorn>=0.24.0
  uvloop>=0.19.0
  websockets>=10.0
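A quick way to confirm the bumped pin in a built server environment (a sketch; it uses only the stdlib importlib.metadata, and assumes the wheel is installed under the distribution name shown in the requirements line):

    from importlib.metadata import version

    # Distribution name comes straight from the requirements line above.
    assert version("truss-transfer") == "0.0.29"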
truss/templates/server/truss_server.py
@@ -231,13 +231,6 @@ class BasetenEndpoints:
          method=self._model.completions, request=request, body_raw=body_raw
      )

-     async def messages(
-         self, request: Request, body_raw: bytes = Depends(parse_body)
-     ) -> Response:
-         return await self._execute_request(
-             method=self._model.messages, request=request, body_raw=body_raw
-         )
-
      async def websocket(self, ws: WebSocket) -> None:
          self.check_healthy()
          trace_ctx = otel_propagate.extract(ws.headers) or None
@@ -435,12 +428,6 @@ class TrussServer:
              methods=["POST"],
              tags=["V1"],
          ),
-         FastAPIRoute(
-             r"/v1/messages",
-             self._endpoints.messages,
-             methods=["POST"],
-             tags=["V1"],
-         ),
          # Websocket endpoint
          FastAPIWebSocketRoute(r"/v1/websocket", self._endpoints.websocket),
          # Endpoint aliases for Sagemaker hosting
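Taken together with the model_wrapper.py hunks above, this removes the /v1/messages endpoint end to end. A hedged sketch of the client-visible effect, assuming a truss server running locally on port 8080 (hypothetical host and port):

    import requests

    BASE = "http://localhost:8080"  # hypothetical local truss server

    # The route no longer exists, so FastAPI returns 404 instead of dispatching
    # to a Model.messages method.
    resp = requests.post(f"{BASE}/v1/messages", json={"messages": []})
    assert resp.status_code == 404

    # The remaining V1 routes, e.g. /v1/chat/completions, are unchanged; the
    # status here depends on whether the deployed model defines chat_completions.
    requests.post(f"{BASE}/v1/chat/completions", json={"messages": []})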
truss/templates/server.Dockerfile.jinja
@@ -70,7 +70,7 @@ COPY ./{{ config.data_dir }} /app/data

  {%- if model_cache_v2 %}
  # v0.0.9, keep synced with server_requirements.txt
- RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.10.9rc0/truss-transfer-cli-v0.10.9rc0-linux-x86_64-unknown-linux-musl
+ RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.10.10rc1/truss-transfer-cli-v0.10.10rc1-linux-x86_64-unknown-linux-musl
  RUN chmod +x /usr/local/bin/truss-transfer-cli
  RUN mkdir /static-bptr
  RUN echo "hash {{model_cache_hash}}"
truss/tests/cli/train/test_deploy_checkpoints.py
@@ -1,6 +1,8 @@
  import os
  import re
+ from dataclasses import dataclass
  from pathlib import Path
+ from typing import Dict, List, Optional
  from unittest.mock import MagicMock, patch

  import pytest
@@ -23,6 +25,11 @@ from truss.cli.train.deploy_checkpoints.deploy_lora_checkpoints import (
      hydrate_lora_checkpoint,
      render_vllm_lora_truss_config,
  )
+ from truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints import (
+     VLLM_WHISPER_START_COMMAND,
+     hydrate_whisper_checkpoint,
+     render_vllm_whisper_truss_config,
+ )
  from truss.cli.train.types import (
      DeployCheckpointsConfigComplete,
      PrepareCheckpointResult,
@@ -682,3 +689,432 @@ def test_get_checkpoint_ids_to_deploy_single_checkpoint():

      # Should return the single checkpoint directly
      assert result == ["checkpoint-1"]
+
+
+ def test_vllm_whisper_start_command_template():
+     """Test that the VLLM_WHISPER_START_COMMAND template renders correctly."""
+     # Test with all variables
+     result = VLLM_WHISPER_START_COMMAND.render(
+         model_path="/path/to/model",
+         envvars="CUDA_VISIBLE_DEVICES=0",
+         specify_tensor_parallelism=4,
+     )
+
+     expected = (
+         "sh -c 'CUDA_VISIBLE_DEVICES=0 "
+         'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+         "vllm serve /path/to/model --port 8000 --tensor-parallel-size 4'"
+     )
+     assert result == expected
+
+     result = VLLM_WHISPER_START_COMMAND.render(
+         model_path="/path/to/model", envvars=None, specify_tensor_parallelism=1
+     )
+
+     expected = (
+         "sh -c '"
+         'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+         "vllm serve /path/to/model --port 8000 --tensor-parallel-size 1'"
+     )
+     assert result == expected
+
+
+ def test_hydrate_whisper_checkpoint():
+     """Test that hydrate_whisper_checkpoint creates correct WhisperCheckpoint object."""
+     job_id = "test-job-123"
+     checkpoint_id = "checkpoint-456"
+     checkpoint = {"some": "data"}
+
+     result = hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
+
+     assert result.training_job_id == job_id
+     assert result.paths == [f"rank-0/{checkpoint_id}/"]
+     assert result.model_weight_format == definitions.ModelWeightsFormat.WHISPER
+     assert isinstance(result, definitions.WhisperCheckpoint)
+
+
+ @patch(
+     "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
+ )
+ @patch(
+     "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
+ )
+ @patch(
+     "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
+ )
+ def test_render_vllm_whisper_truss_config(
+     mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
+ ):
+     """Test that render_vllm_whisper_truss_config renders truss config correctly."""
+     # Mock dependencies
+     mock_truss_config = MagicMock()
+     mock_truss_config.environment_variables = {}
+     mock_truss_config.docker_server = MagicMock()
+     mock_setup_base_config.return_value = mock_truss_config
+
+     mock_setup_env_vars.return_value = "HF_TOKEN=$(cat /secrets/hf_access_token)"
+     mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
+
+     # Create test config
+     deploy_config = DeployCheckpointsConfigComplete(
+         checkpoint_details=definitions.CheckpointList(
+             checkpoints=[
+                 definitions.WhisperCheckpoint(
+                     training_job_id="job123",
+                     paths=["rank-0/checkpoint-1/"],
+                     model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+                 )
+             ],
+             base_model_id="openai/whisper-large-v3",
+         ),
+         model_name="whisper-large-v3-vLLM",
+         compute=definitions.Compute(
+             accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
+         ),
+         runtime=definitions.DeployCheckpointsRuntime(
+             environment_variables={
+                 "HF_TOKEN": definitions.SecretReference(name="hf_access_token")
+             }
+         ),
+         deployment_name="whisper-large-v3-vLLM",
+         model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+     )
+
+     result = render_vllm_whisper_truss_config(deploy_config)
+
+     mock_setup_base_config.assert_called_once_with(deploy_config)
+     mock_setup_env_vars.assert_called_once_with(mock_truss_config, deploy_config)
+     mock_build_full_checkpoint_string.assert_called_once_with(mock_truss_config)
+
+     assert result == mock_truss_config
+
+     expected_start_command = (
+         "sh -c 'HF_TOKEN=$(cat /secrets/hf_access_token) "
+         'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+         "vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 4'"
+     )
+     assert (
+         result.environment_variables[START_COMMAND_ENVVAR_NAME]
+         == expected_start_command
+     )
+
+     assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
+
+
+ @patch(
+     "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
+ )
+ @patch(
+     "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
+ )
+ @patch(
+     "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
+ )
+ def test_render_vllm_whisper_truss_config_with_envvars(
+     mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
+ ):
+     """Test that render_vllm_whisper_truss_config handles environment variables correctly."""
+     # Mock dependencies
+     mock_truss_config = MagicMock()
+     mock_truss_config.environment_variables = {}
+     mock_truss_config.docker_server = MagicMock()
+     mock_setup_base_config.return_value = mock_truss_config
+
+     mock_setup_env_vars.return_value = "CUDA_VISIBLE_DEVICES=0,1"
+     mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
+
+     # Create test config with environment variables
+     deploy_config = DeployCheckpointsConfigComplete(
+         checkpoint_details=definitions.CheckpointList(
+             checkpoints=[
+                 definitions.WhisperCheckpoint(
+                     training_job_id="job123",
+                     paths=["rank-0/checkpoint-1/"],
+                     model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+                 )
+             ],
+             base_model_id="openai/whisper-large-v3",
+         ),
+         model_name="whisper-large-v3-vLLM",
+         compute=definitions.Compute(
+             accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
+         ),
+         runtime=definitions.DeployCheckpointsRuntime(
+             environment_variables={
+                 "CUDA_VISIBLE_DEVICES": "0,1",
+                 "HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
+             }
+         ),
+         deployment_name="whisper-large-v3-vLLM",
+         model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+     )
+
+     # Call function under test
+     result = render_vllm_whisper_truss_config(deploy_config)
+
+     # Verify environment variables are included in start command
+     expected_start_command = (
+         "sh -c 'CUDA_VISIBLE_DEVICES=0,1 "
+         'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+         "vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 2'"
+     )
+     assert (
+         result.environment_variables[START_COMMAND_ENVVAR_NAME]
+         == expected_start_command
+     )
+
+
+ @dataclass
+ class TestCase:
+     """Test case for setup_base_truss_config function."""
+
+     desc: str
+     input_config: DeployCheckpointsConfigComplete
+     expected_model_name: str
+     expected_predict_endpoint: str
+     expected_accelerator: Optional[str]
+     expected_accelerator_count: Optional[int]
+     expected_checkpoint_paths: List[str]
+     expected_environment_variables: Dict[str, str]
+     should_raise: Optional[str] = None  # Error message if function should raise
+
+     __test__ = False  # Tell pytest this is not a test class
+
+
+ def test_setup_base_truss_config():
+     """Table-driven test for setup_base_truss_config function."""
+     from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
+         setup_base_truss_config,
+     )
+
+     # Define test cases
+     test_cases = [
+         TestCase(
+             desc="LoRA checkpoint with H100 accelerator",
+             input_config=DeployCheckpointsConfigComplete(
+                 checkpoint_details=definitions.CheckpointList(
+                     checkpoints=[
+                         definitions.LoRACheckpoint(
+                             training_job_id="job123",
+                             paths=["rank-0/checkpoint-1/"],
+                             model_weight_format=ModelWeightsFormat.LORA,
+                             lora_details=definitions.LoRADetails(rank=32),
+                         )
+                     ],
+                     base_model_id="google/gemma-3-27b-it",
+                 ),
+                 model_name="test-lora-model",
+                 compute=definitions.Compute(
+                     accelerator=truss_config.AcceleratorSpec(
+                         accelerator="H100", count=4
+                     )
+                 ),
+                 runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                 deployment_name="test-deployment",
+                 model_weight_format=ModelWeightsFormat.LORA,
+             ),
+             expected_model_name="test-lora-model",
+             expected_predict_endpoint="/v1/chat/completions",
+             expected_accelerator="H100",
+             expected_accelerator_count=4,
+             expected_checkpoint_paths=["rank-0/checkpoint-1/"],
+             expected_environment_variables={
+                 "VLLM_LOGGING_LEVEL": "WARNING",
+                 "VLLM_USE_V1": "0",
+                 "HF_HUB_ENABLE_HF_TRANSFER": "1",
+             },
+         ),
+         TestCase(
+             desc="Whisper checkpoint with A100 accelerator",
+             input_config=DeployCheckpointsConfigComplete(
+                 checkpoint_details=definitions.CheckpointList(
+                     checkpoints=[
+                         definitions.WhisperCheckpoint(
+                             training_job_id="job123",
+                             paths=["rank-0/checkpoint-1/"],
+                             model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+                         )
+                     ],
+                     base_model_id="openai/whisper-large-v3",
+                 ),
+                 model_name="test-whisper-model",
+                 compute=definitions.Compute(
+                     accelerator=truss_config.AcceleratorSpec(
+                         accelerator="A100", count=2
+                     )
+                 ),
+                 runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                 deployment_name="test-whisper-deployment",
+                 model_weight_format=definitions.ModelWeightsFormat.WHISPER,
+             ),
+             expected_model_name="test-whisper-model",
+             expected_predict_endpoint="/v1/audio/transcriptions",
+             expected_accelerator="A100",
+             expected_accelerator_count=2,
+             expected_checkpoint_paths=["rank-0/checkpoint-1/"],
+             expected_environment_variables={
+                 "VLLM_LOGGING_LEVEL": "WARNING",
+                 "VLLM_USE_V1": "0",
+                 "HF_HUB_ENABLE_HF_TRANSFER": "1",
+             },
+         ),
+         TestCase(
+             desc="Multiple LoRA checkpoints",
+             input_config=DeployCheckpointsConfigComplete(
+                 checkpoint_details=definitions.CheckpointList(
+                     checkpoints=[
+                         definitions.LoRACheckpoint(
+                             training_job_id="job123",
+                             paths=["rank-0/checkpoint-1/"],
+                             model_weight_format=ModelWeightsFormat.LORA,
+                             lora_details=definitions.LoRADetails(rank=16),
+                         ),
+                         definitions.LoRACheckpoint(
+                             training_job_id="job123",
+                             paths=["rank-0/checkpoint-2/"],
+                             model_weight_format=ModelWeightsFormat.LORA,
+                             lora_details=definitions.LoRADetails(rank=32),
+                         ),
+                     ],
+                     base_model_id="google/gemma-3-27b-it",
+                 ),
+                 model_name="test-multi-checkpoint-model",
+                 compute=definitions.Compute(
+                     accelerator=truss_config.AcceleratorSpec(
+                         accelerator="H100", count=4
+                     )
+                 ),
+                 runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                 deployment_name="test-multi-deployment",
+                 model_weight_format=ModelWeightsFormat.LORA,
+             ),
+             expected_model_name="test-multi-checkpoint-model",
+             expected_predict_endpoint="/v1/chat/completions",
+             expected_accelerator="H100",
+             expected_accelerator_count=4,
+             expected_checkpoint_paths=["rank-0/checkpoint-1/", "rank-0/checkpoint-2/"],
+             expected_environment_variables={
+                 "VLLM_LOGGING_LEVEL": "WARNING",
+                 "VLLM_USE_V1": "0",
+                 "HF_HUB_ENABLE_HF_TRANSFER": "1",
+             },
+         ),
+         TestCase(
+             desc="No accelerator specified",
+             input_config=DeployCheckpointsConfigComplete(
+                 checkpoint_details=definitions.CheckpointList(
+                     checkpoints=[
+                         definitions.LoRACheckpoint(
+                             training_job_id="job123",
+                             paths=["rank-0/checkpoint-1/"],
+                             model_weight_format=ModelWeightsFormat.LORA,
+                             lora_details=definitions.LoRADetails(rank=16),
+                         )
+                     ],
+                     base_model_id="google/gemma-3-27b-it",
+                 ),
+                 model_name="test-no-accelerator-model",
+                 compute=definitions.Compute(),  # No accelerator specified
+                 runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
+                 deployment_name="test-no-accelerator-deployment",
+                 model_weight_format=ModelWeightsFormat.LORA,
+             ),
+             expected_model_name="test-no-accelerator-model",
+             expected_predict_endpoint="/v1/chat/completions",
+             expected_accelerator=None,
+             expected_accelerator_count=None,
+             expected_checkpoint_paths=["rank-0/checkpoint-1/"],
+             expected_environment_variables={
+                 "VLLM_LOGGING_LEVEL": "WARNING",
+                 "VLLM_USE_V1": "0",
+                 "HF_HUB_ENABLE_HF_TRANSFER": "1",
+             },
+         ),
+     ]
+
+     # Run test cases
+     for test_case in test_cases:
+         print(f"Running test case: {test_case.desc}")
+
+         if test_case.should_raise:
+             # Test error cases
+             with pytest.raises(Exception, match=test_case.should_raise):
+                 setup_base_truss_config(test_case.input_config)
+         else:
+             # Test success cases
+             result = setup_base_truss_config(test_case.input_config)
+
+             # Verify basic structure
+             assert isinstance(result, truss_config.TrussConfig), (
+                 f"Test case '{test_case.desc}': Result should be TrussConfig"
+             )
+             assert result.model_name == test_case.expected_model_name, (
+                 f"Test case '{test_case.desc}': Model name mismatch"
+             )
+
+             # Verify docker server configuration
+             assert result.docker_server is not None, (
+                 f"Test case '{test_case.desc}': Docker server should not be None"
+             )
+             assert result.docker_server.start_command == 'sh -c ""', (
+                 f"Test case '{test_case.desc}': Start command mismatch"
+             )
+             assert result.docker_server.readiness_endpoint == "/health", (
+                 f"Test case '{test_case.desc}': Readiness endpoint mismatch"
+             )
+             assert result.docker_server.liveness_endpoint == "/health", (
+                 f"Test case '{test_case.desc}': Liveness endpoint mismatch"
+             )
+             assert (
+                 result.docker_server.predict_endpoint
+                 == test_case.expected_predict_endpoint
+             ), f"Test case '{test_case.desc}': Predict endpoint mismatch"
+             assert result.docker_server.server_port == 8000, (
+                 f"Test case '{test_case.desc}': Server port mismatch"
+             )
+
+             # Verify training checkpoints
+             assert result.training_checkpoints is not None, (
+                 f"Test case '{test_case.desc}': Training checkpoints should not be None"
+             )
+             assert len(result.training_checkpoints.artifact_references) == len(
+                 test_case.expected_checkpoint_paths
+             ), f"Test case '{test_case.desc}': Number of checkpoint artifacts mismatch"
+
+             for i, expected_path in enumerate(test_case.expected_checkpoint_paths):
+                 artifact_ref = result.training_checkpoints.artifact_references[i]
+                 assert artifact_ref.paths == [expected_path], (
+                     f"Test case '{test_case.desc}': Checkpoint path {i} mismatch"
+                 )
+
+             # Verify resources
+             assert result.resources is not None, (
+                 f"Test case '{test_case.desc}': Resources should not be None"
+             )
+
+             if test_case.expected_accelerator:
+                 assert result.resources.accelerator is not None, (
+                     f"Test case '{test_case.desc}': Accelerator should not be None"
+                 )
+                 assert (
+                     result.resources.accelerator.accelerator
+                     == test_case.expected_accelerator
+                 ), f"Test case '{test_case.desc}': Accelerator type mismatch"
+                 assert (
+                     result.resources.accelerator.count
+                     == test_case.expected_accelerator_count
+                 ), f"Test case '{test_case.desc}': Accelerator count mismatch"
+             else:
+                 # When no accelerator is specified, it creates an AcceleratorSpec with None values
+                 assert result.resources.accelerator is not None, (
+                     f"Test case '{test_case.desc}': Accelerator should exist"
+                 )
+                 assert result.resources.accelerator.accelerator is None, (
+                     f"Test case '{test_case.desc}': Accelerator type should be None"
+                 )
+
+             # Verify environment variables
+             for key, expected_value in test_case.expected_environment_variables.items():
+                 assert result.environment_variables[key] == expected_value, (
+                     f"Test case '{test_case.desc}': Environment variable {key} mismatch"
+                 )
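Two hedged reconstructions to make the new tests above easier to read. First, the two rendered strings in test_vllm_whisper_start_command_template pin down the shape of VLLM_WHISPER_START_COMMAND; a Jinja template consistent with both cases would look roughly like this (inferred from the assertions, not copied from deploy_whisper_checkpoints.py):

    import jinja2

    # Sketch reconstructed from the expected strings in the test; the actual
    # constant in deploy_whisper_checkpoints.py may differ in formatting.
    VLLM_WHISPER_START_COMMAND = jinja2.Template(
        "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
        "vllm serve {{ model_path }} --port 8000 "
        "--tensor-parallel-size {{ specify_tensor_parallelism }}'"
    )

Likewise, test_hydrate_whisper_checkpoint fully determines the observable behavior of hydrate_whisper_checkpoint; a minimal implementation consistent with those assertions (again a sketch, and note the checkpoint argument is unused by anything the test checks):

    def hydrate_whisper_checkpoint(job_id: str, checkpoint_id: str, checkpoint: dict):
        # Builds exactly the WhisperCheckpoint the test asserts on.
        return definitions.WhisperCheckpoint(
            training_job_id=job_id,
            paths=[f"rank-0/{checkpoint_id}/"],
            model_weight_format=definitions.ModelWeightsFormat.WHISPER,
        )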
truss/tests/contexts/image_builder/test_serving_image_builder.py
@@ -10,7 +10,6 @@ import pytest
  import yaml

  from truss.base.constants import (
-     HF_ACCESS_TOKEN_FILE_NAME,
      TRTLLM_BASE_IMAGE,
      TRTLLM_PREDICT_CONCURRENCY,
      TRTLLM_PYTHON_EXECUTABLE,
@@ -18,6 +17,7 @@ from truss.base.constants import (
  )
  from truss.base.truss_config import ModelCache, ModelRepo, TrussConfig
  from truss.contexts.image_builder.serving_image_builder import (
+     HF_ACCESS_TOKEN_FILE_NAME,
      ServingImageBuilderContext,
      get_files_to_model_cache_v1,
  )
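This hunk suggests HF_ACCESS_TOKEN_FILE_NAME moved out of truss.base.constants (which loses exactly one line in this release) and is now exported by serving_image_builder. Code importing it from the old location would update like so (a sketch of the migration, inferred from this diff alone):

    # Old (removed in 0.10.10):
    # from truss.base.constants import HF_ACCESS_TOKEN_FILE_NAME
    from truss.contexts.image_builder.serving_image_builder import (
        HF_ACCESS_TOKEN_FILE_NAME,
    )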
truss/tests/remote/baseten/conftest.py
@@ -0,0 +1,18 @@
+ from unittest import mock
+
+ import pytest
+
+ from truss.remote.baseten.api import BasetenApi
+
+
+ @pytest.fixture
+ def mock_auth_service():
+     auth_service = mock.Mock()
+     auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
+     auth_service.authenticate.return_value = auth_token
+     return auth_service
+
+
+ @pytest.fixture
+ def baseten_api(mock_auth_service):
+     return BasetenApi("https://app.test.com", mock_auth_service)
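Because these fixtures now live in conftest.py, pytest injects them by name into any test module under truss/tests/remote/baseten/ without an import. A minimal illustration (hypothetical test, not part of the diff):

    # pytest resolves `baseten_api` from conftest.py automatically.
    def test_fixture_is_available(baseten_api):
        assert baseten_api is not None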
truss/tests/remote/baseten/test_api.py
@@ -7,19 +7,10 @@ from requests import Response

  import truss_train.definitions as train_definitions
  from truss.remote.baseten import custom_types as b10_types
- from truss.remote.baseten.api import BasetenApi
  from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
  from truss.remote.baseten.error import ApiError


- @pytest.fixture
- def mock_auth_service():
-     auth_service = mock.Mock()
-     auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
-     auth_service.authenticate.return_value = auth_token
-     return auth_service
-
-
  def mock_successful_response():
      response = Response()
      response.status_code = 200
@@ -134,11 +125,6 @@ def mock_deploy_chain_deployment_response():
      return response


- @pytest.fixture
- def baseten_api(mock_auth_service):
-     return BasetenApi("https://app.test.com", mock_auth_service)
-
-
  @mock.patch("requests.post", return_value=mock_successful_response())
  def test_post_graphql_query_success(mock_post, baseten_api):
      response_data = {"data": {"status": "success"}}
@@ -439,3 +425,52 @@ def test_upsert_training_project(mock_post, baseten_api):
      upsert_body = mock_post.call_args[1]["json"]["training_project"]
      assert "job" not in upsert_body
      assert "training-project" == upsert_body["name"]
+
+
+ # Mock responses for training job logs pagination tests
+ def mock_training_job_logs_response(logs, has_more=True):
+     """Helper function to create mock training job logs response"""
+     response = Response()
+     response.status_code = 200
+     response.json = mock.Mock(return_value={"logs": logs})
+     return response
+
+
+ def mock_training_job_logs_empty_response():
+     """Helper function to create mock empty training job logs response"""
+     response = Response()
+     response.status_code = 200
+     response.json = mock.Mock(return_value={"logs": []})
+     return response
+
+
+ def mock_training_job_logs_error_response():
+     """Helper function to create mock error response for training job logs"""
+     response = Response()
+     response.status_code = 500
+     response.raise_for_status = mock.Mock(
+         side_effect=requests.exceptions.HTTPError("Server Error")
+     )
+     return response
+
+
+ def test_fetch_log_batch(baseten_api):
+     """Test _fetch_log_batch helper method"""
+
+     mock_logs = [
+         {"timestamp": "1640995200000000000", "message": "Log 1"},
+         {"timestamp": "1640995260000000000", "message": "Log 2"},
+     ]
+
+     # Mock the rest_api_client
+     mock_rest_client = mock.Mock()
+     mock_rest_client.post.return_value = {"logs": mock_logs}
+     baseten_api._rest_api_client = mock_rest_client
+
+     query_params = {"limit": 100, "direction": "asc"}
+     result = baseten_api._fetch_log_batch("project-123", "job-456", query_params)
+
+     assert result == mock_logs
+     mock_rest_client.post.assert_called_with(
+         "v1/training_projects/project-123/jobs/job-456/logs", body=query_params
+     )