truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (88)
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/templates/shared/lazy_data_resolver.py

@@ -1,10 +1,207 @@
  import atexit
+ import json
  import logging
  import time
+ from dataclasses import dataclass
  from functools import lru_cache
  from pathlib import Path
  from threading import Lock, Thread
- from typing import Optional, Union
+ from typing import List, Optional, Union
+
+ try:
+     from prometheus_client import Counter, Gauge, Histogram
+
+     PROMETHEUS_AVAILABLE = True
+ except ImportError:
+     PROMETHEUS_AVAILABLE = False
+ METRICS_REGISTERED = False
+
+
+ @dataclass(frozen=True)
+ class FileDownloadMetric:
+     file_name: str
+     file_size_bytes: int
+     download_time_secs: float
+     download_speed_mb_s: float
+
+
+ @dataclass(frozen=True)
+ class TrussTransferStats:
+     total_manifest_size_bytes: int
+     total_download_time_secs: float
+     total_aggregated_mb_s: Optional[float]
+     file_downloads: List[FileDownloadMetric]
+     b10fs_read_speed_mbps: Optional[float]
+     b10fs_decision_to_use: bool
+     b10fs_enabled: bool
+     b10fs_hot_starts_files: int
+     b10fs_hot_starts_bytes: int
+     b10fs_cold_starts_files: int
+     b10fs_cold_starts_bytes: int
+     success: bool
+     timestamp: int
+
+     @classmethod
+     def from_json_file(cls, path: Path) -> Optional["TrussTransferStats"]:
+         if not path.exists():
+             return None
+         try:
+             with open(path) as f:
+                 data = json.load(f)
+             file_downloads = [
+                 FileDownloadMetric(**fd) for fd in data.get("file_downloads", [])
+             ]
+             return cls(
+                 total_manifest_size_bytes=data["total_manifest_size_bytes"],
+                 total_download_time_secs=data["total_download_time_secs"],
+                 total_aggregated_mb_s=data.get("total_aggregated_mb_s"),
+                 file_downloads=file_downloads,
+                 b10fs_read_speed_mbps=data.get("b10fs_read_speed_mbps"),
+                 b10fs_decision_to_use=data["b10fs_decision_to_use"],
+                 b10fs_enabled=data["b10fs_enabled"],
+                 b10fs_hot_starts_files=data["b10fs_hot_starts_files"],
+                 b10fs_hot_starts_bytes=data["b10fs_hot_starts_bytes"],
+                 b10fs_cold_starts_files=data["b10fs_cold_starts_files"],
+                 b10fs_cold_starts_bytes=data["b10fs_cold_starts_bytes"],
+                 success=data["success"],
+                 timestamp=data["timestamp"],
+             )
+         except Exception:
+             return None
+
+     def publish_to_prometheus(self, hidden_time: float = 0.0):
+         """Publish transfer stats to Prometheus metrics. Only runs once."""
+         if not PROMETHEUS_AVAILABLE:
+             return
+         global METRICS_REGISTERED
+
+         if METRICS_REGISTERED:
+             logging.info(
+                 "Model cache metrics already registered, skipping."
+             )  # this should never happen
+             return
+         else:
+             # Ensure metrics are only registered once
+             METRICS_REGISTERED = True
+
+         # Define metrics with model_cache prefix
+         manifest_size_gauge = Gauge(
+             "model_cache_manifest_size_bytes", "Total manifest size in bytes"
+         )
+         # histograms have intentionally wide buckets to capture a variety of download times
+         download_time_histogram = Histogram(
+             "model_cache_download_time_seconds",
+             "Total download time in seconds",
+             buckets=[0]
+             + [
+                 2**i
+                 for i in range(-3, 11)  # = [0.125, ..., 1024] seconds
+             ]
+             + [float("inf")],
+         )
+         download_speed_gauge = Gauge(
+             "model_cache_download_speed_mbps", "Aggregated download speed in MB/s"
+         )
+
+         # File download metrics (aggregated)
+         files_downloaded_counter = Counter(
+             "model_cache_files_downloaded_total", "Total number of files downloaded"
+         )
+         total_file_size_counter = Counter(
+             "model_cache_file_size_bytes_total",
+             "Total size of downloaded files in bytes",
+         )
+         file_download_hidden_time_gauge = Gauge(
+             "model_cache_file_download_hidden_time_seconds",
+             "Total time hidden from user by starting the import before user code (seconds)",
+         )
+         file_download_time_histogram = Histogram(
+             "model_cache_file_download_time_seconds",
+             "File download time distribution",
+             buckets=[0]
+             + [
+                 2**i
+                 for i in range(-3, 11)  # = [0.125, ..., 1024] seconds
+             ]
+             + [float("inf")],
+         )
+         file_download_speed_histogram = Histogram(
+             "model_cache_file_download_speed_mbps",
+             "File download speed distribution",
+             buckets=[0]
+             + [
+                 2**i
+                 for i in range(-1, 12)  # = [0.5, ..., 2048] MB/s
+             ]
+             + [float("inf")],
+         )
+
+         # B10FS specific metrics
+         b10fs_enabled_gauge = Gauge(
+             "model_cache_b10fs_enabled", "Whether B10FS is enabled"
+         )
+         b10fs_decision_gauge = Gauge(
+             "model_cache_b10fs_decision_to_use", "Whether B10FS was chosen for use"
+         )
+         b10fs_read_speed_gauge = Gauge(
+             "model_cache_b10fs_read_speed_mbps", "B10FS read speed in Mbps"
+         )
+         b10fs_hot_files_gauge = Gauge(
+             "model_cache_b10fs_hot_starts_files", "Number of hot start files"
+         )
+         b10fs_hot_bytes_gauge = Gauge(
+             "model_cache_b10fs_hot_starts_bytes", "Number of hot start bytes"
+         )
+         b10fs_cold_files_gauge = Gauge(
+             "model_cache_b10fs_cold_starts_files", "Number of cold start files"
+         )
+         b10fs_cold_bytes_gauge = Gauge(
+             "model_cache_b10fs_cold_starts_bytes", "Number of cold start bytes"
+         )
+
+         # Transfer success metric
+         transfer_success_counter = Counter(
+             "model_cache_transfer_success_total",
+             "Total successful transfers",
+             ["success"],
+         )
+
+         # Set main transfer metrics
+         manifest_size_gauge.set(self.total_manifest_size_bytes)
+         download_time_histogram.observe(self.total_download_time_secs)
+         file_download_hidden_time_gauge.set(hidden_time)
+
+         if self.total_aggregated_mb_s is not None:
+             download_speed_gauge.set(self.total_aggregated_mb_s)
+
+         # Aggregate file download metrics
+         total_files = len(self.file_downloads)
+         total_file_bytes = sum(fd.file_size_bytes for fd in self.file_downloads)
+
+         files_downloaded_counter.inc(total_files)
+         total_file_size_counter.inc(total_file_bytes)
+
+         # Record individual file metrics for distribution
+         for fd in self.file_downloads:
+             if fd.file_size_bytes > 1 * 1024 * 1024:  # Only log files larger than 1MB
+                 file_download_time_histogram.observe(fd.download_time_secs)
+                 file_download_speed_histogram.observe(fd.download_speed_mb_s)
+
+         # B10FS metrics
+         b10fs_enabled_gauge.set(1 if self.b10fs_enabled else 0)
+         b10fs_decision_gauge.set(1 if self.b10fs_decision_to_use else 0)
+
+         if self.b10fs_read_speed_mbps is not None:
+             b10fs_read_speed_gauge.set(self.b10fs_read_speed_mbps)
+
+         b10fs_hot_files_gauge.set(self.b10fs_hot_starts_files)
+         b10fs_hot_bytes_gauge.set(self.b10fs_hot_starts_bytes)
+         b10fs_cold_files_gauge.set(self.b10fs_cold_starts_files)
+         b10fs_cold_bytes_gauge.set(self.b10fs_cold_starts_bytes)
+
+         # Success metric
+         transfer_success_counter.labels(success=str(self.success)).inc()
+

  LAZY_DATA_RESOLVER_PATH = [
      # synced with pub static LAZY_DATA_RESOLVER_PATHS: &[&str]
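
A minimal sketch of how the new stats path could be exercised in isolation, assuming prometheus_client is installed and that the module is importable as truss.templates.shared.lazy_data_resolver (inside the serving image it sits next to the server code, so the import path here is purely illustrative). The payload only fills the fields declared on TrussTransferStats; the numbers are made up.

import json
import time
from pathlib import Path

# Illustrative import path; see the note above.
from truss.templates.shared.lazy_data_resolver import TrussTransferStats

stats_path = Path("/tmp/truss_transfer_stats.json")
stats_path.write_text(
    json.dumps(
        {
            # Fields from the TrussTransferStats dataclass (values are fake).
            "total_manifest_size_bytes": 5_000_000_000,
            "total_download_time_secs": 42.0,
            "total_aggregated_mb_s": 113.5,
            "file_downloads": [
                {
                    "file_name": "model.safetensors",
                    "file_size_bytes": 5_000_000_000,
                    "download_time_secs": 42.0,
                    "download_speed_mb_s": 113.5,
                }
            ],
            "b10fs_read_speed_mbps": None,
            "b10fs_decision_to_use": False,
            "b10fs_enabled": False,
            "b10fs_hot_starts_files": 0,
            "b10fs_hot_starts_bytes": 0,
            "b10fs_cold_starts_files": 1,
            "b10fs_cold_starts_bytes": 5_000_000_000,
            "success": True,
            "timestamp": int(time.time()),
        }
    )
)

stats = TrussTransferStats.from_json_file(stats_path)
if stats is not None:
    # No-op when prometheus_client is missing (PROMETHEUS_AVAILABLE is False);
    # otherwise registers and populates the model_cache_* metrics once per process.
    stats.publish_to_prometheus(hidden_time=3.2)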
@@ -129,6 +326,9 @@ class LazyDataResolverV2:
 
          """
          start_lock = time.time()
+         publish_stats = (
+             log_stats and not self._is_collected_by_user
+         )  # only publish results once per resolver
          self._is_collected_by_user = issue_collect or self._is_collected_by_user
          with self._lock:
              result = self._fetch()
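
Two layers guard against double publishing: the publish_stats flag above keeps a single resolver from publishing more than once, and the module-level METRICS_REGISTERED flag keeps the process from re-registering the collectors (prometheus_client would otherwise raise a duplicate-timeseries error when the same metric name is registered twice in the default registry). A small sketch of the process-level behaviour, continuing from the example above:

from prometheus_client import generate_latest

stats = TrussTransferStats.from_json_file(Path("/tmp/truss_transfer_stats.json"))
if stats is not None:
    stats.publish_to_prometheus()
    # Second call logs "Model cache metrics already registered, skipping." and returns.
    stats.publish_to_prometheus()

# Dump the default registry to see the newly registered series.
for line in generate_latest().decode().splitlines():
    if line.startswith("model_cache_"):
        print(line)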
@@ -137,8 +337,20 @@
                      f"Error occurred while fetching data: {result}"
                  ) from result
          if log_stats and result:
+             # TODO: instrument the stats, which are written to /tmp/truss_transfer_stats.json
+             # also add fetch time, and blocking time
+             # TrussTransferStats
+             fetch_t = time.time() - self._start_time
+             start_lock_t = time.time() - start_lock
+             stats = TrussTransferStats.from_json_file(
+                 Path("/tmp/truss_transfer_stats.json")
+             )
+             if stats and publish_stats:
+                 self.logger.info(f"model_cache: {stats}")
+                 # Publish stats to Prometheus
+                 stats.publish_to_prometheus()
              self.logger.info(
-                 f"model_cache: Fetch took {time.time() - self._start_time:.2f} seconds, of which {time.time() - start_lock:.2f} seconds were spent blocking."
+                 f"model_cache: Fetch took {fetch_t:.2f} seconds, of which {start_lock_t:.2f} seconds were spent blocking."
              )
          return result
 
truss/templates/shared/util.py

@@ -1,7 +1,7 @@
  import multiprocessing
  import os
  import sys
- from typing import List
+ from typing import List, Optional
 
  import psutil
 
@@ -62,7 +62,10 @@ def all_processes_dead(procs: List[multiprocessing.Process]) -> bool:
      return True
 
 
- def kill_child_processes(parent_pid: int):
+ def kill_child_processes(
+     parent_pid: int,
+     timeout_seconds: Optional[float] = CHILD_PROCESS_WAIT_TIMEOUT_SECONDS,
+ ):
      try:
          parent = psutil.Process(parent_pid)
      except psutil.NoSuchProcess:
@@ -70,8 +73,6 @@ def kill_child_processes(parent_pid: int):
      children = parent.children(recursive=True)
      for process in children:
          process.terminate()
-     gone, alive = psutil.wait_procs(
-         children, timeout=CHILD_PROCESS_WAIT_TIMEOUT_SECONDS
-     )
+     gone, alive = psutil.wait_procs(children, timeout=timeout_seconds)
      for process in alive:
          process.kill()
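
For reference, a small sketch of how the widened kill_child_processes signature might be called. The import path truss.templates.shared.util is illustrative (in the serving and control images the module is available under a shared package), and callers that omit the new argument keep the previous behaviour, since the default is still CHILD_PROCESS_WAIT_TIMEOUT_SECONDS.

import multiprocessing
import os
import time

# Illustrative import path; see the note above.
from truss.templates.shared.util import kill_child_processes


def _worker() -> None:
    time.sleep(60)


if __name__ == "__main__":
    proc = multiprocessing.Process(target=_worker)
    proc.start()

    # Give children two seconds to exit after terminate() before they are
    # force-killed; timeout_seconds=None would wait indefinitely and never
    # reach the kill() fallback if a child ignores SIGTERM.
    kill_child_processes(os.getpid(), timeout_seconds=2.0)
    proc.join()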