truss 0.10.9rc601__py3-none-any.whl → 0.10.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/base/constants.py +0 -1
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
- truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
- truss/cli/train_commands.py +11 -3
- truss/contexts/image_builder/cache_warmer.py +1 -3
- truss/contexts/image_builder/serving_image_builder.py +24 -32
- truss/remote/baseten/api.py +11 -0
- truss/remote/baseten/core.py +209 -1
- truss/remote/baseten/utils/time.py +15 -0
- truss/templates/server/model_wrapper.py +0 -12
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server/truss_server.py +0 -13
- truss/templates/server.Dockerfile.jinja +1 -1
- truss/tests/cli/train/test_deploy_checkpoints.py +436 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
- truss/tests/remote/baseten/conftest.py +18 -0
- truss/tests/remote/baseten/test_api.py +49 -14
- truss/tests/remote/baseten/test_core.py +517 -1
- truss/tests/test_data/test_openai/model/model.py +0 -3
- truss/truss_handle/truss_handle.py +0 -1
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/METADATA +2 -2
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/RECORD +30 -28
- truss_train/definitions.py +6 -0
- truss_train/deployment.py +15 -2
- truss/tests/util/test_basetenpointer.py +0 -227
- truss/util/basetenpointer.py +0 -160
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/WHEEL +0 -0
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/entry_points.txt +0 -0
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -59,7 +59,6 @@ class MethodName(str, enum.Enum):
|
|
|
59
59
|
CHAT_COMPLETIONS = enum.auto()
|
|
60
60
|
COMPLETIONS = enum.auto()
|
|
61
61
|
IS_HEALTHY = enum.auto()
|
|
62
|
-
MESSAGES = enum.auto()
|
|
63
62
|
POSTPROCESS = enum.auto()
|
|
64
63
|
PREDICT = enum.auto()
|
|
65
64
|
PREPROCESS = enum.auto()
|
|
@@ -245,7 +244,6 @@ class ModelDescriptor:
|
|
|
245
244
|
is_healthy: Optional[MethodDescriptor]
|
|
246
245
|
completions: Optional[MethodDescriptor]
|
|
247
246
|
chat_completions: Optional[MethodDescriptor]
|
|
248
|
-
messages: Optional[MethodDescriptor]
|
|
249
247
|
websocket: Optional[MethodDescriptor]
|
|
250
248
|
|
|
251
249
|
@cached_property
|
|
@@ -293,7 +291,6 @@ class ModelDescriptor:
|
|
|
293
291
|
setup = cls._safe_extract_descriptor(model_cls, MethodName.SETUP_ENVIRONMENT)
|
|
294
292
|
completions = cls._safe_extract_descriptor(model_cls, MethodName.COMPLETIONS)
|
|
295
293
|
chats = cls._safe_extract_descriptor(model_cls, MethodName.CHAT_COMPLETIONS)
|
|
296
|
-
messages = cls._safe_extract_descriptor(model_cls, MethodName.MESSAGES)
|
|
297
294
|
is_healthy = cls._safe_extract_descriptor(model_cls, MethodName.IS_HEALTHY)
|
|
298
295
|
if is_healthy and is_healthy.arg_config != ArgConfig.NONE:
|
|
299
296
|
raise errors.ModelDefinitionError(
|
|
@@ -362,7 +359,6 @@ class ModelDescriptor:
|
|
|
362
359
|
is_healthy=is_healthy,
|
|
363
360
|
completions=completions,
|
|
364
361
|
chat_completions=chats,
|
|
365
|
-
messages=messages,
|
|
366
362
|
websocket=websocket,
|
|
367
363
|
)
|
|
368
364
|
|
|
@@ -929,14 +925,6 @@ class ModelWrapper:
|
|
|
929
925
|
)
|
|
930
926
|
return await self._execute_model_endpoint(inputs, request, descriptor)
|
|
931
927
|
|
|
932
|
-
async def messages(
|
|
933
|
-
self, inputs: InputType, request: starlette.requests.Request
|
|
934
|
-
) -> OutputType:
|
|
935
|
-
descriptor = self._get_descriptor_or_raise(
|
|
936
|
-
self.model_descriptor.messages, MethodName.MESSAGES
|
|
937
|
-
)
|
|
938
|
-
return await self._execute_model_endpoint(inputs, request, descriptor)
|
|
939
|
-
|
|
940
928
|
async def websocket(self, ws: WebSocket) -> None:
|
|
941
929
|
descriptor = self.model_descriptor.websocket
|
|
942
930
|
assert descriptor, "websocket can only be invoked if present on model."
|
|
@@ -231,13 +231,6 @@ class BasetenEndpoints:
|
|
|
231
231
|
method=self._model.completions, request=request, body_raw=body_raw
|
|
232
232
|
)
|
|
233
233
|
|
|
234
|
-
async def messages(
|
|
235
|
-
self, request: Request, body_raw: bytes = Depends(parse_body)
|
|
236
|
-
) -> Response:
|
|
237
|
-
return await self._execute_request(
|
|
238
|
-
method=self._model.messages, request=request, body_raw=body_raw
|
|
239
|
-
)
|
|
240
|
-
|
|
241
234
|
async def websocket(self, ws: WebSocket) -> None:
|
|
242
235
|
self.check_healthy()
|
|
243
236
|
trace_ctx = otel_propagate.extract(ws.headers) or None
|
|
@@ -435,12 +428,6 @@ class TrussServer:
|
|
|
435
428
|
methods=["POST"],
|
|
436
429
|
tags=["V1"],
|
|
437
430
|
),
|
|
438
|
-
FastAPIRoute(
|
|
439
|
-
r"/v1/messages",
|
|
440
|
-
self._endpoints.messages,
|
|
441
|
-
methods=["POST"],
|
|
442
|
-
tags=["V1"],
|
|
443
|
-
),
|
|
444
431
|
# Websocket endpoint
|
|
445
432
|
FastAPIWebSocketRoute(r"/v1/websocket", self._endpoints.websocket),
|
|
446
433
|
# Endpoint aliases for Sagemaker hosting
|
|
@@ -70,7 +70,7 @@ COPY ./{{ config.data_dir }} /app/data
|
|
|
70
70
|
|
|
71
71
|
{%- if model_cache_v2 %}
|
|
72
72
|
# v0.0.9, keep synced with server_requirements.txt
|
|
73
|
-
RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.10.
|
|
73
|
+
RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.10.10rc1/truss-transfer-cli-v0.10.10rc1-linux-x86_64-unknown-linux-musl
|
|
74
74
|
RUN chmod +x /usr/local/bin/truss-transfer-cli
|
|
75
75
|
RUN mkdir /static-bptr
|
|
76
76
|
RUN echo "hash {{model_cache_hash}}"
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import re
|
|
3
|
+
from dataclasses import dataclass
|
|
3
4
|
from pathlib import Path
|
|
5
|
+
from typing import Dict, List, Optional
|
|
4
6
|
from unittest.mock import MagicMock, patch
|
|
5
7
|
|
|
6
8
|
import pytest
|
|
@@ -23,6 +25,11 @@ from truss.cli.train.deploy_checkpoints.deploy_lora_checkpoints import (
|
|
|
23
25
|
hydrate_lora_checkpoint,
|
|
24
26
|
render_vllm_lora_truss_config,
|
|
25
27
|
)
|
|
28
|
+
from truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints import (
|
|
29
|
+
VLLM_WHISPER_START_COMMAND,
|
|
30
|
+
hydrate_whisper_checkpoint,
|
|
31
|
+
render_vllm_whisper_truss_config,
|
|
32
|
+
)
|
|
26
33
|
from truss.cli.train.types import (
|
|
27
34
|
DeployCheckpointsConfigComplete,
|
|
28
35
|
PrepareCheckpointResult,
|
|
@@ -682,3 +689,432 @@ def test_get_checkpoint_ids_to_deploy_single_checkpoint():
|
|
|
682
689
|
|
|
683
690
|
# Should return the single checkpoint directly
|
|
684
691
|
assert result == ["checkpoint-1"]
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def test_vllm_whisper_start_command_template():
|
|
695
|
+
"""Test that the VLLM_WHISPER_START_COMMAND template renders correctly."""
|
|
696
|
+
# Test with all variables
|
|
697
|
+
result = VLLM_WHISPER_START_COMMAND.render(
|
|
698
|
+
model_path="/path/to/model",
|
|
699
|
+
envvars="CUDA_VISIBLE_DEVICES=0",
|
|
700
|
+
specify_tensor_parallelism=4,
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
expected = (
|
|
704
|
+
"sh -c 'CUDA_VISIBLE_DEVICES=0 "
|
|
705
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
706
|
+
"vllm serve /path/to/model --port 8000 --tensor-parallel-size 4'"
|
|
707
|
+
)
|
|
708
|
+
assert result == expected
|
|
709
|
+
|
|
710
|
+
result = VLLM_WHISPER_START_COMMAND.render(
|
|
711
|
+
model_path="/path/to/model", envvars=None, specify_tensor_parallelism=1
|
|
712
|
+
)
|
|
713
|
+
|
|
714
|
+
expected = (
|
|
715
|
+
"sh -c '"
|
|
716
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
717
|
+
"vllm serve /path/to/model --port 8000 --tensor-parallel-size 1'"
|
|
718
|
+
)
|
|
719
|
+
assert result == expected
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def test_hydrate_whisper_checkpoint():
|
|
723
|
+
"""Test that hydrate_whisper_checkpoint creates correct WhisperCheckpoint object."""
|
|
724
|
+
job_id = "test-job-123"
|
|
725
|
+
checkpoint_id = "checkpoint-456"
|
|
726
|
+
checkpoint = {"some": "data"}
|
|
727
|
+
|
|
728
|
+
result = hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
|
|
729
|
+
|
|
730
|
+
assert result.training_job_id == job_id
|
|
731
|
+
assert result.paths == [f"rank-0/{checkpoint_id}/"]
|
|
732
|
+
assert result.model_weight_format == definitions.ModelWeightsFormat.WHISPER
|
|
733
|
+
assert isinstance(result, definitions.WhisperCheckpoint)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
@patch(
|
|
737
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
|
|
738
|
+
)
|
|
739
|
+
@patch(
|
|
740
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
|
|
741
|
+
)
|
|
742
|
+
@patch(
|
|
743
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
|
|
744
|
+
)
|
|
745
|
+
def test_render_vllm_whisper_truss_config(
|
|
746
|
+
mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
|
|
747
|
+
):
|
|
748
|
+
"""Test that render_vllm_whisper_truss_config renders truss config correctly."""
|
|
749
|
+
# Mock dependencies
|
|
750
|
+
mock_truss_config = MagicMock()
|
|
751
|
+
mock_truss_config.environment_variables = {}
|
|
752
|
+
mock_truss_config.docker_server = MagicMock()
|
|
753
|
+
mock_setup_base_config.return_value = mock_truss_config
|
|
754
|
+
|
|
755
|
+
mock_setup_env_vars.return_value = "HF_TOKEN=$(cat /secrets/hf_access_token)"
|
|
756
|
+
mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
|
|
757
|
+
|
|
758
|
+
# Create test config
|
|
759
|
+
deploy_config = DeployCheckpointsConfigComplete(
|
|
760
|
+
checkpoint_details=definitions.CheckpointList(
|
|
761
|
+
checkpoints=[
|
|
762
|
+
definitions.WhisperCheckpoint(
|
|
763
|
+
training_job_id="job123",
|
|
764
|
+
paths=["rank-0/checkpoint-1/"],
|
|
765
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
766
|
+
)
|
|
767
|
+
],
|
|
768
|
+
base_model_id="openai/whisper-large-v3",
|
|
769
|
+
),
|
|
770
|
+
model_name="whisper-large-v3-vLLM",
|
|
771
|
+
compute=definitions.Compute(
|
|
772
|
+
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
|
|
773
|
+
),
|
|
774
|
+
runtime=definitions.DeployCheckpointsRuntime(
|
|
775
|
+
environment_variables={
|
|
776
|
+
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
|
|
777
|
+
}
|
|
778
|
+
),
|
|
779
|
+
deployment_name="whisper-large-v3-vLLM",
|
|
780
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
781
|
+
)
|
|
782
|
+
|
|
783
|
+
result = render_vllm_whisper_truss_config(deploy_config)
|
|
784
|
+
|
|
785
|
+
mock_setup_base_config.assert_called_once_with(deploy_config)
|
|
786
|
+
mock_setup_env_vars.assert_called_once_with(mock_truss_config, deploy_config)
|
|
787
|
+
mock_build_full_checkpoint_string.assert_called_once_with(mock_truss_config)
|
|
788
|
+
|
|
789
|
+
assert result == mock_truss_config
|
|
790
|
+
|
|
791
|
+
expected_start_command = (
|
|
792
|
+
"sh -c 'HF_TOKEN=$(cat /secrets/hf_access_token) "
|
|
793
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
794
|
+
"vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 4'"
|
|
795
|
+
)
|
|
796
|
+
assert (
|
|
797
|
+
result.environment_variables[START_COMMAND_ENVVAR_NAME]
|
|
798
|
+
== expected_start_command
|
|
799
|
+
)
|
|
800
|
+
|
|
801
|
+
assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
@patch(
|
|
805
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
|
|
806
|
+
)
|
|
807
|
+
@patch(
|
|
808
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
|
|
809
|
+
)
|
|
810
|
+
@patch(
|
|
811
|
+
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
|
|
812
|
+
)
|
|
813
|
+
def test_render_vllm_whisper_truss_config_with_envvars(
|
|
814
|
+
mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
|
|
815
|
+
):
|
|
816
|
+
"""Test that render_vllm_whisper_truss_config handles environment variables correctly."""
|
|
817
|
+
# Mock dependencies
|
|
818
|
+
mock_truss_config = MagicMock()
|
|
819
|
+
mock_truss_config.environment_variables = {}
|
|
820
|
+
mock_truss_config.docker_server = MagicMock()
|
|
821
|
+
mock_setup_base_config.return_value = mock_truss_config
|
|
822
|
+
|
|
823
|
+
mock_setup_env_vars.return_value = "CUDA_VISIBLE_DEVICES=0,1"
|
|
824
|
+
mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
|
|
825
|
+
|
|
826
|
+
# Create test config with environment variables
|
|
827
|
+
deploy_config = DeployCheckpointsConfigComplete(
|
|
828
|
+
checkpoint_details=definitions.CheckpointList(
|
|
829
|
+
checkpoints=[
|
|
830
|
+
definitions.WhisperCheckpoint(
|
|
831
|
+
training_job_id="job123",
|
|
832
|
+
paths=["rank-0/checkpoint-1/"],
|
|
833
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
834
|
+
)
|
|
835
|
+
],
|
|
836
|
+
base_model_id="openai/whisper-large-v3",
|
|
837
|
+
),
|
|
838
|
+
model_name="whisper-large-v3-vLLM",
|
|
839
|
+
compute=definitions.Compute(
|
|
840
|
+
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
|
|
841
|
+
),
|
|
842
|
+
runtime=definitions.DeployCheckpointsRuntime(
|
|
843
|
+
environment_variables={
|
|
844
|
+
"CUDA_VISIBLE_DEVICES": "0,1",
|
|
845
|
+
"HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
|
|
846
|
+
}
|
|
847
|
+
),
|
|
848
|
+
deployment_name="whisper-large-v3-vLLM",
|
|
849
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
# Call function under test
|
|
853
|
+
result = render_vllm_whisper_truss_config(deploy_config)
|
|
854
|
+
|
|
855
|
+
# Verify environment variables are included in start command
|
|
856
|
+
expected_start_command = (
|
|
857
|
+
"sh -c 'CUDA_VISIBLE_DEVICES=0,1 "
|
|
858
|
+
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
859
|
+
"vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 2'"
|
|
860
|
+
)
|
|
861
|
+
assert (
|
|
862
|
+
result.environment_variables[START_COMMAND_ENVVAR_NAME]
|
|
863
|
+
== expected_start_command
|
|
864
|
+
)
|
|
865
|
+
|
|
866
|
+
|
|
867
|
+
@dataclass
|
|
868
|
+
class TestCase:
|
|
869
|
+
"""Test case for setup_base_truss_config function."""
|
|
870
|
+
|
|
871
|
+
desc: str
|
|
872
|
+
input_config: DeployCheckpointsConfigComplete
|
|
873
|
+
expected_model_name: str
|
|
874
|
+
expected_predict_endpoint: str
|
|
875
|
+
expected_accelerator: Optional[str]
|
|
876
|
+
expected_accelerator_count: Optional[int]
|
|
877
|
+
expected_checkpoint_paths: List[str]
|
|
878
|
+
expected_environment_variables: Dict[str, str]
|
|
879
|
+
should_raise: Optional[str] = None # Error message if function should raise
|
|
880
|
+
|
|
881
|
+
__test__ = False # Tell pytest this is not a test class
|
|
882
|
+
|
|
883
|
+
|
|
884
|
+
def test_setup_base_truss_config():
|
|
885
|
+
"""Table-driven test for setup_base_truss_config function."""
|
|
886
|
+
from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
|
|
887
|
+
setup_base_truss_config,
|
|
888
|
+
)
|
|
889
|
+
|
|
890
|
+
# Define test cases
|
|
891
|
+
test_cases = [
|
|
892
|
+
TestCase(
|
|
893
|
+
desc="LoRA checkpoint with H100 accelerator",
|
|
894
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
895
|
+
checkpoint_details=definitions.CheckpointList(
|
|
896
|
+
checkpoints=[
|
|
897
|
+
definitions.LoRACheckpoint(
|
|
898
|
+
training_job_id="job123",
|
|
899
|
+
paths=["rank-0/checkpoint-1/"],
|
|
900
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
901
|
+
lora_details=definitions.LoRADetails(rank=32),
|
|
902
|
+
)
|
|
903
|
+
],
|
|
904
|
+
base_model_id="google/gemma-3-27b-it",
|
|
905
|
+
),
|
|
906
|
+
model_name="test-lora-model",
|
|
907
|
+
compute=definitions.Compute(
|
|
908
|
+
accelerator=truss_config.AcceleratorSpec(
|
|
909
|
+
accelerator="H100", count=4
|
|
910
|
+
)
|
|
911
|
+
),
|
|
912
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
913
|
+
deployment_name="test-deployment",
|
|
914
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
915
|
+
),
|
|
916
|
+
expected_model_name="test-lora-model",
|
|
917
|
+
expected_predict_endpoint="/v1/chat/completions",
|
|
918
|
+
expected_accelerator="H100",
|
|
919
|
+
expected_accelerator_count=4,
|
|
920
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/"],
|
|
921
|
+
expected_environment_variables={
|
|
922
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
923
|
+
"VLLM_USE_V1": "0",
|
|
924
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
925
|
+
},
|
|
926
|
+
),
|
|
927
|
+
TestCase(
|
|
928
|
+
desc="Whisper checkpoint with A100 accelerator",
|
|
929
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
930
|
+
checkpoint_details=definitions.CheckpointList(
|
|
931
|
+
checkpoints=[
|
|
932
|
+
definitions.WhisperCheckpoint(
|
|
933
|
+
training_job_id="job123",
|
|
934
|
+
paths=["rank-0/checkpoint-1/"],
|
|
935
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
936
|
+
)
|
|
937
|
+
],
|
|
938
|
+
base_model_id="openai/whisper-large-v3",
|
|
939
|
+
),
|
|
940
|
+
model_name="test-whisper-model",
|
|
941
|
+
compute=definitions.Compute(
|
|
942
|
+
accelerator=truss_config.AcceleratorSpec(
|
|
943
|
+
accelerator="A100", count=2
|
|
944
|
+
)
|
|
945
|
+
),
|
|
946
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
947
|
+
deployment_name="test-whisper-deployment",
|
|
948
|
+
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
949
|
+
),
|
|
950
|
+
expected_model_name="test-whisper-model",
|
|
951
|
+
expected_predict_endpoint="/v1/audio/transcriptions",
|
|
952
|
+
expected_accelerator="A100",
|
|
953
|
+
expected_accelerator_count=2,
|
|
954
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/"],
|
|
955
|
+
expected_environment_variables={
|
|
956
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
957
|
+
"VLLM_USE_V1": "0",
|
|
958
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
959
|
+
},
|
|
960
|
+
),
|
|
961
|
+
TestCase(
|
|
962
|
+
desc="Multiple LoRA checkpoints",
|
|
963
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
964
|
+
checkpoint_details=definitions.CheckpointList(
|
|
965
|
+
checkpoints=[
|
|
966
|
+
definitions.LoRACheckpoint(
|
|
967
|
+
training_job_id="job123",
|
|
968
|
+
paths=["rank-0/checkpoint-1/"],
|
|
969
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
970
|
+
lora_details=definitions.LoRADetails(rank=16),
|
|
971
|
+
),
|
|
972
|
+
definitions.LoRACheckpoint(
|
|
973
|
+
training_job_id="job123",
|
|
974
|
+
paths=["rank-0/checkpoint-2/"],
|
|
975
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
976
|
+
lora_details=definitions.LoRADetails(rank=32),
|
|
977
|
+
),
|
|
978
|
+
],
|
|
979
|
+
base_model_id="google/gemma-3-27b-it",
|
|
980
|
+
),
|
|
981
|
+
model_name="test-multi-checkpoint-model",
|
|
982
|
+
compute=definitions.Compute(
|
|
983
|
+
accelerator=truss_config.AcceleratorSpec(
|
|
984
|
+
accelerator="H100", count=4
|
|
985
|
+
)
|
|
986
|
+
),
|
|
987
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
988
|
+
deployment_name="test-multi-deployment",
|
|
989
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
990
|
+
),
|
|
991
|
+
expected_model_name="test-multi-checkpoint-model",
|
|
992
|
+
expected_predict_endpoint="/v1/chat/completions",
|
|
993
|
+
expected_accelerator="H100",
|
|
994
|
+
expected_accelerator_count=4,
|
|
995
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/", "rank-0/checkpoint-2/"],
|
|
996
|
+
expected_environment_variables={
|
|
997
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
998
|
+
"VLLM_USE_V1": "0",
|
|
999
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
1000
|
+
},
|
|
1001
|
+
),
|
|
1002
|
+
TestCase(
|
|
1003
|
+
desc="No accelerator specified",
|
|
1004
|
+
input_config=DeployCheckpointsConfigComplete(
|
|
1005
|
+
checkpoint_details=definitions.CheckpointList(
|
|
1006
|
+
checkpoints=[
|
|
1007
|
+
definitions.LoRACheckpoint(
|
|
1008
|
+
training_job_id="job123",
|
|
1009
|
+
paths=["rank-0/checkpoint-1/"],
|
|
1010
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
1011
|
+
lora_details=definitions.LoRADetails(rank=16),
|
|
1012
|
+
)
|
|
1013
|
+
],
|
|
1014
|
+
base_model_id="google/gemma-3-27b-it",
|
|
1015
|
+
),
|
|
1016
|
+
model_name="test-no-accelerator-model",
|
|
1017
|
+
compute=definitions.Compute(), # No accelerator specified
|
|
1018
|
+
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
1019
|
+
deployment_name="test-no-accelerator-deployment",
|
|
1020
|
+
model_weight_format=ModelWeightsFormat.LORA,
|
|
1021
|
+
),
|
|
1022
|
+
expected_model_name="test-no-accelerator-model",
|
|
1023
|
+
expected_predict_endpoint="/v1/chat/completions",
|
|
1024
|
+
expected_accelerator=None,
|
|
1025
|
+
expected_accelerator_count=None,
|
|
1026
|
+
expected_checkpoint_paths=["rank-0/checkpoint-1/"],
|
|
1027
|
+
expected_environment_variables={
|
|
1028
|
+
"VLLM_LOGGING_LEVEL": "WARNING",
|
|
1029
|
+
"VLLM_USE_V1": "0",
|
|
1030
|
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
1031
|
+
},
|
|
1032
|
+
),
|
|
1033
|
+
]
|
|
1034
|
+
|
|
1035
|
+
# Run test cases
|
|
1036
|
+
for test_case in test_cases:
|
|
1037
|
+
print(f"Running test case: {test_case.desc}")
|
|
1038
|
+
|
|
1039
|
+
if test_case.should_raise:
|
|
1040
|
+
# Test error cases
|
|
1041
|
+
with pytest.raises(Exception, match=test_case.should_raise):
|
|
1042
|
+
setup_base_truss_config(test_case.input_config)
|
|
1043
|
+
else:
|
|
1044
|
+
# Test success cases
|
|
1045
|
+
result = setup_base_truss_config(test_case.input_config)
|
|
1046
|
+
|
|
1047
|
+
# Verify basic structure
|
|
1048
|
+
assert isinstance(result, truss_config.TrussConfig), (
|
|
1049
|
+
f"Test case '{test_case.desc}': Result should be TrussConfig"
|
|
1050
|
+
)
|
|
1051
|
+
assert result.model_name == test_case.expected_model_name, (
|
|
1052
|
+
f"Test case '{test_case.desc}': Model name mismatch"
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1055
|
+
# Verify docker server configuration
|
|
1056
|
+
assert result.docker_server is not None, (
|
|
1057
|
+
f"Test case '{test_case.desc}': Docker server should not be None"
|
|
1058
|
+
)
|
|
1059
|
+
assert result.docker_server.start_command == 'sh -c ""', (
|
|
1060
|
+
f"Test case '{test_case.desc}': Start command mismatch"
|
|
1061
|
+
)
|
|
1062
|
+
assert result.docker_server.readiness_endpoint == "/health", (
|
|
1063
|
+
f"Test case '{test_case.desc}': Readiness endpoint mismatch"
|
|
1064
|
+
)
|
|
1065
|
+
assert result.docker_server.liveness_endpoint == "/health", (
|
|
1066
|
+
f"Test case '{test_case.desc}': Liveness endpoint mismatch"
|
|
1067
|
+
)
|
|
1068
|
+
assert (
|
|
1069
|
+
result.docker_server.predict_endpoint
|
|
1070
|
+
== test_case.expected_predict_endpoint
|
|
1071
|
+
), f"Test case '{test_case.desc}': Predict endpoint mismatch"
|
|
1072
|
+
assert result.docker_server.server_port == 8000, (
|
|
1073
|
+
f"Test case '{test_case.desc}': Server port mismatch"
|
|
1074
|
+
)
|
|
1075
|
+
|
|
1076
|
+
# Verify training checkpoints
|
|
1077
|
+
assert result.training_checkpoints is not None, (
|
|
1078
|
+
f"Test case '{test_case.desc}': Training checkpoints should not be None"
|
|
1079
|
+
)
|
|
1080
|
+
assert len(result.training_checkpoints.artifact_references) == len(
|
|
1081
|
+
test_case.expected_checkpoint_paths
|
|
1082
|
+
), f"Test case '{test_case.desc}': Number of checkpoint artifacts mismatch"
|
|
1083
|
+
|
|
1084
|
+
for i, expected_path in enumerate(test_case.expected_checkpoint_paths):
|
|
1085
|
+
artifact_ref = result.training_checkpoints.artifact_references[i]
|
|
1086
|
+
assert artifact_ref.paths == [expected_path], (
|
|
1087
|
+
f"Test case '{test_case.desc}': Checkpoint path {i} mismatch"
|
|
1088
|
+
)
|
|
1089
|
+
|
|
1090
|
+
# Verify resources
|
|
1091
|
+
assert result.resources is not None, (
|
|
1092
|
+
f"Test case '{test_case.desc}': Resources should not be None"
|
|
1093
|
+
)
|
|
1094
|
+
|
|
1095
|
+
if test_case.expected_accelerator:
|
|
1096
|
+
assert result.resources.accelerator is not None, (
|
|
1097
|
+
f"Test case '{test_case.desc}': Accelerator should not be None"
|
|
1098
|
+
)
|
|
1099
|
+
assert (
|
|
1100
|
+
result.resources.accelerator.accelerator
|
|
1101
|
+
== test_case.expected_accelerator
|
|
1102
|
+
), f"Test case '{test_case.desc}': Accelerator type mismatch"
|
|
1103
|
+
assert (
|
|
1104
|
+
result.resources.accelerator.count
|
|
1105
|
+
== test_case.expected_accelerator_count
|
|
1106
|
+
), f"Test case '{test_case.desc}': Accelerator count mismatch"
|
|
1107
|
+
else:
|
|
1108
|
+
# When no accelerator is specified, it creates an AcceleratorSpec with None values
|
|
1109
|
+
assert result.resources.accelerator is not None, (
|
|
1110
|
+
f"Test case '{test_case.desc}': Accelerator should exist"
|
|
1111
|
+
)
|
|
1112
|
+
assert result.resources.accelerator.accelerator is None, (
|
|
1113
|
+
f"Test case '{test_case.desc}': Accelerator type should be None"
|
|
1114
|
+
)
|
|
1115
|
+
|
|
1116
|
+
# Verify environment variables
|
|
1117
|
+
for key, expected_value in test_case.expected_environment_variables.items():
|
|
1118
|
+
assert result.environment_variables[key] == expected_value, (
|
|
1119
|
+
f"Test case '{test_case.desc}': Environment variable {key} mismatch"
|
|
1120
|
+
)
|
|
@@ -10,7 +10,6 @@ import pytest
|
|
|
10
10
|
import yaml
|
|
11
11
|
|
|
12
12
|
from truss.base.constants import (
|
|
13
|
-
HF_ACCESS_TOKEN_FILE_NAME,
|
|
14
13
|
TRTLLM_BASE_IMAGE,
|
|
15
14
|
TRTLLM_PREDICT_CONCURRENCY,
|
|
16
15
|
TRTLLM_PYTHON_EXECUTABLE,
|
|
@@ -18,6 +17,7 @@ from truss.base.constants import (
|
|
|
18
17
|
)
|
|
19
18
|
from truss.base.truss_config import ModelCache, ModelRepo, TrussConfig
|
|
20
19
|
from truss.contexts.image_builder.serving_image_builder import (
|
|
20
|
+
HF_ACCESS_TOKEN_FILE_NAME,
|
|
21
21
|
ServingImageBuilderContext,
|
|
22
22
|
get_files_to_model_cache_v1,
|
|
23
23
|
)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from unittest import mock
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from truss.remote.baseten.api import BasetenApi
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.fixture
|
|
9
|
+
def mock_auth_service():
|
|
10
|
+
auth_service = mock.Mock()
|
|
11
|
+
auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
|
|
12
|
+
auth_service.authenticate.return_value = auth_token
|
|
13
|
+
return auth_service
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture
|
|
17
|
+
def baseten_api(mock_auth_service):
|
|
18
|
+
return BasetenApi("https://app.test.com", mock_auth_service)
|
|
@@ -7,19 +7,10 @@ from requests import Response
|
|
|
7
7
|
|
|
8
8
|
import truss_train.definitions as train_definitions
|
|
9
9
|
from truss.remote.baseten import custom_types as b10_types
|
|
10
|
-
from truss.remote.baseten.api import BasetenApi
|
|
11
10
|
from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
|
|
12
11
|
from truss.remote.baseten.error import ApiError
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
@pytest.fixture
|
|
16
|
-
def mock_auth_service():
|
|
17
|
-
auth_service = mock.Mock()
|
|
18
|
-
auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
|
|
19
|
-
auth_service.authenticate.return_value = auth_token
|
|
20
|
-
return auth_service
|
|
21
|
-
|
|
22
|
-
|
|
23
14
|
def mock_successful_response():
|
|
24
15
|
response = Response()
|
|
25
16
|
response.status_code = 200
|
|
@@ -134,11 +125,6 @@ def mock_deploy_chain_deployment_response():
|
|
|
134
125
|
return response
|
|
135
126
|
|
|
136
127
|
|
|
137
|
-
@pytest.fixture
|
|
138
|
-
def baseten_api(mock_auth_service):
|
|
139
|
-
return BasetenApi("https://app.test.com", mock_auth_service)
|
|
140
|
-
|
|
141
|
-
|
|
142
128
|
@mock.patch("requests.post", return_value=mock_successful_response())
|
|
143
129
|
def test_post_graphql_query_success(mock_post, baseten_api):
|
|
144
130
|
response_data = {"data": {"status": "success"}}
|
|
@@ -439,3 +425,52 @@ def test_upsert_training_project(mock_post, baseten_api):
|
|
|
439
425
|
upsert_body = mock_post.call_args[1]["json"]["training_project"]
|
|
440
426
|
assert "job" not in upsert_body
|
|
441
427
|
assert "training-project" == upsert_body["name"]
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
# Mock responses for training job logs pagination tests
|
|
431
|
+
def mock_training_job_logs_response(logs, has_more=True):
|
|
432
|
+
"""Helper function to create mock training job logs response"""
|
|
433
|
+
response = Response()
|
|
434
|
+
response.status_code = 200
|
|
435
|
+
response.json = mock.Mock(return_value={"logs": logs})
|
|
436
|
+
return response
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def mock_training_job_logs_empty_response():
|
|
440
|
+
"""Helper function to create mock empty training job logs response"""
|
|
441
|
+
response = Response()
|
|
442
|
+
response.status_code = 200
|
|
443
|
+
response.json = mock.Mock(return_value={"logs": []})
|
|
444
|
+
return response
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def mock_training_job_logs_error_response():
|
|
448
|
+
"""Helper function to create mock error response for training job logs"""
|
|
449
|
+
response = Response()
|
|
450
|
+
response.status_code = 500
|
|
451
|
+
response.raise_for_status = mock.Mock(
|
|
452
|
+
side_effect=requests.exceptions.HTTPError("Server Error")
|
|
453
|
+
)
|
|
454
|
+
return response
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def test_fetch_log_batch(baseten_api):
|
|
458
|
+
"""Test _fetch_log_batch helper method"""
|
|
459
|
+
|
|
460
|
+
mock_logs = [
|
|
461
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"},
|
|
462
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"},
|
|
463
|
+
]
|
|
464
|
+
|
|
465
|
+
# Mock the rest_api_client
|
|
466
|
+
mock_rest_client = mock.Mock()
|
|
467
|
+
mock_rest_client.post.return_value = {"logs": mock_logs}
|
|
468
|
+
baseten_api._rest_api_client = mock_rest_client
|
|
469
|
+
|
|
470
|
+
query_params = {"limit": 100, "direction": "asc"}
|
|
471
|
+
result = baseten_api._fetch_log_batch("project-123", "job-456", query_params)
|
|
472
|
+
|
|
473
|
+
assert result == mock_logs
|
|
474
|
+
mock_rest_client.post.assert_called_with(
|
|
475
|
+
"v1/training_projects/project-123/jobs/job-456/logs", body=query_params
|
|
476
|
+
)
|