truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/truss_config.py +10 -3
- truss/cli/chains_commands.py +39 -1
- truss/cli/cli.py +35 -5
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +19 -143
- truss/cli/train_commands.py +69 -11
- truss/cli/utils/common.py +40 -3
- truss/remote/baseten/api.py +58 -5
- truss/remote/baseten/core.py +22 -4
- truss/remote/baseten/remote.py +24 -2
- truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +10 -10
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +44 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +10 -1
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +9 -4
- truss_chains/private_types.py +15 -0
- truss_train/definitions.py +3 -1
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/cli/train/cache.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any, Callable, Optional
|
|
7
|
+
|
|
8
|
+
import rich
|
|
9
|
+
|
|
10
|
+
from truss.cli.train import common
|
|
11
|
+
from truss.cli.utils import common as cli_common
|
|
12
|
+
from truss.cli.utils.output import console
|
|
13
|
+
from truss.remote.baseten.custom_types import (
|
|
14
|
+
FileSummary,
|
|
15
|
+
FileSummaryWithTotalSize,
|
|
16
|
+
GetCacheSummaryResponseV1,
|
|
17
|
+
)
|
|
18
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
19
|
+
|
|
20
|
+
# Accepted values for the ``--sort`` option of the cache summary view.
SORT_BY_FILEPATH = "filepath"
SORT_BY_SIZE = "size"
SORT_BY_MODIFIED = "modified"
SORT_BY_TYPE = "type"
SORT_BY_PERMISSIONS = "permissions"

# Sort direction values; only SORT_ORDER_DESC reverses the ordering.
SORT_ORDER_ASC = "asc"
SORT_ORDER_DESC = "desc"

# Accepted values for ``--output-format``.
OUTPUT_FORMAT_CLI_TABLE = "cli-table"
OUTPUT_FORMAT_CSV = "csv"
OUTPUT_FORMAT_JSON = "json"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def calculate_directory_sizes(
    files: list[FileSummary], max_depth: int = 100
) -> dict[str, int]:
    """Calculate total sizes for directories based on their file contents.

    Every entry whose ``file_type`` is ``"directory"`` gets a running total.
    Each entry in ``files`` (directories included) then adds its own
    ``size_bytes`` to its own path and to every ancestor path, as produced
    by ``os.path.dirname``, that is one of those directories.

    Args:
        files: Cache entries; each must expose ``path``, ``file_type`` and
            ``size_bytes``.
        max_depth: Upper bound on how many path levels (starting with the
            entry's own path) are examined per entry.

    Returns:
        Mapping of directory path to aggregated size in bytes.
    """
    directory_sizes = {
        entry.path: 0 for entry in files if entry.file_type == "directory"
    }

    for entry in files:
        path = entry.path
        levels_left = max_depth
        while levels_left > 0 and path is not None:
            levels_left -= 1
            if path in directory_sizes:
                directory_sizes[path] += entry.size_bytes
            parent = os.path.dirname(path)
            # dirname is a fixed point at the root ("/" or ""), so stop there.
            path = None if parent == path else parent

    return directory_sizes
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def create_file_summary_with_directory_sizes(
    files: list[FileSummary],
) -> list[FileSummaryWithTotalSize]:
    """Pair every cache entry with the size that should be displayed for it.

    Directory entries get the aggregate from ``calculate_directory_sizes``.
    All other entries fall back to their own ``size_bytes``. Input order is
    preserved.
    """
    aggregated_sizes = calculate_directory_sizes(files)
    summaries = []
    for entry in files:
        display_size = aggregated_sizes.get(entry.path, entry.size_bytes)
        summaries.append(
            FileSummaryWithTotalSize(file_summary=entry, total_size=display_size)
        )
    return summaries
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
    """Return the ``list.sort`` key function for a ``--sort`` option value.

    Missing ``file_type``/``permissions`` sort as empty strings.

    Raises:
        ValueError: If ``sort_by`` is not one of the ``SORT_BY_*`` constants.
    """
    # NOTE(review): ``modified`` is used as-is (no ``or ""`` fallback), so a
    # None value would make sorting fail; confirm it is always populated.
    key_functions = (
        (SORT_BY_FILEPATH, lambda entry: entry.file_summary.path),
        (SORT_BY_SIZE, lambda entry: entry.total_size),
        (SORT_BY_MODIFIED, lambda entry: entry.file_summary.modified),
        (SORT_BY_TYPE, lambda entry: entry.file_summary.file_type or ""),
        (SORT_BY_PERMISSIONS, lambda entry: entry.file_summary.permissions or ""),
    )
    for option, key_function in key_functions:
        if sort_by == option:
            return key_function
    raise ValueError(f"Invalid --sort argument: {sort_by}")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class CacheSummaryViewer(ABC):
    """Abstract renderer for a training project's cache summary.

    Concrete subclasses choose the output medium (styled table, CSV, JSON).
    """

    @abstractmethod
    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Render the cache entries together with summary totals.

        Args:
            cache_data: Parsed cache summary, or ``None`` when there is no
                summary to render.
            files_with_total_sizes: Entries to render, in display order.
            total_size: Aggregate size in bytes.
            total_size_str: Human-readable form of ``total_size``.
            project_id: Identifier of the training project.
        """

    @abstractmethod
    def output_no_cache_message(self, project_id: str) -> None:
        """Report that ``project_id`` has no cache summary."""
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class CLITableViewer(CacheSummaryViewer):
    """Viewer that outputs cache summary as a styled CLI table."""

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Print summary header lines followed by a table with one row per entry.

        Args:
            cache_data: Parsed cache summary; when falsy nothing is printed.
            files_with_total_sizes: Entries to render, in display order.
            total_size: Aggregate size in bytes. This viewer does not use it;
                the pre-formatted ``total_size_str`` is shown instead.
            total_size_str: Human-readable aggregate size for the header.
            project_id: Project identifier used in the table title.
        """
        if not cache_data:
            return

        # Import the submodule explicitly: a bare ``import rich`` does not
        # guarantee that ``rich.table`` is bound as an attribute.
        from rich.table import Table

        table = Table(title=f"Cache summary for project: {project_id}")
        table.add_column("File Path", style="cyan", overflow="fold")
        table.add_column("Size", style="green")
        table.add_column("Modified", style="yellow")
        table.add_column("Type")
        table.add_column("Permissions", style="magenta")

        console.print(
            f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
        )
        console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
        console.print()
        console.print(
            f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
        )
        console.print(f"💾 Total size: {total_size_str}", style="bold green")
        console.print()
        # Note: Long file paths wrap across multiple lines. To copy full paths easily,
        # use --output-format csv or --output-format json
        if any(len(f.file_summary.path) > 60 for f in files_with_total_sizes):
            console.print(
                "💡 Tip: Use -o csv or -o json to output paths on single lines for easier copying",
                style="dim",
            )
            console.print()

        for file_info in files_with_total_sizes:
            # Use a dedicated name so the ``total_size`` parameter is not
            # shadowed by each row's size.
            entry_size_str = cli_common.format_bytes_to_human_readable(
                int(file_info.total_size)
            )
            modified_str = cli_common.format_localized_time(
                file_info.file_summary.modified
            )
            table.add_row(
                file_info.file_summary.path,
                entry_size_str,
                modified_str,
                file_info.file_summary.file_type or "Unknown",
                file_info.file_summary.permissions or "Unknown",
            )

        console.print(table)

    def output_no_cache_message(self, project_id: str) -> None:
        """Output message when no cache summary is found."""
        console.print("No cache summary found for this project.", style="yellow")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class CSVViewer(CacheSummaryViewer):
    """Viewer that outputs cache summary in CSV format."""

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Write a header row plus one row per entry to stdout as CSV.

        Only ``files_with_total_sizes`` is written. ``cache_data``,
        ``total_size``, ``total_size_str`` and ``project_id`` are accepted to
        satisfy the viewer interface but are not part of the CSV output.
        """
        # The csv module defaults to "\r\n" terminators and expects files
        # opened with newline="". sys.stdout is a text stream that already
        # translates "\n" to the platform line ending, so "\r\n" would become
        # "\r\r\n" on Windows (blank rows). Write "\n" and let the stream
        # apply the platform convention.
        writer = csv.writer(sys.stdout, lineterminator="\n")
        # Write header
        writer.writerow(
            [
                "File Path",
                "Size (bytes)",
                "Size (human readable)",
                "Modified",
                "Type",
                "Permissions",
            ]
        )
        # Write data rows
        for file_info in files_with_total_sizes:
            size_str = cli_common.format_bytes_to_human_readable(
                int(file_info.total_size)
            )
            modified_str = cli_common.format_localized_time(
                file_info.file_summary.modified
            )
            writer.writerow(
                [
                    file_info.file_summary.path,
                    str(file_info.total_size),
                    size_str,
                    modified_str,
                    file_info.file_summary.file_type or "Unknown",
                    file_info.file_summary.permissions or "Unknown",
                ]
            )

    def output_no_cache_message(self, project_id: str) -> None:
        """Output empty CSV with headers when no cache summary is found."""
        self.output_cache_summary(None, [], 0, "0 B", project_id)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class JSONViewer(CacheSummaryViewer):
    """Emit the cache summary as one pretty-printed JSON document on stdout."""

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Print a JSON object with summary metadata and a ``files`` list.

        When ``cache_data`` is falsy, ``timestamp`` is an empty string and
        ``project_id`` falls back to the ``project_id`` argument.
        """
        entries = []
        for entry in files_with_total_sizes:
            summary = entry.file_summary
            human_size = cli_common.format_bytes_to_human_readable(
                int(entry.total_size)
            )
            localized_modified = cli_common.format_localized_time(summary.modified)
            entries.append(
                {
                    "path": summary.path,
                    "size_bytes": entry.total_size,
                    "size_human_readable": human_size,
                    "modified": localized_modified,
                    "type": summary.file_type or "Unknown",
                    "permissions": summary.permissions or "Unknown",
                }
            )

        if cache_data:
            timestamp = cache_data.timestamp
            reported_project_id = cache_data.project_id
        else:
            timestamp = ""
            reported_project_id = project_id

        document = {
            "timestamp": timestamp,
            "project_id": reported_project_id,
            "total_files": len(files_with_total_sizes),
            "total_size_bytes": total_size,
            "total_size_human_readable": total_size_str,
            "files": entries,
        }
        print(json.dumps(document, indent=2))

    def output_no_cache_message(self, project_id: str) -> None:
        """Emit the same JSON shape with zero files when there is no summary."""
        self.output_cache_summary(None, [], 0, "0 B", project_id)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _get_cache_summary_viewer(output_format: str) -> CacheSummaryViewer:
    """Build the viewer for ``output_format``.

    ``"csv"`` and ``"json"`` map to their dedicated viewers. Any other value,
    including unknown ones, falls back to the styled CLI table viewer.
    """
    specialised = (
        (OUTPUT_FORMAT_CSV, CSVViewer),
        (OUTPUT_FORMAT_JSON, JSONViewer),
    )
    for fmt, viewer_cls in specialised:
        if output_format == fmt:
            return viewer_cls()
    return CLITableViewer()
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def view_cache_summary(
    remote_provider: BasetenRemote,
    project_id: str,
    sort_by: str = SORT_BY_FILEPATH,
    order: str = SORT_ORDER_ASC,
    output_format: str = OUTPUT_FORMAT_CLI_TABLE,
):
    """View cache summary for a training project.

    Fetches the project's cache summary, attaches aggregated directory sizes,
    sorts the entries and hands them to the viewer for ``output_format``.

    Args:
        remote_provider: Remote whose ``api.get_cache_summary`` is queried.
        project_id: ID of the training project.
        sort_by: One of the ``SORT_BY_*`` constants.
        order: ``SORT_ORDER_DESC`` sorts descending; any other value ascending.
        output_format: One of the ``OUTPUT_FORMAT_*`` constants.

    Raises:
        ValueError: If ``output_format`` is unsupported (raised before any
            API call), or if ``sort_by`` is unsupported (raised after the
            summary has been fetched).
        Exception: Errors from fetching, validating, sorting or rendering are
            reported (stderr for CSV/JSON, the console otherwise) and re-raised.
    """
    supported_formats = (
        OUTPUT_FORMAT_CSV,
        OUTPUT_FORMAT_JSON,
        OUTPUT_FORMAT_CLI_TABLE,
    )
    if output_format not in supported_formats:
        raise ValueError(f"Invalid output format: {output_format}")
    # Reuse the module's single viewer factory rather than keeping a second,
    # divergent mapping. The factory falls back to the table viewer for
    # unknown formats, so the validation above keeps this entry point strict.
    viewer = _get_cache_summary_viewer(output_format)
    # For CSV/JSON formats, errors go to stderr to keep stdout clean for piping.
    is_machine_readable = output_format in (OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON)

    try:
        raw_cache_data = remote_provider.api.get_cache_summary(project_id)

        if not raw_cache_data:
            viewer.output_no_cache_message(project_id)
            return

        cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)

        files = cache_data.file_summaries
        files_with_total_sizes = (
            create_file_summary_with_directory_sizes(files) if files else []
        )

        reverse = order == SORT_ORDER_DESC
        sort_key = _get_sort_key(sort_by)
        files_with_total_sizes.sort(key=sort_key, reverse=reverse)

        # Sum each entry's own size_bytes rather than the aggregated
        # total_size; directory totals already include their children, so
        # using them would count nested files more than once.
        total_size = sum(
            file_info.file_summary.size_bytes for file_info in files_with_total_sizes
        )
        total_size_str = common.format_bytes_to_human_readable(total_size)

        viewer.output_cache_summary(
            cache_data, files_with_total_sizes, total_size, total_size_str, project_id
        )

    except Exception as e:
        if is_machine_readable:
            print(f"Error fetching cache summary: {str(e)}", file=sys.stderr)
        else:
            console.print(f"Error fetching cache summary: {str(e)}", style="red")
        raise
|
truss/cli/train/core.py
CHANGED
|
@@ -23,23 +23,10 @@ from truss.cli.train.types import (
|
|
|
23
23
|
)
|
|
24
24
|
from truss.cli.utils import common as cli_common
|
|
25
25
|
from truss.cli.utils.output import console
|
|
26
|
-
from truss.remote.baseten.custom_types import (
|
|
27
|
-
FileSummary,
|
|
28
|
-
FileSummaryWithTotalSize,
|
|
29
|
-
GetCacheSummaryResponseV1,
|
|
30
|
-
)
|
|
31
26
|
from truss.remote.baseten.remote import BasetenRemote
|
|
32
27
|
from truss_train import loader
|
|
33
28
|
from truss_train.definitions import DeployCheckpointsConfig
|
|
34
29
|
|
|
35
|
-
SORT_BY_FILEPATH = "filepath"
|
|
36
|
-
SORT_BY_SIZE = "size"
|
|
37
|
-
SORT_BY_MODIFIED = "modified"
|
|
38
|
-
SORT_BY_TYPE = "type"
|
|
39
|
-
SORT_BY_PERMISSIONS = "permissions"
|
|
40
|
-
SORT_ORDER_ASC = "asc"
|
|
41
|
-
SORT_ORDER_DESC = "desc"
|
|
42
|
-
|
|
43
30
|
ACTIVE_JOB_STATUSES = [
|
|
44
31
|
"TRAINING_JOB_RUNNING",
|
|
45
32
|
"TRAINING_JOB_CREATED",
|
|
@@ -630,139 +617,28 @@ def fetch_project_by_name_or_id(
|
|
|
630
617
|
raise click.ClickException(f"Error fetching project: {str(e)}")
|
|
631
618
|
|
|
632
619
|
|
|
633
|
-
def create_file_summary_with_directory_sizes(
|
|
634
|
-
files: list[FileSummary],
|
|
635
|
-
) -> list[FileSummaryWithTotalSize]:
|
|
636
|
-
directory_sizes = calculate_directory_sizes(files)
|
|
637
|
-
return [
|
|
638
|
-
FileSummaryWithTotalSize(
|
|
639
|
-
file_summary=file_info,
|
|
640
|
-
total_size=directory_sizes.get(file_info.path, file_info.size_bytes),
|
|
641
|
-
)
|
|
642
|
-
for file_info in files
|
|
643
|
-
]
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
def calculate_directory_sizes(
|
|
647
|
-
files: list[FileSummary], max_depth: int = 100
|
|
648
|
-
) -> dict[str, int]:
|
|
649
|
-
directory_sizes = {}
|
|
650
|
-
|
|
651
|
-
for file_info in files:
|
|
652
|
-
if file_info.file_type == "directory":
|
|
653
|
-
directory_sizes[file_info.path] = 0
|
|
654
|
-
|
|
655
|
-
for file_info in files:
|
|
656
|
-
current_path = file_info.path
|
|
657
|
-
for i in range(max_depth):
|
|
658
|
-
if current_path is None:
|
|
659
|
-
break
|
|
660
|
-
if current_path in directory_sizes:
|
|
661
|
-
directory_sizes[current_path] += file_info.size_bytes
|
|
662
|
-
# Move to parent directory
|
|
663
|
-
parent = os.path.dirname(current_path)
|
|
664
|
-
if parent == current_path: # Reached root
|
|
665
|
-
break
|
|
666
|
-
current_path = parent
|
|
667
|
-
|
|
668
|
-
return directory_sizes
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
def view_cache_summary(
|
|
672
|
-
remote_provider: BasetenRemote,
|
|
673
|
-
project_id: str,
|
|
674
|
-
sort_by: str = SORT_BY_FILEPATH,
|
|
675
|
-
order: str = SORT_ORDER_ASC,
|
|
676
|
-
):
|
|
677
|
-
"""View cache summary for a training project."""
|
|
678
|
-
try:
|
|
679
|
-
raw_cache_data = remote_provider.api.get_cache_summary(project_id)
|
|
680
|
-
|
|
681
|
-
if not raw_cache_data:
|
|
682
|
-
console.print("No cache summary found for this project.", style="yellow")
|
|
683
|
-
return
|
|
684
|
-
|
|
685
|
-
cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)
|
|
686
|
-
|
|
687
|
-
table = rich.table.Table(title=f"Cache summary for project: {project_id}")
|
|
688
|
-
table.add_column("File Path", style="cyan")
|
|
689
|
-
table.add_column("Size", style="green")
|
|
690
|
-
table.add_column("Modified", style="yellow")
|
|
691
|
-
table.add_column("Type")
|
|
692
|
-
table.add_column("Permissions", style="magenta")
|
|
693
|
-
|
|
694
|
-
files = cache_data.file_summaries
|
|
695
|
-
if not files:
|
|
696
|
-
console.print("No files found in cache.", style="yellow")
|
|
697
|
-
return
|
|
698
|
-
|
|
699
|
-
files_with_total_sizes = create_file_summary_with_directory_sizes(files)
|
|
700
|
-
|
|
701
|
-
reverse = order == SORT_ORDER_DESC
|
|
702
|
-
sort_key = _get_sort_key(sort_by)
|
|
703
|
-
files_with_total_sizes.sort(key=sort_key, reverse=reverse)
|
|
704
|
-
|
|
705
|
-
total_size = sum(
|
|
706
|
-
file_info.file_summary.size_bytes for file_info in files_with_total_sizes
|
|
707
|
-
)
|
|
708
|
-
total_size_str = common.format_bytes_to_human_readable(total_size)
|
|
709
|
-
|
|
710
|
-
console.print(
|
|
711
|
-
f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
|
|
712
|
-
)
|
|
713
|
-
console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
|
|
714
|
-
console.print()
|
|
715
|
-
console.print(
|
|
716
|
-
f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
|
|
717
|
-
)
|
|
718
|
-
console.print(f"💾 Total size: {total_size_str}", style="bold green")
|
|
719
|
-
console.print()
|
|
720
|
-
|
|
721
|
-
for file_info in files_with_total_sizes:
|
|
722
|
-
total_size = file_info.total_size
|
|
723
|
-
|
|
724
|
-
size_str = cli_common.format_bytes_to_human_readable(int(total_size))
|
|
725
|
-
|
|
726
|
-
modified_str = cli_common.format_localized_time(
|
|
727
|
-
file_info.file_summary.modified
|
|
728
|
-
)
|
|
729
|
-
|
|
730
|
-
table.add_row(
|
|
731
|
-
file_info.file_summary.path,
|
|
732
|
-
size_str,
|
|
733
|
-
modified_str,
|
|
734
|
-
file_info.file_summary.file_type or "Unknown",
|
|
735
|
-
file_info.file_summary.permissions or "Unknown",
|
|
736
|
-
)
|
|
737
|
-
|
|
738
|
-
console.print(table)
|
|
739
|
-
|
|
740
|
-
except Exception as e:
|
|
741
|
-
console.print(f"Error fetching cache summary: {str(e)}", style="red")
|
|
742
|
-
raise
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
|
|
746
|
-
if sort_by == SORT_BY_FILEPATH:
|
|
747
|
-
return lambda x: x.file_summary.path
|
|
748
|
-
elif sort_by == SORT_BY_SIZE:
|
|
749
|
-
return lambda x: x.total_size
|
|
750
|
-
elif sort_by == SORT_BY_MODIFIED:
|
|
751
|
-
return lambda x: x.file_summary.modified
|
|
752
|
-
elif sort_by == SORT_BY_TYPE:
|
|
753
|
-
return lambda x: x.file_summary.file_type or ""
|
|
754
|
-
elif sort_by == SORT_BY_PERMISSIONS:
|
|
755
|
-
return lambda x: x.file_summary.permissions or ""
|
|
756
|
-
else:
|
|
757
|
-
raise ValueError(f"Invalid --sort argument: {sort_by}")
|
|
758
|
-
|
|
759
|
-
|
|
760
620
|
def view_cache_summary_by_project(
    remote_provider: BasetenRemote,
    project_identifier: str,
    sort_by: Optional[str] = None,
    order: Optional[str] = None,
    output_format: Optional[str] = None,
):
    """View cache summary for a training project by ID or name.

    Resolves ``project_identifier`` via ``fetch_project_by_name_or_id`` and
    delegates to ``truss.cli.train.cache.view_cache_summary``. Options left as
    ``None`` fall back to the cache module's defaults: sort by file path,
    ascending, CLI table output.
    """
    # Imported at call time. Presumably this avoids a module-level import
    # cycle; confirm before hoisting.
    from truss.cli.train.cache import (
        OUTPUT_FORMAT_CLI_TABLE,
        SORT_BY_FILEPATH,
        SORT_ORDER_ASC,
        view_cache_summary,
    )

    effective_sort = SORT_BY_FILEPATH if sort_by is None else sort_by
    effective_order = SORT_ORDER_ASC if order is None else order
    effective_format = (
        OUTPUT_FORMAT_CLI_TABLE if output_format is None else output_format
    )

    project = fetch_project_by_name_or_id(remote_provider, project_identifier)
    view_cache_summary(
        remote_provider, project["id"], effective_sort, effective_order, effective_format
    )
|
truss/cli/train_commands.py
CHANGED
|
@@ -12,9 +12,15 @@ from truss.cli import remote_cli
|
|
|
12
12
|
from truss.cli.cli import truss_cli
|
|
13
13
|
from truss.cli.logs import utils as cli_log_utils
|
|
14
14
|
from truss.cli.logs.training_log_watcher import TrainingLogWatcher
|
|
15
|
+
from truss.cli.resolvers.training_project_team_resolver import (
|
|
16
|
+
resolve_training_project_team_name,
|
|
17
|
+
)
|
|
15
18
|
from truss.cli.train import common as train_common
|
|
16
19
|
from truss.cli.train import core
|
|
17
|
-
from truss.cli.train.
|
|
20
|
+
from truss.cli.train.cache import (
|
|
21
|
+
OUTPUT_FORMAT_CLI_TABLE,
|
|
22
|
+
OUTPUT_FORMAT_CSV,
|
|
23
|
+
OUTPUT_FORMAT_JSON,
|
|
18
24
|
SORT_BY_FILEPATH,
|
|
19
25
|
SORT_BY_MODIFIED,
|
|
20
26
|
SORT_BY_PERMISSIONS,
|
|
@@ -108,29 +114,70 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
|
|
|
108
114
|
return ctx
|
|
109
115
|
|
|
110
116
|
|
|
117
|
+
def _resolve_team_name(
    remote_provider: BasetenRemote,
    provided_team_name: Optional[str],
    existing_project_name: Optional[str] = None,
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Resolve which team a training project should be pushed under.

    Thin delegate to ``resolve_training_project_team_name``; every argument
    is forwarded by keyword.

    Returns:
        A ``(team_name, team_id)`` pair, either of which may be ``None``.
    """
    resolution = resolve_training_project_team_name(
        remote_provider=remote_provider,
        provided_team_name=provided_team_name,
        existing_teams=existing_teams,
        existing_project_name=existing_project_name,
    )
    return resolution
|
|
129
|
+
|
|
130
|
+
|
|
111
131
|
@train.command(name="push")
|
|
112
132
|
@click.argument("config", type=Path, required=True)
|
|
113
133
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
114
134
|
@click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
|
|
115
135
|
@click.option("--job-name", type=str, required=False, help="Name of the training job.")
|
|
136
|
+
@click.option(
|
|
137
|
+
"--team",
|
|
138
|
+
"provided_team_name",
|
|
139
|
+
type=str,
|
|
140
|
+
required=False,
|
|
141
|
+
help="Team name for the training project",
|
|
142
|
+
)
|
|
116
143
|
@common.common_options()
|
|
117
144
|
def push_training_job(
|
|
118
|
-
config: Path,
|
|
145
|
+
config: Path,
|
|
146
|
+
remote: Optional[str],
|
|
147
|
+
tail: bool,
|
|
148
|
+
job_name: Optional[str],
|
|
149
|
+
provided_team_name: Optional[str],
|
|
119
150
|
):
|
|
120
151
|
"""Run a training job"""
|
|
121
|
-
from truss_train import deployment
|
|
152
|
+
from truss_train import deployment, loader
|
|
122
153
|
|
|
123
154
|
if not remote:
|
|
124
155
|
remote = remote_cli.inquire_remote_name()
|
|
125
156
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
157
|
+
remote_provider: BasetenRemote = cast(
|
|
158
|
+
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
existing_teams = remote_provider.api.get_teams()
|
|
162
|
+
|
|
163
|
+
with loader.import_training_project(config) as training_project:
|
|
164
|
+
team_name, team_id = _resolve_team_name(
|
|
165
|
+
remote_provider,
|
|
166
|
+
provided_team_name,
|
|
167
|
+
existing_project_name=training_project.name,
|
|
168
|
+
existing_teams=existing_teams,
|
|
132
169
|
)
|
|
133
170
|
|
|
171
|
+
with console.status("Creating training job...", spinner="dots"):
|
|
172
|
+
job_resp = deployment.create_training_job(
|
|
173
|
+
remote_provider,
|
|
174
|
+
config,
|
|
175
|
+
training_project,
|
|
176
|
+
job_name_from_cli=job_name,
|
|
177
|
+
team_name=team_name,
|
|
178
|
+
team_id=team_id,
|
|
179
|
+
)
|
|
180
|
+
|
|
134
181
|
# Note: This post create logic needs to happen outside the context
|
|
135
182
|
# of the above context manager, as only one console session can be active
|
|
136
183
|
# at a time.
|
|
@@ -556,8 +603,17 @@ def cache():
|
|
|
556
603
|
default=SORT_ORDER_ASC,
|
|
557
604
|
help="Sort order: ascending or descending.",
|
|
558
605
|
)
|
|
606
|
+
@click.option(
|
|
607
|
+
"-o",
|
|
608
|
+
"--output-format",
|
|
609
|
+
type=click.Choice([OUTPUT_FORMAT_CLI_TABLE, OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON]),
|
|
610
|
+
default=OUTPUT_FORMAT_CLI_TABLE,
|
|
611
|
+
help="Output format: cli-table (default), csv, or json.",
|
|
612
|
+
)
|
|
559
613
|
@common.common_options()
|
|
560
|
-
def view_cache_summary(
|
|
614
|
+
def view_cache_summary(
|
|
615
|
+
project: str, remote: Optional[str], sort: str, order: str, output_format: str
|
|
616
|
+
):
|
|
561
617
|
"""View cache summary for a training project"""
|
|
562
618
|
if not remote:
|
|
563
619
|
remote = remote_cli.inquire_remote_name()
|
|
@@ -566,7 +622,9 @@ def view_cache_summary(project: str, remote: Optional[str], sort: str, order: st
|
|
|
566
622
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
567
623
|
)
|
|
568
624
|
|
|
569
|
-
train_cli.view_cache_summary_by_project(
|
|
625
|
+
train_cli.view_cache_summary_by_project(
|
|
626
|
+
remote_provider, project, sort, order, output_format
|
|
627
|
+
)
|
|
570
628
|
|
|
571
629
|
|
|
572
630
|
def _maybe_resolve_project_id_from_id_or_name(
|
truss/cli/utils/common.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import datetime
|
|
2
2
|
import logging
|
|
3
|
+
import re
|
|
3
4
|
import sys
|
|
4
5
|
import warnings
|
|
5
6
|
from functools import wraps
|
|
@@ -20,6 +21,8 @@ from truss.cli.utils import self_upgrade
|
|
|
20
21
|
from truss.cli.utils.output import console
|
|
21
22
|
from truss.util import user_config
|
|
22
23
|
|
|
24
|
+
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
23
26
|
INCLUDE_GIT_INFO_DOC = (
|
|
24
27
|
"Whether to attach git versioning info (sha, branch, tag) to deployments made from "
|
|
25
28
|
"within a git repo. If set to True in `.trussrc`, it will always be attached."
|
|
@@ -181,10 +184,44 @@ def is_human_log_level(ctx: click.Context) -> bool:
|
|
|
181
184
|
return get_required_option(ctx, "log") != _HUMANFRIENDLY_LOG_LEVEL
|
|
182
185
|
|
|
183
186
|
|
|
184
|
-
def
|
|
187
|
+
def _normalize_iso_timestamp(iso_timestamp: str) -> str:
|
|
188
|
+
iso_timestamp = iso_timestamp.strip()
|
|
185
189
|
if iso_timestamp.endswith("Z"):
|
|
186
|
-
iso_timestamp = iso_timestamp
|
|
187
|
-
|
|
190
|
+
iso_timestamp = iso_timestamp[:-1] + "+00:00"
|
|
191
|
+
|
|
192
|
+
tz_part = ""
|
|
193
|
+
tz_match = re.search(r"([+-]\d{2}:\d{2}|[+-]\d{4})$", iso_timestamp)
|
|
194
|
+
if tz_match:
|
|
195
|
+
tz_part = tz_match.group(0)
|
|
196
|
+
iso_timestamp = iso_timestamp[: tz_match.start()]
|
|
197
|
+
|
|
198
|
+
iso_timestamp = iso_timestamp.rstrip()
|
|
199
|
+
|
|
200
|
+
if tz_part and ":" not in tz_part:
|
|
201
|
+
tz_part = f"{tz_part[:3]}:{tz_part[3:]}"
|
|
202
|
+
|
|
203
|
+
fractional_match = re.search(r"\.(\d+)$", iso_timestamp)
|
|
204
|
+
if fractional_match:
|
|
205
|
+
fractional_digits = fractional_match.group(1)
|
|
206
|
+
if len(fractional_digits) > 6:
|
|
207
|
+
iso_timestamp = (
|
|
208
|
+
iso_timestamp[: fractional_match.start()] + "." + fractional_digits[:6]
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
return f"{iso_timestamp}{tz_part}"
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# NOTE: `pyproject.toml` declares support down to Python 3.9. Its
# `datetime.fromisoformat` cannot parse nanosecond fractions or colonless
# offsets, so unparseable timestamps are normalized and parsed a second time.
def format_localized_time(iso_timestamp: str) -> str:
    """Render an ISO-8601 timestamp as ``YYYY-MM-DD HH:MM:SS`` local time.

    The string is first parsed as-is. If that raises ``ValueError``, it is
    passed through ``_normalize_iso_timestamp`` and parsed again; a second
    failure propagates to the caller.
    """
    # NOTE(review): astimezone() treats naive datetimes as system-local time,
    # not UTC; confirm timestamps always carry an offset.
    try:
        parsed = datetime.datetime.fromisoformat(iso_timestamp)
    except ValueError:
        parsed = datetime.datetime.fromisoformat(
            _normalize_iso_timestamp(iso_timestamp)
        )
    return parsed.astimezone().strftime("%Y-%m-%d %H:%M:%S")
|
|
190
227
|
|