truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/cli/train/core.py
CHANGED
|
@@ -16,26 +16,17 @@ from rich.text import Text
|
|
|
16
16
|
|
|
17
17
|
from truss.cli.train import common, deploy_checkpoints
|
|
18
18
|
from truss.cli.train.metrics_watcher import MetricsWatcher
|
|
19
|
-
from truss.cli.train.types import
|
|
19
|
+
from truss.cli.train.types import (
|
|
20
|
+
DeployCheckpointArgs,
|
|
21
|
+
DeployCheckpointsConfigComplete,
|
|
22
|
+
DeploySuccessResult,
|
|
23
|
+
)
|
|
20
24
|
from truss.cli.utils import common as cli_common
|
|
21
25
|
from truss.cli.utils.output import console
|
|
22
|
-
from truss.remote.baseten.custom_types import (
|
|
23
|
-
FileSummary,
|
|
24
|
-
FileSummaryWithTotalSize,
|
|
25
|
-
GetCacheSummaryResponseV1,
|
|
26
|
-
)
|
|
27
26
|
from truss.remote.baseten.remote import BasetenRemote
|
|
28
27
|
from truss_train import loader
|
|
29
28
|
from truss_train.definitions import DeployCheckpointsConfig
|
|
30
29
|
|
|
31
|
-
SORT_BY_FILEPATH = "filepath"
|
|
32
|
-
SORT_BY_SIZE = "size"
|
|
33
|
-
SORT_BY_MODIFIED = "modified"
|
|
34
|
-
SORT_BY_TYPE = "type"
|
|
35
|
-
SORT_BY_PERMISSIONS = "permissions"
|
|
36
|
-
SORT_ORDER_ASC = "asc"
|
|
37
|
-
SORT_ORDER_DESC = "desc"
|
|
38
|
-
|
|
39
30
|
ACTIVE_JOB_STATUSES = [
|
|
40
31
|
"TRAINING_JOB_RUNNING",
|
|
41
32
|
"TRAINING_JOB_CREATED",
|
|
@@ -156,7 +147,7 @@ def display_training_projects(projects: list[dict], remote_url: str) -> None:
|
|
|
156
147
|
latest_job = project.get("latest_job") or {}
|
|
157
148
|
if latest_job_id := latest_job.get("id", ""):
|
|
158
149
|
latest_job_link = cli_common.format_link(
|
|
159
|
-
status_page_url(remote_url, latest_job_id), "link"
|
|
150
|
+
status_page_url(remote_url, project["id"], latest_job_id), "link"
|
|
160
151
|
)
|
|
161
152
|
else:
|
|
162
153
|
latest_job_link = ""
|
|
@@ -242,35 +233,43 @@ def view_training_job_metrics(
|
|
|
242
233
|
metrics_display.watch()
|
|
243
234
|
|
|
244
235
|
|
|
245
|
-
def
|
|
246
|
-
remote_provider: BasetenRemote, args:
|
|
247
|
-
) ->
|
|
236
|
+
def create_model_version_from_inference_template(
|
|
237
|
+
remote_provider: BasetenRemote, args: DeployCheckpointArgs
|
|
238
|
+
) -> DeploySuccessResult:
|
|
248
239
|
if not args.deploy_config_path:
|
|
249
|
-
return deploy_checkpoints.
|
|
250
|
-
remote_provider,
|
|
240
|
+
return deploy_checkpoints.create_model_version_from_inference_template(
|
|
241
|
+
remote_provider,
|
|
242
|
+
DeployCheckpointsConfig(),
|
|
243
|
+
args.project_id,
|
|
244
|
+
args.job_id,
|
|
245
|
+
args.dry_run,
|
|
251
246
|
)
|
|
252
247
|
#### User provided a checkpoint deploy config file
|
|
253
248
|
with loader.import_deploy_checkpoints_config(
|
|
254
249
|
Path(args.deploy_config_path)
|
|
255
250
|
) as checkpoint_deploy:
|
|
256
|
-
return deploy_checkpoints.
|
|
257
|
-
remote_provider,
|
|
251
|
+
return deploy_checkpoints.create_model_version_from_inference_template(
|
|
252
|
+
remote_provider,
|
|
253
|
+
checkpoint_deploy,
|
|
254
|
+
args.project_id,
|
|
255
|
+
args.job_id,
|
|
256
|
+
args.dry_run,
|
|
258
257
|
)
|
|
259
258
|
|
|
260
259
|
|
|
261
260
|
def _get_checkpoint_names(
|
|
262
|
-
|
|
261
|
+
checkpoint_deploy_config: DeployCheckpointsConfigComplete,
|
|
263
262
|
) -> list[str]:
|
|
264
263
|
return [
|
|
265
|
-
checkpoint.
|
|
266
|
-
for checkpoint in
|
|
264
|
+
checkpoint.checkpoint_name
|
|
265
|
+
for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints
|
|
267
266
|
]
|
|
268
267
|
|
|
269
268
|
|
|
270
269
|
def print_deploy_checkpoints_success_message(
|
|
271
|
-
|
|
270
|
+
checkpoint_deploy_config: DeployCheckpointsConfigComplete,
|
|
272
271
|
):
|
|
273
|
-
checkpoint_names = _get_checkpoint_names(
|
|
272
|
+
checkpoint_names = _get_checkpoint_names(checkpoint_deploy_config)
|
|
274
273
|
console.print(
|
|
275
274
|
Text("\nTo run the model"),
|
|
276
275
|
Text("ensure your `model` parameter is set to one of"),
|
|
@@ -279,7 +278,9 @@ def print_deploy_checkpoints_success_message(
|
|
|
279
278
|
style="magenta",
|
|
280
279
|
),
|
|
281
280
|
Text("in your request. An example request body might look like this:"),
|
|
282
|
-
Text(
|
|
281
|
+
Text(
|
|
282
|
+
f'\n{{"model": "{checkpoint_names[0]}", "messages": [...]}}', style="green"
|
|
283
|
+
),
|
|
283
284
|
)
|
|
284
285
|
|
|
285
286
|
|
|
@@ -297,16 +298,20 @@ def display_training_job(
|
|
|
297
298
|
table.add_column("Value")
|
|
298
299
|
|
|
299
300
|
# Basic job details
|
|
301
|
+
table.add_row("Job Name", job["name"])
|
|
302
|
+
table.add_row("Job ID", job["id"])
|
|
300
303
|
table.add_row("Project ID", job["training_project"]["id"])
|
|
301
304
|
table.add_row("Project Name", job["training_project"]["name"])
|
|
302
|
-
table.add_row("Job ID", job["id"])
|
|
303
305
|
table.add_row("Status", job["current_status"])
|
|
304
306
|
table.add_row("Instance Type", job["instance_type"]["name"])
|
|
305
307
|
table.add_row("Created", cli_common.format_localized_time(job["created_at"]))
|
|
306
308
|
table.add_row("Last Modified", cli_common.format_localized_time(job["updated_at"]))
|
|
307
309
|
table.add_row(
|
|
308
|
-
"
|
|
309
|
-
cli_common.format_link(
|
|
310
|
+
"Job Page",
|
|
311
|
+
cli_common.format_link(
|
|
312
|
+
status_page_url(remote_url, job["training_project"]["id"], job["id"]),
|
|
313
|
+
"link",
|
|
314
|
+
),
|
|
310
315
|
)
|
|
311
316
|
|
|
312
317
|
# Add error message if present
|
|
@@ -417,8 +422,8 @@ def download_checkpoint_artifacts(
|
|
|
417
422
|
return urls_file
|
|
418
423
|
|
|
419
424
|
|
|
420
|
-
def status_page_url(remote_url: str, training_job_id: str) -> str:
|
|
421
|
-
return f"{remote_url}/training/
|
|
425
|
+
def status_page_url(remote_url: str, project_id: str, training_job_id: str) -> str:
|
|
426
|
+
return f"{remote_url}/training/{project_id}/logs/{training_job_id}"
|
|
422
427
|
|
|
423
428
|
|
|
424
429
|
def _get_all_train_init_example_options(
|
|
@@ -612,139 +617,28 @@ def fetch_project_by_name_or_id(
|
|
|
612
617
|
raise click.ClickException(f"Error fetching project: {str(e)}")
|
|
613
618
|
|
|
614
619
|
|
|
615
|
-
def create_file_summary_with_directory_sizes(
|
|
616
|
-
files: list[FileSummary],
|
|
617
|
-
) -> list[FileSummaryWithTotalSize]:
|
|
618
|
-
directory_sizes = calculate_directory_sizes(files)
|
|
619
|
-
return [
|
|
620
|
-
FileSummaryWithTotalSize(
|
|
621
|
-
file_summary=file_info,
|
|
622
|
-
total_size=directory_sizes.get(file_info.path, file_info.size_bytes),
|
|
623
|
-
)
|
|
624
|
-
for file_info in files
|
|
625
|
-
]
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
def calculate_directory_sizes(
|
|
629
|
-
files: list[FileSummary], max_depth: int = 100
|
|
630
|
-
) -> dict[str, int]:
|
|
631
|
-
directory_sizes = {}
|
|
632
|
-
|
|
633
|
-
for file_info in files:
|
|
634
|
-
if file_info.file_type == "directory":
|
|
635
|
-
directory_sizes[file_info.path] = 0
|
|
636
|
-
|
|
637
|
-
for file_info in files:
|
|
638
|
-
current_path = file_info.path
|
|
639
|
-
for i in range(max_depth):
|
|
640
|
-
if current_path is None:
|
|
641
|
-
break
|
|
642
|
-
if current_path in directory_sizes:
|
|
643
|
-
directory_sizes[current_path] += file_info.size_bytes
|
|
644
|
-
# Move to parent directory
|
|
645
|
-
parent = os.path.dirname(current_path)
|
|
646
|
-
if parent == current_path: # Reached root
|
|
647
|
-
break
|
|
648
|
-
current_path = parent
|
|
649
|
-
|
|
650
|
-
return directory_sizes
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
def view_cache_summary(
|
|
654
|
-
remote_provider: BasetenRemote,
|
|
655
|
-
project_id: str,
|
|
656
|
-
sort_by: str = SORT_BY_FILEPATH,
|
|
657
|
-
order: str = SORT_ORDER_ASC,
|
|
658
|
-
):
|
|
659
|
-
"""View cache summary for a training project."""
|
|
660
|
-
try:
|
|
661
|
-
raw_cache_data = remote_provider.api.get_cache_summary(project_id)
|
|
662
|
-
|
|
663
|
-
if not raw_cache_data:
|
|
664
|
-
console.print("No cache summary found for this project.", style="yellow")
|
|
665
|
-
return
|
|
666
|
-
|
|
667
|
-
cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)
|
|
668
|
-
|
|
669
|
-
table = rich.table.Table(title=f"Cache summary for project: {project_id}")
|
|
670
|
-
table.add_column("File Path", style="cyan")
|
|
671
|
-
table.add_column("Size", style="green")
|
|
672
|
-
table.add_column("Modified", style="yellow")
|
|
673
|
-
table.add_column("Type")
|
|
674
|
-
table.add_column("Permissions", style="magenta")
|
|
675
|
-
|
|
676
|
-
files = cache_data.file_summaries
|
|
677
|
-
if not files:
|
|
678
|
-
console.print("No files found in cache.", style="yellow")
|
|
679
|
-
return
|
|
680
|
-
|
|
681
|
-
files_with_total_sizes = create_file_summary_with_directory_sizes(files)
|
|
682
|
-
|
|
683
|
-
reverse = order == SORT_ORDER_DESC
|
|
684
|
-
sort_key = _get_sort_key(sort_by)
|
|
685
|
-
files_with_total_sizes.sort(key=sort_key, reverse=reverse)
|
|
686
|
-
|
|
687
|
-
total_size = sum(
|
|
688
|
-
file_info.file_summary.size_bytes for file_info in files_with_total_sizes
|
|
689
|
-
)
|
|
690
|
-
total_size_str = common.format_bytes_to_human_readable(total_size)
|
|
691
|
-
|
|
692
|
-
console.print(
|
|
693
|
-
f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
|
|
694
|
-
)
|
|
695
|
-
console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
|
|
696
|
-
console.print()
|
|
697
|
-
console.print(
|
|
698
|
-
f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
|
|
699
|
-
)
|
|
700
|
-
console.print(f"💾 Total size: {total_size_str}", style="bold green")
|
|
701
|
-
console.print()
|
|
702
|
-
|
|
703
|
-
for file_info in files_with_total_sizes:
|
|
704
|
-
total_size = file_info.total_size
|
|
705
|
-
|
|
706
|
-
size_str = cli_common.format_bytes_to_human_readable(int(total_size))
|
|
707
|
-
|
|
708
|
-
modified_str = cli_common.format_localized_time(
|
|
709
|
-
file_info.file_summary.modified
|
|
710
|
-
)
|
|
711
|
-
|
|
712
|
-
table.add_row(
|
|
713
|
-
file_info.file_summary.path,
|
|
714
|
-
size_str,
|
|
715
|
-
modified_str,
|
|
716
|
-
file_info.file_summary.file_type or "Unknown",
|
|
717
|
-
file_info.file_summary.permissions or "Unknown",
|
|
718
|
-
)
|
|
719
|
-
|
|
720
|
-
console.print(table)
|
|
721
|
-
|
|
722
|
-
except Exception as e:
|
|
723
|
-
console.print(f"Error fetching cache summary: {str(e)}", style="red")
|
|
724
|
-
raise
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
|
|
728
|
-
if sort_by == SORT_BY_FILEPATH:
|
|
729
|
-
return lambda x: x.file_summary.path
|
|
730
|
-
elif sort_by == SORT_BY_SIZE:
|
|
731
|
-
return lambda x: x.total_size
|
|
732
|
-
elif sort_by == SORT_BY_MODIFIED:
|
|
733
|
-
return lambda x: x.file_summary.modified
|
|
734
|
-
elif sort_by == SORT_BY_TYPE:
|
|
735
|
-
return lambda x: x.file_summary.file_type or ""
|
|
736
|
-
elif sort_by == SORT_BY_PERMISSIONS:
|
|
737
|
-
return lambda x: x.file_summary.permissions or ""
|
|
738
|
-
else:
|
|
739
|
-
raise ValueError(f"Invalid --sort argument: {sort_by}")
|
|
740
|
-
|
|
741
|
-
|
|
742
620
|
def view_cache_summary_by_project(
|
|
743
621
|
remote_provider: BasetenRemote,
|
|
744
622
|
project_identifier: str,
|
|
745
|
-
sort_by: str =
|
|
746
|
-
order: str =
|
|
623
|
+
sort_by: Optional[str] = None,
|
|
624
|
+
order: Optional[str] = None,
|
|
625
|
+
output_format: Optional[str] = None,
|
|
747
626
|
):
|
|
748
627
|
"""View cache summary for a training project by ID or name."""
|
|
628
|
+
from truss.cli.train.cache import (
|
|
629
|
+
OUTPUT_FORMAT_CLI_TABLE,
|
|
630
|
+
SORT_BY_FILEPATH,
|
|
631
|
+
SORT_ORDER_ASC,
|
|
632
|
+
view_cache_summary,
|
|
633
|
+
)
|
|
634
|
+
|
|
635
|
+
# Use constants for defaults if not provided
|
|
636
|
+
if sort_by is None:
|
|
637
|
+
sort_by = SORT_BY_FILEPATH
|
|
638
|
+
if order is None:
|
|
639
|
+
order = SORT_ORDER_ASC
|
|
640
|
+
if output_format is None:
|
|
641
|
+
output_format = OUTPUT_FORMAT_CLI_TABLE
|
|
642
|
+
|
|
749
643
|
project = fetch_project_by_name_or_id(remote_provider, project_identifier)
|
|
750
|
-
view_cache_summary(remote_provider, project["id"], sort_by, order)
|
|
644
|
+
view_cache_summary(remote_provider, project["id"], sort_by, order, output_format)
|
|
@@ -1,3 +1,3 @@
|
|
|
1
|
-
from .deploy_checkpoints import
|
|
1
|
+
from .deploy_checkpoints import create_model_version_from_inference_template
|
|
2
2
|
|
|
3
|
-
__all__ = ["
|
|
3
|
+
__all__ = ["create_model_version_from_inference_template"]
|