truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/remote/baseten/remote.py
CHANGED
|
@@ -31,6 +31,7 @@ from truss.remote.baseten.core import (
|
|
|
31
31
|
get_model_and_versions,
|
|
32
32
|
get_prod_version_from_versions,
|
|
33
33
|
get_truss_watch_state,
|
|
34
|
+
upload_chain_artifact,
|
|
34
35
|
upload_truss,
|
|
35
36
|
validate_truss_config_against_backend,
|
|
36
37
|
)
|
|
@@ -68,6 +69,7 @@ class FinalPushData(custom_types.OracleData):
|
|
|
68
69
|
origin: Optional[custom_types.ModelOrigin] = None
|
|
69
70
|
environment: Optional[str] = None
|
|
70
71
|
allow_truss_download: bool
|
|
72
|
+
team_id: Optional[str] = None
|
|
71
73
|
|
|
72
74
|
|
|
73
75
|
class BasetenRemote(TrussRemote):
|
|
@@ -126,6 +128,8 @@ class BasetenRemote(TrussRemote):
|
|
|
126
128
|
origin: Optional[custom_types.ModelOrigin] = None,
|
|
127
129
|
environment: Optional[str] = None,
|
|
128
130
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
131
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
132
|
+
team_id: Optional[str] = None,
|
|
129
133
|
) -> FinalPushData:
|
|
130
134
|
if model_name.isspace():
|
|
131
135
|
raise ValueError("Model name cannot be empty")
|
|
@@ -163,6 +167,13 @@ class BasetenRemote(TrussRemote):
|
|
|
163
167
|
"Deployment name must only contain alphanumeric, -, _ and . characters"
|
|
164
168
|
)
|
|
165
169
|
|
|
170
|
+
if deploy_timeout_minutes is not None and (
|
|
171
|
+
deploy_timeout_minutes < 10 or deploy_timeout_minutes > 1440
|
|
172
|
+
):
|
|
173
|
+
raise ValueError(
|
|
174
|
+
"deploy-timeout-minutes must be between 10 minutes and 1440 minutes (24 hours)"
|
|
175
|
+
)
|
|
176
|
+
|
|
166
177
|
model_id = exists_model(self._api, model_name)
|
|
167
178
|
|
|
168
179
|
if model_id is not None and disable_truss_download:
|
|
@@ -187,6 +198,7 @@ class BasetenRemote(TrussRemote):
|
|
|
187
198
|
origin=origin,
|
|
188
199
|
environment=environment,
|
|
189
200
|
allow_truss_download=not disable_truss_download,
|
|
201
|
+
team_id=team_id,
|
|
190
202
|
)
|
|
191
203
|
|
|
192
204
|
def push( # type: ignore
|
|
@@ -204,6 +216,8 @@ class BasetenRemote(TrussRemote):
|
|
|
204
216
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
205
217
|
include_git_info: bool = False,
|
|
206
218
|
preserve_env_instance_type: bool = True,
|
|
219
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
220
|
+
team_id: Optional[str] = None,
|
|
207
221
|
) -> BasetenService:
|
|
208
222
|
push_data = self._prepare_push(
|
|
209
223
|
truss_handle=truss_handle,
|
|
@@ -216,6 +230,8 @@ class BasetenRemote(TrussRemote):
|
|
|
216
230
|
origin=origin,
|
|
217
231
|
environment=environment,
|
|
218
232
|
progress_bar=progress_bar,
|
|
233
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
234
|
+
team_id=team_id,
|
|
219
235
|
)
|
|
220
236
|
|
|
221
237
|
if include_git_info:
|
|
@@ -241,6 +257,8 @@ class BasetenRemote(TrussRemote):
|
|
|
241
257
|
environment=push_data.environment,
|
|
242
258
|
truss_user_env=truss_user_env,
|
|
243
259
|
preserve_env_instance_type=preserve_env_instance_type,
|
|
260
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
261
|
+
team_id=push_data.team_id,
|
|
244
262
|
)
|
|
245
263
|
|
|
246
264
|
if model_version_handle.instance_type_name:
|
|
@@ -263,9 +281,13 @@ class BasetenRemote(TrussRemote):
|
|
|
263
281
|
entrypoint_artifact: custom_types.ChainletArtifact,
|
|
264
282
|
dependency_artifacts: List[custom_types.ChainletArtifact],
|
|
265
283
|
truss_user_env: b10_types.TrussUserEnv,
|
|
284
|
+
chain_root: Optional[Path] = None,
|
|
266
285
|
publish: bool = False,
|
|
267
286
|
environment: Optional[str] = None,
|
|
268
287
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
288
|
+
disable_chain_download: bool = False,
|
|
289
|
+
deployment_name: Optional[str] = None,
|
|
290
|
+
team_id: Optional[str] = None,
|
|
269
291
|
) -> ChainDeploymentHandleAtomic:
|
|
270
292
|
# If we are promoting a model to an environment after deploy, it must be published.
|
|
271
293
|
# Draft models cannot be promoted.
|
|
@@ -285,6 +307,8 @@ class BasetenRemote(TrussRemote):
|
|
|
285
307
|
publish=publish,
|
|
286
308
|
origin=custom_types.ModelOrigin.CHAINS,
|
|
287
309
|
progress_bar=progress_bar,
|
|
310
|
+
disable_truss_download=disable_chain_download,
|
|
311
|
+
deployment_name=deployment_name,
|
|
288
312
|
)
|
|
289
313
|
oracle_data = custom_types.OracleData(
|
|
290
314
|
model_name=push_data.model_name,
|
|
@@ -300,6 +324,18 @@ class BasetenRemote(TrussRemote):
|
|
|
300
324
|
)
|
|
301
325
|
)
|
|
302
326
|
|
|
327
|
+
# Upload raw chain artifact if chain_root is provided
|
|
328
|
+
raw_chain_s3_key = None
|
|
329
|
+
if chain_root is not None:
|
|
330
|
+
logging.info("Uploading source artifact")
|
|
331
|
+
# Create a tar file from the chain root directory
|
|
332
|
+
original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
|
|
333
|
+
# Upload the chain artifact to S3
|
|
334
|
+
raw_chain_s3_key = upload_chain_artifact(
|
|
335
|
+
api=self._api,
|
|
336
|
+
serialize_file=original_source_tar,
|
|
337
|
+
progress_bar=progress_bar,
|
|
338
|
+
)
|
|
303
339
|
chain_deployment_handle = create_chain_atomic(
|
|
304
340
|
api=self._api,
|
|
305
341
|
chain_name=chain_name,
|
|
@@ -308,6 +344,10 @@ class BasetenRemote(TrussRemote):
|
|
|
308
344
|
is_draft=not publish,
|
|
309
345
|
truss_user_env=truss_user_env,
|
|
310
346
|
environment=environment,
|
|
347
|
+
original_source_artifact_s3_key=raw_chain_s3_key,
|
|
348
|
+
allow_truss_download=not disable_chain_download,
|
|
349
|
+
deployment_name=deployment_name,
|
|
350
|
+
team_id=team_id,
|
|
311
351
|
)
|
|
312
352
|
logging.info("Successfully pushed to baseten. Chain is building and deploying.")
|
|
313
353
|
return chain_deployment_handle
|
|
@@ -571,5 +611,5 @@ class BasetenRemote(TrussRemote):
|
|
|
571
611
|
) -> PatchResult:
|
|
572
612
|
return self._patch(watch_path, truss_ignore_patterns, console=None)
|
|
573
613
|
|
|
574
|
-
def upsert_training_project(self, training_project):
|
|
575
|
-
return self._api.upsert_training_project(training_project)
|
|
614
|
+
def upsert_training_project(self, training_project, team_id=None):
|
|
615
|
+
return self._api.upsert_training_project(training_project, team_id=team_id)
|
truss/remote/baseten/service.py
CHANGED
|
@@ -137,13 +137,6 @@ class BasetenService(TrussService):
|
|
|
137
137
|
|
|
138
138
|
return decode_content()
|
|
139
139
|
|
|
140
|
-
parsed_response = response.json()
|
|
141
|
-
|
|
142
|
-
if "error" in parsed_response:
|
|
143
|
-
# In the case that the model is in a non-ready state, the response
|
|
144
|
-
# will be a json with an `error` key.
|
|
145
|
-
return parsed_response
|
|
146
|
-
|
|
147
140
|
return response.json()
|
|
148
141
|
|
|
149
142
|
def authenticate(self) -> dict:
|
|
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional, Type
|
|
|
7
7
|
import boto3
|
|
8
8
|
from boto3.s3.transfer import TransferConfig
|
|
9
9
|
|
|
10
|
-
from truss.util.env_vars import
|
|
10
|
+
from truss.util.env_vars import modify_env_vars
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from rich import progress
|
|
@@ -26,7 +26,10 @@ def multipart_upload_boto3(
|
|
|
26
26
|
) -> None:
|
|
27
27
|
# In the CLI flow, ignore any local ~/.aws/config files,
|
|
28
28
|
# which can interfere with uploading the Truss to S3.
|
|
29
|
-
|
|
29
|
+
aws_env_vars = set(
|
|
30
|
+
env_var for env_var in os.environ.keys() if env_var.startswith("AWS_")
|
|
31
|
+
)
|
|
32
|
+
with modify_env_vars(deletions=aws_env_vars):
|
|
30
33
|
s3_resource = boto3.resource("s3", **credentials)
|
|
31
34
|
filesize = os.stat(file_path).st_size
|
|
32
35
|
|
|
@@ -33,7 +33,7 @@ RUN useradd -u {{ app_user_uid }} -ms /bin/bash {{ app_username }}
|
|
|
33
33
|
ENV DEBIAN_FRONTEND=noninteractive
|
|
34
34
|
|
|
35
35
|
|
|
36
|
-
{%- set UV_VERSION = "0.
|
|
36
|
+
{%- set UV_VERSION = "0.8.22" %}
|
|
37
37
|
{#
|
|
38
38
|
NB(nikhil): We use a semi-complex uv installation command across the board:
|
|
39
39
|
- A generous UV_HTTP_TIMEOUT (5m) for packages that take a long time to install.
|
|
@@ -59,10 +59,12 @@ RUN {{ python_exec_path }} -c "import sys; \
|
|
|
59
59
|
{% endblock %}
|
|
60
60
|
|
|
61
61
|
{% block install_uv %}
|
|
62
|
-
{# Install `uv` and `curl` if not already present in the image.
|
|
62
|
+
{# Install `uv` and `curl` if not already present in the image. We validate the expected location for `uv` at the very end
|
|
63
|
+
due to limitations with `pipefail` in Docker context. #}
|
|
63
64
|
RUN if ! command -v uv >/dev/null 2>&1; then \
|
|
64
65
|
command -v curl >/dev/null 2>&1 || (apt update && apt install -y curl) && \
|
|
65
|
-
curl -LsSf --retry 5 --retry-delay 5 https://astral.sh/uv/{{ UV_VERSION }}/install.sh | sh
|
|
66
|
+
curl -LsSf --retry 5 --retry-delay 5 https://astral.sh/uv/{{ UV_VERSION }}/install.sh | sh && \
|
|
67
|
+
test -x ${HOME}/.local/bin/uv; \
|
|
66
68
|
fi
|
|
67
69
|
{# Add the user's local bin to the path, used by uv. #}
|
|
68
70
|
ENV PATH=${PATH}:${HOME}/.local/bin
|
|
@@ -113,9 +115,11 @@ WORKDIR $APP_HOME
|
|
|
113
115
|
{% endblock %}
|
|
114
116
|
|
|
115
117
|
|
|
118
|
+
{% set packages_dir = "/packages" %}
|
|
119
|
+
RUN mkdir -p {{ packages_dir }}
|
|
116
120
|
{% block bundled_packages_copy %}
|
|
117
121
|
{%- if bundled_packages_dir_exists %}
|
|
118
|
-
COPY --chown={{ default_owner }} ./{{ config.bundled_packages_dir }}
|
|
122
|
+
COPY --chown={{ default_owner }} ./{{ config.bundled_packages_dir }} {{ packages_dir }}
|
|
119
123
|
{%- endif %}
|
|
120
124
|
{% endblock %}
|
|
121
125
|
|
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import http
|
|
2
3
|
import logging
|
|
3
4
|
import logging.config
|
|
4
5
|
import re
|
|
6
|
+
import traceback
|
|
5
7
|
from pathlib import Path
|
|
6
|
-
from typing import Dict
|
|
8
|
+
from typing import Awaitable, Callable, Dict
|
|
7
9
|
|
|
8
10
|
import httpx
|
|
9
11
|
from endpoints import control_app
|
|
10
|
-
from fastapi import FastAPI
|
|
12
|
+
from fastapi import FastAPI, Request, Response
|
|
11
13
|
from fastapi.responses import JSONResponse
|
|
12
14
|
from helpers.errors import ModelLoadFailed, PatchApplicatonError
|
|
13
15
|
from helpers.inference_server_controller import InferenceServerController
|
|
@@ -16,22 +18,50 @@ from helpers.inference_server_starter import async_inference_server_startup_flow
|
|
|
16
18
|
from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier
|
|
17
19
|
from shared import log_config
|
|
18
20
|
from starlette.datastructures import State
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
21
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
22
|
+
|
|
23
|
+
SANITIZED_EXCEPTION_FRAMES = 2
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# NB(nikhil): SanitizedExceptionMiddleware will reduce the noise of control server stack frames, since
|
|
27
|
+
# users often complain about the verbosity. Now, if any exceptions are explicitly raised during a proxied
|
|
28
|
+
# request, we'll log the last two stack frames which should be sufficient for debugging while significantly
|
|
29
|
+
# cutting down the volume.
|
|
30
|
+
class SanitizedExceptionMiddleware(BaseHTTPMiddleware):
|
|
31
|
+
def __init__(self, app, num_frames: int = SANITIZED_EXCEPTION_FRAMES):
|
|
32
|
+
super().__init__(app)
|
|
33
|
+
self.num_frames = num_frames
|
|
34
|
+
|
|
35
|
+
async def dispatch(
|
|
36
|
+
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
|
37
|
+
) -> Response:
|
|
38
|
+
try:
|
|
39
|
+
return await call_next(request)
|
|
40
|
+
except Exception as exc:
|
|
41
|
+
# NB(nikhil): Intentionally bypass error logging for ModelLoadFailed, since health checks
|
|
42
|
+
# are noisy. The underlying model logs for why the load failed will still be visible.
|
|
43
|
+
if isinstance(exc, ModelLoadFailed):
|
|
44
|
+
return JSONResponse(
|
|
45
|
+
{"error": str(exc)}, status_code=http.HTTPStatus.BAD_GATEWAY.value
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
sanitized_traceback = self._create_sanitized_traceback(exc)
|
|
49
|
+
request.app.state.logger.error(sanitized_traceback)
|
|
50
|
+
|
|
51
|
+
if isinstance(exc, PatchApplicatonError):
|
|
52
|
+
error_type = _camel_to_snake_case(type(exc).__name__)
|
|
53
|
+
return JSONResponse({"error": {"type": error_type, "msg": str(exc)}})
|
|
54
|
+
else:
|
|
55
|
+
return JSONResponse(
|
|
56
|
+
{"error": {"type": "unknown", "msg": str(exc)}},
|
|
57
|
+
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
def _create_sanitized_traceback(self, error: Exception) -> str:
|
|
61
|
+
tb_lines = traceback.format_tb(error.__traceback__)
|
|
62
|
+
if tb_lines and self.num_frames > 0:
|
|
63
|
+
return "".join(tb_lines[-self.num_frames :])
|
|
64
|
+
return f"{type(error).__name__}: {error}"
|
|
35
65
|
|
|
36
66
|
|
|
37
67
|
def create_app(base_config: Dict):
|
|
@@ -57,10 +87,9 @@ def create_app(base_config: Dict):
|
|
|
57
87
|
base_url=f"http://localhost:{app_state.inference_server_port}", limits=limits
|
|
58
88
|
)
|
|
59
89
|
|
|
60
|
-
|
|
61
|
-
|
|
90
|
+
uv_path = getattr(app_state, "uv_path", None)
|
|
62
91
|
patch_applier = ModelContainerPatchApplier(
|
|
63
|
-
Path(app_state.inference_server_home), app_logger,
|
|
92
|
+
Path(app_state.inference_server_home), app_logger, uv_path
|
|
64
93
|
)
|
|
65
94
|
|
|
66
95
|
oversee_inference_server = getattr(app_state, "oversee_inference_server", True)
|
|
@@ -82,14 +111,10 @@ def create_app(base_config: Dict):
|
|
|
82
111
|
app = FastAPI(
|
|
83
112
|
title="Truss Live Reload Server",
|
|
84
113
|
on_startup=[start_background_inference_startup],
|
|
85
|
-
exception_handlers={
|
|
86
|
-
PatchApplicatonError: handle_patch_error,
|
|
87
|
-
ModelLoadFailed: handle_model_load_failed,
|
|
88
|
-
Exception: generic_error_handler,
|
|
89
|
-
},
|
|
90
114
|
)
|
|
91
115
|
app.state = app_state
|
|
92
116
|
app.include_router(control_app)
|
|
117
|
+
app.add_middleware(SanitizedExceptionMiddleware)
|
|
93
118
|
|
|
94
119
|
@app.on_event("shutdown")
|
|
95
120
|
def on_shutdown():
|
|
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Protocol
|
|
|
5
5
|
import httpx
|
|
6
6
|
from fastapi import APIRouter, WebSocket
|
|
7
7
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
8
|
+
from helpers.errors import ModelLoadFailed, ModelNotReady
|
|
8
9
|
from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
|
|
9
10
|
from httpx_ws import _exceptions as httpx_ws_exceptions
|
|
10
11
|
from starlette.requests import ClientDisconnect, Request
|
|
@@ -13,11 +14,6 @@ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisco
|
|
|
13
14
|
from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
|
|
14
15
|
from wsproto.events import BytesMessage, TextMessage
|
|
15
16
|
|
|
16
|
-
from truss.templates.control.control.helpers.errors import (
|
|
17
|
-
ModelLoadFailed,
|
|
18
|
-
ModelNotReady,
|
|
19
|
-
)
|
|
20
|
-
|
|
21
17
|
INFERENCE_SERVER_START_WAIT_SECS = 60
|
|
22
18
|
BASE_RETRY_EXCEPTIONS = (
|
|
23
19
|
retry_if_exception_type(httpx.ConnectError)
|
|
@@ -5,6 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
from typing import List, Optional
|
|
6
6
|
|
|
7
7
|
from helpers.context_managers import current_directory
|
|
8
|
+
from shared.util import kill_child_processes
|
|
8
9
|
|
|
9
10
|
INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()
|
|
10
11
|
TERMINATION_TIMEOUT_SECS = 120.0
|
|
@@ -46,17 +47,22 @@ class InferenceServerProcessController:
|
|
|
46
47
|
self._inference_server_ever_started = True
|
|
47
48
|
self._logged_unrecoverable_since_last_restart = False
|
|
48
49
|
|
|
50
|
+
def _terminate_children_and_process(self):
|
|
51
|
+
"""Kill child processes first, then parent. Prevents port binding conflicts."""
|
|
52
|
+
# Use a shorter timeout than the truss patch read timeout (=120s):
|
|
53
|
+
# see remote/baseten/api.py:_post_graphql_query()
|
|
54
|
+
kill_child_processes(self._inference_server_process.pid, timeout_seconds=30)
|
|
55
|
+
self._inference_server_process.terminate()
|
|
56
|
+
|
|
49
57
|
def stop(self):
|
|
50
58
|
if self._inference_server_process is not None:
|
|
51
|
-
self.
|
|
59
|
+
self._terminate_children_and_process()
|
|
52
60
|
self._inference_server_process.wait()
|
|
53
|
-
# Introduce delay to avoid failing to grab the port
|
|
54
|
-
time.sleep(3)
|
|
55
61
|
|
|
56
62
|
self._inference_server_started = False
|
|
57
63
|
|
|
58
64
|
def terminate_with_wait(self):
|
|
59
|
-
self.
|
|
65
|
+
self._terminate_children_and_process()
|
|
60
66
|
self._inference_server_terminated = True
|
|
61
67
|
termination_check_attempts = int(
|
|
62
68
|
TERMINATION_TIMEOUT_SECS / TERMINATION_CHECK_INTERVAL_SECS
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import os
|
|
3
|
+
import shutil
|
|
2
4
|
import subprocess
|
|
3
5
|
from pathlib import Path
|
|
4
6
|
from typing import Optional
|
|
@@ -30,7 +32,7 @@ class ModelContainerPatchApplier:
|
|
|
30
32
|
self,
|
|
31
33
|
inference_server_home: Path,
|
|
32
34
|
app_logger: logging.Logger,
|
|
33
|
-
|
|
35
|
+
uv_path: Optional[str] = None, # Only meant for testing
|
|
34
36
|
) -> None:
|
|
35
37
|
self._inference_server_home = inference_server_home
|
|
36
38
|
self._model_module_dir = (
|
|
@@ -41,9 +43,19 @@ class ModelContainerPatchApplier:
|
|
|
41
43
|
).resolve()
|
|
42
44
|
self._data_dir = self._inference_server_home / self._truss_config.data_dir
|
|
43
45
|
self._app_logger = app_logger
|
|
44
|
-
self.
|
|
45
|
-
if
|
|
46
|
-
self.
|
|
46
|
+
self._uv_path_cached = None
|
|
47
|
+
if uv_path is not None:
|
|
48
|
+
self._uv_path_cached = uv_path
|
|
49
|
+
|
|
50
|
+
self._python_executable = self._get_python_executable()
|
|
51
|
+
|
|
52
|
+
def _get_python_executable(self) -> str:
|
|
53
|
+
# NB(nikhil): `uv` requires the full path to the python interpreter for patching
|
|
54
|
+
# python modules. We expect PYTHON_EXECUTABLE to exist in all development images, but
|
|
55
|
+
# we fallback to python3 as a default.
|
|
56
|
+
python_executable = os.environ.get("PYTHON_EXECUTABLE", "python3")
|
|
57
|
+
full_executable_path = shutil.which(python_executable)
|
|
58
|
+
return full_executable_path or python_executable
|
|
47
59
|
|
|
48
60
|
def __call__(self, patch: Patch, inf_env: dict):
|
|
49
61
|
self._app_logger.debug(f"Applying patch {patch.to_dict()}")
|
|
@@ -79,10 +91,10 @@ class ModelContainerPatchApplier:
|
|
|
79
91
|
return TrussConfig.from_yaml(self._inference_server_home / "config.yaml")
|
|
80
92
|
|
|
81
93
|
@property
|
|
82
|
-
def
|
|
83
|
-
if self.
|
|
84
|
-
self.
|
|
85
|
-
return self.
|
|
94
|
+
def _uv_path(self) -> str:
|
|
95
|
+
if self._uv_path_cached is None:
|
|
96
|
+
self._uv_path_cached = _identify_uv_path()
|
|
97
|
+
return self._uv_path_cached
|
|
86
98
|
|
|
87
99
|
def _apply_python_requirement_patch(
|
|
88
100
|
self, python_requirement_patch: PythonRequirementPatch
|
|
@@ -95,20 +107,25 @@ class ModelContainerPatchApplier:
|
|
|
95
107
|
if action == Action.REMOVE:
|
|
96
108
|
subprocess.run(
|
|
97
109
|
[
|
|
98
|
-
self.
|
|
110
|
+
self._uv_path,
|
|
111
|
+
"pip",
|
|
99
112
|
"uninstall",
|
|
100
|
-
"-y",
|
|
101
113
|
python_requirement_patch.requirement,
|
|
114
|
+
"--python",
|
|
115
|
+
self._python_executable,
|
|
102
116
|
],
|
|
103
117
|
check=True,
|
|
104
118
|
)
|
|
105
119
|
elif action in [Action.ADD, Action.UPDATE]:
|
|
106
120
|
subprocess.run(
|
|
107
121
|
[
|
|
108
|
-
self.
|
|
122
|
+
self._uv_path,
|
|
123
|
+
"pip",
|
|
109
124
|
"install",
|
|
110
125
|
python_requirement_patch.requirement,
|
|
111
126
|
"--upgrade",
|
|
127
|
+
"--python",
|
|
128
|
+
self._python_executable,
|
|
112
129
|
],
|
|
113
130
|
check=True,
|
|
114
131
|
)
|
|
@@ -158,11 +175,9 @@ class ModelContainerPatchApplier:
|
|
|
158
175
|
raise ValueError(f"Unknown patch action {action}")
|
|
159
176
|
|
|
160
177
|
|
|
161
|
-
def
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
if Path("/usr/local/bin/pip").exists():
|
|
166
|
-
return "/usr/local/bin/pip"
|
|
178
|
+
def _identify_uv_path() -> str:
|
|
179
|
+
uv_path = shutil.which("uv")
|
|
180
|
+
if not uv_path:
|
|
181
|
+
raise RuntimeError("Unable to find `uv`, make sure it's installed.")
|
|
167
182
|
|
|
168
|
-
|
|
183
|
+
return uv_path
|
|
@@ -72,9 +72,9 @@ class ControlServer:
|
|
|
72
72
|
# httptools installed, which does not work with our requests & version
|
|
73
73
|
# of uvicorn.
|
|
74
74
|
http="h11",
|
|
75
|
+
loop="uvloop",
|
|
75
76
|
**extra_kwargs,
|
|
76
77
|
)
|
|
77
|
-
cfg.setup_event_loop()
|
|
78
78
|
|
|
79
79
|
server = uvicorn.Server(cfg)
|
|
80
80
|
asyncio.run(server.serve())
|
|
@@ -7,7 +7,6 @@ python-json-logger>=2.0.2
|
|
|
7
7
|
tenacity>=8.1.0
|
|
8
8
|
# To avoid divergence, this should follow the latest release.
|
|
9
9
|
truss==0.11.1
|
|
10
|
-
|
|
11
|
-
uvicorn>=0.24.0,<0.36.0
|
|
10
|
+
uvicorn>=0.24.0
|
|
12
11
|
uvloop>=0.19.0
|
|
13
12
|
websockets>=10.0
|
|
@@ -45,6 +45,19 @@ server {
|
|
|
45
45
|
proxy_pass http://127.0.0.1:{{server_port}};
|
|
46
46
|
}
|
|
47
47
|
|
|
48
|
+
location ~ ^/v1/websocket$ {
|
|
49
|
+
proxy_redirect off;
|
|
50
|
+
proxy_read_timeout 18030s;
|
|
51
|
+
proxy_http_version 1.1;
|
|
52
|
+
|
|
53
|
+
proxy_set_header Upgrade $upgrade_header;
|
|
54
|
+
proxy_set_header Connection $connection_header;
|
|
55
|
+
|
|
56
|
+
rewrite ^/v1/websocket$ {{server_endpoint}} break;
|
|
57
|
+
|
|
58
|
+
proxy_pass http://127.0.0.1:{{server_port}};
|
|
59
|
+
}
|
|
60
|
+
|
|
48
61
|
# Forward all other paths
|
|
49
62
|
location / {
|
|
50
63
|
proxy_redirect off;
|
|
@@ -7,8 +7,9 @@ logfile_maxbytes=0 ; No size limit on logfile (since logging is disabl
|
|
|
7
7
|
[program:model-server]
|
|
8
8
|
command={{start_command}} ; Command to start the model server (provided by Jinja variable)
|
|
9
9
|
startsecs=30 ; Wait 30 seconds before assuming the server is running
|
|
10
|
+
startretries=0 ; Do not retry if server fails to start
|
|
10
11
|
autostart=true ; Automatically start the program when supervisord starts
|
|
11
|
-
autorestart=
|
|
12
|
+
autorestart=false ; Don't restart the program
|
|
12
13
|
stdout_logfile=/dev/fd/1 ; Send stdout to the first file descriptor (stdout)
|
|
13
14
|
stdout_logfile_maxbytes=0 ; No size limit on stdout log
|
|
14
15
|
redirect_stderr=true ; Redirect stderr to stdout
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
FROM {{ config.base_image.image }}
|
|
@@ -18,8 +18,7 @@ psutil>=5.9.4
|
|
|
18
18
|
python-json-logger>=2.0.2
|
|
19
19
|
pyyaml>=6.0.0
|
|
20
20
|
requests>=2.31.0
|
|
21
|
-
truss-transfer==0.0.
|
|
22
|
-
|
|
23
|
-
uvicorn>=0.24.0,<0.36.0
|
|
21
|
+
truss-transfer==0.0.38
|
|
22
|
+
uvicorn>=0.24.0
|
|
24
23
|
uvloop>=0.19.0
|
|
25
24
|
websockets>=10.0
|
|
@@ -185,7 +185,7 @@ class BasetenEndpoints:
|
|
|
185
185
|
request_id = request.headers.get("x-baseten-request-id")
|
|
186
186
|
|
|
187
187
|
logging.debug(
|
|
188
|
-
f"Request received - {request.method} {method.__name__} "
|
|
188
|
+
f"[DEBUG] Request received - {request.method} /{method.__name__} "
|
|
189
189
|
f", Request ID: {request_id}"
|
|
190
190
|
)
|
|
191
191
|
self.check_healthy()
|
|
@@ -470,9 +470,6 @@ class TrussServer:
|
|
|
470
470
|
if self._config["runtime"].get("enable_debug_logs", False)
|
|
471
471
|
else "INFO"
|
|
472
472
|
)
|
|
473
|
-
|
|
474
|
-
logging.info(f"Starting truss server with log level {log_level}")
|
|
475
|
-
logging.info(f"Config: {self._config["runtime"]}")
|
|
476
473
|
extra_kwargs = {}
|
|
477
474
|
# We don't pass these if not set, to not override the default.
|
|
478
475
|
if (
|
|
@@ -500,9 +497,9 @@ class TrussServer:
|
|
|
500
497
|
timeout_graceful_shutdown=TIMEOUT_GRACEFUL_SHUTDOWN,
|
|
501
498
|
log_config=log_config.make_log_config(log_level),
|
|
502
499
|
ws_max_size=WS_MAX_MSG_SZ_BYTES,
|
|
500
|
+
loop="uvloop",
|
|
503
501
|
**extra_kwargs,
|
|
504
502
|
)
|
|
505
|
-
cfg.setup_event_loop() # Call this so uvloop gets used
|
|
506
503
|
server = uvicorn.Server(config=cfg)
|
|
507
504
|
self._server = server
|
|
508
505
|
asyncio.run(server.serve())
|
|
@@ -56,12 +56,6 @@ RUN mkdir -p {{ dst.parent }}; curl -L "{{ url }}" -o {{ dst }}
|
|
|
56
56
|
{% endfor %} {#- endfor external_data_files #}
|
|
57
57
|
{%- endif %} {#- endif external_data_files #}
|
|
58
58
|
|
|
59
|
-
{%- if build_commands %}
|
|
60
|
-
{% for command in build_commands %}
|
|
61
|
-
RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
|
|
62
|
-
{% endfor %} {#- endfor build_commands #}
|
|
63
|
-
{%- endif %} {#- endif build_commands #}
|
|
64
|
-
|
|
65
59
|
{# Copy data before code for better caching #}
|
|
66
60
|
{%- if data_dir_exists %}
|
|
67
61
|
COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
|
|
@@ -69,7 +63,7 @@ COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
|
|
|
69
63
|
|
|
70
64
|
{%- if model_cache_v2 %}
|
|
71
65
|
{# v0.0.9, keep synced with server_requirements.txt #}
|
|
72
|
-
RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.
|
|
66
|
+
RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.11.13rc3/truss-transfer-cli-v0.11.13rc3-linux-x86_64-unknown-linux-musl
|
|
73
67
|
RUN chmod +x /usr/local/bin/truss-transfer-cli
|
|
74
68
|
RUN mkdir /static-bptr
|
|
75
69
|
RUN echo "hash {{model_cache_hash}}"
|
|
@@ -104,12 +98,18 @@ COPY --chown={{ default_owner }} ./{{ config.model_module_dir }} ${APP_HOME}/mod
|
|
|
104
98
|
{# Macro to change ownership of directories and switch to regular user #}
|
|
105
99
|
{%- macro chown_and_switch_to_regular_user_if_enabled(additional_chown_dirs=[]) -%}
|
|
106
100
|
{%- if non_root_user %}
|
|
107
|
-
RUN chown -R {{ app_username }}:{{ app_username }} {% for dir in additional_chown_dirs %}{{ dir }} {% endfor %}
|
|
101
|
+
RUN chown -R {{ app_username }}:{{ app_username }} ${HOME} ${APP_HOME} {{ packages_dir }} {% for dir in additional_chown_dirs %}{{ dir }} {% endfor %}
|
|
108
102
|
USER {{ app_username }}
|
|
109
103
|
{%- endif %} {#- endif non_root_user #}
|
|
110
104
|
{%- endmacro -%}
|
|
111
105
|
|
|
112
|
-
|
|
106
|
+
{%- if build_commands %}
|
|
107
|
+
{% for command in build_commands %}
|
|
108
|
+
RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
|
|
109
|
+
{% endfor %} {#- endfor build_commands #}
|
|
110
|
+
{%- endif %} {#- endif build_commands #}
|
|
111
|
+
|
|
112
|
+
{%- if config.docker_server %}
|
|
113
113
|
RUN apt-get update -y && apt-get install -y --no-install-recommends \
|
|
114
114
|
curl nginx && rm -rf /var/lib/apt/lists/*
|
|
115
115
|
COPY --chown={{ default_owner }} ./docker_server_requirements.txt ${APP_HOME}/docker_server_requirements.txt
|
|
@@ -131,7 +131,7 @@ RUN rm -f /etc/nginx/sites-enabled/default
|
|
|
131
131
|
{{ chown_and_switch_to_regular_user_if_enabled(["/var/lib/nginx", "/var/log/nginx", "/run"]) }}
|
|
132
132
|
ENTRYPOINT ["/docker_server/.venv/bin/supervisord", "-c", "{{ supervisor_config_path }}"]
|
|
133
133
|
|
|
134
|
-
|
|
134
|
+
{%- elif requires_live_reload %} {#- elif requires_live_reload #}
|
|
135
135
|
ENV HASH_TRUSS="{{ truss_hash }}"
|
|
136
136
|
ENV CONTROL_SERVER_PORT="8080"
|
|
137
137
|
ENV INFERENCE_SERVER_PORT="8090"
|
|
@@ -139,11 +139,11 @@ ENV SERVER_START_CMD="/control/.env/bin/python /control/control/server.py"
|
|
|
139
139
|
{{ chown_and_switch_to_regular_user_if_enabled() }}
|
|
140
140
|
ENTRYPOINT ["/control/.env/bin/python", "/control/control/server.py"]
|
|
141
141
|
|
|
142
|
-
|
|
142
|
+
{%- else %} {#- else (default inference server) #}
|
|
143
143
|
ENV INFERENCE_SERVER_PORT="8080"
|
|
144
144
|
ENV SERVER_START_CMD="{{ python_executable }} /app/main.py"
|
|
145
145
|
{{ chown_and_switch_to_regular_user_if_enabled() }}
|
|
146
146
|
ENTRYPOINT ["{{ python_executable }}", "/app/main.py"]
|
|
147
|
-
|
|
147
|
+
{%- endif %} {#- endif config.docker_server / live_reload #}
|
|
148
148
|
|
|
149
149
|
{% endblock %} {#- endblock run #}
|