truss 0.10.13__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (27):
  1. truss/cli/chains_commands.py +1 -1
  2. truss/cli/train/core.py +82 -31
  3. truss/contexts/image_builder/serving_image_builder.py +7 -0
  4. truss/contexts/local_loader/docker_build_emulator.py +32 -8
  5. truss/remote/baseten/custom_types.py +7 -0
  6. truss/templates/base.Dockerfile.jinja +35 -6
  7. truss/templates/cache.Dockerfile.jinja +8 -7
  8. truss/templates/control/control/endpoints.py +72 -32
  9. truss/templates/copy_cache_files.Dockerfile.jinja +1 -1
  10. truss/templates/docker_server/supervisord.conf.jinja +1 -0
  11. truss/templates/server/truss_server.py +3 -3
  12. truss/templates/server.Dockerfile.jinja +33 -19
  13. truss/tests/cli/train/test_train_cli_core.py +254 -1
  14. truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
  15. truss/tests/templates/control/control/test_endpoints.py +22 -14
  16. truss/tests/templates/control/control/test_server_integration.py +62 -41
  17. truss/tests/templates/server/test_truss_server.py +19 -12
  18. truss/tests/test_data/server.Dockerfile +13 -10
  19. truss/tests/test_model_inference.py +4 -2
  20. {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/METADATA +1 -1
  21. {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/RECORD +27 -27
  22. truss_chains/deployment/deployment_client.py +4 -2
  23. truss_chains/public_types.py +1 -0
  24. truss_chains/remote_chainlet/utils.py +8 -0
  25. {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/WHEEL +0 -0
  26. {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/entry_points.txt +0 -0
  27. {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
@@ -359,7 +359,7 @@ def push_chain(
359
359
  "--name",
360
360
  type=str,
361
361
  required=False,
362
- help="Name of the chain to be deployed, if not given, the entrypoint name is used.",
362
+ help="Name of the chain to be watched. If not given, the entrypoint name is used.",
363
363
  )
364
364
  @click.option(
365
365
  "--remote",
truss/cli/train/core.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import json
2
+ import os
2
3
  import tarfile
3
4
  import tempfile
4
5
  from dataclasses import dataclass
@@ -16,6 +17,11 @@ from truss.cli.train.metrics_watcher import MetricsWatcher
16
17
  from truss.cli.train.types import PrepareCheckpointArgs, PrepareCheckpointResult
17
18
  from truss.cli.utils import common as cli_common
18
19
  from truss.cli.utils.output import console
20
+ from truss.remote.baseten.custom_types import (
21
+ FileSummary,
22
+ FileSummaryWithTotalSize,
23
+ GetCacheSummaryResponseV1,
24
+ )
19
25
  from truss.remote.baseten.remote import BasetenRemote
20
26
  from truss_train import loader
21
27
  from truss_train.definitions import DeployCheckpointsConfig
@@ -446,6 +452,44 @@ def fetch_project_by_name_or_id(
446
452
  raise click.ClickException(f"Error fetching project: {str(e)}")
447
453
 
448
454
 
455
+ def create_file_summary_with_directory_sizes(
456
+ files: list[FileSummary],
457
+ ) -> list[FileSummaryWithTotalSize]:
458
+ directory_sizes = calculate_directory_sizes(files)
459
+ return [
460
+ FileSummaryWithTotalSize(
461
+ file_summary=file_info,
462
+ total_size=directory_sizes.get(file_info.path, file_info.size_bytes),
463
+ )
464
+ for file_info in files
465
+ ]
466
+
467
+
468
+ def calculate_directory_sizes(
469
+ files: list[FileSummary], max_depth: int = 100
470
+ ) -> dict[str, int]:
471
+ directory_sizes = {}
472
+
473
+ for file_info in files:
474
+ if file_info.file_type == "directory":
475
+ directory_sizes[file_info.path] = 0
476
+
477
+ for file_info in files:
478
+ current_path = file_info.path
479
+ for i in range(max_depth):
480
+ if current_path is None:
481
+ break
482
+ if current_path in directory_sizes:
483
+ directory_sizes[current_path] += file_info.size_bytes
484
+ # Move to parent directory
485
+ parent = os.path.dirname(current_path)
486
+ if parent == current_path: # Reached root
487
+ break
488
+ current_path = parent
489
+
490
+ return directory_sizes
491
+
492
+
449
493
  def view_cache_summary(
450
494
  remote_provider: BasetenRemote,
451
495
  project_id: str,
@@ -454,12 +498,14 @@ def view_cache_summary(
454
498
  ):
455
499
  """View cache summary for a training project."""
456
500
  try:
457
- cache_data = remote_provider.api.get_cache_summary(project_id)
501
+ raw_cache_data = remote_provider.api.get_cache_summary(project_id)
458
502
 
459
- if not cache_data:
503
+ if not raw_cache_data:
460
504
  console.print("No cache summary found for this project.", style="yellow")
461
505
  return
462
506
 
507
+ cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)
508
+
463
509
  table = rich.table.Table(title=f"Cache summary for project: {project_id}")
464
510
  table.add_column("File Path", style="cyan")
465
511
  table.add_column("Size", style="green")
@@ -467,58 +513,48 @@ def view_cache_summary(
467
513
  table.add_column("Type")
468
514
  table.add_column("Permissions", style="magenta")
469
515
 
470
- files = cache_data.get("file_summaries", [])
516
+ files = cache_data.file_summaries
471
517
  if not files:
472
518
  console.print("No files found in cache.", style="yellow")
473
519
  return
474
520
 
475
- reverse = order == SORT_ORDER_DESC
521
+ files_with_total_sizes = create_file_summary_with_directory_sizes(files)
476
522
 
477
- if sort_by == SORT_BY_FILEPATH:
478
- files.sort(key=lambda x: x.get("path", ""), reverse=reverse)
479
- elif sort_by == SORT_BY_SIZE:
480
- files.sort(key=lambda x: x.get("size_bytes", 0), reverse=reverse)
481
- elif sort_by == SORT_BY_MODIFIED:
482
- files.sort(key=lambda x: x.get("modified", ""), reverse=reverse)
483
- elif sort_by == SORT_BY_TYPE:
484
- files.sort(key=lambda x: x.get("file_type", ""), reverse=reverse)
485
- elif sort_by == SORT_BY_PERMISSIONS:
486
- files.sort(key=lambda x: x.get("permissions", ""), reverse=reverse)
487
-
488
- total_size = 0
489
- for file_info in files:
490
- total_size += file_info.get("size_bytes", 0)
523
+ reverse = order == SORT_ORDER_DESC
524
+ sort_key = _get_sort_key(sort_by)
525
+ files_with_total_sizes.sort(key=sort_key, reverse=reverse)
491
526
 
527
+ total_size = sum(
528
+ file_info.file_summary.size_bytes for file_info in files_with_total_sizes
529
+ )
492
530
  total_size_str = common.format_bytes_to_human_readable(total_size)
493
531
 
494
532
  console.print(
495
- f"📅 Cache captured at: {cache_data.get('timestamp', 'Unknown')}",
496
- style="bold blue",
533
+ f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
497
534
  )
535
+ console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
536
+ console.print()
498
537
  console.print(
499
- f"📁 Project ID: {cache_data.get('project_id', 'Unknown')}",
500
- style="bold blue",
538
+ f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
501
539
  )
502
- console.print()
503
- console.print(f"📊 Total files: {len(files)}", style="bold green")
504
540
  console.print(f"💾 Total size: {total_size_str}", style="bold green")
505
541
  console.print()
506
542
 
507
- for file_info in files:
508
- size_bytes = file_info.get("size_bytes", 0)
543
+ for file_info in files_with_total_sizes:
544
+ total_size = file_info.total_size
509
545
 
510
- size_str = cli_common.format_bytes_to_human_readable(int(size_bytes))
546
+ size_str = cli_common.format_bytes_to_human_readable(int(total_size))
511
547
 
512
548
  modified_str = cli_common.format_localized_time(
513
- file_info.get("modified", "Unknown")
549
+ file_info.file_summary.modified
514
550
  )
515
551
 
516
552
  table.add_row(
517
- file_info.get("path", "Unknown"),
553
+ file_info.file_summary.path,
518
554
  size_str,
519
555
  modified_str,
520
- file_info.get("file_type", "Unknown"),
521
- file_info.get("permissions", "Unknown"),
556
+ file_info.file_summary.file_type or "Unknown",
557
+ file_info.file_summary.permissions or "Unknown",
522
558
  )
523
559
 
524
560
  console.print(table)
@@ -528,6 +564,21 @@ def view_cache_summary(
528
564
  raise
529
565
 
530
566
 
567
+ def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
568
+ if sort_by == SORT_BY_FILEPATH:
569
+ return lambda x: x.file_summary.path
570
+ elif sort_by == SORT_BY_SIZE:
571
+ return lambda x: x.total_size
572
+ elif sort_by == SORT_BY_MODIFIED:
573
+ return lambda x: x.file_summary.modified
574
+ elif sort_by == SORT_BY_TYPE:
575
+ return lambda x: x.file_summary.file_type or ""
576
+ elif sort_by == SORT_BY_PERMISSIONS:
577
+ return lambda x: x.file_summary.permissions or ""
578
+ else:
579
+ raise ValueError(f"Invalid --sort argument: {sort_by}")
580
+
581
+
531
582
  def view_cache_summary_by_project(
532
583
  remote_provider: BasetenRemote,
533
584
  project_identifier: str,
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import json
4
4
  import logging
5
+ import os
5
6
  import re
6
7
  import shutil
7
8
  from abc import ABC, abstractmethod
@@ -783,6 +784,10 @@ class ServingImageBuilder(ImageBuilder):
783
784
  config
784
785
  )
785
786
 
787
+ non_root_user = os.getenv("BT_USE_NON_ROOT_USER", False)
788
+ enable_model_container_admin_commands = os.getenv(
789
+ "BT_ENABLE_MODEL_CONTAINER_ADMIN_CMDS"
790
+ )
786
791
  dockerfile_contents = dockerfile_template.render(
787
792
  should_install_server_requirements=should_install_server_requirements,
788
793
  base_image_name_and_tag=base_image_name_and_tag,
@@ -816,6 +821,8 @@ class ServingImageBuilder(ImageBuilder):
816
821
  build_commands=build_commands,
817
822
  use_local_src=config.use_local_src,
818
823
  passthrough_environment_variables=passthrough_environment_variables,
824
+ non_root_user=non_root_user,
825
+ enable_model_container_admin_commands=enable_model_container_admin_commands,
819
826
  **FILENAME_CONSTANTS_MAP,
820
827
  )
821
828
  # Consolidate repeated empty lines to single empty lines.
@@ -1,3 +1,4 @@
1
+ import re
1
2
  from dataclasses import dataclass, field
2
3
  from pathlib import Path
3
4
  from typing import Dict, List
@@ -31,12 +32,32 @@ class DockerBuildEmulator:
31
32
  self._context_dir = context_dir
32
33
 
33
34
  def run(self, fs_root_dir: Path) -> DockerBuildEmulatorResult:
34
- def _resolve_env(key: str) -> str:
35
- if key.startswith("$"):
36
- key = key.replace("$", "", 1)
37
- v = result.env[key]
38
- return v
39
- return key
35
+ def _resolve_env(in_value: str) -> str:
36
+ # Valid environment variable name pattern
37
+ var_name_pattern = r"[A-Za-z_][A-Za-z0-9_]*"
38
+
39
+ # Handle ${VAR} syntax
40
+ def replace_braced_var(match):
41
+ var_name = match.group(1)
42
+ return result.env.get(
43
+ var_name, match.group(0)
44
+ ) # Return original if not found
45
+
46
+ # Handle $VAR syntax (word boundary ensures we don't match parts of other vars)
47
+ def replace_simple_var(match):
48
+ var_name = match.group(1)
49
+ return result.env.get(
50
+ var_name, match.group(0)
51
+ ) # Return original if not found
52
+
53
+ # Replace ${VAR} patterns first, using % substitution to avoid additional braces noise with f-strings
54
+ value = re.sub(
55
+ r"\$\{(%s)\}" % var_name_pattern, replace_braced_var, in_value
56
+ )
57
+ # Then replace remaining $VAR patterns (only at word boundaries)
58
+ value = re.sub(r"\$(%s)\b" % var_name_pattern, replace_simple_var, value)
59
+
60
+ return value
40
61
 
41
62
  def _resolve_values(keys: List[str]) -> List[str]:
42
63
  return list(map(_resolve_env, keys))
@@ -53,11 +74,14 @@ class DockerBuildEmulator:
53
74
  if cmd.instruction == DockerInstruction.ENTRYPOINT:
54
75
  result.entrypoint = list(values)
55
76
  if cmd.instruction == DockerInstruction.COPY:
77
+ # Filter out --chown flags
78
+ filtered_values = [v for v in values if not v.startswith("--chown")]
79
+
56
80
  # NB(nikhil): Skip COPY commands with --from flag (multi-stage builds)
57
- if len(values) != 2:
81
+ if len(filtered_values) != 2:
58
82
  continue
59
83
 
60
- src, dst = values
84
+ src, dst = filtered_values
61
85
  src = src.replace("./", "", 1)
62
86
  dst = dst.replace("/", "", 1)
63
87
  copy_tree_or_file(self._context_dir / src, fs_root_dir / dst)
@@ -138,6 +138,13 @@ class FileSummary(pydantic.BaseModel):
138
138
  )
139
139
 
140
140
 
141
+ class FileSummaryWithTotalSize(pydantic.BaseModel):
142
+ file_summary: FileSummary
143
+ total_size: int = pydantic.Field(
144
+ description="Total size of the file and all its subdirectories"
145
+ )
146
+
147
+
141
148
  class GetCacheSummaryResponseV1(pydantic.BaseModel):
142
149
  """Response for getting cache summary."""
143
150
 
@@ -8,6 +8,35 @@ FROM {{ base_image_name_and_tag }} AS truss_server
8
8
  {%- set python_executable = config.base_image.python_executable_path or 'python3' %}
9
9
  ENV PYTHON_EXECUTABLE="{{ python_executable }}"
10
10
 
11
+ {%- set app_username = "app" %} {# needed later for USER directive#}
12
+ {% block user_setup %}
13
+ {%- set app_user_uid = 60000 %}
14
+ {%- set control_server_dir = "/control" %}
15
+ {%- set default_owner = "root:root" %}
16
+ {# The non-root user's home directory. #}
17
+ {# uv will use $HOME to install packages. #}
18
+ ENV HOME=/home/{{ app_username }}
19
+ {# Directory containing inference server code. #}
20
+ ENV APP_HOME=/{{ app_username }}
21
+ RUN mkdir -p ${APP_HOME} {{ control_server_dir }}
22
+ {# Create a non-root user to run model containers. #}
23
+ RUN useradd -u {{ app_user_uid }} -ms /bin/bash {{ app_username }}
24
+ {% endblock %} {#- endblock user_setup #}
25
+
26
+ {#- at the very beginning, set non-interactive mode for apt #}
27
+ ENV DEBIAN_FRONTEND=noninteractive
28
+
29
+ {# If non-root user is enabled and model container admin commands are enabled, install sudo #}
30
+ {# to allow the non-root user to install packages. #}
31
+ {%- if non_root_user and enable_model_container_admin_commands %}
32
+ RUN apt update && apt install -y sudo
33
+ {%- set allowed_admin_commands = ["/usr/bin/apt install *", "/usr/bin/apt update"] %}
34
+ RUN echo "Defaults:{{ app_username }} passwd_tries=0\n{{ app_username }} ALL=(root) NOPASSWD: {{ allowed_admin_commands | join(", ") }}" > /etc/sudoers.d/app-packages
35
+ RUN chmod 0440 /etc/sudoers.d/app-packages
36
+ {#- optional but good practice: check if the sudoers file is valid #}
37
+ RUN visudo -c
38
+ {%- endif %} {#- endif non_root_user and enable_model_container_admin_commands #}
39
+
11
40
  {%- set UV_VERSION = "0.7.19" %}
12
41
  {#
13
42
  NB(nikhil): We use a semi-complex uv installation command across the board:
@@ -39,7 +68,8 @@ RUN if ! command -v uv >/dev/null 2>&1; then \
39
68
  command -v curl >/dev/null 2>&1 || (apt update && apt install -y curl) && \
40
69
  curl -LsSf --retry 5 --retry-delay 5 https://astral.sh/uv/{{ UV_VERSION }}/install.sh | sh; \
41
70
  fi
42
- ENV PATH="/root/.local/bin:$PATH"
71
+ {# Add the user's local bin to the path, used by uv. #}
72
+ ENV PATH=${PATH}:${HOME}/.local/bin
43
73
  {% endblock %}
44
74
 
45
75
  {% block base_image_patch %}
@@ -57,7 +87,7 @@ RUN {{ sys_pip_install_command }} install mkl
57
87
 
58
88
  {% block install_system_requirements %}
59
89
  {%- if should_install_system_requirements %}
60
- COPY ./{{ system_packages_filename }} {{ system_packages_filename }}
90
+ COPY --chown={{ default_owner }} ./{{ system_packages_filename }} {{ system_packages_filename }}
61
91
  RUN apt-get update && apt-get install --yes --no-install-recommends $(cat {{ system_packages_filename }}) \
62
92
  && apt-get autoremove -y \
63
93
  && apt-get clean -y \
@@ -68,11 +98,11 @@ RUN apt-get update && apt-get install --yes --no-install-recommends $(cat {{ sys
68
98
 
69
99
  {% block install_requirements %}
70
100
  {%- if should_install_user_requirements_file %}
71
- COPY ./{{ user_supplied_requirements_filename }} {{ user_supplied_requirements_filename }}
101
+ COPY --chown={{ default_owner }} ./{{ user_supplied_requirements_filename }} {{ user_supplied_requirements_filename }}
72
102
  RUN {{ sys_pip_install_command }} -r {{ user_supplied_requirements_filename }} --no-cache-dir
73
103
  {%- endif %}
74
104
  {%- if should_install_requirements %}
75
- COPY ./{{ config_requirements_filename }} {{ config_requirements_filename }}
105
+ COPY --chown={{ default_owner }} ./{{ config_requirements_filename }} {{ config_requirements_filename }}
76
106
  RUN {{ sys_pip_install_command }} -r {{ config_requirements_filename }} --no-cache-dir
77
107
  {%- endif %}
78
108
  {% endblock %}
@@ -80,7 +110,6 @@ RUN {{ sys_pip_install_command }} -r {{ config_requirements_filename }} --no-cac
80
110
 
81
111
 
82
112
  {%- if not config.docker_server %}
83
- ENV APP_HOME="/app"
84
113
  WORKDIR $APP_HOME
85
114
  {%- endif %}
86
115
 
@@ -90,7 +119,7 @@ WORKDIR $APP_HOME
90
119
 
91
120
  {% block bundled_packages_copy %}
92
121
  {%- if bundled_packages_dir_exists %}
93
- COPY ./{{ config.bundled_packages_dir }} /packages
122
+ COPY --chown={{ default_owner }} ./{{ config.bundled_packages_dir }} /packages
94
123
  {%- endif %}
95
124
  {% endblock %}
96
125
 
@@ -1,20 +1,21 @@
1
1
  FROM python:3.11-slim AS cache_warmer
2
2
 
3
- RUN mkdir -p /app/model_cache
4
- WORKDIR /app
3
+ ENV APP_HOME=/app
4
+ RUN mkdir -p ${APP_HOME}/model_cache
5
+ WORKDIR ${APP_HOME}
5
6
 
6
7
  {% if hf_access_token %}
7
8
  ENV HUGGING_FACE_HUB_TOKEN="{{hf_access_token}}"
8
9
  {% endif %}
9
10
 
10
11
  RUN apt-get -y update; apt-get -y install curl; curl -s https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp; chmod +x /app/b10cp
11
- ENV B10CP_PATH_TRUSS="/app/b10cp"
12
- COPY ./cache_requirements.txt /app/cache_requirements.txt
13
- RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
14
- COPY ./cache_warmer.py /cache_warmer.py
12
+ ENV B10CP_PATH_TRUSS="${APP_HOME}/b10cp"
13
+ COPY --chown={{ default_owner }} ./cache_requirements.txt ${APP_HOME}/cache_requirements.txt
14
+ RUN pip install -r ${APP_HOME}/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
15
+ COPY --chown={{ default_owner }} ./cache_warmer.py /cache_warmer.py
15
16
 
16
17
  {% for credential in credentials_to_cache %}
17
- COPY ./{{credential}} /app/{{credential}}
18
+ COPY ./{{credential}} ${APP_HOME}/{{credential}}
18
19
  {% endfor %}
19
20
 
20
21
  {% for repo, hf_dir in models.items() %}
@@ -1,14 +1,15 @@
1
1
  import asyncio
2
2
  import logging
3
- from typing import Any, Callable, Dict
3
+ from typing import Any, Callable, Dict, Optional, Protocol
4
4
 
5
5
  import httpx
6
6
  from fastapi import APIRouter, WebSocket
7
7
  from fastapi.responses import JSONResponse, StreamingResponse
8
+ from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
8
9
  from httpx_ws import _exceptions as httpx_ws_exceptions
9
- from httpx_ws import aconnect_ws
10
10
  from starlette.requests import ClientDisconnect, Request
11
11
  from starlette.responses import Response
12
+ from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect
12
13
  from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
13
14
  from wsproto.events import BytesMessage, TextMessage
14
15
 
@@ -30,6 +31,10 @@ BASE_RETRY_EXCEPTIONS = (
30
31
  control_app = APIRouter()
31
32
 
32
33
 
34
+ class CloseableWebsocket(Protocol):
35
+ async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
36
+
37
+
33
38
  @control_app.get("/")
34
39
  def index():
35
40
  return {}
@@ -118,13 +123,75 @@ def inference_retries(
118
123
  yield attempt
119
124
 
120
125
 
121
- async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
126
+ async def _safe_close_ws(
127
+ ws: CloseableWebsocket,
128
+ logger: logging.Logger,
129
+ code: int = 1000,
130
+ reason: Optional[str] = None,
131
+ ):
122
132
  try:
123
- await ws.close()
133
+ await ws.close(code, reason)
124
134
  except RuntimeError as close_error:
125
135
  logger.debug(f"Duplicate close of websocket: `{close_error}`.")
126
136
 
127
137
 
138
+ async def forward_to_server(
139
+ client_ws: WebSocket, server_ws: AsyncWebSocketSession
140
+ ) -> None:
141
+ while True:
142
+ message = await client_ws.receive()
143
+ if message.get("type") == "websocket.disconnect":
144
+ raise StartletteWebSocketDisconnect(
145
+ message.get("code", 1000), message.get("reason")
146
+ )
147
+ if "text" in message:
148
+ await server_ws.send_text(message["text"])
149
+ elif "bytes" in message:
150
+ await server_ws.send_bytes(message["bytes"])
151
+
152
+
153
+ async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession):
154
+ while True:
155
+ message = await server_ws.receive()
156
+ if message is None:
157
+ break
158
+ if isinstance(message, TextMessage):
159
+ await client_ws.send_text(message.data)
160
+ elif isinstance(message, BytesMessage):
161
+ await client_ws.send_bytes(message.data)
162
+
163
+
164
+ # NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer
165
+ # versions of truss we're guaranteed to be running the control server with at least that version.
166
+ async def _handle_websocket_forwarding(
167
+ client_ws: WebSocket, server_ws: AsyncWebSocketSession
168
+ ):
169
+ logger = client_ws.app.state.logger
170
+ try:
171
+ async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined]
172
+ tg.create_task(forward_to_client(client_ws, server_ws))
173
+ tg.create_task(forward_to_server(client_ws, server_ws))
174
+ except ExceptionGroup as eg: # type: ignore[name-defined] # noqa: F821
175
+ exc = eg.exceptions[0] # NB(nikhil): Only care about the first one.
176
+ if isinstance(exc, WebSocketDisconnect):
177
+ await _safe_close_ws(client_ws, logger, exc.code, exc.reason)
178
+ elif isinstance(exc, StartletteWebSocketDisconnect):
179
+ await _safe_close_ws(server_ws, logger, exc.code, exc.reason)
180
+ else:
181
+ logger.warning(f"Ungraceful websocket close: {exc}")
182
+ finally:
183
+ await _safe_close_ws(client_ws, logger)
184
+ await _safe_close_ws(server_ws, logger)
185
+
186
+
187
+ async def _attempt_websocket_proxy(
188
+ client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger
189
+ ):
190
+ async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
191
+ await client_ws.accept()
192
+ await _handle_websocket_forwarding(client_ws, server_ws)
193
+
194
+
128
195
  async def proxy_ws(client_ws: WebSocket):
129
196
  proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
130
197
  logger = client_ws.app.state.logger
@@ -132,34 +199,7 @@ async def proxy_ws(client_ws: WebSocket):
132
199
  for attempt in inference_retries():
133
200
  with attempt:
134
201
  try:
135
- async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
136
- # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
137
- # for sending data, so it's not easy to create a unified wrapper.
138
- async def forward_to_server():
139
- while True:
140
- message = await client_ws.receive()
141
- if message.get("type") == "websocket.disconnect":
142
- break
143
- if "text" in message:
144
- await server_ws.send_text(message["text"])
145
- elif "bytes" in message:
146
- await server_ws.send_bytes(message["bytes"])
147
-
148
- async def forward_to_client():
149
- while True:
150
- message = await server_ws.receive()
151
- if message is None:
152
- break
153
- if isinstance(message, TextMessage):
154
- await client_ws.send_text(message.data)
155
- elif isinstance(message, BytesMessage):
156
- await client_ws.send_bytes(message.data)
157
-
158
- await client_ws.accept()
159
- try:
160
- await asyncio.gather(forward_to_client(), forward_to_server())
161
- finally:
162
- await _safe_close_ws(client_ws, logger)
202
+ await _attempt_websocket_proxy(client_ws, proxy_client, logger)
163
203
  except httpx_ws_exceptions.HTTPXWSException as e:
164
204
  logger.warning(f"WebSocket connection rejected: {e}")
165
205
  await _safe_close_ws(client_ws, logger)
@@ -1,3 +1,3 @@
1
1
  {% for file in cached_files %}
2
- COPY --from=cache_warmer {{file.source}} {{file.dst}}
2
+ COPY --chown={{ default_owner }} --from=cache_warmer {{file.source}} {{file.dst}}
3
3
  {% endfor %}
@@ -1,4 +1,5 @@
1
1
  [supervisord]
2
+ pidfile=/tmp/supervisord.pid ; Set PID file location to /tmp to be writable by the non-root user
2
3
  nodaemon=true ; Run supervisord in the foreground (useful for containers)
3
4
  logfile=/dev/null ; Disable logging to file (send logs to /dev/null)
4
5
  logfile_maxbytes=0 ; No size limit on logfile (since logging is disabled)
@@ -76,7 +76,7 @@ async def parse_body(request: Request) -> bytes:
76
76
 
77
77
 
78
78
  async def _safe_close_websocket(
79
- ws: WebSocket, reason: Optional[str], status_code: int = 1000
79
+ ws: WebSocket, status_code: int = 1000, reason: Optional[str] = None
80
80
  ) -> None:
81
81
  try:
82
82
  await ws.close(code=status_code, reason=reason)
@@ -257,14 +257,14 @@ class BasetenEndpoints:
257
257
  try:
258
258
  await ws.accept()
259
259
  await self._model.websocket(ws)
260
- await _safe_close_websocket(ws, None, status_code=1000)
260
+ await _safe_close_websocket(ws, status_code=1000, reason=None)
261
261
  except WebSocketDisconnect as ws_error:
262
262
  logging.info(
263
263
  f"Client terminated websocket connection: `{ws_error}`."
264
264
  )
265
265
  except Exception:
266
266
  await _safe_close_websocket(
267
- ws, errors.MODEL_ERROR_MESSAGE, status_code=1011
267
+ ws, status_code=1011, reason=errors.MODEL_ERROR_MESSAGE
268
268
  )
269
269
  raise # Re raise to let `intercept_exceptions` deal with it.
270
270