truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. truss/api/__init__.py +5 -2
  2. truss/base/truss_config.py +10 -3
  3. truss/cli/chains_commands.py +39 -1
  4. truss/cli/cli.py +35 -5
  5. truss/cli/remote_cli.py +29 -0
  6. truss/cli/resolvers/chain_team_resolver.py +82 -0
  7. truss/cli/resolvers/model_team_resolver.py +90 -0
  8. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  9. truss/cli/train/cache.py +332 -0
  10. truss/cli/train/core.py +19 -143
  11. truss/cli/train_commands.py +69 -11
  12. truss/cli/utils/common.py +40 -3
  13. truss/remote/baseten/api.py +58 -5
  14. truss/remote/baseten/core.py +22 -4
  15. truss/remote/baseten/remote.py +24 -2
  16. truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
  17. truss/templates/server/requirements.txt +1 -1
  18. truss/templates/server.Dockerfile.jinja +10 -10
  19. truss/templates/shared/util.py +6 -5
  20. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  21. truss/tests/cli/test_chains_cli.py +44 -0
  22. truss/tests/cli/test_cli.py +134 -1
  23. truss/tests/cli/test_cli_utils_common.py +11 -0
  24. truss/tests/cli/test_model_team_resolver.py +279 -0
  25. truss/tests/cli/train/test_cache_view.py +240 -3
  26. truss/tests/cli/train/test_train_cli_core.py +2 -2
  27. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  28. truss/tests/conftest.py +187 -0
  29. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  30. truss/tests/remote/baseten/test_api.py +122 -3
  31. truss/tests/remote/baseten/test_chain_upload.py +10 -1
  32. truss/tests/remote/baseten/test_core.py +86 -0
  33. truss/tests/remote/baseten/test_remote.py +216 -288
  34. truss/tests/test_config.py +21 -12
  35. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  36. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  37. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  38. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  39. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  40. truss/tests/test_model_inference.py +13 -0
  41. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
  42. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
  43. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  44. truss_chains/deployment/deployment_client.py +9 -4
  45. truss_chains/private_types.py +15 -0
  46. truss_train/definitions.py +3 -1
  47. truss_train/deployment.py +43 -21
  48. truss_train/public_api.py +4 -2
  49. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  50. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,332 @@
1
+ import csv
2
+ import json
3
+ import os
4
+ import sys
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Callable, Optional
7
+
8
+ import rich
9
+
10
+ from truss.cli.train import common
11
+ from truss.cli.utils import common as cli_common
12
+ from truss.cli.utils.output import console
13
+ from truss.remote.baseten.custom_types import (
14
+ FileSummary,
15
+ FileSummaryWithTotalSize,
16
+ GetCacheSummaryResponseV1,
17
+ )
18
+ from truss.remote.baseten.remote import BasetenRemote
19
+
20
+ # Sort constants
21
+ SORT_BY_FILEPATH = "filepath"
22
+ SORT_BY_SIZE = "size"
23
+ SORT_BY_MODIFIED = "modified"
24
+ SORT_BY_TYPE = "type"
25
+ SORT_BY_PERMISSIONS = "permissions"
26
+ SORT_ORDER_ASC = "asc"
27
+ SORT_ORDER_DESC = "desc"
28
+
29
+ # Output format constants
30
+ OUTPUT_FORMAT_CLI_TABLE = "cli-table"
31
+ OUTPUT_FORMAT_CSV = "csv"
32
+ OUTPUT_FORMAT_JSON = "json"
33
+
34
+
35
def calculate_directory_sizes(
    files: list[FileSummary], max_depth: int = 100
) -> dict[str, int]:
    """Aggregate file sizes into every ancestor directory present in *files*.

    Only entries whose ``file_type`` is ``"directory"`` receive a total.
    Each entry's ``size_bytes`` is credited to every directory on its path
    (including itself, when the entry is a directory), ascending at most
    ``max_depth`` levels toward the filesystem root.
    """
    # Seed a zero total for every directory entry we know about.
    totals: dict[str, int] = {
        entry.path: 0 for entry in files if entry.file_type == "directory"
    }

    for entry in files:
        path = entry.path
        depth = 0
        # Walk upward, crediting this entry's size to each known directory.
        while path is not None and depth < max_depth:
            if path in totals:
                totals[path] += entry.size_bytes
            parent = os.path.dirname(path)
            if parent == path:  # Reached root
                break
            path = parent
            depth += 1

    return totals
59
+
60
+
61
def create_file_summary_with_directory_sizes(
    files: list[FileSummary],
) -> list[FileSummaryWithTotalSize]:
    """Wrap each entry with its effective size.

    Directories get the aggregated size of their contents (as computed by
    ``calculate_directory_sizes``); plain files keep their own ``size_bytes``.
    """
    aggregated = calculate_directory_sizes(files)
    wrapped: list[FileSummaryWithTotalSize] = []
    for entry in files:
        # Fall back to the entry's own size when it is not a known directory.
        effective = aggregated.get(entry.path, entry.size_bytes)
        wrapped.append(
            FileSummaryWithTotalSize(file_summary=entry, total_size=effective)
        )
    return wrapped
73
+
74
+
75
def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
    """Map a sort option name to a key function over annotated file entries."""
    key_functions: dict[str, Callable[[FileSummaryWithTotalSize], Any]] = {
        SORT_BY_FILEPATH: lambda item: item.file_summary.path,
        SORT_BY_SIZE: lambda item: item.total_size,
        SORT_BY_MODIFIED: lambda item: item.file_summary.modified,
        # Optional string fields sort as "" when absent.
        SORT_BY_TYPE: lambda item: item.file_summary.file_type or "",
        SORT_BY_PERMISSIONS: lambda item: item.file_summary.permissions or "",
    }
    try:
        return key_functions[sort_by]
    except KeyError:
        raise ValueError(f"Invalid --sort argument: {sort_by}") from None
89
+
90
+
91
class CacheSummaryViewer(ABC):
    """Base class for cache summary viewers that output in different formats.

    Concrete subclasses (CLI table, CSV, JSON) implement the two hooks below;
    ``view_cache_summary`` selects one based on the requested output format.
    """

    @abstractmethod
    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Output the cache summary in the viewer's format.

        Args:
            cache_data: Validated summary payload; may be ``None`` (subclasses
                use this to render an "empty" view).
            files_with_total_sizes: File entries annotated with effective
                (directory-aggregated) sizes, already sorted by the caller.
            total_size: Sum of raw ``size_bytes`` across all entries.
            total_size_str: Human-readable rendering of ``total_size``.
            project_id: Id of the training project the cache belongs to.
        """
        pass

    @abstractmethod
    def output_no_cache_message(self, project_id: str) -> None:
        """Output message when no cache summary is found."""
        pass
110
+
111
+
112
class CLITableViewer(CacheSummaryViewer):
    """Viewer that outputs cache summary as a styled CLI table."""

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Output cache summary as a styled CLI table.

        Prints header metadata (capture time, project id, totals) followed by
        a rich table with one row per file entry.
        """
        # Nothing to render without the validated payload.
        if not cache_data:
            return

        table = rich.table.Table(title=f"Cache summary for project: {project_id}")
        # "fold" wraps long paths across lines instead of truncating them.
        table.add_column("File Path", style="cyan", overflow="fold")
        table.add_column("Size", style="green")
        table.add_column("Modified", style="yellow")
        table.add_column("Type")
        table.add_column("Permissions", style="magenta")

        console.print(
            f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
        )
        console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
        console.print()
        console.print(
            f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
        )
        console.print(f"💾 Total size: {total_size_str}", style="bold green")
        console.print()
        # Note: Long file paths wrap across multiple lines. To copy full paths easily,
        # use --output-format csv or --output-format json
        if any(len(f.file_summary.path) > 60 for f in files_with_total_sizes):
            console.print(
                "💡 Tip: Use -o csv or -o json to output paths on single lines for easier copying",
                style="dim",
            )
            console.print()

        for file_info in files_with_total_sizes:
            # NOTE(review): this shadows the `total_size` parameter; harmless
            # here because the aggregate string (`total_size_str`) was
            # precomputed by the caller — but worth renaming eventually.
            total_size = file_info.total_size

            size_str = cli_common.format_bytes_to_human_readable(int(total_size))

            modified_str = cli_common.format_localized_time(
                file_info.file_summary.modified
            )

            table.add_row(
                file_info.file_summary.path,
                size_str,
                modified_str,
                file_info.file_summary.file_type or "Unknown",
                file_info.file_summary.permissions or "Unknown",
            )

        console.print(table)

    def output_no_cache_message(self, project_id: str) -> None:
        """Output message when no cache summary is found."""
        console.print("No cache summary found for this project.", style="yellow")
175
+
176
+
177
class CSVViewer(CacheSummaryViewer):
    """Viewer that writes the cache summary to stdout as CSV."""

    # Column order mirrors the CLI table, with both raw and human-readable sizes.
    _HEADER = [
        "File Path",
        "Size (bytes)",
        "Size (human readable)",
        "Modified",
        "Type",
        "Permissions",
    ]

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Write a header row followed by one CSV row per file to stdout."""
        writer = csv.writer(sys.stdout)
        writer.writerow(self._HEADER)
        for item in files_with_total_sizes:
            summary = item.file_summary
            writer.writerow(
                [
                    summary.path,
                    str(item.total_size),
                    cli_common.format_bytes_to_human_readable(int(item.total_size)),
                    cli_common.format_localized_time(summary.modified),
                    summary.file_type or "Unknown",
                    summary.permissions or "Unknown",
                ]
            )

    def output_no_cache_message(self, project_id: str) -> None:
        """Output empty CSV with headers when no cache summary is found."""
        self.output_cache_summary(None, [], 0, "0 B", project_id)
223
+
224
+
225
class JSONViewer(CacheSummaryViewer):
    """Viewer that emits the cache summary as a single JSON document."""

    @staticmethod
    def _file_entry(item: FileSummaryWithTotalSize) -> dict:
        # One serializable record per file; key order is part of the output.
        summary = item.file_summary
        return {
            "path": summary.path,
            "size_bytes": item.total_size,
            "size_human_readable": cli_common.format_bytes_to_human_readable(
                int(item.total_size)
            ),
            "modified": cli_common.format_localized_time(summary.modified),
            "type": summary.file_type or "Unknown",
            "permissions": summary.permissions or "Unknown",
        }

    def output_cache_summary(
        self,
        cache_data: Optional[GetCacheSummaryResponseV1],
        files_with_total_sizes: list[FileSummaryWithTotalSize],
        total_size: int,
        total_size_str: str,
        project_id: str,
    ) -> None:
        """Print the full summary (metadata + per-file records) as JSON."""
        payload = {
            "timestamp": cache_data.timestamp if cache_data else "",
            "project_id": cache_data.project_id if cache_data else project_id,
            "total_files": len(files_with_total_sizes),
            "total_size_bytes": total_size,
            "total_size_human_readable": total_size_str,
            "files": [self._file_entry(item) for item in files_with_total_sizes],
        }

        print(json.dumps(payload, indent=2))

    def output_no_cache_message(self, project_id: str) -> None:
        """Output empty JSON structure when no cache summary is found."""
        self.output_cache_summary(None, [], 0, "0 B", project_id)
270
+
271
+
272
def _get_cache_summary_viewer(output_format: str) -> CacheSummaryViewer:
    """Factory function to get the appropriate viewer for the output format.

    Raises:
        ValueError: if *output_format* is not one of the supported formats.
            (Previously any unknown value silently fell back to the CLI
            table, which hid typos from callers and disagreed with the
            validation performed in ``view_cache_summary``.)
    """
    viewers: dict[str, Callable[[], CacheSummaryViewer]] = {
        OUTPUT_FORMAT_CLI_TABLE: CLITableViewer,
        OUTPUT_FORMAT_CSV: CSVViewer,
        OUTPUT_FORMAT_JSON: JSONViewer,
    }
    factory = viewers.get(output_format)
    if factory is None:
        # Same message/exception as view_cache_summary for consistency.
        raise ValueError(f"Invalid output format: {output_format}")
    return factory()
280
+
281
+
282
def view_cache_summary(
    remote_provider: BasetenRemote,
    project_id: str,
    sort_by: str = SORT_BY_FILEPATH,
    order: str = SORT_ORDER_ASC,
    output_format: str = OUTPUT_FORMAT_CLI_TABLE,
):
    """View cache summary for a training project."""
    # Resolve the viewer up front so an invalid format fails fast.
    factories = {
        OUTPUT_FORMAT_CSV: CSVViewer,
        OUTPUT_FORMAT_JSON: JSONViewer,
        OUTPUT_FORMAT_CLI_TABLE: CLITableViewer,
    }
    factory = factories.get(output_format)
    if factory is None:
        raise ValueError(f"Invalid output format: {output_format}")
    viewer = factory()
    try:
        raw_cache_data = remote_provider.api.get_cache_summary(project_id)

        if not raw_cache_data:
            viewer.output_no_cache_message(project_id)
            return

        cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)

        summaries = cache_data.file_summaries
        if summaries:
            entries = create_file_summary_with_directory_sizes(summaries)
        else:
            entries = []

        entries.sort(
            key=_get_sort_key(sort_by), reverse=(order == SORT_ORDER_DESC)
        )

        # Totals are based on raw per-entry sizes, not directory aggregates.
        total = sum(item.file_summary.size_bytes for item in entries)
        viewer.output_cache_summary(
            cache_data,
            entries,
            total,
            common.format_bytes_to_human_readable(total),
            project_id,
        )

    except Exception as e:
        message = f"Error fetching cache summary: {str(e)}"
        # For CSV/JSON formats, print to stderr to keep stdout clean for piping
        if output_format in (OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON):
            print(message, file=sys.stderr)
        else:
            console.print(message, style="red")
        raise
truss/cli/train/core.py CHANGED
@@ -23,23 +23,10 @@ from truss.cli.train.types import (
23
23
  )
24
24
  from truss.cli.utils import common as cli_common
25
25
  from truss.cli.utils.output import console
26
- from truss.remote.baseten.custom_types import (
27
- FileSummary,
28
- FileSummaryWithTotalSize,
29
- GetCacheSummaryResponseV1,
30
- )
31
26
  from truss.remote.baseten.remote import BasetenRemote
32
27
  from truss_train import loader
33
28
  from truss_train.definitions import DeployCheckpointsConfig
34
29
 
35
- SORT_BY_FILEPATH = "filepath"
36
- SORT_BY_SIZE = "size"
37
- SORT_BY_MODIFIED = "modified"
38
- SORT_BY_TYPE = "type"
39
- SORT_BY_PERMISSIONS = "permissions"
40
- SORT_ORDER_ASC = "asc"
41
- SORT_ORDER_DESC = "desc"
42
-
43
30
  ACTIVE_JOB_STATUSES = [
44
31
  "TRAINING_JOB_RUNNING",
45
32
  "TRAINING_JOB_CREATED",
@@ -630,139 +617,28 @@ def fetch_project_by_name_or_id(
630
617
  raise click.ClickException(f"Error fetching project: {str(e)}")
631
618
 
632
619
 
633
- def create_file_summary_with_directory_sizes(
634
- files: list[FileSummary],
635
- ) -> list[FileSummaryWithTotalSize]:
636
- directory_sizes = calculate_directory_sizes(files)
637
- return [
638
- FileSummaryWithTotalSize(
639
- file_summary=file_info,
640
- total_size=directory_sizes.get(file_info.path, file_info.size_bytes),
641
- )
642
- for file_info in files
643
- ]
644
-
645
-
646
- def calculate_directory_sizes(
647
- files: list[FileSummary], max_depth: int = 100
648
- ) -> dict[str, int]:
649
- directory_sizes = {}
650
-
651
- for file_info in files:
652
- if file_info.file_type == "directory":
653
- directory_sizes[file_info.path] = 0
654
-
655
- for file_info in files:
656
- current_path = file_info.path
657
- for i in range(max_depth):
658
- if current_path is None:
659
- break
660
- if current_path in directory_sizes:
661
- directory_sizes[current_path] += file_info.size_bytes
662
- # Move to parent directory
663
- parent = os.path.dirname(current_path)
664
- if parent == current_path: # Reached root
665
- break
666
- current_path = parent
667
-
668
- return directory_sizes
669
-
670
-
671
- def view_cache_summary(
672
- remote_provider: BasetenRemote,
673
- project_id: str,
674
- sort_by: str = SORT_BY_FILEPATH,
675
- order: str = SORT_ORDER_ASC,
676
- ):
677
- """View cache summary for a training project."""
678
- try:
679
- raw_cache_data = remote_provider.api.get_cache_summary(project_id)
680
-
681
- if not raw_cache_data:
682
- console.print("No cache summary found for this project.", style="yellow")
683
- return
684
-
685
- cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)
686
-
687
- table = rich.table.Table(title=f"Cache summary for project: {project_id}")
688
- table.add_column("File Path", style="cyan")
689
- table.add_column("Size", style="green")
690
- table.add_column("Modified", style="yellow")
691
- table.add_column("Type")
692
- table.add_column("Permissions", style="magenta")
693
-
694
- files = cache_data.file_summaries
695
- if not files:
696
- console.print("No files found in cache.", style="yellow")
697
- return
698
-
699
- files_with_total_sizes = create_file_summary_with_directory_sizes(files)
700
-
701
- reverse = order == SORT_ORDER_DESC
702
- sort_key = _get_sort_key(sort_by)
703
- files_with_total_sizes.sort(key=sort_key, reverse=reverse)
704
-
705
- total_size = sum(
706
- file_info.file_summary.size_bytes for file_info in files_with_total_sizes
707
- )
708
- total_size_str = common.format_bytes_to_human_readable(total_size)
709
-
710
- console.print(
711
- f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
712
- )
713
- console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
714
- console.print()
715
- console.print(
716
- f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
717
- )
718
- console.print(f"💾 Total size: {total_size_str}", style="bold green")
719
- console.print()
720
-
721
- for file_info in files_with_total_sizes:
722
- total_size = file_info.total_size
723
-
724
- size_str = cli_common.format_bytes_to_human_readable(int(total_size))
725
-
726
- modified_str = cli_common.format_localized_time(
727
- file_info.file_summary.modified
728
- )
729
-
730
- table.add_row(
731
- file_info.file_summary.path,
732
- size_str,
733
- modified_str,
734
- file_info.file_summary.file_type or "Unknown",
735
- file_info.file_summary.permissions or "Unknown",
736
- )
737
-
738
- console.print(table)
739
-
740
- except Exception as e:
741
- console.print(f"Error fetching cache summary: {str(e)}", style="red")
742
- raise
743
-
744
-
745
- def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
746
- if sort_by == SORT_BY_FILEPATH:
747
- return lambda x: x.file_summary.path
748
- elif sort_by == SORT_BY_SIZE:
749
- return lambda x: x.total_size
750
- elif sort_by == SORT_BY_MODIFIED:
751
- return lambda x: x.file_summary.modified
752
- elif sort_by == SORT_BY_TYPE:
753
- return lambda x: x.file_summary.file_type or ""
754
- elif sort_by == SORT_BY_PERMISSIONS:
755
- return lambda x: x.file_summary.permissions or ""
756
- else:
757
- raise ValueError(f"Invalid --sort argument: {sort_by}")
758
-
759
-
760
620
  def view_cache_summary_by_project(
761
621
  remote_provider: BasetenRemote,
762
622
  project_identifier: str,
763
- sort_by: str = SORT_BY_FILEPATH,
764
- order: str = SORT_ORDER_ASC,
623
+ sort_by: Optional[str] = None,
624
+ order: Optional[str] = None,
625
+ output_format: Optional[str] = None,
765
626
  ):
766
627
  """View cache summary for a training project by ID or name."""
628
+ from truss.cli.train.cache import (
629
+ OUTPUT_FORMAT_CLI_TABLE,
630
+ SORT_BY_FILEPATH,
631
+ SORT_ORDER_ASC,
632
+ view_cache_summary,
633
+ )
634
+
635
+ # Use constants for defaults if not provided
636
+ if sort_by is None:
637
+ sort_by = SORT_BY_FILEPATH
638
+ if order is None:
639
+ order = SORT_ORDER_ASC
640
+ if output_format is None:
641
+ output_format = OUTPUT_FORMAT_CLI_TABLE
642
+
767
643
  project = fetch_project_by_name_or_id(remote_provider, project_identifier)
768
- view_cache_summary(remote_provider, project["id"], sort_by, order)
644
+ view_cache_summary(remote_provider, project["id"], sort_by, order, output_format)
@@ -12,9 +12,15 @@ from truss.cli import remote_cli
12
12
  from truss.cli.cli import truss_cli
13
13
  from truss.cli.logs import utils as cli_log_utils
14
14
  from truss.cli.logs.training_log_watcher import TrainingLogWatcher
15
+ from truss.cli.resolvers.training_project_team_resolver import (
16
+ resolve_training_project_team_name,
17
+ )
15
18
  from truss.cli.train import common as train_common
16
19
  from truss.cli.train import core
17
- from truss.cli.train.core import (
20
+ from truss.cli.train.cache import (
21
+ OUTPUT_FORMAT_CLI_TABLE,
22
+ OUTPUT_FORMAT_CSV,
23
+ OUTPUT_FORMAT_JSON,
18
24
  SORT_BY_FILEPATH,
19
25
  SORT_BY_MODIFIED,
20
26
  SORT_BY_PERMISSIONS,
@@ -108,29 +114,70 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
108
114
  return ctx
109
115
 
110
116
 
117
+ def _resolve_team_name(
118
+ remote_provider: BasetenRemote,
119
+ provided_team_name: Optional[str],
120
+ existing_project_name: Optional[str] = None,
121
+ existing_teams: Optional[dict[str, dict[str, str]]] = None,
122
+ ) -> tuple[Optional[str], Optional[str]]:
123
+ return resolve_training_project_team_name(
124
+ remote_provider=remote_provider,
125
+ provided_team_name=provided_team_name,
126
+ existing_project_name=existing_project_name,
127
+ existing_teams=existing_teams,
128
+ )
129
+
130
+
111
131
  @train.command(name="push")
112
132
  @click.argument("config", type=Path, required=True)
113
133
  @click.option("--remote", type=str, required=False, help="Remote to use")
114
134
  @click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
115
135
  @click.option("--job-name", type=str, required=False, help="Name of the training job.")
136
+ @click.option(
137
+ "--team",
138
+ "provided_team_name",
139
+ type=str,
140
+ required=False,
141
+ help="Team name for the training project",
142
+ )
116
143
  @common.common_options()
117
144
  def push_training_job(
118
- config: Path, remote: Optional[str], tail: bool, job_name: Optional[str]
145
+ config: Path,
146
+ remote: Optional[str],
147
+ tail: bool,
148
+ job_name: Optional[str],
149
+ provided_team_name: Optional[str],
119
150
  ):
120
151
  """Run a training job"""
121
- from truss_train import deployment
152
+ from truss_train import deployment, loader
122
153
 
123
154
  if not remote:
124
155
  remote = remote_cli.inquire_remote_name()
125
156
 
126
- with console.status("Creating training job...", spinner="dots"):
127
- remote_provider: BasetenRemote = cast(
128
- BasetenRemote, RemoteFactory.create(remote=remote)
129
- )
130
- job_resp = deployment.create_training_job_from_file(
131
- remote_provider, config, job_name
157
+ remote_provider: BasetenRemote = cast(
158
+ BasetenRemote, RemoteFactory.create(remote=remote)
159
+ )
160
+
161
+ existing_teams = remote_provider.api.get_teams()
162
+
163
+ with loader.import_training_project(config) as training_project:
164
+ team_name, team_id = _resolve_team_name(
165
+ remote_provider,
166
+ provided_team_name,
167
+ existing_project_name=training_project.name,
168
+ existing_teams=existing_teams,
132
169
  )
133
170
 
171
+ with console.status("Creating training job...", spinner="dots"):
172
+ job_resp = deployment.create_training_job(
173
+ remote_provider,
174
+ config,
175
+ training_project,
176
+ job_name_from_cli=job_name,
177
+ team_name=team_name,
178
+ team_id=team_id,
179
+ )
180
+
134
181
  # Note: This post create logic needs to happen outside the context
135
182
  # of the above context manager, as only one console session can be active
136
183
  # at a time.
@@ -556,8 +603,17 @@ def cache():
556
603
  default=SORT_ORDER_ASC,
557
604
  help="Sort order: ascending or descending.",
558
605
  )
606
+ @click.option(
607
+ "-o",
608
+ "--output-format",
609
+ type=click.Choice([OUTPUT_FORMAT_CLI_TABLE, OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON]),
610
+ default=OUTPUT_FORMAT_CLI_TABLE,
611
+ help="Output format: cli-table (default), csv, or json.",
612
+ )
559
613
  @common.common_options()
560
- def view_cache_summary(project: str, remote: Optional[str], sort: str, order: str):
614
+ def view_cache_summary(
615
+ project: str, remote: Optional[str], sort: str, order: str, output_format: str
616
+ ):
561
617
  """View cache summary for a training project"""
562
618
  if not remote:
563
619
  remote = remote_cli.inquire_remote_name()
@@ -566,7 +622,9 @@ def view_cache_summary(project: str, remote: Optional[str], sort: str, order: st
566
622
  BasetenRemote, RemoteFactory.create(remote=remote)
567
623
  )
568
624
 
569
- train_cli.view_cache_summary_by_project(remote_provider, project, sort, order)
625
+ train_cli.view_cache_summary_by_project(
626
+ remote_provider, project, sort, order, output_format
627
+ )
570
628
 
571
629
 
572
630
  def _maybe_resolve_project_id_from_id_or_name(
truss/cli/utils/common.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import datetime
2
2
  import logging
3
+ import re
3
4
  import sys
4
5
  import warnings
5
6
  from functools import wraps
@@ -20,6 +21,8 @@ from truss.cli.utils import self_upgrade
20
21
  from truss.cli.utils.output import console
21
22
  from truss.util import user_config
22
23
 
24
+ logger = logging.getLogger(__name__)
25
+
23
26
  INCLUDE_GIT_INFO_DOC = (
24
27
  "Whether to attach git versioning info (sha, branch, tag) to deployments made from "
25
28
  "within a git repo. If set to True in `.trussrc`, it will always be attached."
@@ -181,10 +184,44 @@ def is_human_log_level(ctx: click.Context) -> bool:
181
184
  return get_required_option(ctx, "log") != _HUMANFRIENDLY_LOG_LEVEL
182
185
 
183
186
 
184
- def format_localized_time(iso_timestamp: str) -> str:
187
+ def _normalize_iso_timestamp(iso_timestamp: str) -> str:
188
+ iso_timestamp = iso_timestamp.strip()
185
189
  if iso_timestamp.endswith("Z"):
186
- iso_timestamp = iso_timestamp.replace("Z", "+00:00")
187
- utc_time = datetime.datetime.fromisoformat(iso_timestamp)
190
+ iso_timestamp = iso_timestamp[:-1] + "+00:00"
191
+
192
+ tz_part = ""
193
+ tz_match = re.search(r"([+-]\d{2}:\d{2}|[+-]\d{4})$", iso_timestamp)
194
+ if tz_match:
195
+ tz_part = tz_match.group(0)
196
+ iso_timestamp = iso_timestamp[: tz_match.start()]
197
+
198
+ iso_timestamp = iso_timestamp.rstrip()
199
+
200
+ if tz_part and ":" not in tz_part:
201
+ tz_part = f"{tz_part[:3]}:{tz_part[3:]}"
202
+
203
+ fractional_match = re.search(r"\.(\d+)$", iso_timestamp)
204
+ if fractional_match:
205
+ fractional_digits = fractional_match.group(1)
206
+ if len(fractional_digits) > 6:
207
+ iso_timestamp = (
208
+ iso_timestamp[: fractional_match.start()] + "." + fractional_digits[:6]
209
+ )
210
+
211
+ return f"{iso_timestamp}{tz_part}"
212
+
213
+
214
+ # NOTE: `pyproject.toml` declares support down to Python 3.9, whose
215
+ # `datetime.fromisoformat` cannot parse nanosecond fractions or colonless offsets,
216
+ # so normalize timestamps before parsing.
217
+ def format_localized_time(iso_timestamp: str) -> str:
218
+ try:
219
+ utc_time = datetime.datetime.fromisoformat(iso_timestamp)
220
+ except ValueError:
221
+ # Handle non-standard formats (nanoseconds, Z suffix, colonless offsets)
222
+ normalized_timestamp = _normalize_iso_timestamp(iso_timestamp)
223
+ utc_time = datetime.datetime.fromisoformat(normalized_timestamp)
224
+
188
225
  local_time = utc_time.astimezone()
189
226
  return local_time.strftime("%Y-%m-%d %H:%M:%S")
190
227