truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,81 @@
1
+ """Team resolution logic for training projects."""
2
+
3
+ from typing import Optional
4
+
5
+ import click
6
+
7
+ from truss.cli import remote_cli
8
+ from truss.remote.baseten.remote import BasetenRemote
9
+
10
+
11
def resolve_training_project_team_name(
    remote_provider: BasetenRemote,
    provided_team_name: Optional[str],
    existing_project_name: Optional[str] = None,
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Work out which team a training project should be associated with.

    Returns a ``(team_name, team_id)`` tuple. ``team_id`` is read from the
    ``"id"`` field of the team's entry in ``existing_teams`` and is ``None``
    when the name is empty or has no entry there.

    Resolution order (first rule that applies wins):

    1. ``--team`` was provided (``provided_team_name`` is not ``None``):
       * the name is a key of ``existing_teams`` -> return it, no prompt;
       * otherwise -> raise ``click.ClickException`` listing the available
         teams.
    2. ``existing_project_name`` was given: the training projects returned by
       the API are filtered by that name.
       * more than one match -> prompt via ``remote_cli.inquire_team``;
       * exactly one match whose ``team_name`` is a known team -> return that
         team, no prompt;
       * anything else -> fall through to rule 3.
    3. The user has exactly one team -> return it, no prompt.
    4. Otherwise -> prompt via ``remote_cli.inquire_team``.

    Args:
        remote_provider: Remote whose ``api`` is queried for teams (only when
            ``existing_teams`` is ``None``) and for training projects (only
            when ``existing_project_name`` is given and no team was provided).
        provided_team_name: Team name supplied by the user, or ``None``.
        existing_project_name: Name of a training project to look up, or
            ``None`` to skip the project lookup.
        existing_teams: Mapping of team name to team data; fetched from the
            API when omitted.

    Raises:
        click.ClickException: ``provided_team_name`` is not a known team.
    """
    teams = existing_teams
    if teams is None:
        teams = remote_provider.api.get_teams()

    def _name_and_id(name: Optional[str]) -> tuple[Optional[str], Optional[str]]:
        # An empty name or an empty/unknown team mapping yields no id.
        entry = teams.get(name) if (name and teams) else None
        return (name, entry["id"] if entry else None)

    # --team given: it must name a known team.
    if provided_team_name is not None:
        if provided_team_name in teams:
            return _name_and_id(provided_team_name)
        available_teams_str = remote_cli.format_available_teams(teams)
        raise click.ClickException(
            f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
        )

    # No --team: try to infer the team from an existing project of that name.
    if existing_project_name is not None:
        same_name_projects = [
            project
            for project in remote_provider.api.list_training_projects()
            if project.get("name") == existing_project_name
        ]
        match_count = len(same_name_projects)

        if match_count > 1:
            # Ambiguous: the project name exists more than once, ask the user.
            return _name_and_id(remote_cli.inquire_team(existing_teams=teams))

        if match_count == 1:
            owning_team = same_name_projects[0].get("team_name")
            if owning_team in teams:
                return _name_and_id(owning_team)

    # Nothing inferred from projects: a single team needs no prompt.
    if len(teams) == 1:
        return _name_and_id(next(iter(teams)))

    return _name_and_id(remote_cli.inquire_team(existing_teams=teams))
@@ -0,0 +1,332 @@
1
+ import csv
2
+ import json
3
+ import os
4
+ import sys
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Callable, Optional
7
+
8
+ import rich
9
+
10
+ from truss.cli.train import common
11
+ from truss.cli.utils import common as cli_common
12
+ from truss.cli.utils.output import console
13
+ from truss.remote.baseten.custom_types import (
14
+ FileSummary,
15
+ FileSummaryWithTotalSize,
16
+ GetCacheSummaryResponseV1,
17
+ )
18
+ from truss.remote.baseten.remote import BasetenRemote
19
+
20
+ # Sort constants
21
+ SORT_BY_FILEPATH = "filepath"
22
+ SORT_BY_SIZE = "size"
23
+ SORT_BY_MODIFIED = "modified"
24
+ SORT_BY_TYPE = "type"
25
+ SORT_BY_PERMISSIONS = "permissions"
26
+ SORT_ORDER_ASC = "asc"
27
+ SORT_ORDER_DESC = "desc"
28
+
29
+ # Output format constants
30
+ OUTPUT_FORMAT_CLI_TABLE = "cli-table"
31
+ OUTPUT_FORMAT_CSV = "csv"
32
+ OUTPUT_FORMAT_JSON = "json"
33
+
34
+
35
+ def calculate_directory_sizes(
36
+ files: list[FileSummary], max_depth: int = 100
37
+ ) -> dict[str, int]:
38
+ """Calculate total sizes for directories based on their file contents."""
39
+ directory_sizes = {}
40
+
41
+ for file_info in files:
42
+ if file_info.file_type == "directory":
43
+ directory_sizes[file_info.path] = 0
44
+
45
+ for file_info in files:
46
+ current_path = file_info.path
47
+ for i in range(max_depth):
48
+ if current_path is None:
49
+ break
50
+ if current_path in directory_sizes:
51
+ directory_sizes[current_path] += file_info.size_bytes
52
+ # Move to parent directory
53
+ parent = os.path.dirname(current_path)
54
+ if parent == current_path: # Reached root
55
+ break
56
+ current_path = parent
57
+
58
+ return directory_sizes
59
+
60
+
61
+ def create_file_summary_with_directory_sizes(
62
+ files: list[FileSummary],
63
+ ) -> list[FileSummaryWithTotalSize]:
64
+ """Create file summaries with total sizes including directory sizes."""
65
+ directory_sizes = calculate_directory_sizes(files)
66
+ return [
67
+ FileSummaryWithTotalSize(
68
+ file_summary=file_info,
69
+ total_size=directory_sizes.get(file_info.path, file_info.size_bytes),
70
+ )
71
+ for file_info in files
72
+ ]
73
+
74
+
75
+ def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
76
+ """Get the sort key function for the given sort option."""
77
+ if sort_by == SORT_BY_FILEPATH:
78
+ return lambda x: x.file_summary.path
79
+ elif sort_by == SORT_BY_SIZE:
80
+ return lambda x: x.total_size
81
+ elif sort_by == SORT_BY_MODIFIED:
82
+ return lambda x: x.file_summary.modified
83
+ elif sort_by == SORT_BY_TYPE:
84
+ return lambda x: x.file_summary.file_type or ""
85
+ elif sort_by == SORT_BY_PERMISSIONS:
86
+ return lambda x: x.file_summary.permissions or ""
87
+ else:
88
+ raise ValueError(f"Invalid --sort argument: {sort_by}")
89
+
90
+
91
+ class CacheSummaryViewer(ABC):
92
+ """Base class for cache summary viewers that output in different formats."""
93
+
94
+ @abstractmethod
95
+ def output_cache_summary(
96
+ self,
97
+ cache_data: Optional[GetCacheSummaryResponseV1],
98
+ files_with_total_sizes: list[FileSummaryWithTotalSize],
99
+ total_size: int,
100
+ total_size_str: str,
101
+ project_id: str,
102
+ ) -> None:
103
+ """Output the cache summary in the viewer's format."""
104
+ pass
105
+
106
+ @abstractmethod
107
+ def output_no_cache_message(self, project_id: str) -> None:
108
+ """Output message when no cache summary is found."""
109
+ pass
110
+
111
+
112
+ class CLITableViewer(CacheSummaryViewer):
113
+ """Viewer that outputs cache summary as a styled CLI table."""
114
+
115
+ def output_cache_summary(
116
+ self,
117
+ cache_data: Optional[GetCacheSummaryResponseV1],
118
+ files_with_total_sizes: list[FileSummaryWithTotalSize],
119
+ total_size: int,
120
+ total_size_str: str,
121
+ project_id: str,
122
+ ) -> None:
123
+ """Output cache summary as a styled CLI table."""
124
+ if not cache_data:
125
+ return
126
+
127
+ table = rich.table.Table(title=f"Cache summary for project: {project_id}")
128
+ table.add_column("File Path", style="cyan", overflow="fold")
129
+ table.add_column("Size", style="green")
130
+ table.add_column("Modified", style="yellow")
131
+ table.add_column("Type")
132
+ table.add_column("Permissions", style="magenta")
133
+
134
+ console.print(
135
+ f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
136
+ )
137
+ console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
138
+ console.print()
139
+ console.print(
140
+ f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
141
+ )
142
+ console.print(f"💾 Total size: {total_size_str}", style="bold green")
143
+ console.print()
144
+ # Note: Long file paths wrap across multiple lines. To copy full paths easily,
145
+ # use --output-format csv or --output-format json
146
+ if any(len(f.file_summary.path) > 60 for f in files_with_total_sizes):
147
+ console.print(
148
+ "💡 Tip: Use -o csv or -o json to output paths on single lines for easier copying",
149
+ style="dim",
150
+ )
151
+ console.print()
152
+
153
+ for file_info in files_with_total_sizes:
154
+ total_size = file_info.total_size
155
+
156
+ size_str = cli_common.format_bytes_to_human_readable(int(total_size))
157
+
158
+ modified_str = cli_common.format_localized_time(
159
+ file_info.file_summary.modified
160
+ )
161
+
162
+ table.add_row(
163
+ file_info.file_summary.path,
164
+ size_str,
165
+ modified_str,
166
+ file_info.file_summary.file_type or "Unknown",
167
+ file_info.file_summary.permissions or "Unknown",
168
+ )
169
+
170
+ console.print(table)
171
+
172
+ def output_no_cache_message(self, project_id: str) -> None:
173
+ """Output message when no cache summary is found."""
174
+ console.print("No cache summary found for this project.", style="yellow")
175
+
176
+
177
+ class CSVViewer(CacheSummaryViewer):
178
+ """Viewer that outputs cache summary in CSV format."""
179
+
180
+ def output_cache_summary(
181
+ self,
182
+ cache_data: Optional[GetCacheSummaryResponseV1],
183
+ files_with_total_sizes: list[FileSummaryWithTotalSize],
184
+ total_size: int,
185
+ total_size_str: str,
186
+ project_id: str,
187
+ ) -> None:
188
+ """Output cache summary in CSV format."""
189
+ writer = csv.writer(sys.stdout)
190
+ # Write header
191
+ writer.writerow(
192
+ [
193
+ "File Path",
194
+ "Size (bytes)",
195
+ "Size (human readable)",
196
+ "Modified",
197
+ "Type",
198
+ "Permissions",
199
+ ]
200
+ )
201
+ # Write data rows
202
+ for file_info in files_with_total_sizes:
203
+ size_str = cli_common.format_bytes_to_human_readable(
204
+ int(file_info.total_size)
205
+ )
206
+ modified_str = cli_common.format_localized_time(
207
+ file_info.file_summary.modified
208
+ )
209
+ writer.writerow(
210
+ [
211
+ file_info.file_summary.path,
212
+ str(file_info.total_size),
213
+ size_str,
214
+ modified_str,
215
+ file_info.file_summary.file_type or "Unknown",
216
+ file_info.file_summary.permissions or "Unknown",
217
+ ]
218
+ )
219
+
220
+ def output_no_cache_message(self, project_id: str) -> None:
221
+ """Output empty CSV with headers when no cache summary is found."""
222
+ self.output_cache_summary(None, [], 0, "0 B", project_id)
223
+
224
+
225
+ class JSONViewer(CacheSummaryViewer):
226
+ """Viewer that outputs cache summary in JSON format."""
227
+
228
+ def output_cache_summary(
229
+ self,
230
+ cache_data: Optional[GetCacheSummaryResponseV1],
231
+ files_with_total_sizes: list[FileSummaryWithTotalSize],
232
+ total_size: int,
233
+ total_size_str: str,
234
+ project_id: str,
235
+ ) -> None:
236
+ """Output cache summary in JSON format."""
237
+ files_data = []
238
+ for file_info in files_with_total_sizes:
239
+ size_str = cli_common.format_bytes_to_human_readable(
240
+ int(file_info.total_size)
241
+ )
242
+ modified_str = cli_common.format_localized_time(
243
+ file_info.file_summary.modified
244
+ )
245
+ files_data.append(
246
+ {
247
+ "path": file_info.file_summary.path,
248
+ "size_bytes": file_info.total_size,
249
+ "size_human_readable": size_str,
250
+ "modified": modified_str,
251
+ "type": file_info.file_summary.file_type or "Unknown",
252
+ "permissions": file_info.file_summary.permissions or "Unknown",
253
+ }
254
+ )
255
+
256
+ output = {
257
+ "timestamp": cache_data.timestamp if cache_data else "",
258
+ "project_id": cache_data.project_id if cache_data else project_id,
259
+ "total_files": len(files_with_total_sizes),
260
+ "total_size_bytes": total_size,
261
+ "total_size_human_readable": total_size_str,
262
+ "files": files_data,
263
+ }
264
+
265
+ print(json.dumps(output, indent=2))
266
+
267
+ def output_no_cache_message(self, project_id: str) -> None:
268
+ """Output empty JSON structure when no cache summary is found."""
269
+ self.output_cache_summary(None, [], 0, "0 B", project_id)
270
+
271
+
272
+ def _get_cache_summary_viewer(output_format: str) -> CacheSummaryViewer:
273
+ """Factory function to get the appropriate viewer for the output format."""
274
+ if output_format == OUTPUT_FORMAT_CSV:
275
+ return CSVViewer()
276
+ elif output_format == OUTPUT_FORMAT_JSON:
277
+ return JSONViewer()
278
+ else:
279
+ return CLITableViewer()
280
+
281
+
282
+ def view_cache_summary(
283
+ remote_provider: BasetenRemote,
284
+ project_id: str,
285
+ sort_by: str = SORT_BY_FILEPATH,
286
+ order: str = SORT_ORDER_ASC,
287
+ output_format: str = OUTPUT_FORMAT_CLI_TABLE,
288
+ ):
289
+ """View cache summary for a training project."""
290
+ viewer_factories = {
291
+ OUTPUT_FORMAT_CSV: lambda: CSVViewer(),
292
+ OUTPUT_FORMAT_JSON: lambda: JSONViewer(),
293
+ OUTPUT_FORMAT_CLI_TABLE: lambda: CLITableViewer(),
294
+ }
295
+ viewer_factory = viewer_factories.get(output_format)
296
+ if not viewer_factory:
297
+ raise ValueError(f"Invalid output format: {output_format}")
298
+ viewer = viewer_factory()
299
+ try:
300
+ raw_cache_data = remote_provider.api.get_cache_summary(project_id)
301
+
302
+ if not raw_cache_data:
303
+ viewer.output_no_cache_message(project_id)
304
+ return
305
+
306
+ cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)
307
+
308
+ files = cache_data.file_summaries
309
+ files_with_total_sizes = (
310
+ create_file_summary_with_directory_sizes(files) if files else []
311
+ )
312
+
313
+ reverse = order == SORT_ORDER_DESC
314
+ sort_key = _get_sort_key(sort_by)
315
+ files_with_total_sizes.sort(key=sort_key, reverse=reverse)
316
+
317
+ total_size = sum(
318
+ file_info.file_summary.size_bytes for file_info in files_with_total_sizes
319
+ )
320
+ total_size_str = common.format_bytes_to_human_readable(total_size)
321
+
322
+ viewer.output_cache_summary(
323
+ cache_data, files_with_total_sizes, total_size, total_size_str, project_id
324
+ )
325
+
326
+ except Exception as e:
327
+ # For CSV/JSON formats, print to stderr to keep stdout clean for piping
328
+ if output_format in (OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON):
329
+ print(f"Error fetching cache summary: {str(e)}", file=sys.stderr)
330
+ else:
331
+ console.print(f"Error fetching cache summary: {str(e)}", style="red")
332
+ raise