truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,12 +1,3 @@
|
|
|
1
|
-
from pathlib import Path
|
|
2
|
-
|
|
3
|
-
from jinja2 import Template
|
|
4
|
-
|
|
5
|
-
from truss.base import truss_config
|
|
6
|
-
from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
|
|
7
|
-
START_COMMAND_ENVVAR_NAME,
|
|
8
|
-
)
|
|
9
|
-
from truss.cli.train.types import DeployCheckpointsConfigComplete
|
|
10
1
|
from truss_train.definitions import (
|
|
11
2
|
ALLOWED_LORA_RANKS,
|
|
12
3
|
DEFAULT_LORA_RANK,
|
|
@@ -14,78 +5,19 @@ from truss_train.definitions import (
|
|
|
14
5
|
LoRADetails,
|
|
15
6
|
)
|
|
16
7
|
|
|
17
|
-
from .deploy_checkpoints_helpers import (
|
|
18
|
-
setup_base_truss_config,
|
|
19
|
-
setup_environment_variables_and_secrets,
|
|
20
|
-
)
|
|
21
|
-
|
|
22
|
-
VLLM_LORA_START_COMMAND = Template(
|
|
23
|
-
'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ base_model_id }}'
|
|
24
|
-
+ " --port 8000"
|
|
25
|
-
+ "{{ specify_tensor_parallelism }}"
|
|
26
|
-
+ " --enable-lora"
|
|
27
|
-
+ " --max-lora-rank {{ max_lora_rank }}"
|
|
28
|
-
+ " --dtype bfloat16"
|
|
29
|
-
+ ' --lora-modules {{ lora_modules }}"'
|
|
30
|
-
)
|
|
31
|
-
|
|
32
8
|
|
|
33
9
|
def hydrate_lora_checkpoint(
    job_id: str, checkpoint_id: str, checkpoint: dict
) -> LoRACheckpoint:
    """Build the ``LoRACheckpoint`` describing one checkpoint of a training job.

    Args:
        job_id: Training job ID, stored as ``training_job_id``.
        checkpoint_id: Checkpoint identifier, stored as ``checkpoint_name``.
        checkpoint: Checkpoint response dict; the LoRA rank is read from it
            via ``_get_lora_rank``.

    Returns:
        The populated ``LoRACheckpoint``.
    """
    lora_rank = _get_lora_rank(checkpoint)
    details = LoRADetails(rank=lora_rank)
    return LoRACheckpoint(
        training_job_id=job_id, lora_details=details, checkpoint_name=checkpoint_id
    )
|
|
44
19
|
|
|
45
20
|
|
|
46
|
-
def render_vllm_lora_truss_config(
|
|
47
|
-
checkpoint_deploy: DeployCheckpointsConfigComplete,
|
|
48
|
-
) -> truss_config.TrussConfig:
|
|
49
|
-
"""Render truss config specifically for LoRA checkpoints using vLLM."""
|
|
50
|
-
truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
|
|
51
|
-
start_command_envvars = setup_environment_variables_and_secrets(
|
|
52
|
-
truss_deploy_config, checkpoint_deploy
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
checkpoint_str = _build_lora_checkpoint_string(truss_deploy_config)
|
|
56
|
-
|
|
57
|
-
max_lora_rank = max(
|
|
58
|
-
[
|
|
59
|
-
checkpoint.lora_details.rank or DEFAULT_LORA_RANK
|
|
60
|
-
for checkpoint in checkpoint_deploy.checkpoint_details.checkpoints
|
|
61
|
-
if hasattr(checkpoint, "lora_details") and checkpoint.lora_details
|
|
62
|
-
]
|
|
63
|
-
)
|
|
64
|
-
accelerator = checkpoint_deploy.compute.accelerator
|
|
65
|
-
if accelerator:
|
|
66
|
-
specify_tensor_parallelism = f" --tensor-parallel-size {accelerator.count}"
|
|
67
|
-
else:
|
|
68
|
-
specify_tensor_parallelism = ""
|
|
69
|
-
|
|
70
|
-
start_command_args = {
|
|
71
|
-
"base_model_id": checkpoint_deploy.checkpoint_details.base_model_id,
|
|
72
|
-
"lora_modules": checkpoint_str,
|
|
73
|
-
"envvars": start_command_envvars,
|
|
74
|
-
"max_lora_rank": max_lora_rank,
|
|
75
|
-
"specify_tensor_parallelism": specify_tensor_parallelism,
|
|
76
|
-
}
|
|
77
|
-
start_command = VLLM_LORA_START_COMMAND.render(**start_command_args)
|
|
78
|
-
# Note: we set the start command as an environment variable in supervisord config.
|
|
79
|
-
# This is so that we don't have to change the supervisord config when the start command changes.
|
|
80
|
-
# Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
|
|
81
|
-
truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
|
|
82
|
-
# Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
|
|
83
|
-
truss_deploy_config.docker_server.start_command = ( # type: ignore[union-attr]
|
|
84
|
-
f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
|
|
85
|
-
)
|
|
86
|
-
return truss_deploy_config
|
|
87
|
-
|
|
88
|
-
|
|
89
21
|
def _get_lora_rank(checkpoint_resp: dict) -> int:
|
|
90
22
|
"""Extract and validate LoRA rank from checkpoint response."""
|
|
91
23
|
lora_adapter_config = checkpoint_resp.get("lora_adapter_config") or {}
|
|
@@ -99,19 +31,3 @@ def _get_lora_rank(checkpoint_resp: dict) -> int:
|
|
|
99
31
|
)
|
|
100
32
|
|
|
101
33
|
return lora_rank
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def _build_lora_checkpoint_string(truss_deploy_config) -> str:
|
|
105
|
-
"""Build the checkpoint string for LoRA modules from truss deploy config."""
|
|
106
|
-
checkpoint_parts = []
|
|
107
|
-
for (
|
|
108
|
-
truss_checkpoint
|
|
109
|
-
) in truss_deploy_config.training_checkpoints.artifact_references: # type: ignore
|
|
110
|
-
ckpt_path = Path(
|
|
111
|
-
truss_deploy_config.training_checkpoints.download_folder, # type: ignore
|
|
112
|
-
truss_checkpoint.training_job_id,
|
|
113
|
-
truss_checkpoint.paths[0],
|
|
114
|
-
)
|
|
115
|
-
checkpoint_parts.append(f"{truss_checkpoint.training_job_id}={ckpt_path}")
|
|
116
|
-
|
|
117
|
-
return " ".join(checkpoint_parts)
|
|
@@ -1,63 +1,8 @@
|
|
|
1
|
-
from jinja2 import Template
|
|
2
|
-
|
|
3
|
-
from truss.base import truss_config
|
|
4
|
-
from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
|
|
5
|
-
START_COMMAND_ENVVAR_NAME,
|
|
6
|
-
)
|
|
7
|
-
from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
|
|
8
|
-
build_full_checkpoint_string,
|
|
9
|
-
)
|
|
10
|
-
from truss.cli.train.types import DeployCheckpointsConfigComplete
|
|
11
1
|
from truss_train.definitions import WhisperCheckpoint
|
|
12
2
|
|
|
13
|
-
from .deploy_checkpoints_helpers import (
|
|
14
|
-
setup_base_truss_config,
|
|
15
|
-
setup_environment_variables_and_secrets,
|
|
16
|
-
)
|
|
17
|
-
|
|
18
|
-
VLLM_WHISPER_START_COMMAND = Template(
|
|
19
|
-
"sh -c '{% if envvars %}{{ envvars }} {% endif %}"
|
|
20
|
-
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
21
|
-
"vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }}'"
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def render_vllm_whisper_truss_config(
|
|
26
|
-
checkpoint_deploy: DeployCheckpointsConfigComplete,
|
|
27
|
-
) -> truss_config.TrussConfig:
|
|
28
|
-
"""Render truss config specifically for whisper checkpoints using vLLM."""
|
|
29
|
-
truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
|
|
30
|
-
|
|
31
|
-
start_command_envvars = setup_environment_variables_and_secrets(
|
|
32
|
-
truss_deploy_config, checkpoint_deploy
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
|
|
36
|
-
|
|
37
|
-
accelerator = checkpoint_deploy.compute.accelerator
|
|
38
|
-
|
|
39
|
-
start_command_args = {
|
|
40
|
-
"model_path": checkpoint_str,
|
|
41
|
-
"envvars": start_command_envvars,
|
|
42
|
-
"specify_tensor_parallelism": accelerator.count if accelerator else 1,
|
|
43
|
-
}
|
|
44
|
-
# Note: we set the start command as an environment variable in supervisord config.
|
|
45
|
-
# This is so that we don't have to change the supervisord config when the start command changes.
|
|
46
|
-
# Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
|
|
47
|
-
start_command = VLLM_WHISPER_START_COMMAND.render(**start_command_args)
|
|
48
|
-
truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
|
|
49
|
-
# Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
|
|
50
|
-
truss_deploy_config.docker_server.start_command = ( # type: ignore[union-attr]
|
|
51
|
-
f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
return truss_deploy_config
|
|
55
|
-
|
|
56
3
|
|
|
57
4
|
def hydrate_whisper_checkpoint(
    job_id: str, checkpoint_id: str, checkpoint: dict
) -> WhisperCheckpoint:
    """Build the ``WhisperCheckpoint`` for whisper model weights.

    Args:
        job_id: Training job ID, stored as ``training_job_id``.
        checkpoint_id: Checkpoint identifier, stored as ``checkpoint_name``.
        checkpoint: Checkpoint response dict; accepted but not read here.

    Returns:
        The populated ``WhisperCheckpoint``.
    """
    whisper_checkpoint = WhisperCheckpoint(
        training_job_id=job_id, checkpoint_name=checkpoint_id
    )
    return whisper_checkpoint
|
truss/cli/train/types.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from pathlib import Path
|
|
3
2
|
from typing import Optional
|
|
4
3
|
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from truss.base import truss_config
|
|
5
7
|
from truss_train.definitions import (
|
|
6
8
|
CheckpointList,
|
|
7
9
|
Compute,
|
|
8
10
|
DeployCheckpointsConfig,
|
|
9
11
|
DeployCheckpointsRuntime,
|
|
10
|
-
ModelWeightsFormat,
|
|
11
12
|
)
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
@dataclass
|
|
15
|
-
class
|
|
16
|
+
class DeployCheckpointArgs:
|
|
17
|
+
dry_run: bool
|
|
16
18
|
project_id: Optional[str]
|
|
17
19
|
job_id: Optional[str]
|
|
18
20
|
deploy_config_path: Optional[str]
|
|
@@ -26,13 +28,20 @@ class DeployCheckpointsConfigComplete(DeployCheckpointsConfig):
|
|
|
26
28
|
|
|
27
29
|
checkpoint_details: CheckpointList
|
|
28
30
|
model_name: str
|
|
29
|
-
deployment_name: str
|
|
30
31
|
runtime: DeployCheckpointsRuntime
|
|
31
32
|
compute: Compute
|
|
32
|
-
model_weight_format: ModelWeightsFormat
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
35
|
+
class DeploySuccessModelVersion(BaseModel):
    # Pydantic model carrying the name and id of a model version.
    # NOTE(review): presumably parsed from the server's deploy response -- confirm
    # against the code that constructs it.

    # allow extra fields to be forwards compatible with server
    class Config:
        extra = "allow"

    # Model version name.
    name: str
    # Model version id.
    id: str
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DeploySuccessResult(BaseModel):
    # Pydantic model bundling the outcome of a deploy-checkpoints call.

    # The completed deploy configuration that was used.
    deploy_config: DeployCheckpointsConfigComplete
    # Truss config produced for the deployment; may be None.
    truss_config: Optional[truss_config.TrussConfig]
    # Name/id of the created model version; may be None.
    # NOTE(review): presumably None for dry runs -- confirm against the producer.
    model_version: Optional[DeploySuccessModelVersion]
|
truss/cli/train_commands.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import sys
|
|
3
|
+
from datetime import datetime
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import Optional, cast
|
|
5
6
|
|
|
@@ -8,12 +9,18 @@ import rich_click as click
|
|
|
8
9
|
import truss.cli.train.core as train_cli
|
|
9
10
|
from truss.base.constants import TRAINING_TEMPLATE_DIR
|
|
10
11
|
from truss.cli import remote_cli
|
|
11
|
-
from truss.cli.cli import
|
|
12
|
+
from truss.cli.cli import truss_cli
|
|
12
13
|
from truss.cli.logs import utils as cli_log_utils
|
|
13
14
|
from truss.cli.logs.training_log_watcher import TrainingLogWatcher
|
|
15
|
+
from truss.cli.resolvers.training_project_team_resolver import (
|
|
16
|
+
resolve_training_project_team_name,
|
|
17
|
+
)
|
|
14
18
|
from truss.cli.train import common as train_common
|
|
15
19
|
from truss.cli.train import core
|
|
16
|
-
from truss.cli.train.
|
|
20
|
+
from truss.cli.train.cache import (
|
|
21
|
+
OUTPUT_FORMAT_CLI_TABLE,
|
|
22
|
+
OUTPUT_FORMAT_CSV,
|
|
23
|
+
OUTPUT_FORMAT_JSON,
|
|
17
24
|
SORT_BY_FILEPATH,
|
|
18
25
|
SORT_BY_MODIFIED,
|
|
19
26
|
SORT_BY_PERMISSIONS,
|
|
@@ -22,6 +29,7 @@ from truss.cli.train.core import (
|
|
|
22
29
|
SORT_ORDER_ASC,
|
|
23
30
|
SORT_ORDER_DESC,
|
|
24
31
|
)
|
|
32
|
+
from truss.cli.train.types import DeploySuccessResult
|
|
25
33
|
from truss.cli.utils import common
|
|
26
34
|
from truss.cli.utils.output import console, error_console
|
|
27
35
|
from truss.remote.baseten.core import get_training_job_logs_with_pagination
|
|
@@ -41,13 +49,14 @@ truss_cli.add_command(train)
|
|
|
41
49
|
|
|
42
50
|
def _print_training_job_success_message(
|
|
43
51
|
job_id: str,
|
|
52
|
+
project_id: str,
|
|
44
53
|
project_name: str,
|
|
45
|
-
job_object: TrainingJob,
|
|
54
|
+
job_object: Optional[TrainingJob],
|
|
46
55
|
remote_provider: BasetenRemote,
|
|
47
56
|
) -> None:
|
|
48
57
|
"""Print success message and helpful commands for a training job."""
|
|
49
58
|
console.print("✨ Training job successfully created!", style="green")
|
|
50
|
-
should_print_cache_summary = (
|
|
59
|
+
should_print_cache_summary = job_object and (
|
|
51
60
|
job_object.runtime.enable_cache
|
|
52
61
|
or job_object.runtime.cache_config
|
|
53
62
|
and job_object.runtime.cache_config.enabled
|
|
@@ -64,7 +73,7 @@ def _print_training_job_success_message(
|
|
|
64
73
|
f"🔍 View metrics for your job via "
|
|
65
74
|
f"[cyan]'truss train metrics --job-id {job_id}'[/cyan]\n"
|
|
66
75
|
f"{cache_summary_snippet}"
|
|
67
|
-
f"🌐
|
|
76
|
+
f"🌐 View job in the UI: {common.format_link(core.status_page_url(remote_provider.remote_url, project_id, job_id))}"
|
|
68
77
|
)
|
|
69
78
|
|
|
70
79
|
|
|
@@ -80,8 +89,13 @@ def _handle_post_create_logic(
|
|
|
80
89
|
style="green",
|
|
81
90
|
)
|
|
82
91
|
else:
|
|
92
|
+
# recreate currently doesn't pass back a job object.
|
|
83
93
|
_print_training_job_success_message(
|
|
84
|
-
job_id,
|
|
94
|
+
job_id,
|
|
95
|
+
project_id,
|
|
96
|
+
project_name,
|
|
97
|
+
job_resp.get("job_object"),
|
|
98
|
+
remote_provider,
|
|
85
99
|
)
|
|
86
100
|
|
|
87
101
|
if tail:
|
|
@@ -100,29 +114,70 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
|
|
|
100
114
|
return ctx
|
|
101
115
|
|
|
102
116
|
|
|
117
|
+
def _resolve_team_name(
    remote_provider: BasetenRemote,
    provided_team_name: Optional[str],
    existing_project_name: Optional[str] = None,
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Delegate team resolution to ``resolve_training_project_team_name``.

    Every argument is forwarded unchanged, by keyword, and the resolver's
    result is returned as-is.

    Args:
        remote_provider: Remote passed through to the resolver.
        provided_team_name: Team name given on the command line, if any.
        existing_project_name: Name of the training project, if known.
        existing_teams: Teams mapping to pass through, if already fetched.

    Returns:
        The two-item tuple produced by the resolver.
    """
    resolver_kwargs = {
        "remote_provider": remote_provider,
        "provided_team_name": provided_team_name,
        "existing_project_name": existing_project_name,
        "existing_teams": existing_teams,
    }
    return resolve_training_project_team_name(**resolver_kwargs)
|
|
129
|
+
|
|
130
|
+
|
|
103
131
|
@train.command(name="push")
|
|
104
132
|
@click.argument("config", type=Path, required=True)
|
|
105
133
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
106
134
|
@click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
|
|
107
135
|
@click.option("--job-name", type=str, required=False, help="Name of the training job.")
|
|
136
|
+
@click.option(
|
|
137
|
+
"--team",
|
|
138
|
+
"provided_team_name",
|
|
139
|
+
type=str,
|
|
140
|
+
required=False,
|
|
141
|
+
help="Team name for the training project",
|
|
142
|
+
)
|
|
108
143
|
@common.common_options()
|
|
109
144
|
def push_training_job(
|
|
110
|
-
config: Path,
|
|
145
|
+
config: Path,
|
|
146
|
+
remote: Optional[str],
|
|
147
|
+
tail: bool,
|
|
148
|
+
job_name: Optional[str],
|
|
149
|
+
provided_team_name: Optional[str],
|
|
111
150
|
):
|
|
112
151
|
"""Run a training job"""
|
|
113
|
-
from truss_train import deployment
|
|
152
|
+
from truss_train import deployment, loader
|
|
114
153
|
|
|
115
154
|
if not remote:
|
|
116
155
|
remote = remote_cli.inquire_remote_name()
|
|
117
156
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
157
|
+
remote_provider: BasetenRemote = cast(
|
|
158
|
+
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
existing_teams = remote_provider.api.get_teams()
|
|
162
|
+
|
|
163
|
+
with loader.import_training_project(config) as training_project:
|
|
164
|
+
team_name, team_id = _resolve_team_name(
|
|
165
|
+
remote_provider,
|
|
166
|
+
provided_team_name,
|
|
167
|
+
existing_project_name=training_project.name,
|
|
168
|
+
existing_teams=existing_teams,
|
|
124
169
|
)
|
|
125
170
|
|
|
171
|
+
with console.status("Creating training job...", spinner="dots"):
|
|
172
|
+
job_resp = deployment.create_training_job(
|
|
173
|
+
remote_provider,
|
|
174
|
+
config,
|
|
175
|
+
training_project,
|
|
176
|
+
job_name_from_cli=job_name,
|
|
177
|
+
team_name=team_name,
|
|
178
|
+
team_id=team_id,
|
|
179
|
+
)
|
|
180
|
+
|
|
126
181
|
# Note: This post create logic needs to happen outside the context
|
|
127
182
|
# of the above context manager, as only one console session can be active
|
|
128
183
|
# at a time.
|
|
@@ -156,11 +211,16 @@ def recreate_training_job(job_id: Optional[str], remote: Optional[str], tail: bo
|
|
|
156
211
|
@train.command(name="logs")
|
|
157
212
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
158
213
|
@click.option("--project-id", type=str, required=False, help="Project ID.")
|
|
214
|
+
@click.option("--project", type=str, required=False, help="Project name or project id.")
|
|
159
215
|
@click.option("--job-id", type=str, required=False, help="Job ID.")
|
|
160
216
|
@click.option("--tail", is_flag=True, help="Tail for ongoing logs.")
|
|
161
217
|
@common.common_options()
|
|
162
218
|
def get_job_logs(
|
|
163
|
-
remote: Optional[str],
|
|
219
|
+
remote: Optional[str],
|
|
220
|
+
project_id: Optional[str],
|
|
221
|
+
project: Optional[str],
|
|
222
|
+
job_id: Optional[str],
|
|
223
|
+
tail: bool,
|
|
164
224
|
):
|
|
165
225
|
"""Fetch logs for a training job"""
|
|
166
226
|
|
|
@@ -170,6 +230,10 @@ def get_job_logs(
|
|
|
170
230
|
remote_provider: BasetenRemote = cast(
|
|
171
231
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
172
232
|
)
|
|
233
|
+
project_id = _maybe_resolve_project_id_from_id_or_name(
|
|
234
|
+
remote_provider, project_id=project_id, project=project
|
|
235
|
+
)
|
|
236
|
+
|
|
173
237
|
project_id, job_id = train_common.get_most_recent_job(
|
|
174
238
|
remote_provider, project_id, job_id
|
|
175
239
|
)
|
|
@@ -188,12 +252,17 @@ def get_job_logs(
|
|
|
188
252
|
|
|
189
253
|
@train.command(name="stop")
|
|
190
254
|
@click.option("--project-id", type=str, required=False, help="Project ID.")
|
|
255
|
+
@click.option("--project", type=str, required=False, help="Project name or project id.")
|
|
191
256
|
@click.option("--job-id", type=str, required=False, help="Job ID.")
|
|
192
257
|
@click.option("--all", is_flag=True, help="Stop all running jobs.")
|
|
193
258
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
194
259
|
@common.common_options()
|
|
195
260
|
def stop_job(
|
|
196
|
-
project_id: Optional[str],
|
|
261
|
+
project_id: Optional[str],
|
|
262
|
+
project: Optional[str],
|
|
263
|
+
job_id: Optional[str],
|
|
264
|
+
all: bool,
|
|
265
|
+
remote: Optional[str],
|
|
197
266
|
):
|
|
198
267
|
"""Stop a training job"""
|
|
199
268
|
|
|
@@ -203,6 +272,9 @@ def stop_job(
|
|
|
203
272
|
remote_provider: BasetenRemote = cast(
|
|
204
273
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
205
274
|
)
|
|
275
|
+
project_id = _maybe_resolve_project_id_from_id_or_name(
|
|
276
|
+
remote_provider, project_id=project_id, project=project
|
|
277
|
+
)
|
|
206
278
|
if all:
|
|
207
279
|
train_cli.stop_all_jobs(remote_provider, project_id)
|
|
208
280
|
else:
|
|
@@ -217,13 +289,17 @@ def stop_job(
|
|
|
217
289
|
@click.option(
|
|
218
290
|
"--project-id", type=str, required=False, help="View training jobs for a project."
|
|
219
291
|
)
|
|
292
|
+
@click.option("--project", type=str, required=False, help="Project name or project id.")
|
|
220
293
|
@click.option(
|
|
221
294
|
"--job-id", type=str, required=False, help="View a specific training job."
|
|
222
295
|
)
|
|
223
296
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
224
297
|
@common.common_options()
|
|
225
298
|
def view_training(
|
|
226
|
-
project_id: Optional[str],
|
|
299
|
+
project_id: Optional[str],
|
|
300
|
+
project: Optional[str],
|
|
301
|
+
job_id: Optional[str],
|
|
302
|
+
remote: Optional[str],
|
|
227
303
|
):
|
|
228
304
|
"""List all training jobs for a project"""
|
|
229
305
|
|
|
@@ -233,16 +309,24 @@ def view_training(
|
|
|
233
309
|
remote_provider: BasetenRemote = cast(
|
|
234
310
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
235
311
|
)
|
|
312
|
+
project_id = _maybe_resolve_project_id_from_id_or_name(
|
|
313
|
+
remote_provider, project_id=project_id, project=project
|
|
314
|
+
)
|
|
315
|
+
|
|
236
316
|
train_cli.view_training_details(remote_provider, project_id, job_id)
|
|
237
317
|
|
|
238
318
|
|
|
239
319
|
@train.command(name="metrics")
|
|
240
320
|
@click.option("--project-id", type=str, required=False, help="Project ID.")
|
|
321
|
+
@click.option("--project", type=str, required=False, help="Project name or project id.")
|
|
241
322
|
@click.option("--job-id", type=str, required=False, help="Job ID.")
|
|
242
323
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
243
324
|
@common.common_options()
|
|
244
325
|
def get_job_metrics(
|
|
245
|
-
project_id: Optional[str],
|
|
326
|
+
project_id: Optional[str],
|
|
327
|
+
project: Optional[str],
|
|
328
|
+
job_id: Optional[str],
|
|
329
|
+
remote: Optional[str],
|
|
246
330
|
):
|
|
247
331
|
"""Get metrics for a training job"""
|
|
248
332
|
|
|
@@ -252,11 +336,15 @@ def get_job_metrics(
|
|
|
252
336
|
remote_provider: BasetenRemote = cast(
|
|
253
337
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
254
338
|
)
|
|
339
|
+
project_id = _maybe_resolve_project_id_from_id_or_name(
|
|
340
|
+
remote_provider, project_id=project_id, project=project
|
|
341
|
+
)
|
|
255
342
|
train_cli.view_training_job_metrics(remote_provider, project_id, job_id)
|
|
256
343
|
|
|
257
344
|
|
|
258
345
|
@train.command(name="deploy_checkpoints")
|
|
259
346
|
@click.option("--project-id", type=str, required=False, help="Project ID.")
|
|
347
|
+
@click.option("--project", type=str, required=False, help="Project name or project id.")
|
|
260
348
|
@click.option("--job-id", type=str, required=False, help="Job ID.")
|
|
261
349
|
@click.option(
|
|
262
350
|
"--config",
|
|
@@ -267,14 +355,22 @@ def get_job_metrics(
|
|
|
267
355
|
@click.option(
|
|
268
356
|
"--dry-run", is_flag=True, help="Generate a truss config without deploying"
|
|
269
357
|
)
|
|
358
|
+
@click.option(
|
|
359
|
+
"--truss-config-output-dir",
|
|
360
|
+
type=str,
|
|
361
|
+
required=False,
|
|
362
|
+
help="Path to output the truss config to. If not provided, will output to truss_configs/<model_version_name>_<model_version_id> or truss_configs/dry_run_<timestamp> if dry run.",
|
|
363
|
+
)
|
|
270
364
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
271
365
|
@common.common_options()
|
|
272
366
|
def deploy_checkpoints(
|
|
273
367
|
project_id: Optional[str],
|
|
368
|
+
project: Optional[str],
|
|
274
369
|
job_id: Optional[str],
|
|
275
370
|
config: Optional[str],
|
|
276
371
|
remote: Optional[str],
|
|
277
372
|
dry_run: bool,
|
|
373
|
+
truss_config_output_dir: Optional[str],
|
|
278
374
|
):
|
|
279
375
|
"""
|
|
280
376
|
Deploy a LoRA checkpoint via vLLM.
|
|
@@ -286,26 +382,52 @@ def deploy_checkpoints(
|
|
|
286
382
|
remote_provider: BasetenRemote = cast(
|
|
287
383
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
288
384
|
)
|
|
289
|
-
|
|
385
|
+
project_id = _maybe_resolve_project_id_from_id_or_name(
|
|
386
|
+
remote_provider, project_id=project_id, project=project
|
|
387
|
+
)
|
|
388
|
+
result = train_cli.create_model_version_from_inference_template(
|
|
290
389
|
remote_provider,
|
|
291
|
-
train_cli.
|
|
292
|
-
project_id=project_id,
|
|
390
|
+
train_cli.DeployCheckpointArgs(
|
|
391
|
+
project_id=project_id,
|
|
392
|
+
job_id=job_id,
|
|
393
|
+
deploy_config_path=config,
|
|
394
|
+
dry_run=dry_run,
|
|
293
395
|
),
|
|
294
396
|
)
|
|
295
397
|
|
|
296
|
-
params = {
|
|
297
|
-
"target_directory": prepare_checkpoint_result.truss_directory,
|
|
298
|
-
"remote": remote,
|
|
299
|
-
"model_name": prepare_checkpoint_result.checkpoint_deploy_config.model_name,
|
|
300
|
-
"publish": True,
|
|
301
|
-
"deployment_name": prepare_checkpoint_result.checkpoint_deploy_config.deployment_name,
|
|
302
|
-
}
|
|
303
|
-
ctx = _prepare_click_context(push, params)
|
|
304
398
|
if dry_run:
|
|
305
|
-
console.print("--dry-run flag provided
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
399
|
+
console.print("did not deploy because --dry-run flag provided", style="yellow")
|
|
400
|
+
|
|
401
|
+
_write_truss_config(result, truss_config_output_dir, dry_run)
|
|
402
|
+
|
|
403
|
+
if not dry_run:
|
|
404
|
+
train_cli.print_deploy_checkpoints_success_message(result.deploy_config)
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _write_truss_config(
    result: DeploySuccessResult, truss_config_output_dir: Optional[str], dry_run: bool
) -> None:
    """Write the truss config held by ``result`` to ``<dir>/config.yaml``.

    Does nothing when ``result.truss_config`` is falsy. The output directory is
    ``truss_config_output_dir`` when given; otherwise it is
    ``truss_configs/<name>_<id>`` of the model version, or
    ``truss_configs/dry_run_<timestamp>`` when there is no model version.

    Args:
        result: Deploy result carrying the truss config and model version.
        truss_config_output_dir: Optional explicit output directory.
        dry_run: When true, also print the command to deploy the written truss.
    """
    config_to_write = result.truss_config
    if not config_to_write:
        return

    # format: 20251006_123456
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    version = result.model_version
    if version:
        default_folder = f"{version.name}_{version.id}"
    else:
        default_folder = f"dry_run_{timestamp}"

    target_dir = Path(truss_config_output_dir or f"truss_configs/{default_folder}")
    target_file = target_dir / "config.yaml"
    os.makedirs(target_dir, exist_ok=True)

    # The hints are printed before the file is written, as in the original order.
    console.print(f"Writing truss config to {target_file}", style="yellow")
    console.print(f"👀 Run `cat {target_file}` to view the truss config", style="green")
    if dry_run:
        console.print(
            f"🚀 Run `cd {target_dir} && truss push --publish` to deploy the truss",
            style="green",
        )
    config_to_write.write_to_yaml_file(target_file)
|
|
309
431
|
|
|
310
432
|
|
|
311
433
|
@train.command(name="download")
|
|
@@ -481,8 +603,17 @@ def cache():
|
|
|
481
603
|
default=SORT_ORDER_ASC,
|
|
482
604
|
help="Sort order: ascending or descending.",
|
|
483
605
|
)
|
|
606
|
+
@click.option(
|
|
607
|
+
"-o",
|
|
608
|
+
"--output-format",
|
|
609
|
+
type=click.Choice([OUTPUT_FORMAT_CLI_TABLE, OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON]),
|
|
610
|
+
default=OUTPUT_FORMAT_CLI_TABLE,
|
|
611
|
+
help="Output format: cli-table (default), csv, or json.",
|
|
612
|
+
)
|
|
484
613
|
@common.common_options()
|
|
485
|
-
def view_cache_summary(
|
|
614
|
+
def view_cache_summary(
|
|
615
|
+
project: str, remote: Optional[str], sort: str, order: str, output_format: str
|
|
616
|
+
):
|
|
486
617
|
"""View cache summary for a training project"""
|
|
487
618
|
if not remote:
|
|
488
619
|
remote = remote_cli.inquire_remote_name()
|
|
@@ -491,4 +622,18 @@ def view_cache_summary(project: str, remote: Optional[str], sort: str, order: st
|
|
|
491
622
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
492
623
|
)
|
|
493
624
|
|
|
494
|
-
train_cli.view_cache_summary_by_project(
|
|
625
|
+
train_cli.view_cache_summary_by_project(
|
|
626
|
+
remote_provider, project, sort, order, output_format
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
def _maybe_resolve_project_id_from_id_or_name(
    remote_provider: BasetenRemote, project_id: Optional[str], project: Optional[str]
) -> Optional[str]:
    """Resolve a project id from ``project`` (name or id) or ``project_id``.

    ``project`` takes precedence when both are given, and a notice is printed.

    Args:
        remote_provider: Remote used to look the project up.
        project_id: Value of ``--project-id``, if any.
        project: Value of ``--project`` (a name or an id), if any.

    Returns:
        The ``"id"`` of the project found by
        ``train_cli.fetch_project_by_name_or_id``, or ``None`` when neither
        option was supplied.
    """
    if project_id and project:
        console.print("Both `project-id` and `project` provided. Using `project`.")

    identifier = project if project else project_id
    if identifier:
        matched_project = train_cli.fetch_project_by_name_or_id(
            remote_provider, identifier
        )
        return matched_project["id"]
    return None
|
truss/cli/utils/common.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import datetime
|
|
2
2
|
import logging
|
|
3
|
+
import re
|
|
3
4
|
import sys
|
|
4
5
|
import warnings
|
|
5
6
|
from functools import wraps
|
|
@@ -20,6 +21,8 @@ from truss.cli.utils import self_upgrade
|
|
|
20
21
|
from truss.cli.utils.output import console
|
|
21
22
|
from truss.util import user_config
|
|
22
23
|
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
23
26
|
INCLUDE_GIT_INFO_DOC = (
|
|
24
27
|
"Whether to attach git versioning info (sha, branch, tag) to deployments made from "
|
|
25
28
|
"within a git repo. If set to True in `.trussrc`, it will always be attached."
|
|
@@ -181,10 +184,44 @@ def is_human_log_level(ctx: click.Context) -> bool:
|
|
|
181
184
|
return get_required_option(ctx, "log") != _HUMANFRIENDLY_LOG_LEVEL
|
|
182
185
|
|
|
183
186
|
|
|
184
|
-
def
|
|
187
|
+
def _normalize_iso_timestamp(iso_timestamp: str) -> str:
|
|
188
|
+
iso_timestamp = iso_timestamp.strip()
|
|
185
189
|
if iso_timestamp.endswith("Z"):
|
|
186
|
-
iso_timestamp = iso_timestamp
|
|
187
|
-
|
|
190
|
+
iso_timestamp = iso_timestamp[:-1] + "+00:00"
|
|
191
|
+
|
|
192
|
+
tz_part = ""
|
|
193
|
+
tz_match = re.search(r"([+-]\d{2}:\d{2}|[+-]\d{4})$", iso_timestamp)
|
|
194
|
+
if tz_match:
|
|
195
|
+
tz_part = tz_match.group(0)
|
|
196
|
+
iso_timestamp = iso_timestamp[: tz_match.start()]
|
|
197
|
+
|
|
198
|
+
iso_timestamp = iso_timestamp.rstrip()
|
|
199
|
+
|
|
200
|
+
if tz_part and ":" not in tz_part:
|
|
201
|
+
tz_part = f"{tz_part[:3]}:{tz_part[3:]}"
|
|
202
|
+
|
|
203
|
+
fractional_match = re.search(r"\.(\d+)$", iso_timestamp)
|
|
204
|
+
if fractional_match:
|
|
205
|
+
fractional_digits = fractional_match.group(1)
|
|
206
|
+
if len(fractional_digits) > 6:
|
|
207
|
+
iso_timestamp = (
|
|
208
|
+
iso_timestamp[: fractional_match.start()] + "." + fractional_digits[:6]
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
return f"{iso_timestamp}{tz_part}"
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# NOTE: `pyproject.toml` declares support down to Python 3.9, whose
|
|
215
|
+
# `datetime.fromisoformat` cannot parse nanosecond fractions or colonless offsets,
|
|
216
|
+
# so normalize timestamps before parsing.
|
|
217
|
+
def format_localized_time(iso_timestamp: str) -> str:
    """Format an ISO timestamp as ``YYYY-MM-DD HH:MM:SS`` in the local timezone.

    The string is parsed with ``datetime.fromisoformat`` directly; if that
    raises ``ValueError`` it is normalized with ``_normalize_iso_timestamp``
    and parsed again (a second failure propagates).

    Args:
        iso_timestamp: Timestamp string to convert.

    Returns:
        The timestamp converted with ``astimezone()`` and formatted.
    """
    try:
        parsed = datetime.datetime.fromisoformat(iso_timestamp)
    except ValueError:
        # Handle non-standard formats (nanoseconds, Z suffix, colonless offsets)
        parsed = datetime.datetime.fromisoformat(
            _normalize_iso_timestamp(iso_timestamp)
        )

    return parsed.astimezone().strftime("%Y-%m-%d %H:%M:%S")
|
|
190
227
|
|