truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,10 +1,207 @@
|
|
|
1
1
|
import atexit
|
|
2
|
+
import json
|
|
2
3
|
import logging
|
|
3
4
|
import time
|
|
5
|
+
from dataclasses import dataclass
|
|
4
6
|
from functools import lru_cache
|
|
5
7
|
from pathlib import Path
|
|
6
8
|
from threading import Lock, Thread
|
|
7
|
-
from typing import Optional, Union
|
|
9
|
+
from typing import List, Optional, Union
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from prometheus_client import Counter, Gauge, Histogram
|
|
13
|
+
|
|
14
|
+
PROMETHEUS_AVAILABLE = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
PROMETHEUS_AVAILABLE = False
|
|
17
|
+
METRICS_REGISTERED = False
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
|
|
21
|
+
class FileDownloadMetric:
|
|
22
|
+
file_name: str
|
|
23
|
+
file_size_bytes: int
|
|
24
|
+
download_time_secs: float
|
|
25
|
+
download_speed_mb_s: float
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
|
|
29
|
+
class TrussTransferStats:
|
|
30
|
+
total_manifest_size_bytes: int
|
|
31
|
+
total_download_time_secs: float
|
|
32
|
+
total_aggregated_mb_s: Optional[float]
|
|
33
|
+
file_downloads: List[FileDownloadMetric]
|
|
34
|
+
b10fs_read_speed_mbps: Optional[float]
|
|
35
|
+
b10fs_decision_to_use: bool
|
|
36
|
+
b10fs_enabled: bool
|
|
37
|
+
b10fs_hot_starts_files: int
|
|
38
|
+
b10fs_hot_starts_bytes: int
|
|
39
|
+
b10fs_cold_starts_files: int
|
|
40
|
+
b10fs_cold_starts_bytes: int
|
|
41
|
+
success: bool
|
|
42
|
+
timestamp: int
|
|
43
|
+
|
|
44
|
+
@classmethod
|
|
45
|
+
def from_json_file(cls, path: Path) -> Optional["TrussTransferStats"]:
|
|
46
|
+
if not path.exists():
|
|
47
|
+
return None
|
|
48
|
+
try:
|
|
49
|
+
with open(path) as f:
|
|
50
|
+
data = json.load(f)
|
|
51
|
+
file_downloads = [
|
|
52
|
+
FileDownloadMetric(**fd) for fd in data.get("file_downloads", [])
|
|
53
|
+
]
|
|
54
|
+
return cls(
|
|
55
|
+
total_manifest_size_bytes=data["total_manifest_size_bytes"],
|
|
56
|
+
total_download_time_secs=data["total_download_time_secs"],
|
|
57
|
+
total_aggregated_mb_s=data.get("total_aggregated_mb_s"),
|
|
58
|
+
file_downloads=file_downloads,
|
|
59
|
+
b10fs_read_speed_mbps=data.get("b10fs_read_speed_mbps"),
|
|
60
|
+
b10fs_decision_to_use=data["b10fs_decision_to_use"],
|
|
61
|
+
b10fs_enabled=data["b10fs_enabled"],
|
|
62
|
+
b10fs_hot_starts_files=data["b10fs_hot_starts_files"],
|
|
63
|
+
b10fs_hot_starts_bytes=data["b10fs_hot_starts_bytes"],
|
|
64
|
+
b10fs_cold_starts_files=data["b10fs_cold_starts_files"],
|
|
65
|
+
b10fs_cold_starts_bytes=data["b10fs_cold_starts_bytes"],
|
|
66
|
+
success=data["success"],
|
|
67
|
+
timestamp=data["timestamp"],
|
|
68
|
+
)
|
|
69
|
+
except Exception:
|
|
70
|
+
return None
|
|
71
|
+
|
|
72
|
+
def publish_to_prometheus(self, hidden_time: float = 0.0):
|
|
73
|
+
"""Publish transfer stats to Prometheus metrics. Only runs once."""
|
|
74
|
+
if not PROMETHEUS_AVAILABLE:
|
|
75
|
+
return
|
|
76
|
+
global METRICS_REGISTERED
|
|
77
|
+
|
|
78
|
+
if METRICS_REGISTERED:
|
|
79
|
+
logging.info(
|
|
80
|
+
"Model cache metrics already registered, skipping."
|
|
81
|
+
) # this should never happen
|
|
82
|
+
return
|
|
83
|
+
else:
|
|
84
|
+
# Ensure metrics are only registered once
|
|
85
|
+
METRICS_REGISTERED = True
|
|
86
|
+
|
|
87
|
+
# Define metrics with model_cache prefix
|
|
88
|
+
manifest_size_gauge = Gauge(
|
|
89
|
+
"model_cache_manifest_size_bytes", "Total manifest size in bytes"
|
|
90
|
+
)
|
|
91
|
+
# histograms have intentionally wide buckets to capture a variety of download times
|
|
92
|
+
download_time_histogram = Histogram(
|
|
93
|
+
"model_cache_download_time_seconds",
|
|
94
|
+
"Total download time in seconds",
|
|
95
|
+
buckets=[0]
|
|
96
|
+
+ [
|
|
97
|
+
2**i
|
|
98
|
+
for i in range(-3, 11) # = [0.125, .. 1024] seconds
|
|
99
|
+
]
|
|
100
|
+
+ [float("inf")],
|
|
101
|
+
)
|
|
102
|
+
download_speed_gauge = Gauge(
|
|
103
|
+
"model_cache_download_speed_mbps", "Aggregated download speed in MB/s"
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
# File download metrics (aggregated)
|
|
107
|
+
files_downloaded_counter = Counter(
|
|
108
|
+
"model_cache_files_downloaded_total", "Total number of files downloaded"
|
|
109
|
+
)
|
|
110
|
+
total_file_size_counter = Counter(
|
|
111
|
+
"model_cache_file_size_bytes_total",
|
|
112
|
+
"Total size of downloaded files in bytes",
|
|
113
|
+
)
|
|
114
|
+
file_download_hidden_time_gauge = Gauge(
|
|
115
|
+
"model_cache_file_download_hidden_time_seconds",
|
|
116
|
+
"Total time hidden from user by starting the import before user code (seconds)",
|
|
117
|
+
)
|
|
118
|
+
file_download_time_histogram = Histogram(
|
|
119
|
+
"model_cache_file_download_time_seconds",
|
|
120
|
+
"File download time distribution",
|
|
121
|
+
buckets=[0]
|
|
122
|
+
+ [
|
|
123
|
+
2**i
|
|
124
|
+
for i in range(-3, 11) # = [0.125, .. 1024] seconds
|
|
125
|
+
]
|
|
126
|
+
+ [float("inf")],
|
|
127
|
+
)
|
|
128
|
+
file_download_speed_histogram = Histogram(
|
|
129
|
+
"model_cache_file_download_speed_mbps",
|
|
130
|
+
"File download speed distribution",
|
|
131
|
+
buckets=[0]
|
|
132
|
+
+ [
|
|
133
|
+
2**i
|
|
134
|
+
for i in range(-1, 12) # = [0.5, .. 2048] MB/s
|
|
135
|
+
]
|
|
136
|
+
+ [float("inf")],
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
# B10FS specific metrics
|
|
140
|
+
b10fs_enabled_gauge = Gauge(
|
|
141
|
+
"model_cache_b10fs_enabled", "Whether B10FS is enabled"
|
|
142
|
+
)
|
|
143
|
+
b10fs_decision_gauge = Gauge(
|
|
144
|
+
"model_cache_b10fs_decision_to_use", "Whether B10FS was chosen for use"
|
|
145
|
+
)
|
|
146
|
+
b10fs_read_speed_gauge = Gauge(
|
|
147
|
+
"model_cache_b10fs_read_speed_mbps", "B10FS read speed in Mbps"
|
|
148
|
+
)
|
|
149
|
+
b10fs_hot_files_gauge = Gauge(
|
|
150
|
+
"model_cache_b10fs_hot_starts_files", "Number of hot start files"
|
|
151
|
+
)
|
|
152
|
+
b10fs_hot_bytes_gauge = Gauge(
|
|
153
|
+
"model_cache_b10fs_hot_starts_bytes", "Number of hot start bytes"
|
|
154
|
+
)
|
|
155
|
+
b10fs_cold_files_gauge = Gauge(
|
|
156
|
+
"model_cache_b10fs_cold_starts_files", "Number of cold start files"
|
|
157
|
+
)
|
|
158
|
+
b10fs_cold_bytes_gauge = Gauge(
|
|
159
|
+
"model_cache_b10fs_cold_starts_bytes", "Number of cold start bytes"
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
# Transfer success metric
|
|
163
|
+
transfer_success_counter = Counter(
|
|
164
|
+
"model_cache_transfer_success_total",
|
|
165
|
+
"Total successful transfers",
|
|
166
|
+
["success"],
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
# Set main transfer metrics
|
|
170
|
+
manifest_size_gauge.set(self.total_manifest_size_bytes)
|
|
171
|
+
download_time_histogram.observe(self.total_download_time_secs)
|
|
172
|
+
file_download_hidden_time_gauge.set(hidden_time)
|
|
173
|
+
|
|
174
|
+
if self.total_aggregated_mb_s is not None:
|
|
175
|
+
download_speed_gauge.set(self.total_aggregated_mb_s)
|
|
176
|
+
|
|
177
|
+
# Aggregate file download metrics
|
|
178
|
+
total_files = len(self.file_downloads)
|
|
179
|
+
total_file_bytes = sum(fd.file_size_bytes for fd in self.file_downloads)
|
|
180
|
+
|
|
181
|
+
files_downloaded_counter.inc(total_files)
|
|
182
|
+
total_file_size_counter.inc(total_file_bytes)
|
|
183
|
+
|
|
184
|
+
# Record individual file metrics for distribution
|
|
185
|
+
for fd in self.file_downloads:
|
|
186
|
+
if fd.file_size_bytes > 1 * 1024 * 1024: # Only log files larger than 1MB
|
|
187
|
+
file_download_time_histogram.observe(fd.download_time_secs)
|
|
188
|
+
file_download_speed_histogram.observe(fd.download_speed_mb_s)
|
|
189
|
+
|
|
190
|
+
# B10FS metrics
|
|
191
|
+
b10fs_enabled_gauge.set(1 if self.b10fs_enabled else 0)
|
|
192
|
+
b10fs_decision_gauge.set(1 if self.b10fs_decision_to_use else 0)
|
|
193
|
+
|
|
194
|
+
if self.b10fs_read_speed_mbps is not None:
|
|
195
|
+
b10fs_read_speed_gauge.set(self.b10fs_read_speed_mbps)
|
|
196
|
+
|
|
197
|
+
b10fs_hot_files_gauge.set(self.b10fs_hot_starts_files)
|
|
198
|
+
b10fs_hot_bytes_gauge.set(self.b10fs_hot_starts_bytes)
|
|
199
|
+
b10fs_cold_files_gauge.set(self.b10fs_cold_starts_files)
|
|
200
|
+
b10fs_cold_bytes_gauge.set(self.b10fs_cold_starts_bytes)
|
|
201
|
+
|
|
202
|
+
# Success metric
|
|
203
|
+
transfer_success_counter.labels(success=str(self.success)).inc()
|
|
204
|
+
|
|
8
205
|
|
|
9
206
|
LAZY_DATA_RESOLVER_PATH = [
|
|
10
207
|
# synced with pub static LAZY_DATA_RESOLVER_PATHS: &[&str]
|
|
@@ -129,6 +326,9 @@ class LazyDataResolverV2:
|
|
|
129
326
|
|
|
130
327
|
"""
|
|
131
328
|
start_lock = time.time()
|
|
329
|
+
publish_stats = (
|
|
330
|
+
log_stats and not self._is_collected_by_user
|
|
331
|
+
) # only publish results once per resolver
|
|
132
332
|
self._is_collected_by_user = issue_collect or self._is_collected_by_user
|
|
133
333
|
with self._lock:
|
|
134
334
|
result = self._fetch()
|
|
@@ -137,8 +337,20 @@ class LazyDataResolverV2:
|
|
|
137
337
|
f"Error occurred while fetching data: {result}"
|
|
138
338
|
) from result
|
|
139
339
|
if log_stats and result:
|
|
340
|
+
# TODO: instrument the stats, which are written to /tmp/truss_transfer_stats.json
|
|
341
|
+
# also add fetch time, and blocking time
|
|
342
|
+
# TrussTransferStats
|
|
343
|
+
fetch_t = time.time() - self._start_time
|
|
344
|
+
start_lock_t = time.time() - start_lock
|
|
345
|
+
stats = TrussTransferStats.from_json_file(
|
|
346
|
+
Path("/tmp/truss_transfer_stats.json")
|
|
347
|
+
)
|
|
348
|
+
if stats and publish_stats:
|
|
349
|
+
self.logger.info(f"model_cache: {stats}")
|
|
350
|
+
# Publish stats to Prometheus
|
|
351
|
+
stats.publish_to_prometheus()
|
|
140
352
|
self.logger.info(
|
|
141
|
-
f"model_cache: Fetch took {
|
|
353
|
+
f"model_cache: Fetch took {fetch_t:.2f} seconds, of which {start_lock_t:.2f} seconds were spent blocking."
|
|
142
354
|
)
|
|
143
355
|
return result
|
|
144
356
|
|
truss/templates/shared/util.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import multiprocessing
|
|
2
2
|
import os
|
|
3
3
|
import sys
|
|
4
|
-
from typing import List
|
|
4
|
+
from typing import List, Optional
|
|
5
5
|
|
|
6
6
|
import psutil
|
|
7
7
|
|
|
@@ -62,7 +62,10 @@ def all_processes_dead(procs: List[multiprocessing.Process]) -> bool:
|
|
|
62
62
|
return True
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
def kill_child_processes(
|
|
65
|
+
def kill_child_processes(
|
|
66
|
+
parent_pid: int,
|
|
67
|
+
timeout_seconds: Optional[float] = CHILD_PROCESS_WAIT_TIMEOUT_SECONDS,
|
|
68
|
+
):
|
|
66
69
|
try:
|
|
67
70
|
parent = psutil.Process(parent_pid)
|
|
68
71
|
except psutil.NoSuchProcess:
|
|
@@ -70,8 +73,6 @@ def kill_child_processes(parent_pid: int):
|
|
|
70
73
|
children = parent.children(recursive=True)
|
|
71
74
|
for process in children:
|
|
72
75
|
process.terminate()
|
|
73
|
-
gone, alive = psutil.wait_procs(
|
|
74
|
-
children, timeout=CHILD_PROCESS_WAIT_TIMEOUT_SECONDS
|
|
75
|
-
)
|
|
76
|
+
gone, alive = psutil.wait_procs(children, timeout=timeout_seconds)
|
|
76
77
|
for process in alive:
|
|
77
78
|
process.kill()
|