truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Team resolution logic for training projects."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
|
|
7
|
+
from truss.cli import remote_cli
|
|
8
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def resolve_training_project_team_name(
    remote_provider: BasetenRemote,
    provided_team_name: Optional[str],
    existing_project_name: Optional[str] = None,
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Resolve team name and team_id from provided team name or by prompting the user.

    Returns a tuple of (team_name, team_id).

    Args:
        remote_provider: Remote whose ``api`` is used to fetch the user's teams
            (when ``existing_teams`` is not supplied) and training projects.
        provided_team_name: Value of ``--team``; ``None`` when not given.
        existing_project_name: Name of the targeted training project. When
            given, existing projects with this name are used to infer the team.
        existing_teams: Pre-fetched mapping of team name -> team data, where the
            team data holds the team's ``"id"``. Fetched via
            ``remote_provider.api.get_teams()`` when ``None``.

    Raises:
        click.ClickException: ``provided_team_name`` is not one of the user's
            teams.

    This function handles 8 distinct scenarios organized into 3 high-level categories:

    HIGH-LEVEL SCENARIO 1: --team PROVIDED
        SCENARIO 1: Valid team name, user has access
            → Returns (team_name, team_id) for that team (no prompt, no error)
        SCENARIO 2: Invalid team name (does not exist)
            → Raises ClickException with error message listing available teams

    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Training project does not exist
        SCENARIO 3: User has multiple teams, no existing project
            → Prompts user to select a team via inquire_team()
        SCENARIO 6: User has exactly one team, no existing project
            → Returns (team_name, team_id) for the single team automatically (no prompt)

    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Training project exists
        SCENARIO 4: User has multiple teams, existing project in exactly one of them
            → Auto-detects and returns (team_name, team_id) for that team (no prompt)
        SCENARIO 5: User has multiple teams, existing project exists in more than
            one of the user's teams
            → Prompts user to select a team via inquire_team()
        SCENARIO 7: User has exactly one team, existing project matches the team
            → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
        SCENARIO 8: User has exactly one team, existing project exists in different team
            → Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)

    Several matching projects that live in the same team count as one team.
    Matching projects in teams that are not in ``existing_teams`` are ignored.
    """
    if existing_teams is None:
        existing_teams = remote_provider.api.get_teams()

    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
        # Look up the team's id in the team map; None when unknown.
        if team_name and existing_teams:
            team_data = existing_teams.get(team_name)
            return team_data["id"] if team_data else None
        return None

    if provided_team_name is not None:
        if provided_team_name not in existing_teams:
            available_teams_str = remote_cli.format_available_teams(existing_teams)
            raise click.ClickException(
                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
            )
        return (provided_team_name, _get_team_id(provided_team_name))

    if existing_project_name is not None:
        existing_projects = remote_provider.api.list_training_projects()
        # Distinct teams (among the user's teams) that own a project with this
        # name, in first-seen order. Duplicate entries within one team, or
        # entries in teams the user cannot pick, must not force a prompt.
        matching_team_names = list(
            dict.fromkeys(
                p.get("team_name")
                for p in existing_projects
                if p.get("name") == existing_project_name
                and p.get("team_name") in existing_teams
            )
        )

        if len(matching_team_names) > 1:
            # Genuinely ambiguous: the project exists in several of the user's teams.
            selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
            return (selected_team_name, _get_team_id(selected_team_name))

        if len(matching_team_names) == 1:
            project_team_name = matching_team_names[0]
            return (project_team_name, _get_team_id(project_team_name))

    if len(existing_teams) == 1:
        single_team_name = next(iter(existing_teams))
        return (single_team_name, _get_team_id(single_team_name))

    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
    return (selected_team_name, _get_team_id(selected_team_name))
|
truss/cli/train/cache.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any, Callable, Optional
|
|
7
|
+
|
|
8
|
+
import rich
|
|
9
|
+
|
|
10
|
+
from truss.cli.train import common
|
|
11
|
+
from truss.cli.utils import common as cli_common
|
|
12
|
+
from truss.cli.utils.output import console
|
|
13
|
+
from truss.remote.baseten.custom_types import (
|
|
14
|
+
FileSummary,
|
|
15
|
+
FileSummaryWithTotalSize,
|
|
16
|
+
GetCacheSummaryResponseV1,
|
|
17
|
+
)
|
|
18
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
19
|
+
|
|
20
|
+
# Values accepted by the ``sort_by`` argument of ``view_cache_summary``
# (each maps to a key function in ``_get_sort_key``).
SORT_BY_FILEPATH = "filepath"
SORT_BY_SIZE = "size"
SORT_BY_MODIFIED = "modified"
SORT_BY_TYPE = "type"
SORT_BY_PERMISSIONS = "permissions"

# Values accepted by the ``order`` argument; only DESC reverses the sort.
SORT_ORDER_ASC = "asc"
SORT_ORDER_DESC = "desc"

# Values accepted by the ``output_format`` argument (table, CSV or JSON viewer).
OUTPUT_FORMAT_CLI_TABLE = "cli-table"
OUTPUT_FORMAT_CSV = "csv"
OUTPUT_FORMAT_JSON = "json"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def calculate_directory_sizes(
    files: list[FileSummary], max_depth: int = 100
) -> dict[str, int]:
    """Aggregate entry sizes into every directory listed in ``files``.

    Each entry whose ``file_type`` is ``"directory"`` starts a running total
    of 0. Every entry's ``size_bytes`` is then added to each such directory
    met while walking from the entry's own path upward with
    ``os.path.dirname``. A directory's total therefore includes its own
    ``size_bytes`` plus those of all its descendants present in ``files``.

    Args:
        files: Flat listing of cache entries (files and directories).
        max_depth: Upper bound on how many paths (the entry's own path, then
            each ancestor) are visited per entry.

    Returns:
        Mapping of directory path to accumulated size in bytes.
    """
    totals: dict[str, int] = {
        entry.path: 0 for entry in files if entry.file_type == "directory"
    }

    for entry in files:
        path = entry.path
        for _ in range(max_depth):
            if path is None:
                break
            if path in totals:
                totals[path] += entry.size_bytes
            parent = os.path.dirname(path)
            # dirname() returns its input unchanged at the top of the path
            # (e.g. "/" or ""), which ends the upward walk.
            if parent == path:
                break
            path = parent

    return totals
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def create_file_summary_with_directory_sizes(
    files: list[FileSummary],
) -> list[FileSummaryWithTotalSize]:
    """Pair every entry in ``files`` with the size to display for it.

    Directories receive the aggregate computed by
    ``calculate_directory_sizes``; all other entries keep their own
    ``size_bytes``. The input order is preserved.
    """
    aggregated = calculate_directory_sizes(files)
    return [
        FileSummaryWithTotalSize(
            file_summary=entry,
            total_size=aggregated.get(entry.path, entry.size_bytes),
        )
        for entry in files
    ]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
    """Return the ``key=`` callable used to order rows for ``sort_by``.

    Falsy ``file_type``/``permissions`` values are replaced with ``""`` so such
    rows still compare against strings.

    Raises:
        ValueError: ``sort_by`` is not one of the ``SORT_BY_*`` constants.
    """
    key_functions = {
        SORT_BY_FILEPATH: lambda row: row.file_summary.path,
        SORT_BY_SIZE: lambda row: row.total_size,
        SORT_BY_MODIFIED: lambda row: row.file_summary.modified,
        SORT_BY_TYPE: lambda row: row.file_summary.file_type or "",
        SORT_BY_PERMISSIONS: lambda row: row.file_summary.permissions or "",
    }
    key_function = key_functions.get(sort_by)
    if key_function is None:
        raise ValueError(f"Invalid --sort argument: {sort_by}")
    return key_function
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class CacheSummaryViewer(ABC):
    """Interface for rendering a training project's cache summary.

    Each concrete subclass chooses one output format.
    """

    @abstractmethod
    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Render the cache summary in this viewer's format.

        Args:
            cache_data: Parsed cache summary response, or ``None`` when absent.
            files_with_total_sizes: Rows to render, in display order.
            total_size: Aggregate size in bytes.
            total_size_str: Human-readable form of ``total_size``.
            project_id: ID of the training project being summarized.
        """

    @abstractmethod
    def output_no_cache_message(self, project_id: str) -> None:
        """Render the output used when no cache summary exists for ``project_id``."""
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class CLITableViewer(CacheSummaryViewer):
    """Viewer that outputs cache summary as a styled CLI table."""

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Output cache summary as a styled CLI table.

        Prints a header (capture timestamp, project ID, file count, total size),
        an optional tip about long paths, then a table with one row per entry.
        Prints nothing when ``cache_data`` is falsy.

        Args:
            cache_data: Parsed cache summary; supplies the header fields.
            files_with_total_sizes: Rows to display, in display order.
            total_size: Aggregate size in bytes. Not displayed by this viewer
                (``total_size_str`` is); accepted for the shared interface.
            total_size_str: Human-readable total shown in the header.
            project_id: Used in the table title.
        """
        if not cache_data:
            return

        # ``import rich`` does not guarantee that the ``rich.table`` submodule
        # is loaded; import it explicitly rather than relying on some other
        # module having imported it first.
        import rich.table

        table = rich.table.Table(title=f"Cache summary for project: {project_id}")
        table.add_column("File Path", style="cyan", overflow="fold")
        table.add_column("Size", style="green")
        table.add_column("Modified", style="yellow")
        table.add_column("Type")
        table.add_column("Permissions", style="magenta")

        console.print(
            f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
        )
        console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
        console.print()
        console.print(
            f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
        )
        console.print(f"💾 Total size: {total_size_str}", style="bold green")
        console.print()
        # Long file paths wrap across multiple lines in the table, which makes
        # them awkward to copy; point users at the single-line formats.
        if any(len(f.file_summary.path) > 60 for f in files_with_total_sizes):
            console.print(
                "💡 Tip: Use -o csv or -o json to output paths on single lines for easier copying",
                style="dim",
            )
            console.print()

        for file_info in files_with_total_sizes:
            # Per-row size; deliberately not reusing the ``total_size`` name,
            # which is this method's aggregate parameter.
            size_str = cli_common.format_bytes_to_human_readable(
                int(file_info.total_size)
            )
            modified_str = cli_common.format_localized_time(
                file_info.file_summary.modified
            )

            table.add_row(
                file_info.file_summary.path,
                size_str,
                modified_str,
                file_info.file_summary.file_type or "Unknown",
                file_info.file_summary.permissions or "Unknown",
            )

        console.print(table)

    def output_no_cache_message(self, project_id: str) -> None:
        """Output message when no cache summary is found."""
        console.print("No cache summary found for this project.", style="yellow")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class CSVViewer(CacheSummaryViewer):
    """Viewer that writes the cache summary to stdout as CSV."""

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Write a header row and then one row per entry to ``sys.stdout``.

        Only per-entry rows are emitted; ``cache_data``, ``total_size``,
        ``total_size_str`` and ``project_id`` are accepted for the shared viewer
        interface but not written.
        """

        def to_row(entry):
            # Column order matches the header written below.
            byte_count = entry.total_size
            return [
                entry.file_summary.path,
                str(byte_count),
                cli_common.format_bytes_to_human_readable(int(byte_count)),
                cli_common.format_localized_time(entry.file_summary.modified),
                entry.file_summary.file_type or "Unknown",
                entry.file_summary.permissions or "Unknown",
            ]

        writer = csv.writer(sys.stdout)
        writer.writerow(
            [
                "File Path",
                "Size (bytes)",
                "Size (human readable)",
                "Modified",
                "Type",
                "Permissions",
            ]
        )
        # A generator keeps output streaming: each row is written as soon as
        # it is built.
        writer.writerows(to_row(entry) for entry in files_with_total_sizes)

    def output_no_cache_message(self, project_id: str) -> None:
        """Emit just the CSV header, so the output stays machine-readable."""
        self.output_cache_summary(None, [], 0, "0 B", project_id)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class JSONViewer(CacheSummaryViewer):
    """Viewer that prints the cache summary to stdout as a JSON document."""

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Print the summary as a JSON object indented by two spaces.

        Without ``cache_data``, ``timestamp`` is ``""`` and ``project_id`` is
        taken from the argument instead of the response.
        """

        def describe(entry):
            byte_count = entry.total_size
            human_size = cli_common.format_bytes_to_human_readable(int(byte_count))
            modified = cli_common.format_localized_time(entry.file_summary.modified)
            return {
                "path": entry.file_summary.path,
                "size_bytes": byte_count,
                "size_human_readable": human_size,
                "modified": modified,
                "type": entry.file_summary.file_type or "Unknown",
                "permissions": entry.file_summary.permissions or "Unknown",
            }

        file_entries = [describe(entry) for entry in files_with_total_sizes]

        document = {
            "timestamp": cache_data.timestamp if cache_data else "",
            "project_id": cache_data.project_id if cache_data else project_id,
            "total_files": len(files_with_total_sizes),
            "total_size_bytes": total_size,
            "total_size_human_readable": total_size_str,
            "files": file_entries,
        }

        print(json.dumps(document, indent=2))

    def output_no_cache_message(self, project_id: str) -> None:
        """Print an empty summary document rather than a free-text message."""
        self.output_cache_summary(None, [], 0, "0 B", project_id)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _get_cache_summary_viewer(output_format: str) -> CacheSummaryViewer:
    """Instantiate the viewer matching ``output_format``.

    Anything other than the CSV or JSON format falls back to the CLI table
    viewer; this factory never raises for an unknown format.
    """
    viewer_class = {
        OUTPUT_FORMAT_CSV: CSVViewer,
        OUTPUT_FORMAT_JSON: JSONViewer,
    }.get(output_format, CLITableViewer)
    return viewer_class()
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def view_cache_summary(
    remote_provider: BasetenRemote,
    project_id: str,
    sort_by: str = SORT_BY_FILEPATH,
    order: str = SORT_ORDER_ASC,
    output_format: str = OUTPUT_FORMAT_CLI_TABLE,
):
    """View cache summary for a training project.

    Fetches the project's cache summary, computes per-directory totals, sorts
    the entries, and renders them with the viewer for ``output_format``.

    Args:
        remote_provider: Remote whose ``api.get_cache_summary`` is queried.
        project_id: Training project to summarize.
        sort_by: One of the ``SORT_BY_*`` constants.
        order: ``SORT_ORDER_DESC`` sorts descending; any other value ascending.
        output_format: One of the ``OUTPUT_FORMAT_*`` constants.

    Raises:
        ValueError: ``output_format`` or ``sort_by`` is not recognized. Both
            are checked before any request is made.
        Exception: Any error while fetching, parsing or rendering the summary
            is reported (stderr for CSV/JSON, the console otherwise) and then
            re-raised.
    """
    # Validate the format explicitly, then reuse the shared factory, which on
    # its own would silently fall back to the table viewer.
    if output_format not in (
        OUTPUT_FORMAT_CLI_TABLE,
        OUTPUT_FORMAT_CSV,
        OUTPUT_FORMAT_JSON,
    ):
        raise ValueError(f"Invalid output format: {output_format}")
    viewer = _get_cache_summary_viewer(output_format)

    # An invalid --sort is a usage error: fail fast, before the network call,
    # instead of reporting it as a failure to fetch the cache summary.
    sort_key = _get_sort_key(sort_by)
    reverse = order == SORT_ORDER_DESC

    try:
        raw_cache_data = remote_provider.api.get_cache_summary(project_id)

        if not raw_cache_data:
            viewer.output_no_cache_message(project_id)
            return

        cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)

        files = cache_data.file_summaries
        files_with_total_sizes = (
            create_file_summary_with_directory_sizes(files) if files else []
        )
        files_with_total_sizes.sort(key=sort_key, reverse=reverse)

        total_size = sum(
            file_info.file_summary.size_bytes for file_info in files_with_total_sizes
        )
        # Same formatter the viewers use for per-row sizes, so the total and
        # the rows are rendered consistently.
        total_size_str = cli_common.format_bytes_to_human_readable(total_size)

        viewer.output_cache_summary(
            cache_data, files_with_total_sizes, total_size, total_size_str, project_id
        )

    except Exception as e:
        # For CSV/JSON formats, print to stderr to keep stdout clean for piping
        if output_format in (OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON):
            print(f"Error fetching cache summary: {e}", file=sys.stderr)
        else:
            console.print(f"Error fetching cache summary: {e}", style="red")
        raise
|