truss 0.10.13__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This release has been flagged as potentially problematic.
- truss/cli/chains_commands.py +1 -1
- truss/cli/train/core.py +82 -31
- truss/contexts/image_builder/serving_image_builder.py +7 -0
- truss/contexts/local_loader/docker_build_emulator.py +32 -8
- truss/remote/baseten/custom_types.py +7 -0
- truss/templates/base.Dockerfile.jinja +35 -6
- truss/templates/cache.Dockerfile.jinja +8 -7
- truss/templates/control/control/endpoints.py +72 -32
- truss/templates/copy_cache_files.Dockerfile.jinja +1 -1
- truss/templates/docker_server/supervisord.conf.jinja +1 -0
- truss/templates/server/truss_server.py +3 -3
- truss/templates/server.Dockerfile.jinja +33 -19
- truss/tests/cli/train/test_train_cli_core.py +254 -1
- truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
- truss/tests/templates/control/control/test_endpoints.py +22 -14
- truss/tests/templates/control/control/test_server_integration.py +62 -41
- truss/tests/templates/server/test_truss_server.py +19 -12
- truss/tests/test_data/server.Dockerfile +13 -10
- truss/tests/test_model_inference.py +4 -2
- {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/METADATA +1 -1
- {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/RECORD +27 -27
- truss_chains/deployment/deployment_client.py +4 -2
- truss_chains/public_types.py +1 -0
- truss_chains/remote_chainlet/utils.py +8 -0
- {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/WHEEL +0 -0
- {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/entry_points.txt +0 -0
- {truss-0.10.13.dist-info → truss-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
truss/cli/chains_commands.py
CHANGED
```diff
@@ -359,7 +359,7 @@ def push_chain(
     "--name",
     type=str,
     required=False,
-    help="Name of the chain to be
+    help="Name of the chain to be watched. If not given, the entrypoint name is used.",
 )
 @click.option(
     "--remote",
```
truss/cli/train/core.py
CHANGED
```diff
@@ -1,4 +1,5 @@
 import json
+import os
 import tarfile
 import tempfile
 from dataclasses import dataclass
@@ -16,6 +17,11 @@ from truss.cli.train.metrics_watcher import MetricsWatcher
 from truss.cli.train.types import PrepareCheckpointArgs, PrepareCheckpointResult
 from truss.cli.utils import common as cli_common
 from truss.cli.utils.output import console
+from truss.remote.baseten.custom_types import (
+    FileSummary,
+    FileSummaryWithTotalSize,
+    GetCacheSummaryResponseV1,
+)
 from truss.remote.baseten.remote import BasetenRemote
 from truss_train import loader
 from truss_train.definitions import DeployCheckpointsConfig
@@ -446,6 +452,44 @@ def fetch_project_by_name_or_id(
         raise click.ClickException(f"Error fetching project: {str(e)}")
 
 
+def create_file_summary_with_directory_sizes(
+    files: list[FileSummary],
+) -> list[FileSummaryWithTotalSize]:
+    directory_sizes = calculate_directory_sizes(files)
+    return [
+        FileSummaryWithTotalSize(
+            file_summary=file_info,
+            total_size=directory_sizes.get(file_info.path, file_info.size_bytes),
+        )
+        for file_info in files
+    ]
+
+
+def calculate_directory_sizes(
+    files: list[FileSummary], max_depth: int = 100
+) -> dict[str, int]:
+    directory_sizes = {}
+
+    for file_info in files:
+        if file_info.file_type == "directory":
+            directory_sizes[file_info.path] = 0
+
+    for file_info in files:
+        current_path = file_info.path
+        for i in range(max_depth):
+            if current_path is None:
+                break
+            if current_path in directory_sizes:
+                directory_sizes[current_path] += file_info.size_bytes
+            # Move to parent directory
+            parent = os.path.dirname(current_path)
+            if parent == current_path:  # Reached root
+                break
+            current_path = parent
+
+    return directory_sizes
+
+
 def view_cache_summary(
     remote_provider: BasetenRemote,
     project_id: str,
@@ -454,12 +498,14 @@ def view_cache_summary(
 ):
     """View cache summary for a training project."""
     try:
-
+        raw_cache_data = remote_provider.api.get_cache_summary(project_id)
 
-        if not
+        if not raw_cache_data:
             console.print("No cache summary found for this project.", style="yellow")
             return
 
+        cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)
+
         table = rich.table.Table(title=f"Cache summary for project: {project_id}")
         table.add_column("File Path", style="cyan")
         table.add_column("Size", style="green")
@@ -467,58 +513,48 @@ def view_cache_summary(
         table.add_column("Type")
         table.add_column("Permissions", style="magenta")
 
-        files = cache_data.
+        files = cache_data.file_summaries
         if not files:
             console.print("No files found in cache.", style="yellow")
             return
 
-
+        files_with_total_sizes = create_file_summary_with_directory_sizes(files)
 
-
-
-
-            files.sort(key=lambda x: x.get("size_bytes", 0), reverse=reverse)
-        elif sort_by == SORT_BY_MODIFIED:
-            files.sort(key=lambda x: x.get("modified", ""), reverse=reverse)
-        elif sort_by == SORT_BY_TYPE:
-            files.sort(key=lambda x: x.get("file_type", ""), reverse=reverse)
-        elif sort_by == SORT_BY_PERMISSIONS:
-            files.sort(key=lambda x: x.get("permissions", ""), reverse=reverse)
-
-        total_size = 0
-        for file_info in files:
-            total_size += file_info.get("size_bytes", 0)
+        reverse = order == SORT_ORDER_DESC
+        sort_key = _get_sort_key(sort_by)
+        files_with_total_sizes.sort(key=sort_key, reverse=reverse)
 
+        total_size = sum(
+            file_info.file_summary.size_bytes for file_info in files_with_total_sizes
+        )
         total_size_str = common.format_bytes_to_human_readable(total_size)
 
         console.print(
-            f"📅 Cache captured at: {cache_data.
-            style="bold blue",
+            f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
         )
+        console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
+        console.print()
         console.print(
-            f"
-            style="bold blue",
+            f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
        )
-        console.print()
-        console.print(f"📊 Total files: {len(files)}", style="bold green")
         console.print(f"💾 Total size: {total_size_str}", style="bold green")
         console.print()
 
-        for file_info in
-
+        for file_info in files_with_total_sizes:
+            total_size = file_info.total_size
 
-            size_str = cli_common.format_bytes_to_human_readable(int(
+            size_str = cli_common.format_bytes_to_human_readable(int(total_size))
 
             modified_str = cli_common.format_localized_time(
-                file_info.
+                file_info.file_summary.modified
             )
 
             table.add_row(
-                file_info.
+                file_info.file_summary.path,
                 size_str,
                 modified_str,
-                file_info.
-                file_info.
+                file_info.file_summary.file_type or "Unknown",
+                file_info.file_summary.permissions or "Unknown",
             )
 
         console.print(table)
@@ -528,6 +564,21 @@ def view_cache_summary(
         raise
 
 
+def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
+    if sort_by == SORT_BY_FILEPATH:
+        return lambda x: x.file_summary.path
+    elif sort_by == SORT_BY_SIZE:
+        return lambda x: x.total_size
+    elif sort_by == SORT_BY_MODIFIED:
+        return lambda x: x.file_summary.modified
+    elif sort_by == SORT_BY_TYPE:
+        return lambda x: x.file_summary.file_type or ""
+    elif sort_by == SORT_BY_PERMISSIONS:
+        return lambda x: x.file_summary.permissions or ""
+    else:
+        raise ValueError(f"Invalid --sort argument: {sort_by}")
+
+
 def view_cache_summary_by_project(
     remote_provider: BasetenRemote,
     project_identifier: str,
```
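To make the new directory-size rollup concrete, here is a standalone sketch of the same algorithm; `FakeSummary` is a hypothetical stand-in for truss's `FileSummary` model, carrying only the fields the function reads.

```python
# Standalone sketch of the rollup performed by calculate_directory_sizes above.
# FakeSummary is a hypothetical stand-in for truss's FileSummary pydantic model.
import os
from dataclasses import dataclass


@dataclass
class FakeSummary:
    path: str
    size_bytes: int
    file_type: str  # "file" or "directory"


files = [
    FakeSummary("datasets", 0, "directory"),
    FakeSummary("datasets/train.jsonl", 300, "file"),
    FakeSummary("datasets/val.jsonl", 100, "file"),
]

# Same idea: seed each directory at 0, then walk every entry's path up through
# its ancestors, adding the entry's size to each ancestor that is a directory.
sizes = {f.path: 0 for f in files if f.file_type == "directory"}
for f in files:
    current = f.path
    while current:
        if current in sizes:
            sizes[current] += f.size_bytes
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            break
        current = parent

print(sizes)  # {'datasets': 400}
```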
truss/contexts/image_builder/serving_image_builder.py
CHANGED
```diff
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import json
 import logging
+import os
 import re
 import shutil
 from abc import ABC, abstractmethod
@@ -783,6 +784,10 @@ class ServingImageBuilder(ImageBuilder):
             config
         )
 
+        non_root_user = os.getenv("BT_USE_NON_ROOT_USER", False)
+        enable_model_container_admin_commands = os.getenv(
+            "BT_ENABLE_MODEL_CONTAINER_ADMIN_CMDS"
+        )
         dockerfile_contents = dockerfile_template.render(
             should_install_server_requirements=should_install_server_requirements,
             base_image_name_and_tag=base_image_name_and_tag,
@@ -816,6 +821,8 @@ class ServingImageBuilder(ImageBuilder):
             build_commands=build_commands,
             use_local_src=config.use_local_src,
             passthrough_environment_variables=passthrough_environment_variables,
+            non_root_user=non_root_user,
+            enable_model_container_admin_commands=enable_model_container_admin_commands,
             **FILENAME_CONSTANTS_MAP,
         )
         # Consolidate repeated empty lines to single empty lines.
```
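One behavioral detail worth noting: `os.getenv` returns the raw string value when the variable is set, so the template's `{%- if non_root_user %}` guard treats any non-empty setting as enabled. A quick illustration of standard Python semantics:

```python
# os.getenv returns the raw string, so any non-empty value (even "0" or
# "false") is truthy in the template's `if` guard; the toggle is off only
# when the variable is unset.
import os

os.environ["BT_USE_NON_ROOT_USER"] = "false"
print(bool(os.getenv("BT_USE_NON_ROOT_USER", False)))  # True: "false" is a non-empty string

del os.environ["BT_USE_NON_ROOT_USER"]
print(bool(os.getenv("BT_USE_NON_ROOT_USER", False)))  # False: falls back to the default
```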
truss/contexts/local_loader/docker_build_emulator.py
CHANGED
```diff
@@ -1,3 +1,4 @@
+import re
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, List
@@ -31,12 +32,32 @@ class DockerBuildEmulator:
         self._context_dir = context_dir
 
     def run(self, fs_root_dir: Path) -> DockerBuildEmulatorResult:
-        def _resolve_env(
-
-
-
-
-
+        def _resolve_env(in_value: str) -> str:
+            # Valid environment variable name pattern
+            var_name_pattern = r"[A-Za-z_][A-Za-z0-9_]*"
+
+            # Handle ${VAR} syntax
+            def replace_braced_var(match):
+                var_name = match.group(1)
+                return result.env.get(
+                    var_name, match.group(0)
+                )  # Return original if not found
+
+            # Handle $VAR syntax (word boundary ensures we don't match parts of other vars)
+            def replace_simple_var(match):
+                var_name = match.group(1)
+                return result.env.get(
+                    var_name, match.group(0)
+                )  # Return original if not found
+
+            # Replace ${VAR} patterns first, using % substitution to avoid additional braces noise with f-strings
+            value = re.sub(
+                r"\$\{(%s)\}" % var_name_pattern, replace_braced_var, in_value
+            )
+            # Then replace remaining $VAR patterns (only at word boundaries)
+            value = re.sub(r"\$(%s)\b" % var_name_pattern, replace_simple_var, value)
+
+            return value
 
         def _resolve_values(keys: List[str]) -> List[str]:
             return list(map(_resolve_env, keys))
@@ -53,11 +74,14 @@ class DockerBuildEmulator:
             if cmd.instruction == DockerInstruction.ENTRYPOINT:
                 result.entrypoint = list(values)
             if cmd.instruction == DockerInstruction.COPY:
+                # Filter out --chown flags
+                filtered_values = [v for v in values if not v.startswith("--chown")]
+
                 # NB(nikhil): Skip COPY commands with --from flag (multi-stage builds)
-                if len(
+                if len(filtered_values) != 2:
                     continue
 
-                src, dst =
+                src, dst = filtered_values
                 src = src.replace("./", "", 1)
                 dst = dst.replace("/", "", 1)
                 copy_tree_or_file(self._context_dir / src, fs_root_dir / dst)
```
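A standalone sketch of the substitution rules the emulator now implements; `env` here is a plain dict standing in for `result.env`:

```python
# Sketch of the ${VAR} / $VAR resolution added to DockerBuildEmulator;
# `env` is a plain dict standing in for result.env.
import re

env = {"APP_HOME": "/app", "PATH": "/usr/bin"}
var_name = r"[A-Za-z_][A-Za-z0-9_]*"


def resolve(value: str) -> str:
    # ${VAR} first, then bare $VAR; unknown variables are left untouched.
    value = re.sub(
        r"\$\{(%s)\}" % var_name, lambda m: env.get(m.group(1), m.group(0)), value
    )
    return re.sub(
        r"\$(%s)\b" % var_name, lambda m: env.get(m.group(1), m.group(0)), value
    )


print(resolve("${APP_HOME}/bin"))     # /app/bin
print(resolve("$PATH:$HOME/.local"))  # /usr/bin:$HOME/.local ($HOME unknown, kept)
```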
truss/remote/baseten/custom_types.py
CHANGED
```diff
@@ -138,6 +138,13 @@ class FileSummary(pydantic.BaseModel):
     )
 
 
+class FileSummaryWithTotalSize(pydantic.BaseModel):
+    file_summary: FileSummary
+    total_size: int = pydantic.Field(
+        description="Total size of the file and all its subdirectories"
+    )
+
+
 class GetCacheSummaryResponseV1(pydantic.BaseModel):
     """Response for getting cache summary."""
 
```
truss/templates/base.Dockerfile.jinja
CHANGED
```diff
@@ -8,6 +8,35 @@ FROM {{ base_image_name_and_tag }} AS truss_server
 {%- set python_executable = config.base_image.python_executable_path or 'python3' %}
 ENV PYTHON_EXECUTABLE="{{ python_executable }}"
 
+{%- set app_username = "app" %} {# needed later for USER directive#}
+{% block user_setup %}
+{%- set app_user_uid = 60000 %}
+{%- set control_server_dir = "/control" %}
+{%- set default_owner = "root:root" %}
+{# The non-root user's home directory. #}
+{# uv will use $HOME to install packages. #}
+ENV HOME=/home/{{ app_username }}
+{# Directory containing inference server code. #}
+ENV APP_HOME=/{{ app_username }}
+RUN mkdir -p ${APP_HOME} {{ control_server_dir }}
+{# Create a non-root user to run model containers. #}
+RUN useradd -u {{ app_user_uid }} -ms /bin/bash {{ app_username }}
+{% endblock %} {#- endblock user_setup #}
+
+{#- at the very beginning, set non-interactive mode for apt #}
+ENV DEBIAN_FRONTEND=noninteractive
+
+{# If non-root user is enabled and model container admin commands are enabled, install sudo #}
+{# to allow the non-root user to install packages. #}
+{%- if non_root_user and enable_model_container_admin_commands %}
+RUN apt update && apt install -y sudo
+{%- set allowed_admin_commands = ["/usr/bin/apt install *", "/usr/bin/apt update"] %}
+RUN echo "Defaults:{{ app_username }} passwd_tries=0\n{{ app_username }} ALL=(root) NOPASSWD: {{ allowed_admin_commands | join(", ") }}" > /etc/sudoers.d/app-packages
+RUN chmod 0440 /etc/sudoers.d/app-packages
+{#- optional but good practice: check if the sudoers file is valid #}
+RUN visudo -c
+{%- endif %} {#- endif non_root_user and enable_model_container_admin_commands #}
+
 {%- set UV_VERSION = "0.7.19" %}
 {#
 NB(nikhil): We use a semi-complex uv installation command across the board:
@@ -39,7 +68,8 @@ RUN if ! command -v uv >/dev/null 2>&1; then \
     command -v curl >/dev/null 2>&1 || (apt update && apt install -y curl) && \
     curl -LsSf --retry 5 --retry-delay 5 https://astral.sh/uv/{{ UV_VERSION }}/install.sh | sh; \
     fi
-
+{# Add the user's local bin to the path, used by uv. #}
+ENV PATH=${PATH}:${HOME}/.local/bin
 {% endblock %}
 
 {% block base_image_patch %}
@@ -57,7 +87,7 @@ RUN {{ sys_pip_install_command }} install mkl
 
 {% block install_system_requirements %}
 {%- if should_install_system_requirements %}
-COPY ./{{ system_packages_filename }} {{ system_packages_filename }}
+COPY --chown={{ default_owner }} ./{{ system_packages_filename }} {{ system_packages_filename }}
 RUN apt-get update && apt-get install --yes --no-install-recommends $(cat {{ system_packages_filename }}) \
     && apt-get autoremove -y \
     && apt-get clean -y \
@@ -68,11 +98,11 @@ RUN apt-get update && apt-get install --yes --no-install-recommends $(cat {{ sys
 
 {% block install_requirements %}
 {%- if should_install_user_requirements_file %}
-COPY ./{{ user_supplied_requirements_filename }} {{ user_supplied_requirements_filename }}
+COPY --chown={{ default_owner }} ./{{ user_supplied_requirements_filename }} {{ user_supplied_requirements_filename }}
 RUN {{ sys_pip_install_command }} -r {{ user_supplied_requirements_filename }} --no-cache-dir
 {%- endif %}
 {%- if should_install_requirements %}
-COPY ./{{ config_requirements_filename }} {{ config_requirements_filename }}
+COPY --chown={{ default_owner }} ./{{ config_requirements_filename }} {{ config_requirements_filename }}
 RUN {{ sys_pip_install_command }} -r {{ config_requirements_filename }} --no-cache-dir
 {%- endif %}
 {% endblock %}
@@ -80,7 +110,6 @@ RUN {{ sys_pip_install_command }} -r {{ config_requirements_filename }} --no-cac
 
 
 {%- if not config.docker_server %}
-ENV APP_HOME="/app"
 WORKDIR $APP_HOME
 {%- endif %}
 
@@ -90,7 +119,7 @@ WORKDIR $APP_HOME
 
 {% block bundled_packages_copy %}
 {%- if bundled_packages_dir_exists %}
-COPY ./{{ config.bundled_packages_dir }} /packages
+COPY --chown={{ default_owner }} ./{{ config.bundled_packages_dir }} /packages
 {%- endif %}
 {% endblock %}
 
```
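For reference, a minimal sketch (assuming the jinja2 package is available) of what the new sudoers rule renders to with the defaults set in the template; the fragment is abridged from the diff above:

```python
# Renders the sudoers line from base.Dockerfile.jinja with the template's
# default values, to show the concrete Dockerfile instruction it produces.
import jinja2

fragment = (
    'RUN echo "Defaults:{{ app_username }} passwd_tries=0\\n'
    '{{ app_username }} ALL=(root) NOPASSWD: {{ allowed_admin_commands | join(", ") }}"'
    " > /etc/sudoers.d/app-packages"
)
print(
    jinja2.Template(fragment).render(
        app_username="app",
        allowed_admin_commands=["/usr/bin/apt install *", "/usr/bin/apt update"],
    )
)
# RUN echo "Defaults:app passwd_tries=0\napp ALL=(root) NOPASSWD: /usr/bin/apt install *, /usr/bin/apt update" > /etc/sudoers.d/app-packages
```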
truss/templates/cache.Dockerfile.jinja
CHANGED
```diff
@@ -1,20 +1,21 @@
 FROM python:3.11-slim AS cache_warmer
 
-
-
+ENV APP_HOME=/app
+RUN mkdir -p ${APP_HOME}/model_cache
+WORKDIR ${APP_HOME}
 
 {% if hf_access_token %}
 ENV HUGGING_FACE_HUB_TOKEN="{{hf_access_token}}"
 {% endif %}
 
 RUN apt-get -y update; apt-get -y install curl; curl -s https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp; chmod +x /app/b10cp
-ENV B10CP_PATH_TRUSS="/
-COPY ./cache_requirements.txt /
-RUN pip install -r /
-COPY ./cache_warmer.py /cache_warmer.py
+ENV B10CP_PATH_TRUSS="${APP_HOME}/b10cp"
+COPY --chown={{ default_owner }} ./cache_requirements.txt ${APP_HOME}/cache_requirements.txt
+RUN pip install -r ${APP_HOME}/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
+COPY --chown={{ default_owner }} ./cache_warmer.py /cache_warmer.py
 
 {% for credential in credentials_to_cache %}
-COPY ./{{credential}} /
+COPY ./{{credential}} ${APP_HOME}/{{credential}}
 {% endfor %}
 
 {% for repo, hf_dir in models.items() %}
```
truss/templates/control/control/endpoints.py
CHANGED
```diff
@@ -1,14 +1,15 @@
 import asyncio
 import logging
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Optional, Protocol
 
 import httpx
 from fastapi import APIRouter, WebSocket
 from fastapi.responses import JSONResponse, StreamingResponse
+from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
 from httpx_ws import _exceptions as httpx_ws_exceptions
-from httpx_ws import aconnect_ws
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import Response
+from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect
 from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
 from wsproto.events import BytesMessage, TextMessage
 
@@ -30,6 +31,10 @@ BASE_RETRY_EXCEPTIONS = (
 control_app = APIRouter()
 
 
+class CloseableWebsocket(Protocol):
+    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
+
+
 @control_app.get("/")
 def index():
     return {}
@@ -118,13 +123,75 @@ def inference_retries(
             yield attempt
 
 
-async def _safe_close_ws(
+async def _safe_close_ws(
+    ws: CloseableWebsocket,
+    logger: logging.Logger,
+    code: int = 1000,
+    reason: Optional[str] = None,
+):
     try:
-        await ws.close()
+        await ws.close(code, reason)
     except RuntimeError as close_error:
         logger.debug(f"Duplicate close of websocket: `{close_error}`.")
 
 
+async def forward_to_server(
+    client_ws: WebSocket, server_ws: AsyncWebSocketSession
+) -> None:
+    while True:
+        message = await client_ws.receive()
+        if message.get("type") == "websocket.disconnect":
+            raise StartletteWebSocketDisconnect(
+                message.get("code", 1000), message.get("reason")
+            )
+        if "text" in message:
+            await server_ws.send_text(message["text"])
+        elif "bytes" in message:
+            await server_ws.send_bytes(message["bytes"])
+
+
+async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession):
+    while True:
+        message = await server_ws.receive()
+        if message is None:
+            break
+        if isinstance(message, TextMessage):
+            await client_ws.send_text(message.data)
+        elif isinstance(message, BytesMessage):
+            await client_ws.send_bytes(message.data)
+
+
+# NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer
+# versions of truss we're guaranteed to be running the control server with at least that version.
+async def _handle_websocket_forwarding(
+    client_ws: WebSocket, server_ws: AsyncWebSocketSession
+):
+    logger = client_ws.app.state.logger
+    try:
+        async with asyncio.TaskGroup() as tg:  # type: ignore[attr-defined]
+            tg.create_task(forward_to_client(client_ws, server_ws))
+            tg.create_task(forward_to_server(client_ws, server_ws))
+    except ExceptionGroup as eg:  # type: ignore[name-defined]  # noqa: F821
+        exc = eg.exceptions[0]  # NB(nikhil): Only care about the first one.
+        if isinstance(exc, WebSocketDisconnect):
+            await _safe_close_ws(client_ws, logger, exc.code, exc.reason)
+        elif isinstance(exc, StartletteWebSocketDisconnect):
+            await _safe_close_ws(server_ws, logger, exc.code, exc.reason)
+        else:
+            logger.warning(f"Ungraceful websocket close: {exc}")
+    finally:
+        await _safe_close_ws(client_ws, logger)
+        await _safe_close_ws(server_ws, logger)
+
+
+async def _attempt_websocket_proxy(
+    client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger
+):
+    async with aconnect_ws("/v1/websocket", proxy_client) as server_ws:  # type: ignore
+        await client_ws.accept()
+        await _handle_websocket_forwarding(client_ws, server_ws)
+
+
 async def proxy_ws(client_ws: WebSocket):
     proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
     logger = client_ws.app.state.logger
@@ -132,34 +199,7 @@ async def proxy_ws(client_ws: WebSocket):
     for attempt in inference_retries():
         with attempt:
             try:
-
-                # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
-                # for sending data, so it's not easy to create a unified wrapper.
-                async def forward_to_server():
-                    while True:
-                        message = await client_ws.receive()
-                        if message.get("type") == "websocket.disconnect":
-                            break
-                        if "text" in message:
-                            await server_ws.send_text(message["text"])
-                        elif "bytes" in message:
-                            await server_ws.send_bytes(message["bytes"])
-
-                async def forward_to_client():
-                    while True:
-                        message = await server_ws.receive()
-                        if message is None:
-                            break
-                        if isinstance(message, TextMessage):
-                            await client_ws.send_text(message.data)
-                        elif isinstance(message, BytesMessage):
-                            await client_ws.send_bytes(message.data)
-
-                await client_ws.accept()
-                try:
-                    await asyncio.gather(forward_to_client(), forward_to_server())
-                finally:
-                    await _safe_close_ws(client_ws, logger)
+                await _attempt_websocket_proxy(client_ws, proxy_client, logger)
             except httpx_ws_exceptions.HTTPXWSException as e:
                 logger.warning(f"WebSocket connection rejected: {e}")
                 await _safe_close_ws(client_ws, logger)
```
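A self-contained sketch (not truss code, requires Python 3.11) of the pattern the refactor relies on: `asyncio.TaskGroup` cancels the sibling task as soon as one forwarding direction fails, and the failure surfaces as an `ExceptionGroup` that can be unpacked afterwards.

```python
# Minimal illustration of the TaskGroup / ExceptionGroup pattern used in
# _handle_websocket_forwarding: two pumps run concurrently, and the first
# failure cancels the other and is reported via the ExceptionGroup.
import asyncio


class Disconnect(Exception):  # stand-in for a websocket disconnect
    pass


async def pump(name: str, fail_after: float) -> None:
    await asyncio.sleep(fail_after)
    raise Disconnect(name)


async def main() -> None:
    try:
        async with asyncio.TaskGroup() as tg:
            tg.create_task(pump("client->server", 0.01))
            tg.create_task(pump("server->client", 10.0))  # cancelled when the other fails
    except ExceptionGroup as eg:
        print("closed by:", eg.exceptions[0])


asyncio.run(main())  # -> closed by: client->server
```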
truss/templates/docker_server/supervisord.conf.jinja
CHANGED
```diff
@@ -1,4 +1,5 @@
 [supervisord]
+pidfile=/tmp/supervisord.pid ; Set PID file location to /tmp to be writable by the non-root user
 nodaemon=true ; Run supervisord in the foreground (useful for containers)
 logfile=/dev/null ; Disable logging to file (send logs to /dev/null)
 logfile_maxbytes=0 ; No size limit on logfile (since logging is disabled)
```
truss/templates/server/truss_server.py
CHANGED
```diff
@@ -76,7 +76,7 @@ async def parse_body(request: Request) -> bytes:
 
 
 async def _safe_close_websocket(
-    ws: WebSocket, reason: Optional[str]
+    ws: WebSocket, status_code: int = 1000, reason: Optional[str] = None
 ) -> None:
     try:
         await ws.close(code=status_code, reason=reason)
@@ -257,14 +257,14 @@ class BasetenEndpoints:
         try:
             await ws.accept()
             await self._model.websocket(ws)
-            await _safe_close_websocket(ws,
+            await _safe_close_websocket(ws, status_code=1000, reason=None)
         except WebSocketDisconnect as ws_error:
             logging.info(
                 f"Client terminated websocket connection: `{ws_error}`."
             )
         except Exception:
             await _safe_close_websocket(
-                ws, errors.MODEL_ERROR_MESSAGE
+                ws, status_code=1011, reason=errors.MODEL_ERROR_MESSAGE
             )
             raise  # Re raise to let `intercept_exceptions` deal with it.
 
```
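The close codes used here follow RFC 6455: 1000 signals a normal closure and 1011 signals an unexpected server-side error. A minimal illustration with a hypothetical handler (not truss code):

```python
# Hypothetical FastAPI endpoint illustrating the close-code convention above:
# 1000 for a normal close, 1011 when the handler hits an internal error.
from fastapi import FastAPI, WebSocket
from starlette.websockets import WebSocketDisconnect

app = FastAPI()


@app.websocket("/echo")
async def echo(ws: WebSocket) -> None:
    await ws.accept()
    try:
        while True:
            await ws.send_text(await ws.receive_text())
    except WebSocketDisconnect:
        pass  # client already closed; nothing left to do
    except Exception:
        await ws.close(code=1011, reason="internal error")  # 1011: server error
        raise
```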