truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (88)
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/cli/train/core.py CHANGED
@@ -16,26 +16,17 @@ from rich.text import Text
16
16
 
17
17
  from truss.cli.train import common, deploy_checkpoints
18
18
  from truss.cli.train.metrics_watcher import MetricsWatcher
19
- from truss.cli.train.types import PrepareCheckpointArgs, PrepareCheckpointResult
19
+ from truss.cli.train.types import (
20
+ DeployCheckpointArgs,
21
+ DeployCheckpointsConfigComplete,
22
+ DeploySuccessResult,
23
+ )
20
24
  from truss.cli.utils import common as cli_common
21
25
  from truss.cli.utils.output import console
22
- from truss.remote.baseten.custom_types import (
23
- FileSummary,
24
- FileSummaryWithTotalSize,
25
- GetCacheSummaryResponseV1,
26
- )
27
26
  from truss.remote.baseten.remote import BasetenRemote
28
27
  from truss_train import loader
29
28
  from truss_train.definitions import DeployCheckpointsConfig
30
29
 
31
- SORT_BY_FILEPATH = "filepath"
32
- SORT_BY_SIZE = "size"
33
- SORT_BY_MODIFIED = "modified"
34
- SORT_BY_TYPE = "type"
35
- SORT_BY_PERMISSIONS = "permissions"
36
- SORT_ORDER_ASC = "asc"
37
- SORT_ORDER_DESC = "desc"
38
-
39
30
  ACTIVE_JOB_STATUSES = [
40
31
  "TRAINING_JOB_RUNNING",
41
32
  "TRAINING_JOB_CREATED",
@@ -156,7 +147,7 @@ def display_training_projects(projects: list[dict], remote_url: str) -> None:
156
147
  latest_job = project.get("latest_job") or {}
157
148
  if latest_job_id := latest_job.get("id", ""):
158
149
  latest_job_link = cli_common.format_link(
159
- status_page_url(remote_url, latest_job_id), "link"
150
+ status_page_url(remote_url, project["id"], latest_job_id), "link"
160
151
  )
161
152
  else:
162
153
  latest_job_link = ""
@@ -242,35 +233,43 @@ def view_training_job_metrics(
242
233
  metrics_display.watch()
243
234
 
244
235
 
245
- def prepare_checkpoint_deploy(
246
- remote_provider: BasetenRemote, args: PrepareCheckpointArgs
247
- ) -> PrepareCheckpointResult:
236
+ def create_model_version_from_inference_template(
237
+ remote_provider: BasetenRemote, args: DeployCheckpointArgs
238
+ ) -> DeploySuccessResult:
248
239
  if not args.deploy_config_path:
249
- return deploy_checkpoints.prepare_checkpoint_deploy(
250
- remote_provider, DeployCheckpointsConfig(), args.project_id, args.job_id
240
+ return deploy_checkpoints.create_model_version_from_inference_template(
241
+ remote_provider,
242
+ DeployCheckpointsConfig(),
243
+ args.project_id,
244
+ args.job_id,
245
+ args.dry_run,
251
246
  )
252
247
  #### User provided a checkpoint deploy config file
253
248
  with loader.import_deploy_checkpoints_config(
254
249
  Path(args.deploy_config_path)
255
250
  ) as checkpoint_deploy:
256
- return deploy_checkpoints.prepare_checkpoint_deploy(
257
- remote_provider, checkpoint_deploy, args.project_id, args.job_id
251
+ return deploy_checkpoints.create_model_version_from_inference_template(
252
+ remote_provider,
253
+ checkpoint_deploy,
254
+ args.project_id,
255
+ args.job_id,
256
+ args.dry_run,
258
257
  )
259
258
 
260
259
 
261
260
  def _get_checkpoint_names(
262
- prepare_checkpoint_result: PrepareCheckpointResult,
261
+ checkpoint_deploy_config: DeployCheckpointsConfigComplete,
263
262
  ) -> list[str]:
264
263
  return [
265
- checkpoint.paths[0].strip("/").split("/")[-1]
266
- for checkpoint in prepare_checkpoint_result.checkpoint_deploy_config.checkpoint_details.checkpoints
264
+ checkpoint.checkpoint_name
265
+ for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints
267
266
  ]
268
267
 
269
268
 
270
269
  def print_deploy_checkpoints_success_message(
271
- prepare_checkpoint_result: PrepareCheckpointResult,
270
+ checkpoint_deploy_config: DeployCheckpointsConfigComplete,
272
271
  ):
273
- checkpoint_names = _get_checkpoint_names(prepare_checkpoint_result)
272
+ checkpoint_names = _get_checkpoint_names(checkpoint_deploy_config)
274
273
  console.print(
275
274
  Text("\nTo run the model"),
276
275
  Text("ensure your `model` parameter is set to one of"),
@@ -279,7 +278,9 @@ def print_deploy_checkpoints_success_message(
279
278
  style="magenta",
280
279
  ),
281
280
  Text("in your request. An example request body might look like this:"),
282
- Text(f"\n{{'model': {checkpoint_names[0]}, 'messages': [...]}}", style="green"),
281
+ Text(
282
+ f'\n{{"model": "{checkpoint_names[0]}", "messages": [...]}}', style="green"
283
+ ),
283
284
  )
284
285
 
285
286
 
@@ -297,16 +298,20 @@ def display_training_job(
297
298
  table.add_column("Value")
298
299
 
299
300
  # Basic job details
301
+ table.add_row("Job Name", job["name"])
302
+ table.add_row("Job ID", job["id"])
300
303
  table.add_row("Project ID", job["training_project"]["id"])
301
304
  table.add_row("Project Name", job["training_project"]["name"])
302
- table.add_row("Job ID", job["id"])
303
305
  table.add_row("Status", job["current_status"])
304
306
  table.add_row("Instance Type", job["instance_type"]["name"])
305
307
  table.add_row("Created", cli_common.format_localized_time(job["created_at"]))
306
308
  table.add_row("Last Modified", cli_common.format_localized_time(job["updated_at"]))
307
309
  table.add_row(
308
- "Status Page",
309
- cli_common.format_link(status_page_url(remote_url, job["id"]), "link"),
310
+ "Job Page",
311
+ cli_common.format_link(
312
+ status_page_url(remote_url, job["training_project"]["id"], job["id"]),
313
+ "link",
314
+ ),
310
315
  )
311
316
 
312
317
  # Add error message if present
@@ -417,8 +422,8 @@ def download_checkpoint_artifacts(
417
422
  return urls_file
418
423
 
419
424
 
420
- def status_page_url(remote_url: str, training_job_id: str) -> str:
421
- return f"{remote_url}/training/jobs/{training_job_id}"
425
+ def status_page_url(remote_url: str, project_id: str, training_job_id: str) -> str:
426
+ return f"{remote_url}/training/{project_id}/logs/{training_job_id}"
422
427
 
423
428
 
424
429
  def _get_all_train_init_example_options(
@@ -612,139 +617,28 @@ def fetch_project_by_name_or_id(
612
617
  raise click.ClickException(f"Error fetching project: {str(e)}")
613
618
 
614
619
 
615
- def create_file_summary_with_directory_sizes(
616
- files: list[FileSummary],
617
- ) -> list[FileSummaryWithTotalSize]:
618
- directory_sizes = calculate_directory_sizes(files)
619
- return [
620
- FileSummaryWithTotalSize(
621
- file_summary=file_info,
622
- total_size=directory_sizes.get(file_info.path, file_info.size_bytes),
623
- )
624
- for file_info in files
625
- ]
626
-
627
-
628
- def calculate_directory_sizes(
629
- files: list[FileSummary], max_depth: int = 100
630
- ) -> dict[str, int]:
631
- directory_sizes = {}
632
-
633
- for file_info in files:
634
- if file_info.file_type == "directory":
635
- directory_sizes[file_info.path] = 0
636
-
637
- for file_info in files:
638
- current_path = file_info.path
639
- for i in range(max_depth):
640
- if current_path is None:
641
- break
642
- if current_path in directory_sizes:
643
- directory_sizes[current_path] += file_info.size_bytes
644
- # Move to parent directory
645
- parent = os.path.dirname(current_path)
646
- if parent == current_path: # Reached root
647
- break
648
- current_path = parent
649
-
650
- return directory_sizes
651
-
652
-
653
- def view_cache_summary(
654
- remote_provider: BasetenRemote,
655
- project_id: str,
656
- sort_by: str = SORT_BY_FILEPATH,
657
- order: str = SORT_ORDER_ASC,
658
- ):
659
- """View cache summary for a training project."""
660
- try:
661
- raw_cache_data = remote_provider.api.get_cache_summary(project_id)
662
-
663
- if not raw_cache_data:
664
- console.print("No cache summary found for this project.", style="yellow")
665
- return
666
-
667
- cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)
668
-
669
- table = rich.table.Table(title=f"Cache summary for project: {project_id}")
670
- table.add_column("File Path", style="cyan")
671
- table.add_column("Size", style="green")
672
- table.add_column("Modified", style="yellow")
673
- table.add_column("Type")
674
- table.add_column("Permissions", style="magenta")
675
-
676
- files = cache_data.file_summaries
677
- if not files:
678
- console.print("No files found in cache.", style="yellow")
679
- return
680
-
681
- files_with_total_sizes = create_file_summary_with_directory_sizes(files)
682
-
683
- reverse = order == SORT_ORDER_DESC
684
- sort_key = _get_sort_key(sort_by)
685
- files_with_total_sizes.sort(key=sort_key, reverse=reverse)
686
-
687
- total_size = sum(
688
- file_info.file_summary.size_bytes for file_info in files_with_total_sizes
689
- )
690
- total_size_str = common.format_bytes_to_human_readable(total_size)
691
-
692
- console.print(
693
- f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
694
- )
695
- console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
696
- console.print()
697
- console.print(
698
- f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
699
- )
700
- console.print(f"💾 Total size: {total_size_str}", style="bold green")
701
- console.print()
702
-
703
- for file_info in files_with_total_sizes:
704
- total_size = file_info.total_size
705
-
706
- size_str = cli_common.format_bytes_to_human_readable(int(total_size))
707
-
708
- modified_str = cli_common.format_localized_time(
709
- file_info.file_summary.modified
710
- )
711
-
712
- table.add_row(
713
- file_info.file_summary.path,
714
- size_str,
715
- modified_str,
716
- file_info.file_summary.file_type or "Unknown",
717
- file_info.file_summary.permissions or "Unknown",
718
- )
719
-
720
- console.print(table)
721
-
722
- except Exception as e:
723
- console.print(f"Error fetching cache summary: {str(e)}", style="red")
724
- raise
725
-
726
-
727
- def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
728
- if sort_by == SORT_BY_FILEPATH:
729
- return lambda x: x.file_summary.path
730
- elif sort_by == SORT_BY_SIZE:
731
- return lambda x: x.total_size
732
- elif sort_by == SORT_BY_MODIFIED:
733
- return lambda x: x.file_summary.modified
734
- elif sort_by == SORT_BY_TYPE:
735
- return lambda x: x.file_summary.file_type or ""
736
- elif sort_by == SORT_BY_PERMISSIONS:
737
- return lambda x: x.file_summary.permissions or ""
738
- else:
739
- raise ValueError(f"Invalid --sort argument: {sort_by}")
740
-
741
-
742
620
  def view_cache_summary_by_project(
743
621
  remote_provider: BasetenRemote,
744
622
  project_identifier: str,
745
- sort_by: str = SORT_BY_FILEPATH,
746
- order: str = SORT_ORDER_ASC,
623
+ sort_by: Optional[str] = None,
624
+ order: Optional[str] = None,
625
+ output_format: Optional[str] = None,
747
626
  ):
748
627
  """View cache summary for a training project by ID or name."""
628
+ from truss.cli.train.cache import (
629
+ OUTPUT_FORMAT_CLI_TABLE,
630
+ SORT_BY_FILEPATH,
631
+ SORT_ORDER_ASC,
632
+ view_cache_summary,
633
+ )
634
+
635
+ # Use constants for defaults if not provided
636
+ if sort_by is None:
637
+ sort_by = SORT_BY_FILEPATH
638
+ if order is None:
639
+ order = SORT_ORDER_ASC
640
+ if output_format is None:
641
+ output_format = OUTPUT_FORMAT_CLI_TABLE
642
+
749
643
  project = fetch_project_by_name_or_id(remote_provider, project_identifier)
750
- view_cache_summary(remote_provider, project["id"], sort_by, order)
644
+ view_cache_summary(remote_provider, project["id"], sort_by, order, output_format)
truss/cli/train/deploy_checkpoints/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from .deploy_checkpoints import prepare_checkpoint_deploy
1
+ from .deploy_checkpoints import create_model_version_from_inference_template
2
2
 
3
- __all__ = ["prepare_checkpoint_deploy"]
3
+ __all__ = ["create_model_version_from_inference_template"]