truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -31,6 +31,7 @@ from truss.remote.baseten.core import (
31
31
  get_model_and_versions,
32
32
  get_prod_version_from_versions,
33
33
  get_truss_watch_state,
34
+ upload_chain_artifact,
34
35
  upload_truss,
35
36
  validate_truss_config_against_backend,
36
37
  )
@@ -68,6 +69,7 @@ class FinalPushData(custom_types.OracleData):
68
69
  origin: Optional[custom_types.ModelOrigin] = None
69
70
  environment: Optional[str] = None
70
71
  allow_truss_download: bool
72
+ team_id: Optional[str] = None
71
73
 
72
74
 
73
75
  class BasetenRemote(TrussRemote):
@@ -126,6 +128,8 @@ class BasetenRemote(TrussRemote):
126
128
  origin: Optional[custom_types.ModelOrigin] = None,
127
129
  environment: Optional[str] = None,
128
130
  progress_bar: Optional[Type["progress.Progress"]] = None,
131
+ deploy_timeout_minutes: Optional[int] = None,
132
+ team_id: Optional[str] = None,
129
133
  ) -> FinalPushData:
130
134
  if model_name.isspace():
131
135
  raise ValueError("Model name cannot be empty")
@@ -163,6 +167,13 @@ class BasetenRemote(TrussRemote):
163
167
  "Deployment name must only contain alphanumeric, -, _ and . characters"
164
168
  )
165
169
 
170
+ if deploy_timeout_minutes is not None and (
171
+ deploy_timeout_minutes < 10 or deploy_timeout_minutes > 1440
172
+ ):
173
+ raise ValueError(
174
+ "deploy-timeout-minutes must be between 10 minutes and 1440 minutes (24 hours)"
175
+ )
176
+
166
177
  model_id = exists_model(self._api, model_name)
167
178
 
168
179
  if model_id is not None and disable_truss_download:
@@ -187,6 +198,7 @@ class BasetenRemote(TrussRemote):
187
198
  origin=origin,
188
199
  environment=environment,
189
200
  allow_truss_download=not disable_truss_download,
201
+ team_id=team_id,
190
202
  )
191
203
 
192
204
  def push( # type: ignore
@@ -204,6 +216,8 @@ class BasetenRemote(TrussRemote):
204
216
  progress_bar: Optional[Type["progress.Progress"]] = None,
205
217
  include_git_info: bool = False,
206
218
  preserve_env_instance_type: bool = True,
219
+ deploy_timeout_minutes: Optional[int] = None,
220
+ team_id: Optional[str] = None,
207
221
  ) -> BasetenService:
208
222
  push_data = self._prepare_push(
209
223
  truss_handle=truss_handle,
@@ -216,6 +230,8 @@ class BasetenRemote(TrussRemote):
216
230
  origin=origin,
217
231
  environment=environment,
218
232
  progress_bar=progress_bar,
233
+ deploy_timeout_minutes=deploy_timeout_minutes,
234
+ team_id=team_id,
219
235
  )
220
236
 
221
237
  if include_git_info:
@@ -241,6 +257,8 @@ class BasetenRemote(TrussRemote):
241
257
  environment=push_data.environment,
242
258
  truss_user_env=truss_user_env,
243
259
  preserve_env_instance_type=preserve_env_instance_type,
260
+ deploy_timeout_minutes=deploy_timeout_minutes,
261
+ team_id=push_data.team_id,
244
262
  )
245
263
 
246
264
  if model_version_handle.instance_type_name:
@@ -263,9 +281,13 @@ class BasetenRemote(TrussRemote):
263
281
  entrypoint_artifact: custom_types.ChainletArtifact,
264
282
  dependency_artifacts: List[custom_types.ChainletArtifact],
265
283
  truss_user_env: b10_types.TrussUserEnv,
284
+ chain_root: Optional[Path] = None,
266
285
  publish: bool = False,
267
286
  environment: Optional[str] = None,
268
287
  progress_bar: Optional[Type["progress.Progress"]] = None,
288
+ disable_chain_download: bool = False,
289
+ deployment_name: Optional[str] = None,
290
+ team_id: Optional[str] = None,
269
291
  ) -> ChainDeploymentHandleAtomic:
270
292
  # If we are promoting a model to an environment after deploy, it must be published.
271
293
  # Draft models cannot be promoted.
@@ -285,6 +307,8 @@ class BasetenRemote(TrussRemote):
285
307
  publish=publish,
286
308
  origin=custom_types.ModelOrigin.CHAINS,
287
309
  progress_bar=progress_bar,
310
+ disable_truss_download=disable_chain_download,
311
+ deployment_name=deployment_name,
288
312
  )
289
313
  oracle_data = custom_types.OracleData(
290
314
  model_name=push_data.model_name,
@@ -300,6 +324,18 @@ class BasetenRemote(TrussRemote):
300
324
  )
301
325
  )
302
326
 
327
+ # Upload raw chain artifact if chain_root is provided
328
+ raw_chain_s3_key = None
329
+ if chain_root is not None:
330
+ logging.info("Uploading source artifact")
331
+ # Create a tar file from the chain root directory
332
+ original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
333
+ # Upload the chain artifact to S3
334
+ raw_chain_s3_key = upload_chain_artifact(
335
+ api=self._api,
336
+ serialize_file=original_source_tar,
337
+ progress_bar=progress_bar,
338
+ )
303
339
  chain_deployment_handle = create_chain_atomic(
304
340
  api=self._api,
305
341
  chain_name=chain_name,
@@ -308,6 +344,10 @@ class BasetenRemote(TrussRemote):
308
344
  is_draft=not publish,
309
345
  truss_user_env=truss_user_env,
310
346
  environment=environment,
347
+ original_source_artifact_s3_key=raw_chain_s3_key,
348
+ allow_truss_download=not disable_chain_download,
349
+ deployment_name=deployment_name,
350
+ team_id=team_id,
311
351
  )
312
352
  logging.info("Successfully pushed to baseten. Chain is building and deploying.")
313
353
  return chain_deployment_handle
@@ -571,5 +611,5 @@ class BasetenRemote(TrussRemote):
571
611
  ) -> PatchResult:
572
612
  return self._patch(watch_path, truss_ignore_patterns, console=None)
573
613
 
574
- def upsert_training_project(self, training_project):
575
- return self._api.upsert_training_project(training_project)
614
+ def upsert_training_project(self, training_project, team_id=None):
615
+ return self._api.upsert_training_project(training_project, team_id=team_id)
@@ -137,13 +137,6 @@ class BasetenService(TrussService):
137
137
 
138
138
  return decode_content()
139
139
 
140
- parsed_response = response.json()
141
-
142
- if "error" in parsed_response:
143
- # In the case that the model is in a non-ready state, the response
144
- # will be a json with an `error` key.
145
- return parsed_response
146
-
147
140
  return response.json()
148
141
 
149
142
  def authenticate(self) -> dict:
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional, Type
7
7
  import boto3
8
8
  from boto3.s3.transfer import TransferConfig
9
9
 
10
- from truss.util.env_vars import override_env_vars
10
+ from truss.util.env_vars import modify_env_vars
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from rich import progress
@@ -26,7 +26,10 @@ def multipart_upload_boto3(
26
26
  ) -> None:
27
27
  # In the CLI flow, ignore any local ~/.aws/config files,
28
28
  # which can interfere with uploading the Truss to S3.
29
- with override_env_vars({"AWS_CONFIG_FILE": ""}):
29
+ aws_env_vars = set(
30
+ env_var for env_var in os.environ.keys() if env_var.startswith("AWS_")
31
+ )
32
+ with modify_env_vars(deletions=aws_env_vars):
30
33
  s3_resource = boto3.resource("s3", **credentials)
31
34
  filesize = os.stat(file_path).st_size
32
35
 
@@ -33,7 +33,7 @@ RUN useradd -u {{ app_user_uid }} -ms /bin/bash {{ app_username }}
33
33
  ENV DEBIAN_FRONTEND=noninteractive
34
34
 
35
35
 
36
- {%- set UV_VERSION = "0.7.19" %}
36
+ {%- set UV_VERSION = "0.8.22" %}
37
37
  {#
38
38
  NB(nikhil): We use a semi-complex uv installation command across the board:
39
39
  - A generous UV_HTTP_TIMEOUT (5m) for packages that take a long time to install.
@@ -59,10 +59,12 @@ RUN {{ python_exec_path }} -c "import sys; \
59
59
  {% endblock %}
60
60
 
61
61
  {% block install_uv %}
62
- {# Install `uv` and `curl` if not already present in the image. #}
62
+ {# Install `uv` and `curl` if not already present in the image. We validate the expected location for `uv` at the very end
63
+ due to limitations with `pipefail` in Docker context. #}
63
64
  RUN if ! command -v uv >/dev/null 2>&1; then \
64
65
  command -v curl >/dev/null 2>&1 || (apt update && apt install -y curl) && \
65
- curl -LsSf --retry 5 --retry-delay 5 https://astral.sh/uv/{{ UV_VERSION }}/install.sh | sh; \
66
+ curl -LsSf --retry 5 --retry-delay 5 https://astral.sh/uv/{{ UV_VERSION }}/install.sh | sh && \
67
+ test -x ${HOME}/.local/bin/uv; \
66
68
  fi
67
69
  {# Add the user's local bin to the path, used by uv. #}
68
70
  ENV PATH=${PATH}:${HOME}/.local/bin
@@ -113,9 +115,11 @@ WORKDIR $APP_HOME
113
115
  {% endblock %}
114
116
 
115
117
 
118
+ {% set packages_dir = "/packages" %}
119
+ RUN mkdir -p {{ packages_dir }}
116
120
  {% block bundled_packages_copy %}
117
121
  {%- if bundled_packages_dir_exists %}
118
- COPY --chown={{ default_owner }} ./{{ config.bundled_packages_dir }} /packages
122
+ COPY --chown={{ default_owner }} ./{{ config.bundled_packages_dir }} {{ packages_dir }}
119
123
  {%- endif %}
120
124
  {% endblock %}
121
125
 
@@ -1,13 +1,15 @@
1
1
  import asyncio
2
+ import http
2
3
  import logging
3
4
  import logging.config
4
5
  import re
6
+ import traceback
5
7
  from pathlib import Path
6
- from typing import Dict
8
+ from typing import Awaitable, Callable, Dict
7
9
 
8
10
  import httpx
9
11
  from endpoints import control_app
10
- from fastapi import FastAPI
12
+ from fastapi import FastAPI, Request, Response
11
13
  from fastapi.responses import JSONResponse
12
14
  from helpers.errors import ModelLoadFailed, PatchApplicatonError
13
15
  from helpers.inference_server_controller import InferenceServerController
@@ -16,22 +18,50 @@ from helpers.inference_server_starter import async_inference_server_startup_flow
16
18
  from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier
17
19
  from shared import log_config
18
20
  from starlette.datastructures import State
19
-
20
-
21
- async def handle_patch_error(_, exc):
22
- error_type = _camel_to_snake_case(type(exc).__name__)
23
- return JSONResponse(content={"error": {"type": error_type, "msg": str(exc)}})
24
-
25
-
26
- async def generic_error_handler(_, exc):
27
- return JSONResponse(
28
- content={"error": {"type": "unknown", "msg": f"{type(exc)}: {exc}"}}
29
- )
30
-
31
-
32
- async def handle_model_load_failed(_, error):
33
- # Model load failures should result in 503 status
34
- return JSONResponse({"error": str(error)}, 503)
21
+ from starlette.middleware.base import BaseHTTPMiddleware
22
+
23
+ SANITIZED_EXCEPTION_FRAMES = 2
24
+
25
+
26
+ # NB(nikhil): SanitizedExceptionMiddleware will reduce the noise of control server stack frames, since
27
+ # users often complain about the verbosity. Now, if any exceptions are explicitly raised during a proxied
28
+ # request, we'll log the last two stack frames which should be sufficient for debugging while significantly
29
+ # cutting down the volume.
30
+ class SanitizedExceptionMiddleware(BaseHTTPMiddleware):
31
+ def __init__(self, app, num_frames: int = SANITIZED_EXCEPTION_FRAMES):
32
+ super().__init__(app)
33
+ self.num_frames = num_frames
34
+
35
+ async def dispatch(
36
+ self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
37
+ ) -> Response:
38
+ try:
39
+ return await call_next(request)
40
+ except Exception as exc:
41
+ # NB(nikhil): Intentionally bypass error logging for ModelLoadFailed, since health checks
42
+ # are noisy. The underlying model logs for why the load failed will still be visible.
43
+ if isinstance(exc, ModelLoadFailed):
44
+ return JSONResponse(
45
+ {"error": str(exc)}, status_code=http.HTTPStatus.BAD_GATEWAY.value
46
+ )
47
+
48
+ sanitized_traceback = self._create_sanitized_traceback(exc)
49
+ request.app.state.logger.error(sanitized_traceback)
50
+
51
+ if isinstance(exc, PatchApplicatonError):
52
+ error_type = _camel_to_snake_case(type(exc).__name__)
53
+ return JSONResponse({"error": {"type": error_type, "msg": str(exc)}})
54
+ else:
55
+ return JSONResponse(
56
+ {"error": {"type": "unknown", "msg": str(exc)}},
57
+ status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value,
58
+ )
59
+
60
+ def _create_sanitized_traceback(self, error: Exception) -> str:
61
+ tb_lines = traceback.format_tb(error.__traceback__)
62
+ if tb_lines and self.num_frames > 0:
63
+ return "".join(tb_lines[-self.num_frames :])
64
+ return f"{type(error).__name__}: {error}"
35
65
 
36
66
 
37
67
  def create_app(base_config: Dict):
@@ -57,10 +87,9 @@ def create_app(base_config: Dict):
57
87
  base_url=f"http://localhost:{app_state.inference_server_port}", limits=limits
58
88
  )
59
89
 
60
- pip_path = getattr(app_state, "pip_path", None)
61
-
90
+ uv_path = getattr(app_state, "uv_path", None)
62
91
  patch_applier = ModelContainerPatchApplier(
63
- Path(app_state.inference_server_home), app_logger, pip_path
92
+ Path(app_state.inference_server_home), app_logger, uv_path
64
93
  )
65
94
 
66
95
  oversee_inference_server = getattr(app_state, "oversee_inference_server", True)
@@ -82,14 +111,10 @@ def create_app(base_config: Dict):
82
111
  app = FastAPI(
83
112
  title="Truss Live Reload Server",
84
113
  on_startup=[start_background_inference_startup],
85
- exception_handlers={
86
- PatchApplicatonError: handle_patch_error,
87
- ModelLoadFailed: handle_model_load_failed,
88
- Exception: generic_error_handler,
89
- },
90
114
  )
91
115
  app.state = app_state
92
116
  app.include_router(control_app)
117
+ app.add_middleware(SanitizedExceptionMiddleware)
93
118
 
94
119
  @app.on_event("shutdown")
95
120
  def on_shutdown():
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Protocol
5
5
  import httpx
6
6
  from fastapi import APIRouter, WebSocket
7
7
  from fastapi.responses import JSONResponse, StreamingResponse
8
+ from helpers.errors import ModelLoadFailed, ModelNotReady
8
9
  from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
9
10
  from httpx_ws import _exceptions as httpx_ws_exceptions
10
11
  from starlette.requests import ClientDisconnect, Request
@@ -13,11 +14,6 @@ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisco
13
14
  from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
14
15
  from wsproto.events import BytesMessage, TextMessage
15
16
 
16
- from truss.templates.control.control.helpers.errors import (
17
- ModelLoadFailed,
18
- ModelNotReady,
19
- )
20
-
21
17
  INFERENCE_SERVER_START_WAIT_SECS = 60
22
18
  BASE_RETRY_EXCEPTIONS = (
23
19
  retry_if_exception_type(httpx.ConnectError)
@@ -5,6 +5,7 @@ from pathlib import Path
5
5
  from typing import List, Optional
6
6
 
7
7
  from helpers.context_managers import current_directory
8
+ from shared.util import kill_child_processes
8
9
 
9
10
  INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()
10
11
  TERMINATION_TIMEOUT_SECS = 120.0
@@ -46,17 +47,22 @@ class InferenceServerProcessController:
46
47
  self._inference_server_ever_started = True
47
48
  self._logged_unrecoverable_since_last_restart = False
48
49
 
50
+ def _terminate_children_and_process(self):
51
+ """Kill child processes first, then parent. Prevents port binding conflicts."""
52
+ # Use a shorter timeout than the truss patch read timeout (=120s):
53
+ # see remote/baseten/api.py:_post_graphql_query()
54
+ kill_child_processes(self._inference_server_process.pid, timeout_seconds=30)
55
+ self._inference_server_process.terminate()
56
+
49
57
  def stop(self):
50
58
  if self._inference_server_process is not None:
51
- self._inference_server_process.terminate()
59
+ self._terminate_children_and_process()
52
60
  self._inference_server_process.wait()
53
- # Introduce delay to avoid failing to grab the port
54
- time.sleep(3)
55
61
 
56
62
  self._inference_server_started = False
57
63
 
58
64
  def terminate_with_wait(self):
59
- self._inference_server_process.terminate()
65
+ self._terminate_children_and_process()
60
66
  self._inference_server_terminated = True
61
67
  termination_check_attempts = int(
62
68
  TERMINATION_TIMEOUT_SECS / TERMINATION_CHECK_INTERVAL_SECS
@@ -1,4 +1,6 @@
1
1
  import logging
2
+ import os
3
+ import shutil
2
4
  import subprocess
3
5
  from pathlib import Path
4
6
  from typing import Optional
@@ -30,7 +32,7 @@ class ModelContainerPatchApplier:
30
32
  self,
31
33
  inference_server_home: Path,
32
34
  app_logger: logging.Logger,
33
- pip_path: Optional[str] = None, # Only meant for testing
35
+ uv_path: Optional[str] = None, # Only meant for testing
34
36
  ) -> None:
35
37
  self._inference_server_home = inference_server_home
36
38
  self._model_module_dir = (
@@ -41,9 +43,19 @@ class ModelContainerPatchApplier:
41
43
  ).resolve()
42
44
  self._data_dir = self._inference_server_home / self._truss_config.data_dir
43
45
  self._app_logger = app_logger
44
- self._pip_path_cached = None
45
- if pip_path is not None:
46
- self._pip_path_cached = "pip"
46
+ self._uv_path_cached = None
47
+ if uv_path is not None:
48
+ self._uv_path_cached = uv_path
49
+
50
+ self._python_executable = self._get_python_executable()
51
+
52
+ def _get_python_executable(self) -> str:
53
+ # NB(nikhil): `uv` requires the full path to the python interpreter for patching
54
+ # python modules. We expect PYTHON_EXECUTABLE to exist in all development images, but
55
+ # we fallback to python3 as a default.
56
+ python_executable = os.environ.get("PYTHON_EXECUTABLE", "python3")
57
+ full_executable_path = shutil.which(python_executable)
58
+ return full_executable_path or python_executable
47
59
 
48
60
  def __call__(self, patch: Patch, inf_env: dict):
49
61
  self._app_logger.debug(f"Applying patch {patch.to_dict()}")
@@ -79,10 +91,10 @@ class ModelContainerPatchApplier:
79
91
  return TrussConfig.from_yaml(self._inference_server_home / "config.yaml")
80
92
 
81
93
  @property
82
- def _pip_path(self) -> str:
83
- if self._pip_path_cached is None:
84
- self._pip_path_cached = _identify_pip_path()
85
- return self._pip_path_cached
94
+ def _uv_path(self) -> str:
95
+ if self._uv_path_cached is None:
96
+ self._uv_path_cached = _identify_uv_path()
97
+ return self._uv_path_cached
86
98
 
87
99
  def _apply_python_requirement_patch(
88
100
  self, python_requirement_patch: PythonRequirementPatch
@@ -95,20 +107,25 @@ class ModelContainerPatchApplier:
95
107
  if action == Action.REMOVE:
96
108
  subprocess.run(
97
109
  [
98
- self._pip_path,
110
+ self._uv_path,
111
+ "pip",
99
112
  "uninstall",
100
- "-y",
101
113
  python_requirement_patch.requirement,
114
+ "--python",
115
+ self._python_executable,
102
116
  ],
103
117
  check=True,
104
118
  )
105
119
  elif action in [Action.ADD, Action.UPDATE]:
106
120
  subprocess.run(
107
121
  [
108
- self._pip_path,
122
+ self._uv_path,
123
+ "pip",
109
124
  "install",
110
125
  python_requirement_patch.requirement,
111
126
  "--upgrade",
127
+ "--python",
128
+ self._python_executable,
112
129
  ],
113
130
  check=True,
114
131
  )
@@ -158,11 +175,9 @@ class ModelContainerPatchApplier:
158
175
  raise ValueError(f"Unknown patch action {action}")
159
176
 
160
177
 
161
- def _identify_pip_path() -> str:
162
- if Path("/usr/local/bin/pip3").exists():
163
- return "/usr/local/bin/pip3"
164
-
165
- if Path("/usr/local/bin/pip").exists():
166
- return "/usr/local/bin/pip"
178
+ def _identify_uv_path() -> str:
179
+ uv_path = shutil.which("uv")
180
+ if not uv_path:
181
+ raise RuntimeError("Unable to find `uv`, make sure it's installed.")
167
182
 
168
- raise RuntimeError("Unable to find pip, make sure it's installed.")
183
+ return uv_path
@@ -72,9 +72,9 @@ class ControlServer:
72
72
  # httptools installed, which does not work with our requests & version
73
73
  # of uvicorn.
74
74
  http="h11",
75
+ loop="uvloop",
75
76
  **extra_kwargs,
76
77
  )
77
- cfg.setup_event_loop()
78
78
 
79
79
  server = uvicorn.Server(cfg)
80
80
  asyncio.run(server.serve())
@@ -7,7 +7,6 @@ python-json-logger>=2.0.2
7
7
  tenacity>=8.1.0
8
8
  # To avoid divergence, this should follow the latest release.
9
9
  truss==0.11.1
10
- # NB(nikhil): Uvicorn 0.36.0 has breaking changes for the event loop, so we pin to a lower version.
11
- uvicorn>=0.24.0,<0.36.0
10
+ uvicorn>=0.24.0
12
11
  uvloop>=0.19.0
13
12
  websockets>=10.0
@@ -45,6 +45,19 @@ server {
45
45
  proxy_pass http://127.0.0.1:{{server_port}};
46
46
  }
47
47
 
48
+ location ~ ^/v1/websocket$ {
49
+ proxy_redirect off;
50
+ proxy_read_timeout 18030s;
51
+ proxy_http_version 1.1;
52
+
53
+ proxy_set_header Upgrade $upgrade_header;
54
+ proxy_set_header Connection $connection_header;
55
+
56
+ rewrite ^/v1/websocket$ {{server_endpoint}} break;
57
+
58
+ proxy_pass http://127.0.0.1:{{server_port}};
59
+ }
60
+
48
61
  # Forward all other paths
49
62
  location / {
50
63
  proxy_redirect off;
@@ -7,8 +7,9 @@ logfile_maxbytes=0 ; No size limit on logfile (since logging is disabl
7
7
  [program:model-server]
8
8
  command={{start_command}} ; Command to start the model server (provided by Jinja variable)
9
9
  startsecs=30 ; Wait 30 seconds before assuming the server is running
10
+ startretries=0 ; Do not retry if server fails to start
10
11
  autostart=true ; Automatically start the program when supervisord starts
11
- autorestart=true ; Always restart the program if it exits, no matter what the exit code
12
+ autorestart=false ; Don't restart the program
12
13
  stdout_logfile=/dev/fd/1 ; Send stdout to the first file descriptor (stdout)
13
14
  stdout_logfile_maxbytes=0 ; No size limit on stdout log
14
15
  redirect_stderr=true ; Redirect stderr to stdout
@@ -0,0 +1 @@
1
+ FROM {{ config.base_image.image }}
@@ -18,8 +18,7 @@ psutil>=5.9.4
18
18
  python-json-logger>=2.0.2
19
19
  pyyaml>=6.0.0
20
20
  requests>=2.31.0
21
- truss-transfer==0.0.31
22
- # NB(nikhil): Uvicorn 0.36.0 has breaking changes for the event loop, so we pin to a lower version.
23
- uvicorn>=0.24.0,<0.36.0
21
+ truss-transfer==0.0.38
22
+ uvicorn>=0.24.0
24
23
  uvloop>=0.19.0
25
24
  websockets>=10.0
@@ -185,7 +185,7 @@ class BasetenEndpoints:
185
185
  request_id = request.headers.get("x-baseten-request-id")
186
186
 
187
187
  logging.debug(
188
- f"Request received - {request.method} {method.__name__} "
188
+ f"[DEBUG] Request received - {request.method} /{method.__name__} "
189
189
  f", Request ID: {request_id}"
190
190
  )
191
191
  self.check_healthy()
@@ -470,9 +470,6 @@ class TrussServer:
470
470
  if self._config["runtime"].get("enable_debug_logs", False)
471
471
  else "INFO"
472
472
  )
473
-
474
- logging.info(f"Starting truss server with log level {log_level}")
475
- logging.info(f"Config: {self._config["runtime"]}")
476
473
  extra_kwargs = {}
477
474
  # We don't pass these if not set, to not override the default.
478
475
  if (
@@ -500,9 +497,9 @@ class TrussServer:
500
497
  timeout_graceful_shutdown=TIMEOUT_GRACEFUL_SHUTDOWN,
501
498
  log_config=log_config.make_log_config(log_level),
502
499
  ws_max_size=WS_MAX_MSG_SZ_BYTES,
500
+ loop="uvloop",
503
501
  **extra_kwargs,
504
502
  )
505
- cfg.setup_event_loop() # Call this so uvloop gets used
506
503
  server = uvicorn.Server(config=cfg)
507
504
  self._server = server
508
505
  asyncio.run(server.serve())
@@ -56,12 +56,6 @@ RUN mkdir -p {{ dst.parent }}; curl -L "{{ url }}" -o {{ dst }}
56
56
  {% endfor %} {#- endfor external_data_files #}
57
57
  {%- endif %} {#- endif external_data_files #}
58
58
 
59
- {%- if build_commands %}
60
- {% for command in build_commands %}
61
- RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
62
- {% endfor %} {#- endfor build_commands #}
63
- {%- endif %} {#- endif build_commands #}
64
-
65
59
  {# Copy data before code for better caching #}
66
60
  {%- if data_dir_exists %}
67
61
  COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
@@ -69,7 +63,7 @@ COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
69
63
 
70
64
  {%- if model_cache_v2 %}
71
65
  {# v0.0.9, keep synced with server_requirements.txt #}
72
- RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.10.11rc1/truss-transfer-cli-v0.10.11rc1-linux-x86_64-unknown-linux-musl
66
+ RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.11.13rc3/truss-transfer-cli-v0.11.13rc3-linux-x86_64-unknown-linux-musl
73
67
  RUN chmod +x /usr/local/bin/truss-transfer-cli
74
68
  RUN mkdir /static-bptr
75
69
  RUN echo "hash {{model_cache_hash}}"
@@ -104,12 +98,18 @@ COPY --chown={{ default_owner }} ./{{ config.model_module_dir }} ${APP_HOME}/mod
104
98
  {# Macro to change ownership of directories and switch to regular user #}
105
99
  {%- macro chown_and_switch_to_regular_user_if_enabled(additional_chown_dirs=[]) -%}
106
100
  {%- if non_root_user %}
107
- RUN chown -R {{ app_username }}:{{ app_username }} {% for dir in additional_chown_dirs %}{{ dir }} {% endfor %}${HOME} ${APP_HOME}
101
+ RUN chown -R {{ app_username }}:{{ app_username }} ${HOME} ${APP_HOME} {{ packages_dir }} {% for dir in additional_chown_dirs %}{{ dir }} {% endfor %}
108
102
  USER {{ app_username }}
109
103
  {%- endif %} {#- endif non_root_user #}
110
104
  {%- endmacro -%}
111
105
 
112
- {%- if config.docker_server %}
106
+ {%- if build_commands %}
107
+ {% for command in build_commands %}
108
+ RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
109
+ {% endfor %} {#- endfor build_commands #}
110
+ {%- endif %} {#- endif build_commands #}
111
+
112
+ {%- if config.docker_server %}
113
113
  RUN apt-get update -y && apt-get install -y --no-install-recommends \
114
114
  curl nginx && rm -rf /var/lib/apt/lists/*
115
115
  COPY --chown={{ default_owner }} ./docker_server_requirements.txt ${APP_HOME}/docker_server_requirements.txt
@@ -131,7 +131,7 @@ RUN rm -f /etc/nginx/sites-enabled/default
131
131
  {{ chown_and_switch_to_regular_user_if_enabled(["/var/lib/nginx", "/var/log/nginx", "/run"]) }}
132
132
  ENTRYPOINT ["/docker_server/.venv/bin/supervisord", "-c", "{{ supervisor_config_path }}"]
133
133
 
134
- {%- elif requires_live_reload %} {#- elif requires_live_reload #}
134
+ {%- elif requires_live_reload %} {#- elif requires_live_reload #}
135
135
  ENV HASH_TRUSS="{{ truss_hash }}"
136
136
  ENV CONTROL_SERVER_PORT="8080"
137
137
  ENV INFERENCE_SERVER_PORT="8090"
@@ -139,11 +139,11 @@ ENV SERVER_START_CMD="/control/.env/bin/python /control/control/server.py"
139
139
  {{ chown_and_switch_to_regular_user_if_enabled() }}
140
140
  ENTRYPOINT ["/control/.env/bin/python", "/control/control/server.py"]
141
141
 
142
- {%- else %} {#- else (default inference server) #}
142
+ {%- else %} {#- else (default inference server) #}
143
143
  ENV INFERENCE_SERVER_PORT="8080"
144
144
  ENV SERVER_START_CMD="{{ python_executable }} /app/main.py"
145
145
  {{ chown_and_switch_to_regular_user_if_enabled() }}
146
146
  ENTRYPOINT ["{{ python_executable }}", "/app/main.py"]
147
- {%- endif %} {#- endif config.docker_server / live_reload #}
147
+ {%- endif %} {#- endif config.docker_server / live_reload #}
148
148
 
149
149
  {% endblock %} {#- endblock run #}