truss 0.11.9rc1-py3-none-any.whl → 0.11.9rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of truss might be problematic.

@@ -1,12 +1,3 @@
-from pathlib import Path
-
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import (
     ALLOWED_LORA_RANKS,
     DEFAULT_LORA_RANK,
@@ -14,21 +5,6 @@ from truss_train.definitions import (
     LoRADetails,
 )

-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-VLLM_LORA_START_COMMAND = Template(
-    'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ base_model_id }}'
-    + " --port 8000"
-    + "{{ specify_tensor_parallelism }}"
-    + " --enable-lora"
-    + " --max-lora-rank {{ max_lora_rank }}"
-    + " --dtype bfloat16"
-    + ' --lora-modules {{ lora_modules }}"'
-)
-

 def hydrate_lora_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
@@ -43,49 +19,6 @@ def hydrate_lora_checkpoint(
     )


-def render_vllm_lora_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for LoRA checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = _build_lora_checkpoint_string(truss_deploy_config)
-
-    max_lora_rank = max(
-        [
-            checkpoint.lora_details.rank or DEFAULT_LORA_RANK
-            for checkpoint in checkpoint_deploy.checkpoint_details.checkpoints
-            if hasattr(checkpoint, "lora_details") and checkpoint.lora_details
-        ]
-    )
-    accelerator = checkpoint_deploy.compute.accelerator
-    if accelerator:
-        specify_tensor_parallelism = f" --tensor-parallel-size {accelerator.count}"
-    else:
-        specify_tensor_parallelism = ""
-
-    start_command_args = {
-        "base_model_id": checkpoint_deploy.checkpoint_details.base_model_id,
-        "lora_modules": checkpoint_str,
-        "envvars": start_command_envvars,
-        "max_lora_rank": max_lora_rank,
-        "specify_tensor_parallelism": specify_tensor_parallelism,
-    }
-    start_command = VLLM_LORA_START_COMMAND.render(**start_command_args)
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-    return truss_deploy_config
-
-
 def _get_lora_rank(checkpoint_resp: dict) -> int:
     """Extract and validate LoRA rank from checkpoint response."""
     lora_adapter_config = checkpoint_resp.get("lora_adapter_config") or {}
@@ -99,19 +32,3 @@ def _get_lora_rank(checkpoint_resp: dict) -> int:
     )

     return lora_rank
-
-
-def _build_lora_checkpoint_string(truss_deploy_config) -> str:
-    """Build the checkpoint string for LoRA modules from truss deploy config."""
-    checkpoint_parts = []
-    for (
-        truss_checkpoint
-    ) in truss_deploy_config.training_checkpoints.artifact_references:  # type: ignore
-        ckpt_path = Path(
-            truss_deploy_config.training_checkpoints.download_folder,  # type: ignore
-            truss_checkpoint.training_job_id,
-            truss_checkpoint.paths[0],
-        )
-        checkpoint_parts.append(f"{truss_checkpoint.training_job_id}={ckpt_path}")
-
-    return " ".join(checkpoint_parts)
@@ -1,58 +1,5 @@
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
-    build_full_checkpoint_string,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import WhisperCheckpoint

-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-VLLM_WHISPER_START_COMMAND = Template(
-    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
-    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
-    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }}'"
-)
-
-
-def render_vllm_whisper_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for whisper checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
-
-    accelerator = checkpoint_deploy.compute.accelerator
-
-    start_command_args = {
-        "model_path": checkpoint_str,
-        "envvars": start_command_envvars,
-        "specify_tensor_parallelism": accelerator.count if accelerator else 1,
-    }
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    start_command = VLLM_WHISPER_START_COMMAND.render(**start_command_args)
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-
-    return truss_deploy_config
-

 def hydrate_whisper_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
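
For reference, rendering the removed whisper template produces a shell command like the one below; a quick sketch using jinja2 directly (the model path and GPU count are invented, and "$$" is supervisord's escape for a literal "$"):

from jinja2 import Template

# Verbatim copy of the removed template, kept here only to show its output.
VLLM_WHISPER_START_COMMAND = Template(
    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }}'"
)

print(
    VLLM_WHISPER_START_COMMAND.render(
        envvars="", model_path="/tmp/checkpoints/job123", specify_tensor_parallelism=2
    )
)
# sh -c 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && vllm serve /tmp/checkpoints/job123 --port 8000 --tensor-parallel-size 2'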
truss/cli/train/types.py CHANGED
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Optional

 from truss_train.definitions import (
@@ -7,12 +6,11 @@ from truss_train.definitions import (
     Compute,
     DeployCheckpointsConfig,
     DeployCheckpointsRuntime,
-    ModelWeightsFormat,
 )


 @dataclass
-class PrepareCheckpointArgs:
+class DeployCheckpointArgs:
     project_id: Optional[str]
     job_id: Optional[str]
     deploy_config_path: Optional[str]
@@ -26,13 +24,5 @@ class DeployCheckpointsConfigComplete(DeployCheckpointsConfig):

     checkpoint_details: CheckpointList
     model_name: str
-    deployment_name: str
     runtime: DeployCheckpointsRuntime
     compute: Compute
-    model_weight_format: ModelWeightsFormat
-
-
-@dataclass
-class PrepareCheckpointResult:
-    truss_directory: Path
-    checkpoint_deploy_config: DeployCheckpointsConfigComplete
@@ -8,7 +8,7 @@ import rich_click as click
 import truss.cli.train.core as train_cli
 from truss.base.constants import TRAINING_TEMPLATE_DIR
 from truss.cli import remote_cli
-from truss.cli.cli import push, truss_cli
+from truss.cli.cli import truss_cli
 from truss.cli.logs import utils as cli_log_utils
 from truss.cli.logs.training_log_watcher import TrainingLogWatcher
 from truss.cli.train import common as train_common
@@ -329,26 +329,16 @@ def deploy_checkpoints(
     project_id = _maybe_resolve_project_id_from_id_or_name(
         remote_provider, project_id=project_id, project=project
     )
-    prepare_checkpoint_result = train_cli.prepare_checkpoint_deploy(
+    result = train_cli.create_model_version_from_inference_template(
         remote_provider,
-        train_cli.PrepareCheckpointArgs(
+        train_cli.DeployCheckpointArgs(
             project_id=project_id, job_id=job_id, deploy_config_path=config
         ),
     )

-    params = {
-        "target_directory": prepare_checkpoint_result.truss_directory,
-        "remote": remote,
-        "model_name": prepare_checkpoint_result.checkpoint_deploy_config.model_name,
-        "publish": True,
-        "deployment_name": prepare_checkpoint_result.checkpoint_deploy_config.deployment_name,
-    }
-    ctx = _prepare_click_context(push, params)
     if dry_run:
-        console.print("--dry-run flag provided, not deploying", style="yellow")
-    else:
-        push.invoke(ctx)
-    train_cli.print_deploy_checkpoints_success_message(prepare_checkpoint_result)
+        console.print("--dry-run flag provided, did not deploy", style="yellow")
+    train_cli.print_deploy_checkpoints_success_message(result)


 @train.command(name="download")
@@ -3,6 +3,7 @@ from enum import Enum
 from typing import Any, Dict, List, Mapping, Optional

 import requests
+from pydantic import BaseModel, Field

 from truss.remote.baseten import custom_types as b10_types
 from truss.remote.baseten.auth import ApiKey, AuthService
@@ -14,6 +15,39 @@ from truss.remote.baseten.utils.transfer import base64_encoded_json_str
 logger = logging.getLogger(__name__)


+class InstanceTypeV1(BaseModel):
+    """An instance type."""
+
+    id: str = Field(description="Identifier string for the instance type")
+    name: str = Field(description="Name of the instance type")
+    display_name: str = Field(
+        alias="displayName", description="Display name of the instance type"
+    )
+    gpu_count: int = Field(
+        alias="gpuCount", description="Number of GPUs on the instance type"
+    )
+    default: bool = Field(description="Whether this is the default instance type")
+    gpu_memory: Optional[int] = Field(alias="gpuMemory", description="GPU memory in MB")
+    node_count: int = Field(alias="nodeCount", description="Number of nodes")
+    gpu_type: Optional[str] = Field(
+        alias="gpuType", description="Type of GPU on the instance type"
+    )
+    millicpu_limit: int = Field(
+        alias="millicpuLimit", description="CPU limit of the instance type in millicpu"
+    )
+    memory_limit: int = Field(
+        alias="memoryLimit", description="Memory limit of the instance type in MB"
+    )
+    price: Optional[float] = Field(description="Price of the instance type")
+    limited_capacity: Optional[bool] = Field(
+        alias="limitedCapacity",
+        description="Whether this instance type has limited capacity",
+    )
+
+    class Config:
+        populate_by_name = True
+
+
 API_URL_MAPPING = {
     "https://app.baseten.co": "https://api.baseten.co",
     "https://app.staging.baseten.co": "https://api.staging.baseten.co",
@@ -750,3 +784,56 @@ class BasetenApi:

         # NB(nikhil): reverse order so latest logs are at the end
         return resp_json["logs"][::-1]
+
+    def create_model_version_from_inference_template(self, request_data: dict):
+        """
+        Create a model version from an inference template using GraphQL mutation.
+
+        Args:
+            request_data: Dictionary containing the request structure with metadata,
+                weights_sources, inference_stack, and instance_type_id
+        """
+        query_string = """
+            mutation ($request: CreateModelVersionFromInferenceTemplateRequest!) {
+                create_model_version_from_inference_template(request: $request) {
+                    model_version {
+                        id
+                        name
+                    }
+                }
+            }
+        """
+
+        resp = self._post_graphql_query(
+            query_string, variables={"request": request_data}
+        )
+        return resp["data"]["create_model_version_from_inference_template"]
+
+    def get_instance_types(self) -> List[InstanceTypeV1]:
+        """
+        Get all available instance types via GraphQL API.
+        """
+        query_string = """
+            query Instances {
+                listedInstances: listed_instances {
+                    id
+                    name
+                    millicpuLimit: millicpu_limit
+                    memoryLimit: memory_limit
+                    gpuCount: gpu_count
+                    gpuType: gpu_type
+                    gpuMemory: gpu_memory
+                    default
+                    displayName: display_name
+                    nodeCount: node_count
+                    price
+                    limitedCapacity: limited_capacity
+                }
+            }
+        """
+
+        resp = self._post_graphql_query(query_string)
+        instance_types_data = resp["data"]["listedInstances"]
+        return [
+            InstanceTypeV1(**instance_type) for instance_type in instance_types_data
+        ]
@@ -1,13 +1,15 @@
 import asyncio
+import http
 import logging
 import logging.config
 import re
+import traceback
 from pathlib import Path
-from typing import Dict
+from typing import Awaitable, Callable, Dict

 import httpx
 from endpoints import control_app
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, Response
 from fastapi.responses import JSONResponse
 from helpers.errors import ModelLoadFailed, PatchApplicatonError
 from helpers.inference_server_controller import InferenceServerController
@@ -16,22 +18,47 @@ from helpers.inference_server_starter import async_inference_server_startup_flow
 from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier
 from shared import log_config
 from starlette.datastructures import State
-
-
-async def handle_patch_error(_, exc):
-    error_type = _camel_to_snake_case(type(exc).__name__)
-    return JSONResponse(content={"error": {"type": error_type, "msg": str(exc)}})
-
-
-async def generic_error_handler(_, exc):
-    return JSONResponse(
-        content={"error": {"type": "unknown", "msg": f"{type(exc)}: {exc}"}}
-    )
-
-
-async def handle_model_load_failed(_, error):
-    # Model load failures should result in 503 status
-    return JSONResponse({"error": str(error)}, 503)
+from starlette.middleware.base import BaseHTTPMiddleware
+
+SANITIZED_EXCEPTION_FRAMES = 2
+
+
+# NB(nikhil): SanitizedExceptionMiddleware will reduce the noise of control server stack frames, since
+# users often complain about the verbosity. Now, if any exceptions are explicitly raised during a proxied
+# request, we'll log the last two stack frames which should be sufficient for debugging while significantly
+# cutting down the volume.
+class SanitizedExceptionMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, num_frames: int = SANITIZED_EXCEPTION_FRAMES):
+        super().__init__(app)
+        self.num_frames = num_frames
+
+    async def dispatch(
+        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        try:
+            return await call_next(request)
+        except Exception as exc:
+            sanitized_traceback = self._create_sanitized_traceback(exc)
+            request.app.state.logger.error(sanitized_traceback)
+
+            if isinstance(exc, ModelLoadFailed):
+                return JSONResponse(
+                    {"error": str(exc)}, status_code=http.HTTPStatus.BAD_GATEWAY.value
+                )
+            elif isinstance(exc, PatchApplicatonError):
+                error_type = _camel_to_snake_case(type(exc).__name__)
+                return JSONResponse({"error": {"type": error_type, "msg": str(exc)}})
+            else:
+                return JSONResponse(
+                    {"error": {"type": "unknown", "msg": str(exc)}},
+                    status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value,
+                )
+
+    def _create_sanitized_traceback(self, error: Exception) -> str:
+        tb_lines = traceback.format_tb(error.__traceback__)
+        if tb_lines and self.num_frames > 0:
+            return "".join(tb_lines[-self.num_frames :])
+        return f"{type(error).__name__}: {error}"


 def create_app(base_config: Dict):
@@ -82,14 +109,10 @@ def create_app(base_config: Dict):
     app = FastAPI(
         title="Truss Live Reload Server",
         on_startup=[start_background_inference_startup],
-        exception_handlers={
-            PatchApplicatonError: handle_patch_error,
-            ModelLoadFailed: handle_model_load_failed,
-            Exception: generic_error_handler,
-        },
     )
     app.state = app_state
     app.include_router(control_app)
+    app.add_middleware(SanitizedExceptionMiddleware)

     @app.on_event("shutdown")
     def on_shutdown():
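
The frame-trimming behavior is easy to see in isolation; a self-contained sketch mirroring _create_sanitized_traceback above:

import traceback

def sanitize(error: Exception, num_frames: int = 2) -> str:
    # Same logic as SanitizedExceptionMiddleware._create_sanitized_traceback:
    # keep only the last `num_frames` frames of the traceback.
    tb_lines = traceback.format_tb(error.__traceback__)
    if tb_lines and num_frames > 0:
        return "".join(tb_lines[-num_frames:])
    return f"{type(error).__name__}: {error}"

def inner():
    raise ValueError("boom")

def outer():
    inner()

try:
    outer()
except ValueError as exc:
    # Prints only the outer() and inner() frames, not the caller's frame.
    print(sanitize(exc))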
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Protocol
 import httpx
 from fastapi import APIRouter, WebSocket
 from fastapi.responses import JSONResponse, StreamingResponse
+from helpers.errors import ModelLoadFailed, ModelNotReady
 from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
 from httpx_ws import _exceptions as httpx_ws_exceptions
 from starlette.requests import ClientDisconnect, Request
@@ -13,11 +14,6 @@ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisco
 from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
 from wsproto.events import BytesMessage, TextMessage

-from truss.templates.control.control.helpers.errors import (
-    ModelLoadFailed,
-    ModelNotReady,
-)
-
 INFERENCE_SERVER_START_WAIT_SECS = 60
 BASE_RETRY_EXCEPTIONS = (
     retry_if_exception_type(httpx.ConnectError)
@@ -45,6 +45,19 @@ server {
         proxy_pass http://127.0.0.1:{{server_port}};
     }

+    location ~ ^/v1/websocket$ {
+        proxy_redirect off;
+        proxy_read_timeout 18030s;
+        proxy_http_version 1.1;
+
+        proxy_set_header Upgrade $upgrade_header;
+        proxy_set_header Connection $connection_header;
+
+        rewrite ^/v1/websocket$ {{server_endpoint}} break;
+
+        proxy_pass http://127.0.0.1:{{server_port}};
+    }
+
     # Forward all other paths
     location / {
         proxy_redirect off;