truss 0.11.8rc12__py3-none-any.whl → 0.11.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (27)
  1. truss/cli/logs/base_watcher.py +31 -12
  2. truss/cli/logs/model_log_watcher.py +24 -1
  3. truss/cli/train/core.py +13 -11
  4. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  5. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +211 -106
  6. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  7. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +0 -59
  8. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +0 -83
  9. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +0 -53
  10. truss/cli/train/types.py +1 -11
  11. truss/cli/train_commands.py +5 -15
  12. truss/remote/baseten/api.py +87 -0
  13. truss/templates/control/control/application.py +48 -26
  14. truss/templates/control/control/endpoints.py +1 -5
  15. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  16. truss/tests/cli/train/test_deploy_checkpoints.py +0 -843
  17. truss/tests/templates/control/control/conftest.py +20 -0
  18. truss/tests/templates/control/control/test_endpoints.py +4 -0
  19. truss/tests/templates/control/control/test_server.py +8 -24
  20. truss/tests/templates/control/control/test_server_integration.py +4 -2
  21. truss/util/__init__.py +0 -0
  22. {truss-0.11.8rc12.dist-info → truss-0.11.9.dist-info}/METADATA +1 -1
  23. {truss-0.11.8rc12.dist-info → truss-0.11.9.dist-info}/RECORD +27 -25
  24. truss_train/definitions.py +0 -1
  25. {truss-0.11.8rc12.dist-info → truss-0.11.9.dist-info}/WHEEL +0 -0
  26. {truss-0.11.8rc12.dist-info → truss-0.11.9.dist-info}/entry_points.txt +0 -0
  27. {truss-0.11.8rc12.dist-info → truss-0.11.9.dist-info}/licenses/LICENSE +0 -0
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py CHANGED
@@ -1,66 +1,7 @@
 from pathlib import Path
 
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import FullCheckpoint
 
-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-# NB(aghilan): Transformers was recently changed to save a chat_template.jinja file instead of inside the tokenizer_config.json file.
-# Old Models will not have this file, so we check for it and use it if it exists.
-# vLLM will not automatically resolve the chat_template.jinja file, so we need to pass it to the start command.
-# This logic is needed for any models trained using Transformers v4.51.3 or later
-VLLM_FULL_START_COMMAND = Template(
-    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
-    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
-    "if [ -f {{ model_path }}/chat_template.jinja ]; then "
-    " vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
-    " --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
-    "else "
-    " vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
-    "fi'"
-)
-
-
-def render_vllm_full_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for full checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
-
-    accelerator = checkpoint_deploy.compute.accelerator
-
-    start_command_args = {
-        "model_path": checkpoint_str,
-        "envvars": start_command_envvars,
-        "specify_tensor_parallelism": accelerator.count if accelerator else 1,
-    }
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    start_command = VLLM_FULL_START_COMMAND.render(**start_command_args)
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-
-    return truss_deploy_config
-
 
 def hydrate_full_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
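For context on what was removed: the template above rendered to a single shell command at deploy time. A minimal sketch of rendering a trimmed copy of it, with made-up placeholder values (the path and env var below are illustrative, not from this package):

    from jinja2 import Template

    # Trimmed copy of the removed VLLM_FULL_START_COMMAND template.
    template = Template(
        "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
        "if [ -f {{ model_path }}/chat_template.jinja ]; then "
        "vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja --port 8000; "
        "else vllm serve {{ model_path }} --port 8000; fi'"
    )
    print(
        template.render(
            model_path="/tmp/training_checkpoints/job-123/checkpoint-1",  # placeholder
            envvars="VLLM_LOGGING_LEVEL=INFO",  # placeholder
        )
    )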
truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py CHANGED
@@ -1,12 +1,3 @@
-from pathlib import Path
-
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import (
     ALLOWED_LORA_RANKS,
     DEFAULT_LORA_RANK,
@@ -14,21 +5,6 @@ from truss_train.definitions import (
     LoRADetails,
 )
 
-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-VLLM_LORA_START_COMMAND = Template(
-    'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ base_model_id }}'
-    + " --port 8000"
-    + "{{ specify_tensor_parallelism }}"
-    + " --enable-lora"
-    + " --max-lora-rank {{ max_lora_rank }}"
-    + " --dtype bfloat16"
-    + ' --lora-modules {{ lora_modules }}"'
-)
-
 
 def hydrate_lora_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
@@ -43,49 +19,6 @@ def hydrate_lora_checkpoint(
     )
 
 
-def render_vllm_lora_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for LoRA checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = _build_lora_checkpoint_string(truss_deploy_config)
-
-    max_lora_rank = max(
-        [
-            checkpoint.lora_details.rank or DEFAULT_LORA_RANK
-            for checkpoint in checkpoint_deploy.checkpoint_details.checkpoints
-            if hasattr(checkpoint, "lora_details") and checkpoint.lora_details
-        ]
-    )
-    accelerator = checkpoint_deploy.compute.accelerator
-    if accelerator:
-        specify_tensor_parallelism = f" --tensor-parallel-size {accelerator.count}"
-    else:
-        specify_tensor_parallelism = ""
-
-    start_command_args = {
-        "base_model_id": checkpoint_deploy.checkpoint_details.base_model_id,
-        "lora_modules": checkpoint_str,
-        "envvars": start_command_envvars,
-        "max_lora_rank": max_lora_rank,
-        "specify_tensor_parallelism": specify_tensor_parallelism,
-    }
-    start_command = VLLM_LORA_START_COMMAND.render(**start_command_args)
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-    return truss_deploy_config
-
-
 def _get_lora_rank(checkpoint_resp: dict) -> int:
     """Extract and validate LoRA rank from checkpoint response."""
     lora_adapter_config = checkpoint_resp.get("lora_adapter_config") or {}
@@ -99,19 +32,3 @@ def _get_lora_rank(checkpoint_resp: dict) -> int:
     )
 
     return lora_rank
-
-
-def _build_lora_checkpoint_string(truss_deploy_config) -> str:
-    """Build the checkpoint string for LoRA modules from truss deploy config."""
-    checkpoint_parts = []
-    for (
-        truss_checkpoint
-    ) in truss_deploy_config.training_checkpoints.artifact_references:  # type: ignore
-        ckpt_path = Path(
-            truss_deploy_config.training_checkpoints.download_folder,  # type: ignore
-            truss_checkpoint.training_job_id,
-            truss_checkpoint.paths[0],
-        )
-        checkpoint_parts.append(f"{truss_checkpoint.training_job_id}={ckpt_path}")
-
-    return " ".join(checkpoint_parts)
truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py CHANGED
@@ -1,58 +1,5 @@
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
-    build_full_checkpoint_string,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import WhisperCheckpoint
 
-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-VLLM_WHISPER_START_COMMAND = Template(
-    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
-    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
-    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }}'"
-)
-
-
-def render_vllm_whisper_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for whisper checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
-
-    accelerator = checkpoint_deploy.compute.accelerator
-
-    start_command_args = {
-        "model_path": checkpoint_str,
-        "envvars": start_command_envvars,
-        "specify_tensor_parallelism": accelerator.count if accelerator else 1,
-    }
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    start_command = VLLM_WHISPER_START_COMMAND.render(**start_command_args)
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-
-    return truss_deploy_config
-
 
 def hydrate_whisper_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
truss/cli/train/types.py CHANGED
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Optional
 
 from truss_train.definitions import (
@@ -7,12 +6,11 @@ from truss_train.definitions import (
     Compute,
     DeployCheckpointsConfig,
     DeployCheckpointsRuntime,
-    ModelWeightsFormat,
 )
 
 
 @dataclass
-class PrepareCheckpointArgs:
+class DeployCheckpointArgs:
     project_id: Optional[str]
     job_id: Optional[str]
     deploy_config_path: Optional[str]
@@ -26,13 +24,5 @@ class DeployCheckpointsConfigComplete(DeployCheckpointsConfig):
 
     checkpoint_details: CheckpointList
     model_name: str
-    deployment_name: str
     runtime: DeployCheckpointsRuntime
     compute: Compute
-    model_weight_format: ModelWeightsFormat
-
-
-@dataclass
-class PrepareCheckpointResult:
-    truss_directory: Path
-    checkpoint_deploy_config: DeployCheckpointsConfigComplete
truss/cli/train_commands.py CHANGED
@@ -8,7 +8,7 @@ import rich_click as click
 import truss.cli.train.core as train_cli
 from truss.base.constants import TRAINING_TEMPLATE_DIR
 from truss.cli import remote_cli
-from truss.cli.cli import push, truss_cli
+from truss.cli.cli import truss_cli
 from truss.cli.logs import utils as cli_log_utils
 from truss.cli.logs.training_log_watcher import TrainingLogWatcher
 from truss.cli.train import common as train_common
@@ -329,26 +329,16 @@ def deploy_checkpoints(
     project_id = _maybe_resolve_project_id_from_id_or_name(
         remote_provider, project_id=project_id, project=project
     )
-    prepare_checkpoint_result = train_cli.prepare_checkpoint_deploy(
+    result = train_cli.create_model_version_from_inference_template(
         remote_provider,
-        train_cli.PrepareCheckpointArgs(
+        train_cli.DeployCheckpointArgs(
             project_id=project_id, job_id=job_id, deploy_config_path=config
         ),
     )
 
-    params = {
-        "target_directory": prepare_checkpoint_result.truss_directory,
-        "remote": remote,
-        "model_name": prepare_checkpoint_result.checkpoint_deploy_config.model_name,
-        "publish": True,
-        "deployment_name": prepare_checkpoint_result.checkpoint_deploy_config.deployment_name,
-    }
-    ctx = _prepare_click_context(push, params)
     if dry_run:
-        console.print("--dry-run flag provided, not deploying", style="yellow")
-    else:
-        push.invoke(ctx)
-    train_cli.print_deploy_checkpoints_success_message(prepare_checkpoint_result)
+        console.print("--dry-run flag provided, did not deploy", style="yellow")
+    train_cli.print_deploy_checkpoints_success_message(result)
 
 
 @train.command(name="download")
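The command now delegates deployment to the backend instead of materializing a local truss and invoking push. A condensed sketch of the new flow, assuming a configured remote_provider (its construction is omitted here and the IDs are placeholders):

    import truss.cli.train.core as train_cli

    args = train_cli.DeployCheckpointArgs(
        project_id="proj-123",        # placeholder
        job_id="job-abc",             # placeholder
        deploy_config_path=None,
    )
    result = train_cli.create_model_version_from_inference_template(remote_provider, args)
    train_cli.print_deploy_checkpoints_success_message(result)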
truss/remote/baseten/api.py CHANGED
@@ -3,6 +3,7 @@ from enum import Enum
 from typing import Any, Dict, List, Mapping, Optional
 
 import requests
+from pydantic import BaseModel, Field
 
 from truss.remote.baseten import custom_types as b10_types
 from truss.remote.baseten.auth import ApiKey, AuthService
@@ -14,6 +15,39 @@ from truss.remote.baseten.utils.transfer import base64_encoded_json_str
 logger = logging.getLogger(__name__)
 
 
+class InstanceTypeV1(BaseModel):
+    """An instance type."""
+
+    id: str = Field(description="Identifier string for the instance type")
+    name: str = Field(description="Name of the instance type")
+    display_name: str = Field(
+        alias="displayName", description="Display name of the instance type"
+    )
+    gpu_count: int = Field(
+        alias="gpuCount", description="Number of GPUs on the instance type"
+    )
+    default: bool = Field(description="Whether this is the default instance type")
+    gpu_memory: Optional[int] = Field(alias="gpuMemory", description="GPU memory in MB")
+    node_count: int = Field(alias="nodeCount", description="Number of nodes")
+    gpu_type: Optional[str] = Field(
+        alias="gpuType", description="Type of GPU on the instance type"
+    )
+    millicpu_limit: int = Field(
+        alias="millicpuLimit", description="CPU limit of the instance type in millicpu"
+    )
+    memory_limit: int = Field(
+        alias="memoryLimit", description="Memory limit of the instance type in MB"
+    )
+    price: Optional[float] = Field(description="Price of the instance type")
+    limited_capacity: Optional[bool] = Field(
+        alias="limitedCapacity",
+        description="Whether this instance type has limited capacity",
+    )
+
+    class Config:
+        populate_by_name = True
+
+
 API_URL_MAPPING = {
     "https://app.baseten.co": "https://api.baseten.co",
     "https://app.staging.baseten.co": "https://api.staging.baseten.co",
@@ -750,3 +784,56 @@ class BasetenApi:
 
         # NB(nikhil): reverse order so latest logs are at the end
         return resp_json["logs"][::-1]
+
+    def create_model_version_from_inference_template(self, request_data: dict):
+        """
+        Create a model version from an inference template using GraphQL mutation.
+
+        Args:
+            request_data: Dictionary containing the request structure with metadata,
+                weights_sources, inference_stack, and instance_type_id
+        """
+        query_string = """
+        mutation ($request: CreateModelVersionFromInferenceTemplateRequest!) {
+            create_model_version_from_inference_template(request: $request) {
+                model_version {
+                    id
+                    name
+                }
+            }
+        }
+        """
+
+        resp = self._post_graphql_query(
+            query_string, variables={"request": request_data}
+        )
+        return resp["data"]["create_model_version_from_inference_template"]
+
+    def get_instance_types(self) -> List[InstanceTypeV1]:
+        """
+        Get all available instance types via GraphQL API.
+        """
+        query_string = """
+        query Instances {
+            listedInstances: listed_instances {
+                id
+                name
+                millicpuLimit: millicpu_limit
+                memoryLimit: memory_limit
+                gpuCount: gpu_count
+                gpuType: gpu_type
+                gpuMemory: gpu_memory
+                default
+                displayName: display_name
+                nodeCount: node_count
+                price
+                limitedCapacity: limited_capacity
+            }
+        }
+        """
+
+        resp = self._post_graphql_query(query_string)
+        instance_types_data = resp["data"]["listedInstances"]
+        return [
+            InstanceTypeV1(**instance_type) for instance_type in instance_types_data
+        ]
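A minimal usage sketch of the two new API methods, assuming an authenticated BasetenApi instance named api. The payload keys are only those named in the method's docstring; their value shapes are assumptions, not a documented contract:

    instance_types = api.get_instance_types()
    default = next((it for it in instance_types if it.default), None)

    resp = api.create_model_version_from_inference_template(
        {
            "metadata": {},            # model naming / labels (shape assumed)
            "weights_sources": [],     # checkpoint references (shape assumed)
            "inference_stack": {},     # e.g. vLLM settings (shape assumed)
            "instance_type_id": default.id if default else None,
        }
    )
    model_version = resp["model_version"]  # {"id": ..., "name": ...}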
truss/templates/control/control/application.py CHANGED
@@ -1,13 +1,15 @@
 import asyncio
+import http
 import logging
 import logging.config
 import re
+import traceback
 from pathlib import Path
-from typing import Dict
+from typing import Awaitable, Callable, Dict
 
 import httpx
 from endpoints import control_app
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, Response
 from fastapi.responses import JSONResponse
 from helpers.errors import ModelLoadFailed, PatchApplicatonError
 from helpers.inference_server_controller import InferenceServerController
@@ -16,22 +18,47 @@ from helpers.inference_server_starter import async_inference_server_startup_flow
 from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier
 from shared import log_config
 from starlette.datastructures import State
-
-
-async def handle_patch_error(_, exc):
-    error_type = _camel_to_snake_case(type(exc).__name__)
-    return JSONResponse(content={"error": {"type": error_type, "msg": str(exc)}})
-
-
-async def generic_error_handler(_, exc):
-    return JSONResponse(
-        content={"error": {"type": "unknown", "msg": f"{type(exc)}: {exc}"}}
-    )
-
-
-async def handle_model_load_failed(_, error):
-    # Model load failures should result in 503 status
-    return JSONResponse({"error": str(error)}, 503)
+from starlette.middleware.base import BaseHTTPMiddleware
+
+SANITIZED_EXCEPTION_FRAMES = 2
+
+
+# NB(nikhil): SanitizedExceptionMiddleware will reduce the noise of control server stack frames, since
+# users often complain about the verbosity. Now, if any exceptions are explicitly raised during a proxied
+# request, we'll log the last two stack frames which should be sufficient for debugging while significantly
+# cutting down the volume.
+class SanitizedExceptionMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, num_frames: int = SANITIZED_EXCEPTION_FRAMES):
+        super().__init__(app)
+        self.num_frames = num_frames
+
+    async def dispatch(
+        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        try:
+            return await call_next(request)
+        except Exception as exc:
+            sanitized_traceback = self._create_sanitized_traceback(exc)
+            request.app.state.logger.error(sanitized_traceback)
+
+            if isinstance(exc, ModelLoadFailed):
+                return JSONResponse(
+                    {"error": str(exc)}, status_code=http.HTTPStatus.BAD_GATEWAY.value
+                )
+            elif isinstance(exc, PatchApplicatonError):
+                error_type = _camel_to_snake_case(type(exc).__name__)
+                return JSONResponse({"error": {"type": error_type, "msg": str(exc)}})
+            else:
+                return JSONResponse(
+                    {"error": {"type": "unknown", "msg": str(exc)}},
+                    status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value,
+                )
+
+    def _create_sanitized_traceback(self, error: Exception) -> str:
+        tb_lines = traceback.format_tb(error.__traceback__)
+        if tb_lines and self.num_frames > 0:
+            return "".join(tb_lines[-self.num_frames :])
+        return f"{type(error).__name__}: {error}"
 
 
 def create_app(base_config: Dict):
@@ -57,10 +84,9 @@ def create_app(base_config: Dict):
         base_url=f"http://localhost:{app_state.inference_server_port}", limits=limits
     )
 
-    pip_path = getattr(app_state, "pip_path", None)
-
+    uv_path = getattr(app_state, "uv_path", None)
     patch_applier = ModelContainerPatchApplier(
-        Path(app_state.inference_server_home), app_logger, pip_path
+        Path(app_state.inference_server_home), app_logger, uv_path
    )
 
     oversee_inference_server = getattr(app_state, "oversee_inference_server", True)
@@ -82,14 +108,10 @@ def create_app(base_config: Dict):
     app = FastAPI(
         title="Truss Live Reload Server",
         on_startup=[start_background_inference_startup],
-        exception_handlers={
-            PatchApplicatonError: handle_patch_error,
-            ModelLoadFailed: handle_model_load_failed,
-            Exception: generic_error_handler,
-        },
     )
     app.state = app_state
     app.include_router(control_app)
+    app.add_middleware(SanitizedExceptionMiddleware)
 
     @app.on_event("shutdown")
     def on_shutdown():
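The frame-trimming logic is self-contained and easy to see in isolation. A runnable sketch mirroring _create_sanitized_traceback outside the middleware:

    import traceback

    def sanitize(error: Exception, num_frames: int = 2) -> str:
        # Keep only the last `num_frames` stack frames, like SanitizedExceptionMiddleware.
        tb_lines = traceback.format_tb(error.__traceback__)
        if tb_lines and num_frames > 0:
            return "".join(tb_lines[-num_frames:])
        return f"{type(error).__name__}: {error}"

    def inner():
        raise ValueError("boom")

    def outer():
        inner()

    try:
        outer()
    except ValueError as exc:
        print(sanitize(exc))  # only the outer() and inner() frames are printed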
truss/templates/control/control/endpoints.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Protocol
 import httpx
 from fastapi import APIRouter, WebSocket
 from fastapi.responses import JSONResponse, StreamingResponse
+from helpers.errors import ModelLoadFailed, ModelNotReady
 from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
 from httpx_ws import _exceptions as httpx_ws_exceptions
 from starlette.requests import ClientDisconnect, Request
@@ -13,11 +14,6 @@ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisco
 from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
 from wsproto.events import BytesMessage, TextMessage
 
-from truss.templates.control.control.helpers.errors import (
-    ModelLoadFailed,
-    ModelNotReady,
-)
-
 INFERENCE_SERVER_START_WAIT_SECS = 60
 BASE_RETRY_EXCEPTIONS = (
     retry_if_exception_type(httpx.ConnectError)
truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py CHANGED
@@ -1,4 +1,6 @@
 import logging
+import os
+import shutil
 import subprocess
 from pathlib import Path
 from typing import Optional
@@ -30,7 +32,7 @@ class ModelContainerPatchApplier:
         self,
         inference_server_home: Path,
         app_logger: logging.Logger,
-        pip_path: Optional[str] = None,  # Only meant for testing
+        uv_path: Optional[str] = None,  # Only meant for testing
     ) -> None:
         self._inference_server_home = inference_server_home
         self._model_module_dir = (
@@ -41,9 +43,19 @@ class ModelContainerPatchApplier:
         ).resolve()
         self._data_dir = self._inference_server_home / self._truss_config.data_dir
         self._app_logger = app_logger
-        self._pip_path_cached = None
-        if pip_path is not None:
-            self._pip_path_cached = "pip"
+        self._uv_path_cached = None
+        if uv_path is not None:
+            self._uv_path_cached = uv_path
+
+        self._python_executable = self._get_python_executable()
+
+    def _get_python_executable(self) -> str:
+        # NB(nikhil): `uv` requires the full path to the python interpreter for patching
+        # python modules. We expect PYTHON_EXECUTABLE to exist in all development images, but
+        # we fallback to python3 as a default.
+        python_executable = os.environ.get("PYTHON_EXECUTABLE", "python3")
+        full_executable_path = shutil.which(python_executable)
+        return full_executable_path or python_executable
 
     def __call__(self, patch: Patch, inf_env: dict):
         self._app_logger.debug(f"Applying patch {patch.to_dict()}")
@@ -79,10 +91,10 @@ class ModelContainerPatchApplier:
         return TrussConfig.from_yaml(self._inference_server_home / "config.yaml")
 
     @property
-    def _pip_path(self) -> str:
-        if self._pip_path_cached is None:
-            self._pip_path_cached = _identify_pip_path()
-        return self._pip_path_cached
+    def _uv_path(self) -> str:
+        if self._uv_path_cached is None:
+            self._uv_path_cached = _identify_uv_path()
+        return self._uv_path_cached
 
     def _apply_python_requirement_patch(
         self, python_requirement_patch: PythonRequirementPatch
@@ -95,20 +107,25 @@
         if action == Action.REMOVE:
             subprocess.run(
                 [
-                    self._pip_path,
+                    self._uv_path,
+                    "pip",
                     "uninstall",
-                    "-y",
                     python_requirement_patch.requirement,
+                    "--python",
+                    self._python_executable,
                 ],
                 check=True,
             )
         elif action in [Action.ADD, Action.UPDATE]:
             subprocess.run(
                 [
-                    self._pip_path,
+                    self._uv_path,
+                    "pip",
                     "install",
                     python_requirement_patch.requirement,
                     "--upgrade",
+                    "--python",
+                    self._python_executable,
                 ],
                 check=True,
             )
@@ -158,11 +175,9 @@
         raise ValueError(f"Unknown patch action {action}")
 
 
-def _identify_pip_path() -> str:
-    if Path("/usr/local/bin/pip3").exists():
-        return "/usr/local/bin/pip3"
-
-    if Path("/usr/local/bin/pip").exists():
-        return "/usr/local/bin/pip"
+def _identify_uv_path() -> str:
+    uv_path = shutil.which("uv")
+    if not uv_path:
+        raise RuntimeError("Unable to find `uv`, make sure it's installed.")
 
-    raise RuntimeError("Unable to find pip, make sure it's installed.")
+    return uv_path
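The resulting subprocess call is plain uv pip with an explicit target interpreter. A standalone sketch of the equivalent invocation (the requirement name is a placeholder):

    import shutil
    import subprocess

    uv_path = shutil.which("uv")                              # same lookup as _identify_uv_path
    python_executable = shutil.which("python3") or "python3"  # stand-in for PYTHON_EXECUTABLE

    if uv_path:
        subprocess.run(
            [uv_path, "pip", "install", "httpx", "--upgrade", "--python", python_executable],
            check=True,
        )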