truss 0.11.8rc12__py3-none-any.whl → 0.11.9rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of truss might be problematic.
- truss/cli/train/core.py +13 -11
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +211 -106
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +0 -59
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +0 -83
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +0 -53
- truss/cli/train/types.py +1 -11
- truss/cli/train_commands.py +5 -15
- truss/remote/baseten/api.py +87 -0
- truss/templates/control/control/application.py +46 -23
- truss/templates/control/control/endpoints.py +1 -5
- truss/tests/cli/train/test_deploy_checkpoints.py +0 -843
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +7 -23
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- {truss-0.11.8rc12.dist-info → truss-0.11.9rc2.dist-info}/METADATA +1 -1
- {truss-0.11.8rc12.dist-info → truss-0.11.9rc2.dist-info}/RECORD +23 -22
- truss_train/definitions.py +0 -1
- {truss-0.11.8rc12.dist-info → truss-0.11.9rc2.dist-info}/WHEEL +0 -0
- {truss-0.11.8rc12.dist-info → truss-0.11.9rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.8rc12.dist-info → truss-0.11.9rc2.dist-info}/licenses/LICENSE +0 -0
truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py
CHANGED

@@ -1,12 +1,3 @@
-from pathlib import Path
-
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import (
     ALLOWED_LORA_RANKS,
     DEFAULT_LORA_RANK,
@@ -14,21 +5,6 @@ from truss_train.definitions import (
     LoRADetails,
 )
 
-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-VLLM_LORA_START_COMMAND = Template(
-    'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ base_model_id }}'
-    + " --port 8000"
-    + "{{ specify_tensor_parallelism }}"
-    + " --enable-lora"
-    + " --max-lora-rank {{ max_lora_rank }}"
-    + " --dtype bfloat16"
-    + ' --lora-modules {{ lora_modules }}"'
-)
-
 
 def hydrate_lora_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
@@ -43,49 +19,6 @@ def hydrate_lora_checkpoint(
     )
 
 
-def render_vllm_lora_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for LoRA checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = _build_lora_checkpoint_string(truss_deploy_config)
-
-    max_lora_rank = max(
-        [
-            checkpoint.lora_details.rank or DEFAULT_LORA_RANK
-            for checkpoint in checkpoint_deploy.checkpoint_details.checkpoints
-            if hasattr(checkpoint, "lora_details") and checkpoint.lora_details
-        ]
-    )
-    accelerator = checkpoint_deploy.compute.accelerator
-    if accelerator:
-        specify_tensor_parallelism = f" --tensor-parallel-size {accelerator.count}"
-    else:
-        specify_tensor_parallelism = ""
-
-    start_command_args = {
-        "base_model_id": checkpoint_deploy.checkpoint_details.base_model_id,
-        "lora_modules": checkpoint_str,
-        "envvars": start_command_envvars,
-        "max_lora_rank": max_lora_rank,
-        "specify_tensor_parallelism": specify_tensor_parallelism,
-    }
-    start_command = VLLM_LORA_START_COMMAND.render(**start_command_args)
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-    return truss_deploy_config
-
-
 def _get_lora_rank(checkpoint_resp: dict) -> int:
     """Extract and validate LoRA rank from checkpoint response."""
     lora_adapter_config = checkpoint_resp.get("lora_adapter_config") or {}
@@ -99,19 +32,3 @@ def _get_lora_rank(checkpoint_resp: dict) -> int:
     )
 
     return lora_rank
-
-
-def _build_lora_checkpoint_string(truss_deploy_config) -> str:
-    """Build the checkpoint string for LoRA modules from truss deploy config."""
-    checkpoint_parts = []
-    for (
-        truss_checkpoint
-    ) in truss_deploy_config.training_checkpoints.artifact_references:  # type: ignore
-        ckpt_path = Path(
-            truss_deploy_config.training_checkpoints.download_folder,  # type: ignore
-            truss_checkpoint.training_job_id,
-            truss_checkpoint.paths[0],
-        )
-        checkpoint_parts.append(f"{truss_checkpoint.training_job_id}={ckpt_path}")
-
-    return " ".join(checkpoint_parts)
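For reference, the removed VLLM_LORA_START_COMMAND above assembled the vLLM start command for LoRA serving from a Jinja2 template (the whisper file below used the same pattern). A minimal sketch of how that template rendered; every sample value here is hypothetical:

# Sketch: rendering the removed VLLM_LORA_START_COMMAND template.
# All argument values below are made up for illustration.
from jinja2 import Template

VLLM_LORA_START_COMMAND = Template(
    'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ base_model_id }}'
    + " --port 8000"
    + "{{ specify_tensor_parallelism }}"
    + " --enable-lora"
    + " --max-lora-rank {{ max_lora_rank }}"
    + " --dtype bfloat16"
    + ' --lora-modules {{ lora_modules }}"'
)

start_command = VLLM_LORA_START_COMMAND.render(
    envvars="HF_TOKEN=$(cat /secrets/hf_access_token)",  # hypothetical
    base_model_id="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical
    specify_tensor_parallelism=" --tensor-parallel-size 4",
    max_lora_rank=16,
    lora_modules="job_abc123=/tmp/training_checkpoints/job_abc123/checkpoint-42",
)
# Produces a single shell command, e.g.:
# sh -c "HF_TOKEN=$(cat /secrets/hf_access_token) vllm serve meta-llama/Llama-3.1-8B-Instruct
#        --port 8000 --tensor-parallel-size 4 --enable-lora --max-lora-rank 16
#        --dtype bfloat16 --lora-modules job_abc123=/tmp/training_checkpoints/job_abc123/checkpoint-42"
print(start_command)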
truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py
CHANGED

@@ -1,58 +1,5 @@
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
-    build_full_checkpoint_string,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import WhisperCheckpoint
 
-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-VLLM_WHISPER_START_COMMAND = Template(
-    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
-    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
-    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }}'"
-)
-
-
-def render_vllm_whisper_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for whisper checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
-
-    accelerator = checkpoint_deploy.compute.accelerator
-
-    start_command_args = {
-        "model_path": checkpoint_str,
-        "envvars": start_command_envvars,
-        "specify_tensor_parallelism": accelerator.count if accelerator else 1,
-    }
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    start_command = VLLM_WHISPER_START_COMMAND.render(**start_command_args)
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-
-    return truss_deploy_config
-
 
 def hydrate_whisper_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
truss/cli/train/types.py
CHANGED
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Optional
 
 from truss_train.definitions import (
@@ -7,12 +6,11 @@ from truss_train.definitions import (
     Compute,
     DeployCheckpointsConfig,
     DeployCheckpointsRuntime,
-    ModelWeightsFormat,
 )
 
 
 @dataclass
-class
+class DeployCheckpointArgs:
     project_id: Optional[str]
     job_id: Optional[str]
     deploy_config_path: Optional[str]
@@ -26,13 +24,5 @@ class DeployCheckpointsConfigComplete(DeployCheckpointsConfig):
 
     checkpoint_details: CheckpointList
     model_name: str
-    deployment_name: str
     runtime: DeployCheckpointsRuntime
     compute: Compute
-    model_weight_format: ModelWeightsFormat
-
-
-@dataclass
-class PrepareCheckpointResult:
-    truss_directory: Path
-    checkpoint_deploy_config: DeployCheckpointsConfigComplete
truss/cli/train_commands.py
CHANGED
@@ -8,7 +8,7 @@ import rich_click as click
 import truss.cli.train.core as train_cli
 from truss.base.constants import TRAINING_TEMPLATE_DIR
 from truss.cli import remote_cli
-from truss.cli.cli import
+from truss.cli.cli import truss_cli
 from truss.cli.logs import utils as cli_log_utils
 from truss.cli.logs.training_log_watcher import TrainingLogWatcher
 from truss.cli.train import common as train_common
@@ -329,26 +329,16 @@ def deploy_checkpoints(
     project_id = _maybe_resolve_project_id_from_id_or_name(
         remote_provider, project_id=project_id, project=project
     )
-
+    result = train_cli.create_model_version_from_inference_template(
         remote_provider,
-        train_cli.
+        train_cli.DeployCheckpointArgs(
             project_id=project_id, job_id=job_id, deploy_config_path=config
         ),
     )
 
-    params = {
-        "target_directory": prepare_checkpoint_result.truss_directory,
-        "remote": remote,
-        "model_name": prepare_checkpoint_result.checkpoint_deploy_config.model_name,
-        "publish": True,
-        "deployment_name": prepare_checkpoint_result.checkpoint_deploy_config.deployment_name,
-    }
-    ctx = _prepare_click_context(push, params)
     if dry_run:
-        console.print("--dry-run flag provided, not
-
-    push.invoke(ctx)
-    train_cli.print_deploy_checkpoints_success_message(prepare_checkpoint_result)
+        console.print("--dry-run flag provided, did not deploy", style="yellow")
+    train_cli.print_deploy_checkpoints_success_message(result)
 
 
 @train.command(name="download")
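The CLI now delegates checkpoint deployment to the backend (via the GraphQL mutation added to BasetenApi below) instead of building a truss directory and invoking push locally. A rough sketch of the underlying call; the top-level request keys come from the new method's docstring, while the nested values and the helper function are illustrative guesses, not a documented schema:

# Sketch, assuming the request shape named in the docstring of
# BasetenApi.create_model_version_from_inference_template: metadata,
# weights_sources, inference_stack, instance_type_id. Values are hypothetical.
from truss.remote.baseten.api import BasetenApi


def deploy_from_template(api: BasetenApi, instance_type_id: str) -> str:
    request_data = {
        "metadata": {"model_name": "my-finetune"},  # hypothetical
        "weights_sources": [{"training_job_id": "job_abc123"}],  # hypothetical
        "inference_stack": "vllm",  # hypothetical
        "instance_type_id": instance_type_id,
    }
    resp = api.create_model_version_from_inference_template(request_data)
    # The mutation's selection set returns model_version { id name }.
    return resp["model_version"]["id"]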
truss/remote/baseten/api.py
CHANGED
@@ -3,6 +3,7 @@ from enum import Enum
 from typing import Any, Dict, List, Mapping, Optional
 
 import requests
+from pydantic import BaseModel, Field
 
 from truss.remote.baseten import custom_types as b10_types
 from truss.remote.baseten.auth import ApiKey, AuthService
@@ -14,6 +15,39 @@ from truss.remote.baseten.utils.transfer import base64_encoded_json_str
 logger = logging.getLogger(__name__)
 
 
+class InstanceTypeV1(BaseModel):
+    """An instance type."""
+
+    id: str = Field(description="Identifier string for the instance type")
+    name: str = Field(description="Name of the instance type")
+    display_name: str = Field(
+        alias="displayName", description="Display name of the instance type"
+    )
+    gpu_count: int = Field(
+        alias="gpuCount", description="Number of GPUs on the instance type"
+    )
+    default: bool = Field(description="Whether this is the default instance type")
+    gpu_memory: Optional[int] = Field(alias="gpuMemory", description="GPU memory in MB")
+    node_count: int = Field(alias="nodeCount", description="Number of nodes")
+    gpu_type: Optional[str] = Field(
+        alias="gpuType", description="Type of GPU on the instance type"
+    )
+    millicpu_limit: int = Field(
+        alias="millicpuLimit", description="CPU limit of the instance type in millicpu"
+    )
+    memory_limit: int = Field(
+        alias="memoryLimit", description="Memory limit of the instance type in MB"
+    )
+    price: Optional[float] = Field(description="Price of the instance type")
+    limited_capacity: Optional[bool] = Field(
+        alias="limitedCapacity",
+        description="Whether this instance type has limited capacity",
+    )
+
+    class Config:
+        populate_by_name = True
+
+
 API_URL_MAPPING = {
     "https://app.baseten.co": "https://api.baseten.co",
     "https://app.staging.baseten.co": "https://api.staging.baseten.co",
@@ -750,3 +784,56 @@ class BasetenApi:
 
         # NB(nikhil): reverse order so latest logs are at the end
         return resp_json["logs"][::-1]
+
+    def create_model_version_from_inference_template(self, request_data: dict):
+        """
+        Create a model version from an inference template using GraphQL mutation.
+
+        Args:
+            request_data: Dictionary containing the request structure with metadata,
+                weights_sources, inference_stack, and instance_type_id
+        """
+        query_string = """
+        mutation ($request: CreateModelVersionFromInferenceTemplateRequest!) {
+            create_model_version_from_inference_template(request: $request) {
+                model_version {
+                    id
+                    name
+                }
+            }
+        }
+        """
+
+        resp = self._post_graphql_query(
+            query_string, variables={"request": request_data}
+        )
+        return resp["data"]["create_model_version_from_inference_template"]
+
+    def get_instance_types(self) -> List[InstanceTypeV1]:
+        """
+        Get all available instance types via GraphQL API.
+        """
+        query_string = """
+        query Instances {
+            listedInstances: listed_instances {
+                id
+                name
+                millicpuLimit: millicpu_limit
+                memoryLimit: memory_limit
+                gpuCount: gpu_count
+                gpuType: gpu_type
+                gpuMemory: gpu_memory
+                default
+                displayName: display_name
+                nodeCount: node_count
+                price
+                limitedCapacity: limited_capacity
+            }
+        }
+        """
+
+        resp = self._post_graphql_query(query_string)
+        instance_types_data = resp["data"]["listedInstances"]
+        return [
+            InstanceTypeV1(**instance_type) for instance_type in instance_types_data
+        ]
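The Field aliases on InstanceTypeV1 let the camelCase GraphQL payload (which the query above also enforces via field aliases) populate snake_case attributes. A small sketch with made-up values:

# Sketch: constructing InstanceTypeV1 from a camelCase payload as returned by
# the get_instance_types query. All payload values are hypothetical.
from truss.remote.baseten.api import InstanceTypeV1

payload = {
    "id": "it_a100",  # hypothetical values throughout
    "name": "a100",
    "displayName": "1x A100 80GB",
    "gpuCount": 1,
    "default": False,
    "gpuMemory": 81920,
    "nodeCount": 1,
    "gpuType": "A100",
    "millicpuLimit": 12000,
    "memoryLimit": 131072,
    "price": 4.0,
    "limitedCapacity": None,
}
it = InstanceTypeV1(**payload)
print(it.display_name, it.gpu_count)  # -> "1x A100 80GB" 1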
truss/templates/control/control/application.py
CHANGED

@@ -1,13 +1,15 @@
 import asyncio
+import http
 import logging
 import logging.config
 import re
+import traceback
 from pathlib import Path
-from typing import Dict
+from typing import Awaitable, Callable, Dict
 
 import httpx
 from endpoints import control_app
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, Response
 from fastapi.responses import JSONResponse
 from helpers.errors import ModelLoadFailed, PatchApplicatonError
 from helpers.inference_server_controller import InferenceServerController
@@ -16,22 +18,47 @@ from helpers.inference_server_starter import async_inference_server_startup_flow
 from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier
 from shared import log_config
 from starlette.datastructures import State
-
-
-
-
-
-
-
-
-
-
-)
-
-
-
-
-
+from starlette.middleware.base import BaseHTTPMiddleware
+
+SANITIZED_EXCEPTION_FRAMES = 2
+
+
+# NB(nikhil): SanitizedExceptionMiddleware will reduce the noise of control server stack frames, since
+# users often complain about the verbosity. Now, if any exceptions are explicitly raised during a proxied
+# request, we'll log the last two stack frames which should be sufficient for debugging while significantly
+# cutting down the volume.
+class SanitizedExceptionMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, num_frames: int = SANITIZED_EXCEPTION_FRAMES):
+        super().__init__(app)
+        self.num_frames = num_frames
+
+    async def dispatch(
+        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        try:
+            return await call_next(request)
+        except Exception as exc:
+            sanitized_traceback = self._create_sanitized_traceback(exc)
+            request.app.state.logger.error(sanitized_traceback)
+
+            if isinstance(exc, ModelLoadFailed):
+                return JSONResponse(
+                    {"error": str(exc)}, status_code=http.HTTPStatus.BAD_GATEWAY.value
+                )
+            elif isinstance(exc, PatchApplicatonError):
+                error_type = _camel_to_snake_case(type(exc).__name__)
+                return JSONResponse({"error": {"type": error_type, "msg": str(exc)}})
+            else:
+                return JSONResponse(
+                    {"error": {"type": "unknown", "msg": str(exc)}},
+                    status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value,
+                )
+
+    def _create_sanitized_traceback(self, error: Exception) -> str:
+        tb_lines = traceback.format_tb(error.__traceback__)
+        if tb_lines and self.num_frames > 0:
+            return "".join(tb_lines[-self.num_frames :])
+        return f"{type(error).__name__}: {error}"
 
 
 def create_app(base_config: Dict):
@@ -82,14 +109,10 @@ def create_app(base_config: Dict):
     app = FastAPI(
         title="Truss Live Reload Server",
         on_startup=[start_background_inference_startup],
-        exception_handlers={
-            PatchApplicatonError: handle_patch_error,
-            ModelLoadFailed: handle_model_load_failed,
-            Exception: generic_error_handler,
-        },
     )
     app.state = app_state
     app.include_router(control_app)
+    app.add_middleware(SanitizedExceptionMiddleware)
 
     @app.on_event("shutdown")
     def on_shutdown():
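The middleware's _create_sanitized_traceback keeps only the last SANITIZED_EXCEPTION_FRAMES frames of a traceback. A standalone sketch of just that truncation behavior, with the frame-keeping logic lifted from the method above:

# Sketch: truncating a traceback to its last N frames, as the middleware does.
import traceback


def sanitize(error: Exception, num_frames: int = 2) -> str:
    tb_lines = traceback.format_tb(error.__traceback__)  # one string per frame
    if tb_lines and num_frames > 0:
        return "".join(tb_lines[-num_frames:])
    return f"{type(error).__name__}: {error}"


def inner():
    raise ValueError("model load blew up")


def outer():
    inner()


try:
    outer()
except ValueError as exc:
    # Prints only the frames for outer() and inner(); the surrounding
    # call-site frame is dropped from the log output.
    print(sanitize(exc))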
truss/templates/control/control/endpoints.py
CHANGED

@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Protocol
 import httpx
 from fastapi import APIRouter, WebSocket
 from fastapi.responses import JSONResponse, StreamingResponse
+from helpers.errors import ModelLoadFailed, ModelNotReady
 from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
 from httpx_ws import _exceptions as httpx_ws_exceptions
 from starlette.requests import ClientDisconnect, Request
@@ -13,11 +14,6 @@ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisco
 from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
 from wsproto.events import BytesMessage, TextMessage
 
-from truss.templates.control.control.helpers.errors import (
-    ModelLoadFailed,
-    ModelNotReady,
-)
-
 INFERENCE_SERVER_START_WAIT_SECS = 60
 BASE_RETRY_EXCEPTIONS = (
     retry_if_exception_type(httpx.ConnectError)
|