truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,3 @@
1
- from pathlib import Path
2
-
3
- from jinja2 import Template
4
-
5
- from truss.base import truss_config
6
- from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
7
- START_COMMAND_ENVVAR_NAME,
8
- )
9
- from truss.cli.train.types import DeployCheckpointsConfigComplete
10
1
  from truss_train.definitions import (
11
2
  ALLOWED_LORA_RANKS,
12
3
  DEFAULT_LORA_RANK,
@@ -14,78 +5,19 @@ from truss_train.definitions import (
14
5
  LoRADetails,
15
6
  )
16
7
 
17
- from .deploy_checkpoints_helpers import (
18
- setup_base_truss_config,
19
- setup_environment_variables_and_secrets,
20
- )
21
-
22
- VLLM_LORA_START_COMMAND = Template(
23
- 'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ base_model_id }}'
24
- + " --port 8000"
25
- + "{{ specify_tensor_parallelism }}"
26
- + " --enable-lora"
27
- + " --max-lora-rank {{ max_lora_rank }}"
28
- + " --dtype bfloat16"
29
- + ' --lora-modules {{ lora_modules }}"'
30
- )
31
-
32
8
 
33
9
  def hydrate_lora_checkpoint(
34
10
  job_id: str, checkpoint_id: str, checkpoint: dict
35
11
  ) -> LoRACheckpoint:
36
12
  """Create a LoRA-specific Checkpoint object."""
37
13
  # NOTE: Slash at the end is important since it means the checkpoint is a directory
38
- paths = [f"rank-0/{checkpoint_id}/"]
39
14
  return LoRACheckpoint(
40
15
  training_job_id=job_id,
41
- paths=paths,
42
16
  lora_details=LoRADetails(rank=_get_lora_rank(checkpoint)),
17
+ checkpoint_name=checkpoint_id,
43
18
  )
44
19
 
45
20
 
46
- def render_vllm_lora_truss_config(
47
- checkpoint_deploy: DeployCheckpointsConfigComplete,
48
- ) -> truss_config.TrussConfig:
49
- """Render truss config specifically for LoRA checkpoints using vLLM."""
50
- truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
51
- start_command_envvars = setup_environment_variables_and_secrets(
52
- truss_deploy_config, checkpoint_deploy
53
- )
54
-
55
- checkpoint_str = _build_lora_checkpoint_string(truss_deploy_config)
56
-
57
- max_lora_rank = max(
58
- [
59
- checkpoint.lora_details.rank or DEFAULT_LORA_RANK
60
- for checkpoint in checkpoint_deploy.checkpoint_details.checkpoints
61
- if hasattr(checkpoint, "lora_details") and checkpoint.lora_details
62
- ]
63
- )
64
- accelerator = checkpoint_deploy.compute.accelerator
65
- if accelerator:
66
- specify_tensor_parallelism = f" --tensor-parallel-size {accelerator.count}"
67
- else:
68
- specify_tensor_parallelism = ""
69
-
70
- start_command_args = {
71
- "base_model_id": checkpoint_deploy.checkpoint_details.base_model_id,
72
- "lora_modules": checkpoint_str,
73
- "envvars": start_command_envvars,
74
- "max_lora_rank": max_lora_rank,
75
- "specify_tensor_parallelism": specify_tensor_parallelism,
76
- }
77
- start_command = VLLM_LORA_START_COMMAND.render(**start_command_args)
78
- # Note: we set the start command as an environment variable in supervisord config.
79
- # This is so that we don't have to change the supervisord config when the start command changes.
80
- # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
81
- truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
82
- # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
83
- truss_deploy_config.docker_server.start_command = ( # type: ignore[union-attr]
84
- f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
85
- )
86
- return truss_deploy_config
87
-
88
-
89
21
  def _get_lora_rank(checkpoint_resp: dict) -> int:
90
22
  """Extract and validate LoRA rank from checkpoint response."""
91
23
  lora_adapter_config = checkpoint_resp.get("lora_adapter_config") or {}
@@ -99,19 +31,3 @@ def _get_lora_rank(checkpoint_resp: dict) -> int:
99
31
  )
100
32
 
101
33
  return lora_rank
102
-
103
-
104
- def _build_lora_checkpoint_string(truss_deploy_config) -> str:
105
- """Build the checkpoint string for LoRA modules from truss deploy config."""
106
- checkpoint_parts = []
107
- for (
108
- truss_checkpoint
109
- ) in truss_deploy_config.training_checkpoints.artifact_references: # type: ignore
110
- ckpt_path = Path(
111
- truss_deploy_config.training_checkpoints.download_folder, # type: ignore
112
- truss_checkpoint.training_job_id,
113
- truss_checkpoint.paths[0],
114
- )
115
- checkpoint_parts.append(f"{truss_checkpoint.training_job_id}={ckpt_path}")
116
-
117
- return " ".join(checkpoint_parts)
@@ -1,63 +1,8 @@
1
- from jinja2 import Template
2
-
3
- from truss.base import truss_config
4
- from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
5
- START_COMMAND_ENVVAR_NAME,
6
- )
7
- from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
8
- build_full_checkpoint_string,
9
- )
10
- from truss.cli.train.types import DeployCheckpointsConfigComplete
11
1
  from truss_train.definitions import WhisperCheckpoint
12
2
 
13
- from .deploy_checkpoints_helpers import (
14
- setup_base_truss_config,
15
- setup_environment_variables_and_secrets,
16
- )
17
-
18
- VLLM_WHISPER_START_COMMAND = Template(
19
- "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
20
- 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
21
- "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }}'"
22
- )
23
-
24
-
25
- def render_vllm_whisper_truss_config(
26
- checkpoint_deploy: DeployCheckpointsConfigComplete,
27
- ) -> truss_config.TrussConfig:
28
- """Render truss config specifically for whisper checkpoints using vLLM."""
29
- truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
30
-
31
- start_command_envvars = setup_environment_variables_and_secrets(
32
- truss_deploy_config, checkpoint_deploy
33
- )
34
-
35
- checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
36
-
37
- accelerator = checkpoint_deploy.compute.accelerator
38
-
39
- start_command_args = {
40
- "model_path": checkpoint_str,
41
- "envvars": start_command_envvars,
42
- "specify_tensor_parallelism": accelerator.count if accelerator else 1,
43
- }
44
- # Note: we set the start command as an environment variable in supervisord config.
45
- # This is so that we don't have to change the supervisord config when the start command changes.
46
- # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
47
- start_command = VLLM_WHISPER_START_COMMAND.render(**start_command_args)
48
- truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
49
- # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
50
- truss_deploy_config.docker_server.start_command = ( # type: ignore[union-attr]
51
- f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
52
- )
53
-
54
- return truss_deploy_config
55
-
56
3
 
57
4
  def hydrate_whisper_checkpoint(
58
5
  job_id: str, checkpoint_id: str, checkpoint: dict
59
6
  ) -> WhisperCheckpoint:
60
7
  """Create a Checkpoint object for whisper model weights."""
61
- # NOTE: Slash at the end is important since it means the checkpoint is a directory
62
- paths = [f"rank-0/{checkpoint_id}/"]
63
- return WhisperCheckpoint(training_job_id=job_id, paths=paths)
8
+ return WhisperCheckpoint(training_job_id=job_id, checkpoint_name=checkpoint_id)
truss/cli/train/types.py CHANGED
@@ -1,18 +1,20 @@
1
1
  from dataclasses import dataclass
2
- from pathlib import Path
3
2
  from typing import Optional
4
3
 
4
+ from pydantic import BaseModel
5
+
6
+ from truss.base import truss_config
5
7
  from truss_train.definitions import (
6
8
  CheckpointList,
7
9
  Compute,
8
10
  DeployCheckpointsConfig,
9
11
  DeployCheckpointsRuntime,
10
- ModelWeightsFormat,
11
12
  )
12
13
 
13
14
 
14
15
  @dataclass
15
- class PrepareCheckpointArgs:
16
+ class DeployCheckpointArgs:
17
+ dry_run: bool
16
18
  project_id: Optional[str]
17
19
  job_id: Optional[str]
18
20
  deploy_config_path: Optional[str]
@@ -26,13 +28,20 @@ class DeployCheckpointsConfigComplete(DeployCheckpointsConfig):
26
28
 
27
29
  checkpoint_details: CheckpointList
28
30
  model_name: str
29
- deployment_name: str
30
31
  runtime: DeployCheckpointsRuntime
31
32
  compute: Compute
32
- model_weight_format: ModelWeightsFormat
33
33
 
34
34
 
35
- @dataclass
36
- class PrepareCheckpointResult:
37
- truss_directory: Path
38
- checkpoint_deploy_config: DeployCheckpointsConfigComplete
35
+ class DeploySuccessModelVersion(BaseModel):
36
+ # allow extra fields to be forwards compatible with server
37
+ class Config:
38
+ extra = "allow"
39
+
40
+ name: str
41
+ id: str
42
+
43
+
44
+ class DeploySuccessResult(BaseModel):
45
+ deploy_config: DeployCheckpointsConfigComplete
46
+ truss_config: Optional[truss_config.TrussConfig]
47
+ model_version: Optional[DeploySuccessModelVersion]
@@ -1,5 +1,6 @@
1
1
  import os
2
2
  import sys
3
+ from datetime import datetime
3
4
  from pathlib import Path
4
5
  from typing import Optional, cast
5
6
 
@@ -8,12 +9,18 @@ import rich_click as click
8
9
  import truss.cli.train.core as train_cli
9
10
  from truss.base.constants import TRAINING_TEMPLATE_DIR
10
11
  from truss.cli import remote_cli
11
- from truss.cli.cli import push, truss_cli
12
+ from truss.cli.cli import truss_cli
12
13
  from truss.cli.logs import utils as cli_log_utils
13
14
  from truss.cli.logs.training_log_watcher import TrainingLogWatcher
15
+ from truss.cli.resolvers.training_project_team_resolver import (
16
+ resolve_training_project_team_name,
17
+ )
14
18
  from truss.cli.train import common as train_common
15
19
  from truss.cli.train import core
16
- from truss.cli.train.core import (
20
+ from truss.cli.train.cache import (
21
+ OUTPUT_FORMAT_CLI_TABLE,
22
+ OUTPUT_FORMAT_CSV,
23
+ OUTPUT_FORMAT_JSON,
17
24
  SORT_BY_FILEPATH,
18
25
  SORT_BY_MODIFIED,
19
26
  SORT_BY_PERMISSIONS,
@@ -22,6 +29,7 @@ from truss.cli.train.core import (
22
29
  SORT_ORDER_ASC,
23
30
  SORT_ORDER_DESC,
24
31
  )
32
+ from truss.cli.train.types import DeploySuccessResult
25
33
  from truss.cli.utils import common
26
34
  from truss.cli.utils.output import console, error_console
27
35
  from truss.remote.baseten.core import get_training_job_logs_with_pagination
@@ -41,13 +49,14 @@ truss_cli.add_command(train)
41
49
 
42
50
  def _print_training_job_success_message(
43
51
  job_id: str,
52
+ project_id: str,
44
53
  project_name: str,
45
- job_object: TrainingJob,
54
+ job_object: Optional[TrainingJob],
46
55
  remote_provider: BasetenRemote,
47
56
  ) -> None:
48
57
  """Print success message and helpful commands for a training job."""
49
58
  console.print("✨ Training job successfully created!", style="green")
50
- should_print_cache_summary = (
59
+ should_print_cache_summary = job_object and (
51
60
  job_object.runtime.enable_cache
52
61
  or job_object.runtime.cache_config
53
62
  and job_object.runtime.cache_config.enabled
@@ -64,7 +73,7 @@ def _print_training_job_success_message(
64
73
  f"🔍 View metrics for your job via "
65
74
  f"[cyan]'truss train metrics --job-id {job_id}'[/cyan]\n"
66
75
  f"{cache_summary_snippet}"
67
- f"🌐 Status page: {common.format_link(core.status_page_url(remote_provider.remote_url, job_id))}"
76
+ f"🌐 View job in the UI: {common.format_link(core.status_page_url(remote_provider.remote_url, project_id, job_id))}"
68
77
  )
69
78
 
70
79
 
@@ -80,8 +89,13 @@ def _handle_post_create_logic(
80
89
  style="green",
81
90
  )
82
91
  else:
92
+ # recreate currently doesn't pass back a job object.
83
93
  _print_training_job_success_message(
84
- job_id, project_name, job_resp["job_object"], remote_provider
94
+ job_id,
95
+ project_id,
96
+ project_name,
97
+ job_resp.get("job_object"),
98
+ remote_provider,
85
99
  )
86
100
 
87
101
  if tail:
@@ -100,29 +114,70 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
100
114
  return ctx
101
115
 
102
116
 
117
+ def _resolve_team_name(
118
+ remote_provider: BasetenRemote,
119
+ provided_team_name: Optional[str],
120
+ existing_project_name: Optional[str] = None,
121
+ existing_teams: Optional[dict[str, dict[str, str]]] = None,
122
+ ) -> tuple[Optional[str], Optional[str]]:
123
+ return resolve_training_project_team_name(
124
+ remote_provider=remote_provider,
125
+ provided_team_name=provided_team_name,
126
+ existing_project_name=existing_project_name,
127
+ existing_teams=existing_teams,
128
+ )
129
+
130
+
103
131
  @train.command(name="push")
104
132
  @click.argument("config", type=Path, required=True)
105
133
  @click.option("--remote", type=str, required=False, help="Remote to use")
106
134
  @click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
107
135
  @click.option("--job-name", type=str, required=False, help="Name of the training job.")
136
+ @click.option(
137
+ "--team",
138
+ "provided_team_name",
139
+ type=str,
140
+ required=False,
141
+ help="Team name for the training project",
142
+ )
108
143
  @common.common_options()
109
144
  def push_training_job(
110
- config: Path, remote: Optional[str], tail: bool, job_name: Optional[str]
145
+ config: Path,
146
+ remote: Optional[str],
147
+ tail: bool,
148
+ job_name: Optional[str],
149
+ provided_team_name: Optional[str],
111
150
  ):
112
151
  """Run a training job"""
113
- from truss_train import deployment
152
+ from truss_train import deployment, loader
114
153
 
115
154
  if not remote:
116
155
  remote = remote_cli.inquire_remote_name()
117
156
 
118
- with console.status("Creating training job...", spinner="dots"):
119
- remote_provider: BasetenRemote = cast(
120
- BasetenRemote, RemoteFactory.create(remote=remote)
121
- )
122
- job_resp = deployment.create_training_job_from_file(
123
- remote_provider, config, job_name
157
+ remote_provider: BasetenRemote = cast(
158
+ BasetenRemote, RemoteFactory.create(remote=remote)
159
+ )
160
+
161
+ existing_teams = remote_provider.api.get_teams()
162
+
163
+ with loader.import_training_project(config) as training_project:
164
+ team_name, team_id = _resolve_team_name(
165
+ remote_provider,
166
+ provided_team_name,
167
+ existing_project_name=training_project.name,
168
+ existing_teams=existing_teams,
124
169
  )
125
170
 
171
+ with console.status("Creating training job...", spinner="dots"):
172
+ job_resp = deployment.create_training_job(
173
+ remote_provider,
174
+ config,
175
+ training_project,
176
+ job_name_from_cli=job_name,
177
+ team_name=team_name,
178
+ team_id=team_id,
179
+ )
180
+
126
181
  # Note: This post create logic needs to happen outside the context
127
182
  # of the above context manager, as only one console session can be active
128
183
  # at a time.
@@ -156,11 +211,16 @@ def recreate_training_job(job_id: Optional[str], remote: Optional[str], tail: bo
156
211
  @train.command(name="logs")
157
212
  @click.option("--remote", type=str, required=False, help="Remote to use")
158
213
  @click.option("--project-id", type=str, required=False, help="Project ID.")
214
+ @click.option("--project", type=str, required=False, help="Project name or project id.")
159
215
  @click.option("--job-id", type=str, required=False, help="Job ID.")
160
216
  @click.option("--tail", is_flag=True, help="Tail for ongoing logs.")
161
217
  @common.common_options()
162
218
  def get_job_logs(
163
- remote: Optional[str], project_id: Optional[str], job_id: Optional[str], tail: bool
219
+ remote: Optional[str],
220
+ project_id: Optional[str],
221
+ project: Optional[str],
222
+ job_id: Optional[str],
223
+ tail: bool,
164
224
  ):
165
225
  """Fetch logs for a training job"""
166
226
 
@@ -170,6 +230,10 @@ def get_job_logs(
170
230
  remote_provider: BasetenRemote = cast(
171
231
  BasetenRemote, RemoteFactory.create(remote=remote)
172
232
  )
233
+ project_id = _maybe_resolve_project_id_from_id_or_name(
234
+ remote_provider, project_id=project_id, project=project
235
+ )
236
+
173
237
  project_id, job_id = train_common.get_most_recent_job(
174
238
  remote_provider, project_id, job_id
175
239
  )
@@ -188,12 +252,17 @@ def get_job_logs(
188
252
 
189
253
  @train.command(name="stop")
190
254
  @click.option("--project-id", type=str, required=False, help="Project ID.")
255
+ @click.option("--project", type=str, required=False, help="Project name or project id.")
191
256
  @click.option("--job-id", type=str, required=False, help="Job ID.")
192
257
  @click.option("--all", is_flag=True, help="Stop all running jobs.")
193
258
  @click.option("--remote", type=str, required=False, help="Remote to use")
194
259
  @common.common_options()
195
260
  def stop_job(
196
- project_id: Optional[str], job_id: Optional[str], all: bool, remote: Optional[str]
261
+ project_id: Optional[str],
262
+ project: Optional[str],
263
+ job_id: Optional[str],
264
+ all: bool,
265
+ remote: Optional[str],
197
266
  ):
198
267
  """Stop a training job"""
199
268
 
@@ -203,6 +272,9 @@ def stop_job(
203
272
  remote_provider: BasetenRemote = cast(
204
273
  BasetenRemote, RemoteFactory.create(remote=remote)
205
274
  )
275
+ project_id = _maybe_resolve_project_id_from_id_or_name(
276
+ remote_provider, project_id=project_id, project=project
277
+ )
206
278
  if all:
207
279
  train_cli.stop_all_jobs(remote_provider, project_id)
208
280
  else:
@@ -217,13 +289,17 @@ def stop_job(
217
289
  @click.option(
218
290
  "--project-id", type=str, required=False, help="View training jobs for a project."
219
291
  )
292
+ @click.option("--project", type=str, required=False, help="Project name or project id.")
220
293
  @click.option(
221
294
  "--job-id", type=str, required=False, help="View a specific training job."
222
295
  )
223
296
  @click.option("--remote", type=str, required=False, help="Remote to use")
224
297
  @common.common_options()
225
298
  def view_training(
226
- project_id: Optional[str], job_id: Optional[str], remote: Optional[str]
299
+ project_id: Optional[str],
300
+ project: Optional[str],
301
+ job_id: Optional[str],
302
+ remote: Optional[str],
227
303
  ):
228
304
  """List all training jobs for a project"""
229
305
 
@@ -233,16 +309,24 @@ def view_training(
233
309
  remote_provider: BasetenRemote = cast(
234
310
  BasetenRemote, RemoteFactory.create(remote=remote)
235
311
  )
312
+ project_id = _maybe_resolve_project_id_from_id_or_name(
313
+ remote_provider, project_id=project_id, project=project
314
+ )
315
+
236
316
  train_cli.view_training_details(remote_provider, project_id, job_id)
237
317
 
238
318
 
239
319
  @train.command(name="metrics")
240
320
  @click.option("--project-id", type=str, required=False, help="Project ID.")
321
+ @click.option("--project", type=str, required=False, help="Project name or project id.")
241
322
  @click.option("--job-id", type=str, required=False, help="Job ID.")
242
323
  @click.option("--remote", type=str, required=False, help="Remote to use")
243
324
  @common.common_options()
244
325
  def get_job_metrics(
245
- project_id: Optional[str], job_id: Optional[str], remote: Optional[str]
326
+ project_id: Optional[str],
327
+ project: Optional[str],
328
+ job_id: Optional[str],
329
+ remote: Optional[str],
246
330
  ):
247
331
  """Get metrics for a training job"""
248
332
 
@@ -252,11 +336,15 @@ def get_job_metrics(
252
336
  remote_provider: BasetenRemote = cast(
253
337
  BasetenRemote, RemoteFactory.create(remote=remote)
254
338
  )
339
+ project_id = _maybe_resolve_project_id_from_id_or_name(
340
+ remote_provider, project_id=project_id, project=project
341
+ )
255
342
  train_cli.view_training_job_metrics(remote_provider, project_id, job_id)
256
343
 
257
344
 
258
345
  @train.command(name="deploy_checkpoints")
259
346
  @click.option("--project-id", type=str, required=False, help="Project ID.")
347
+ @click.option("--project", type=str, required=False, help="Project name or project id.")
260
348
  @click.option("--job-id", type=str, required=False, help="Job ID.")
261
349
  @click.option(
262
350
  "--config",
@@ -267,14 +355,22 @@ def get_job_metrics(
267
355
  @click.option(
268
356
  "--dry-run", is_flag=True, help="Generate a truss config without deploying"
269
357
  )
358
+ @click.option(
359
+ "--truss-config-output-dir",
360
+ type=str,
361
+ required=False,
362
+ help="Path to output the truss config to. If not provided, will output to truss_configs/<model_version_name>_<model_version_id> or truss_configs/dry_run_<timestamp> if dry run.",
363
+ )
270
364
  @click.option("--remote", type=str, required=False, help="Remote to use")
271
365
  @common.common_options()
272
366
  def deploy_checkpoints(
273
367
  project_id: Optional[str],
368
+ project: Optional[str],
274
369
  job_id: Optional[str],
275
370
  config: Optional[str],
276
371
  remote: Optional[str],
277
372
  dry_run: bool,
373
+ truss_config_output_dir: Optional[str],
278
374
  ):
279
375
  """
280
376
  Deploy a LoRA checkpoint via vLLM.
@@ -286,26 +382,52 @@ def deploy_checkpoints(
286
382
  remote_provider: BasetenRemote = cast(
287
383
  BasetenRemote, RemoteFactory.create(remote=remote)
288
384
  )
289
- prepare_checkpoint_result = train_cli.prepare_checkpoint_deploy(
385
+ project_id = _maybe_resolve_project_id_from_id_or_name(
386
+ remote_provider, project_id=project_id, project=project
387
+ )
388
+ result = train_cli.create_model_version_from_inference_template(
290
389
  remote_provider,
291
- train_cli.PrepareCheckpointArgs(
292
- project_id=project_id, job_id=job_id, deploy_config_path=config
390
+ train_cli.DeployCheckpointArgs(
391
+ project_id=project_id,
392
+ job_id=job_id,
393
+ deploy_config_path=config,
394
+ dry_run=dry_run,
293
395
  ),
294
396
  )
295
397
 
296
- params = {
297
- "target_directory": prepare_checkpoint_result.truss_directory,
298
- "remote": remote,
299
- "model_name": prepare_checkpoint_result.checkpoint_deploy_config.model_name,
300
- "publish": True,
301
- "deployment_name": prepare_checkpoint_result.checkpoint_deploy_config.deployment_name,
302
- }
303
- ctx = _prepare_click_context(push, params)
304
398
  if dry_run:
305
- console.print("--dry-run flag provided, not deploying", style="yellow")
306
- else:
307
- push.invoke(ctx)
308
- train_cli.print_deploy_checkpoints_success_message(prepare_checkpoint_result)
399
+ console.print("did not deploy because --dry-run flag provided", style="yellow")
400
+
401
+ _write_truss_config(result, truss_config_output_dir, dry_run)
402
+
403
+ if not dry_run:
404
+ train_cli.print_deploy_checkpoints_success_message(result.deploy_config)
405
+
406
+
407
+ def _write_truss_config(
408
+ result: DeploySuccessResult, truss_config_output_dir: Optional[str], dry_run: bool
409
+ ) -> None:
410
+ if not result.truss_config:
411
+ return
412
+ # format: 20251006_123456
413
+ datestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
414
+ folder_name = (
415
+ f"{result.model_version.name}_{result.model_version.id}"
416
+ if result.model_version
417
+ else f"dry_run_{datestamp}"
418
+ )
419
+ output_dir_str = truss_config_output_dir or f"truss_configs/{folder_name}"
420
+ output_dir = Path(output_dir_str)
421
+ output_path = output_dir / "config.yaml"
422
+ os.makedirs(output_dir, exist_ok=True)
423
+ console.print(f"Writing truss config to {output_path}", style="yellow")
424
+ console.print(f"👀 Run `cat {output_path}` to view the truss config", style="green")
425
+ if dry_run:
426
+ console.print(
427
+ f"🚀 Run `cd {output_dir} && truss push --publish` to deploy the truss",
428
+ style="green",
429
+ )
430
+ result.truss_config.write_to_yaml_file(output_path)
309
431
 
310
432
 
311
433
  @train.command(name="download")
@@ -481,8 +603,17 @@ def cache():
481
603
  default=SORT_ORDER_ASC,
482
604
  help="Sort order: ascending or descending.",
483
605
  )
606
+ @click.option(
607
+ "-o",
608
+ "--output-format",
609
+ type=click.Choice([OUTPUT_FORMAT_CLI_TABLE, OUTPUT_FORMAT_CSV, OUTPUT_FORMAT_JSON]),
610
+ default=OUTPUT_FORMAT_CLI_TABLE,
611
+ help="Output format: cli-table (default), csv, or json.",
612
+ )
484
613
  @common.common_options()
485
- def view_cache_summary(project: str, remote: Optional[str], sort: str, order: str):
614
+ def view_cache_summary(
615
+ project: str, remote: Optional[str], sort: str, order: str, output_format: str
616
+ ):
486
617
  """View cache summary for a training project"""
487
618
  if not remote:
488
619
  remote = remote_cli.inquire_remote_name()
@@ -491,4 +622,18 @@ def view_cache_summary(project: str, remote: Optional[str], sort: str, order: st
491
622
  BasetenRemote, RemoteFactory.create(remote=remote)
492
623
  )
493
624
 
494
- train_cli.view_cache_summary_by_project(remote_provider, project, sort, order)
625
+ train_cli.view_cache_summary_by_project(
626
+ remote_provider, project, sort, order, output_format
627
+ )
628
+
629
+
630
+ def _maybe_resolve_project_id_from_id_or_name(
631
+ remote_provider: BasetenRemote, project_id: Optional[str], project: Optional[str]
632
+ ) -> Optional[str]:
633
+ """resolve the project_id or project. `project` can be name or id"""
634
+ if project and project_id:
635
+ console.print("Both `project-id` and `project` provided. Using `project`.")
636
+ project_str = project or project_id
637
+ if not project_str:
638
+ return None
639
+ return train_cli.fetch_project_by_name_or_id(remote_provider, project_str)["id"]
truss/cli/utils/common.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import datetime
2
2
  import logging
3
+ import re
3
4
  import sys
4
5
  import warnings
5
6
  from functools import wraps
@@ -20,6 +21,8 @@ from truss.cli.utils import self_upgrade
20
21
  from truss.cli.utils.output import console
21
22
  from truss.util import user_config
22
23
 
24
+ logger = logging.getLogger(__name__)
25
+
23
26
  INCLUDE_GIT_INFO_DOC = (
24
27
  "Whether to attach git versioning info (sha, branch, tag) to deployments made from "
25
28
  "within a git repo. If set to True in `.trussrc`, it will always be attached."
@@ -181,10 +184,44 @@ def is_human_log_level(ctx: click.Context) -> bool:
181
184
  return get_required_option(ctx, "log") != _HUMANFRIENDLY_LOG_LEVEL
182
185
 
183
186
 
184
- def format_localized_time(iso_timestamp: str) -> str:
187
+ def _normalize_iso_timestamp(iso_timestamp: str) -> str:
188
+ iso_timestamp = iso_timestamp.strip()
185
189
  if iso_timestamp.endswith("Z"):
186
- iso_timestamp = iso_timestamp.replace("Z", "+00:00")
187
- utc_time = datetime.datetime.fromisoformat(iso_timestamp)
190
+ iso_timestamp = iso_timestamp[:-1] + "+00:00"
191
+
192
+ tz_part = ""
193
+ tz_match = re.search(r"([+-]\d{2}:\d{2}|[+-]\d{4})$", iso_timestamp)
194
+ if tz_match:
195
+ tz_part = tz_match.group(0)
196
+ iso_timestamp = iso_timestamp[: tz_match.start()]
197
+
198
+ iso_timestamp = iso_timestamp.rstrip()
199
+
200
+ if tz_part and ":" not in tz_part:
201
+ tz_part = f"{tz_part[:3]}:{tz_part[3:]}"
202
+
203
+ fractional_match = re.search(r"\.(\d+)$", iso_timestamp)
204
+ if fractional_match:
205
+ fractional_digits = fractional_match.group(1)
206
+ if len(fractional_digits) > 6:
207
+ iso_timestamp = (
208
+ iso_timestamp[: fractional_match.start()] + "." + fractional_digits[:6]
209
+ )
210
+
211
+ return f"{iso_timestamp}{tz_part}"
212
+
213
+
214
+ # NOTE: `pyproject.toml` declares support down to Python 3.9, whose
215
+ # `datetime.fromisoformat` cannot parse nanosecond fractions or colonless offsets,
216
+ # so normalize timestamps before parsing.
217
+ def format_localized_time(iso_timestamp: str) -> str:
218
+ try:
219
+ utc_time = datetime.datetime.fromisoformat(iso_timestamp)
220
+ except ValueError:
221
+ # Handle non-standard formats (nanoseconds, Z suffix, colonless offsets)
222
+ normalized_timestamp = _normalize_iso_timestamp(iso_timestamp)
223
+ utc_time = datetime.datetime.fromisoformat(normalized_timestamp)
224
+
188
225
  local_time = utc_time.astimezone()
189
226
  return local_time.strftime("%Y-%m-%d %H:%M:%S")
190
227