truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/truss_config.py +10 -3
- truss/cli/chains_commands.py +39 -1
- truss/cli/cli.py +35 -5
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +19 -143
- truss/cli/train_commands.py +69 -11
- truss/cli/utils/common.py +40 -3
- truss/remote/baseten/api.py +58 -5
- truss/remote/baseten/core.py +22 -4
- truss/remote/baseten/remote.py +24 -2
- truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +10 -10
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +44 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +10 -1
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +9 -4
- truss_chains/private_types.py +15 -0
- truss_train/definitions.py +3 -1
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/api/__init__.py
CHANGED
|
@@ -65,6 +65,7 @@ def push(
|
|
|
65
65
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
66
66
|
include_git_info: bool = False,
|
|
67
67
|
preserve_env_instance_type: bool = True,
|
|
68
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
68
69
|
) -> definitions.ModelDeployment:
|
|
69
70
|
"""
|
|
70
71
|
Pushes a Truss to Baseten.
|
|
@@ -77,13 +78,13 @@ def push(
|
|
|
77
78
|
promote the truss to production after deploy completes.
|
|
78
79
|
promote: Push the truss as a published deployment. Even if a production deployment exists,
|
|
79
80
|
promote the truss to production after deploy completes.
|
|
80
|
-
preserve_previous_production_deployment: Preserve the previous production deployment
|
|
81
|
+
preserve_previous_production_deployment: Preserve the previous production deployment's autoscaling
|
|
81
82
|
setting. When not specified, the previous production deployment will be updated to allow it to
|
|
82
83
|
scale to zero. Can only be used in combination with `promote` option.
|
|
83
84
|
trusted: [DEPRECATED]
|
|
84
85
|
deployment_name: Name of the deployment created by the push. Can only be
|
|
85
86
|
used in combination with `publish` or `promote`. Deployment name must
|
|
86
|
-
only contain alphanumeric,
|
|
87
|
+
only contain alphanumeric, '.', '-' or '_' characters.
|
|
87
88
|
environment: Name of stable environment on baseten.
|
|
88
89
|
progress_bar: Optional `rich.progress.Progress` if output is desired.
|
|
89
90
|
include_git_info: Whether to attach git versioning info (sha, branch, tag) to
|
|
@@ -92,6 +93,7 @@ def push(
|
|
|
92
93
|
preserve_env_instance_type: When pushing a truss to an environment, whether to use the resources
|
|
93
94
|
specified in the truss config to resolve the instance type or preserve the instance type
|
|
94
95
|
configured in the specified environment.
|
|
96
|
+
deploy_timeout_minutes: Optional timeout in minutes for the deployment operation.
|
|
95
97
|
|
|
96
98
|
Returns:
|
|
97
99
|
The newly created ModelDeployment.
|
|
@@ -135,6 +137,7 @@ def push(
|
|
|
135
137
|
progress_bar=progress_bar,
|
|
136
138
|
include_git_info=include_git_info,
|
|
137
139
|
preserve_env_instance_type=preserve_env_instance_type,
|
|
140
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
138
141
|
) # type: ignore
|
|
139
142
|
|
|
140
143
|
return definitions.ModelDeployment(cast(BasetenService, service))
|
truss/base/truss_config.py
CHANGED
|
@@ -147,7 +147,7 @@ class ModelRepo(custom_types.ConfigModel):
|
|
|
147
147
|
volume_folder: Optional[
|
|
148
148
|
Annotated[str, pydantic.StringConstraints(min_length=1)]
|
|
149
149
|
] = None
|
|
150
|
-
use_volume: bool
|
|
150
|
+
use_volume: bool
|
|
151
151
|
kind: ModelRepoSourceKind = ModelRepoSourceKind.HF
|
|
152
152
|
runtime_secret_name: str = "hf_access_token"
|
|
153
153
|
|
|
@@ -163,7 +163,7 @@ class ModelRepo(custom_types.ConfigModel):
|
|
|
163
163
|
return v
|
|
164
164
|
if v.get("kind") == ModelRepoSourceKind.HF.value and v.get("revision") is None:
|
|
165
165
|
logger.warning(
|
|
166
|
-
"the key `revision: str` is required for use_volume=True huggingface repos."
|
|
166
|
+
"the key `revision: str` is required for use_volume=True huggingface repos. For S3/GCS/Azure repos, set it to any non-empty string."
|
|
167
167
|
)
|
|
168
168
|
raise_insufficent_revision(v.get("repo_id"), v.get("revision"))
|
|
169
169
|
if v.get("volume_folder") is None or len(v["volume_folder"]) == 0:
|
|
@@ -202,7 +202,14 @@ class ModelCache(pydantic.RootModel[list[ModelRepo]]):
|
|
|
202
202
|
)
|
|
203
203
|
|
|
204
204
|
|
|
205
|
-
class
|
|
205
|
+
class ModelRepoCacheInternal(ModelRepo):
|
|
206
|
+
use_volume: bool = False # override
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class CacheInternal(pydantic.RootModel[list[ModelRepoCacheInternal]]):
|
|
210
|
+
@property
|
|
211
|
+
def models(self) -> list[ModelRepoCacheInternal]:
|
|
212
|
+
return self.root
|
|
206
213
|
|
|
207
214
|
|
|
208
215
|
class HealthChecks(custom_types.ConfigModel):
|
truss/cli/chains_commands.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import time
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import List, Optional, Tuple
|
|
3
|
+
from typing import List, Optional, Tuple, cast
|
|
4
4
|
|
|
5
5
|
import rich
|
|
6
6
|
import rich.live
|
|
@@ -14,10 +14,13 @@ from rich import progress
|
|
|
14
14
|
|
|
15
15
|
from truss.cli import remote_cli
|
|
16
16
|
from truss.cli.cli import truss_cli
|
|
17
|
+
from truss.cli.resolvers.chain_team_resolver import resolve_chain_team_name
|
|
17
18
|
from truss.cli.utils import common, output
|
|
18
19
|
from truss.cli.utils.output import console
|
|
19
20
|
from truss.remote.baseten.core import ACTIVE_STATUS, DEPLOYING_STATUSES
|
|
21
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
20
22
|
from truss.remote.baseten.utils.status import get_displayable_status
|
|
23
|
+
from truss.remote.remote_factory import RemoteFactory
|
|
21
24
|
from truss.util import user_config
|
|
22
25
|
from truss.util.log_utils import LogInterceptor
|
|
23
26
|
|
|
@@ -219,6 +222,22 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
|
|
|
219
222
|
default=False,
|
|
220
223
|
help="Disable downloading of pushed chain source code from the UI.",
|
|
221
224
|
)
|
|
225
|
+
@click.option(
|
|
226
|
+
"--deployment-name",
|
|
227
|
+
type=str,
|
|
228
|
+
required=False,
|
|
229
|
+
help=(
|
|
230
|
+
"Name of the deployment created by the publish. Can only be used "
|
|
231
|
+
"in combination with '--publish' or '--promote'."
|
|
232
|
+
),
|
|
233
|
+
)
|
|
234
|
+
@click.option(
|
|
235
|
+
"--team",
|
|
236
|
+
"provided_team_name",
|
|
237
|
+
type=str,
|
|
238
|
+
required=False,
|
|
239
|
+
help="Team name for the chain deployment",
|
|
240
|
+
)
|
|
222
241
|
@click.pass_context
|
|
223
242
|
@common.common_options()
|
|
224
243
|
def push_chain(
|
|
@@ -236,6 +255,8 @@ def push_chain(
|
|
|
236
255
|
experimental_watch_chainlet_names: Optional[str],
|
|
237
256
|
include_git_info: bool = False,
|
|
238
257
|
disable_chain_download: bool = False,
|
|
258
|
+
deployment_name: Optional[str] = None,
|
|
259
|
+
provided_team_name: Optional[str] = None,
|
|
239
260
|
) -> None:
|
|
240
261
|
"""
|
|
241
262
|
Deploys a chain remotely.
|
|
@@ -280,10 +301,24 @@ def push_chain(
|
|
|
280
301
|
if not include_git_info:
|
|
281
302
|
include_git_info = user_config.settings.include_git_info
|
|
282
303
|
|
|
304
|
+
# Resolve team if not in dryrun mode
|
|
305
|
+
team_id = None
|
|
283
306
|
with framework.ChainletImporter.import_target(source, entrypoint) as entrypoint_cls:
|
|
284
307
|
chain_name = (
|
|
285
308
|
name or entrypoint_cls.meta_data.chain_name or entrypoint_cls.display_name
|
|
286
309
|
)
|
|
310
|
+
|
|
311
|
+
remote_provider = None
|
|
312
|
+
if not dryrun and remote:
|
|
313
|
+
remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
|
|
314
|
+
existing_teams = remote_provider.api.get_teams()
|
|
315
|
+
_, team_id = resolve_chain_team_name(
|
|
316
|
+
remote_provider,
|
|
317
|
+
provided_team_name,
|
|
318
|
+
existing_chain_name=chain_name,
|
|
319
|
+
existing_teams=existing_teams,
|
|
320
|
+
)
|
|
321
|
+
|
|
287
322
|
options = chains_def.PushOptionsBaseten.create(
|
|
288
323
|
chain_name=chain_name,
|
|
289
324
|
promote=promote,
|
|
@@ -294,6 +329,9 @@ def push_chain(
|
|
|
294
329
|
include_git_info=include_git_info,
|
|
295
330
|
working_dir=source.parent if source.is_file() else source.resolve(),
|
|
296
331
|
disable_chain_download=disable_chain_download,
|
|
332
|
+
deployment_name=deployment_name,
|
|
333
|
+
team_id=team_id,
|
|
334
|
+
remote_provider=remote_provider,
|
|
297
335
|
)
|
|
298
336
|
service = deployment_client.push(
|
|
299
337
|
entrypoint_cls, options, progress_bar=progress.Progress
|
truss/cli/cli.py
CHANGED
|
@@ -19,6 +19,7 @@ from truss.base.truss_config import Build, ModelServer, TransportKind
|
|
|
19
19
|
from truss.cli import remote_cli
|
|
20
20
|
from truss.cli.logs import utils as cli_log_utils
|
|
21
21
|
from truss.cli.logs.model_log_watcher import ModelDeploymentLogWatcher
|
|
22
|
+
from truss.cli.resolvers.model_team_resolver import resolve_model_team_name
|
|
22
23
|
from truss.cli.utils import common
|
|
23
24
|
from truss.cli.utils.output import console, error_console
|
|
24
25
|
from truss.remote.baseten.core import (
|
|
@@ -462,7 +463,7 @@ def run_python(script, target_directory):
|
|
|
462
463
|
required=False,
|
|
463
464
|
help=(
|
|
464
465
|
"Name of the deployment created by the push. Can only be "
|
|
465
|
-
"used in combination with
|
|
466
|
+
"used in combination with --publish or --promote."
|
|
466
467
|
),
|
|
467
468
|
)
|
|
468
469
|
@click.option(
|
|
@@ -501,6 +502,19 @@ def run_python(script, target_directory):
|
|
|
501
502
|
"Default is --preserve-env-instance-type."
|
|
502
503
|
),
|
|
503
504
|
)
|
|
505
|
+
@click.option(
|
|
506
|
+
"--deploy-timeout-minutes",
|
|
507
|
+
type=int,
|
|
508
|
+
required=False,
|
|
509
|
+
help="Timeout in minutes for the deploy operation.",
|
|
510
|
+
)
|
|
511
|
+
@click.option(
|
|
512
|
+
"--team",
|
|
513
|
+
"provided_team_name",
|
|
514
|
+
type=str,
|
|
515
|
+
required=False,
|
|
516
|
+
help="Team name for the model",
|
|
517
|
+
)
|
|
504
518
|
@common.common_options()
|
|
505
519
|
def push(
|
|
506
520
|
target_directory: str,
|
|
@@ -518,6 +532,8 @@ def push(
|
|
|
518
532
|
include_git_info: bool = False,
|
|
519
533
|
tail: bool = False,
|
|
520
534
|
preserve_env_instance_type: bool = True,
|
|
535
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
536
|
+
provided_team_name: Optional[str] = None,
|
|
521
537
|
) -> None:
|
|
522
538
|
"""
|
|
523
539
|
Pushes a truss to a TrussRemote.
|
|
@@ -547,6 +563,17 @@ def push(
|
|
|
547
563
|
if not model_name:
|
|
548
564
|
model_name = remote_cli.inquire_model_name()
|
|
549
565
|
|
|
566
|
+
# Resolve team_id if BasetenRemote
|
|
567
|
+
team_id = None
|
|
568
|
+
if isinstance(remote_provider, BasetenRemote):
|
|
569
|
+
existing_teams = remote_provider.api.get_teams()
|
|
570
|
+
_, team_id = resolve_model_team_name(
|
|
571
|
+
remote_provider=remote_provider,
|
|
572
|
+
provided_team_name=provided_team_name,
|
|
573
|
+
existing_model_name=model_name,
|
|
574
|
+
existing_teams=existing_teams,
|
|
575
|
+
)
|
|
576
|
+
|
|
550
577
|
if promote and environment:
|
|
551
578
|
raise click.UsageError(
|
|
552
579
|
"'promote' flag and 'environment' flag cannot both be specified."
|
|
@@ -612,11 +639,12 @@ def push(
|
|
|
612
639
|
console.print(fp8_and_num_builder_gpus_text, style="yellow")
|
|
613
640
|
|
|
614
641
|
source = Path(target_directory)
|
|
615
|
-
|
|
642
|
+
working_dir = source.parent if source.is_file() else source.resolve()
|
|
643
|
+
|
|
616
644
|
service = remote_provider.push(
|
|
617
|
-
tr,
|
|
645
|
+
truss_handle=tr,
|
|
618
646
|
model_name=model_name,
|
|
619
|
-
working_dir=
|
|
647
|
+
working_dir=working_dir,
|
|
620
648
|
publish=publish,
|
|
621
649
|
promote=promote,
|
|
622
650
|
preserve_previous_prod_deployment=preserve_previous_production_deployment,
|
|
@@ -626,7 +654,9 @@ def push(
|
|
|
626
654
|
progress_bar=progress.Progress,
|
|
627
655
|
include_git_info=include_git_info,
|
|
628
656
|
preserve_env_instance_type=preserve_env_instance_type,
|
|
629
|
-
|
|
657
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
658
|
+
team_id=team_id,
|
|
659
|
+
)
|
|
630
660
|
|
|
631
661
|
click.echo(f"✨ Model {model_name} was successfully pushed ✨")
|
|
632
662
|
|
truss/cli/remote_cli.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
1
3
|
from InquirerPy import inquirer
|
|
2
4
|
from InquirerPy.validator import ValidationError, Validator
|
|
3
5
|
|
|
@@ -56,3 +58,30 @@ def inquire_remote_name() -> str:
|
|
|
56
58
|
|
|
57
59
|
def inquire_model_name() -> str:
|
|
58
60
|
return inquirer.text("📦 Name this model:", qmark="").execute()
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_team_id_from_name(
|
|
64
|
+
teams: dict[str, dict[str, str]], team_name: str
|
|
65
|
+
) -> Optional[str]:
|
|
66
|
+
team = teams.get(team_name)
|
|
67
|
+
return team["id"] if team else None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def format_available_teams(teams: dict[str, dict[str, str]]) -> str:
|
|
71
|
+
team_names = list(teams.keys())
|
|
72
|
+
return ", ".join(team_names) if team_names else "none"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def inquire_team(
|
|
76
|
+
existing_teams: Optional[dict[str, dict[str, str]]] = None,
|
|
77
|
+
) -> Optional[str]:
|
|
78
|
+
if existing_teams is not None:
|
|
79
|
+
selected_team_name = inquirer.select(
|
|
80
|
+
"👥 Which team do you want to push to?",
|
|
81
|
+
qmark="",
|
|
82
|
+
choices=list[str](existing_teams.keys()),
|
|
83
|
+
).execute()
|
|
84
|
+
return selected_team_name
|
|
85
|
+
|
|
86
|
+
# existing_teams is None (no team mapping supplied): return None (don't propagate team param)
|
|
87
|
+
return None
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Team resolution logic for chains."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
|
|
7
|
+
from truss.cli import remote_cli
|
|
8
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def resolve_chain_team_name(
|
|
12
|
+
remote_provider: BasetenRemote,
|
|
13
|
+
provided_team_name: Optional[str],
|
|
14
|
+
existing_chain_name: Optional[str] = None,
|
|
15
|
+
existing_teams: Optional[dict[str, dict[str, str]]] = None,
|
|
16
|
+
) -> tuple[Optional[str], Optional[str]]:
|
|
17
|
+
"""Resolve team name and team_id from provided team name or by prompting the user.
|
|
18
|
+
Returns a tuple of (team_name, team_id).
|
|
19
|
+
This function handles 8 distinct scenarios organized into 3 high-level categories:
|
|
20
|
+
|
|
21
|
+
HIGH-LEVEL SCENARIO 1: --team PROVIDED
|
|
22
|
+
SCENARIO 1: Valid team name, user has access
|
|
23
|
+
→ Returns (team_name, team_id) for that team (no prompt, no error)
|
|
24
|
+
SCENARIO 2: Invalid team name (does not exist)
|
|
25
|
+
→ Raises ClickException with error message listing available teams
|
|
26
|
+
|
|
27
|
+
HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Chain does not exist
|
|
28
|
+
SCENARIO 3: User has multiple teams, no existing chain
|
|
29
|
+
→ Prompts user to select a team via inquire_team()
|
|
30
|
+
SCENARIO 6: User has exactly one team, no existing chain
|
|
31
|
+
→ Returns (team_name, team_id) for the single team automatically (no prompt)
|
|
32
|
+
|
|
33
|
+
HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Chain exists
|
|
34
|
+
SCENARIO 4: User has multiple teams, existing chain in exactly one team
|
|
35
|
+
→ Auto-detects and returns (team_name, team_id) for that team (no prompt)
|
|
36
|
+
SCENARIO 5: User has multiple teams, existing chain exists in multiple teams
|
|
37
|
+
→ Prompts user to select a team via inquire_team()
|
|
38
|
+
SCENARIO 7: User has exactly one team, existing chain matches the team
|
|
39
|
+
→ Auto-detects and returns (team_name, team_id) for the single team (no prompt)
|
|
40
|
+
SCENARIO 8: User has exactly one team, existing chain exists in different team
|
|
41
|
+
→ Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)
|
|
42
|
+
"""
|
|
43
|
+
if existing_teams is None:
|
|
44
|
+
existing_teams = remote_provider.api.get_teams()
|
|
45
|
+
|
|
46
|
+
def _get_team_id(team_name: Optional[str]) -> Optional[str]:
|
|
47
|
+
if team_name and existing_teams:
|
|
48
|
+
team_data = existing_teams.get(team_name)
|
|
49
|
+
return team_data["id"] if team_data else None
|
|
50
|
+
return None
|
|
51
|
+
|
|
52
|
+
if provided_team_name is not None:
|
|
53
|
+
if provided_team_name not in existing_teams:
|
|
54
|
+
available_teams_str = remote_cli.format_available_teams(existing_teams)
|
|
55
|
+
raise click.ClickException(
|
|
56
|
+
f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
|
|
57
|
+
)
|
|
58
|
+
return (provided_team_name, _get_team_id(provided_team_name))
|
|
59
|
+
|
|
60
|
+
existing_chains = None
|
|
61
|
+
if existing_chain_name is not None:
|
|
62
|
+
existing_chains = remote_provider.api.get_chains()
|
|
63
|
+
matching_chains = [
|
|
64
|
+
c for c in existing_chains if c.get("name") == existing_chain_name
|
|
65
|
+
]
|
|
66
|
+
|
|
67
|
+
if len(matching_chains) > 1:
|
|
68
|
+
selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
|
|
69
|
+
return (selected_team_name, _get_team_id(selected_team_name))
|
|
70
|
+
|
|
71
|
+
if len(matching_chains) == 1:
|
|
72
|
+
chain_team = matching_chains[0].get("team")
|
|
73
|
+
chain_team_name = chain_team.get("name") if chain_team else None
|
|
74
|
+
if chain_team_name and chain_team_name in existing_teams:
|
|
75
|
+
return (chain_team_name, _get_team_id(chain_team_name))
|
|
76
|
+
|
|
77
|
+
if len(existing_teams) == 1:
|
|
78
|
+
single_team_name = list(existing_teams.keys())[0]
|
|
79
|
+
return (single_team_name, _get_team_id(single_team_name))
|
|
80
|
+
|
|
81
|
+
selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
|
|
82
|
+
return (selected_team_name, _get_team_id(selected_team_name))
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Team resolution logic for models."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
|
|
7
|
+
from truss.cli import remote_cli
|
|
8
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def resolve_model_team_name(
|
|
12
|
+
remote_provider: BasetenRemote,
|
|
13
|
+
provided_team_name: Optional[str],
|
|
14
|
+
existing_model_name: Optional[str] = None,
|
|
15
|
+
existing_teams: Optional[dict[str, dict[str, str]]] = None,
|
|
16
|
+
) -> tuple[Optional[str], Optional[str]]:
|
|
17
|
+
"""Resolve team name and team_id from provided team name or by prompting the user.
|
|
18
|
+
Returns a tuple of (team_name, team_id).
|
|
19
|
+
This function handles 8 distinct scenarios organized into 3 high-level categories:
|
|
20
|
+
|
|
21
|
+
HIGH-LEVEL SCENARIO 1: --team PROVIDED
|
|
22
|
+
SCENARIO 1: Valid team name, user has access
|
|
23
|
+
→ Returns (team_name, team_id) for that team (no prompt, no error)
|
|
24
|
+
SCENARIO 2: Invalid team name (does not exist)
|
|
25
|
+
→ Raises ClickException with error message listing available teams
|
|
26
|
+
|
|
27
|
+
HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Model does not exist
|
|
28
|
+
SCENARIO 3: User has multiple teams, no existing model
|
|
29
|
+
→ Prompts user to select a team via inquire_team()
|
|
30
|
+
SCENARIO 6: User has exactly one team, no existing model
|
|
31
|
+
→ Returns (team_name, team_id) for the single team automatically (no prompt)
|
|
32
|
+
|
|
33
|
+
HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Model exists
|
|
34
|
+
SCENARIO 4: User has multiple teams, existing model in exactly one team
|
|
35
|
+
→ Auto-detects and returns (team_name, team_id) for that team (no prompt)
|
|
36
|
+
SCENARIO 5: User has multiple teams, existing model exists in multiple teams
|
|
37
|
+
→ Prompts user to select a team via inquire_team()
|
|
38
|
+
SCENARIO 7: User has exactly one team, existing model matches the team
|
|
39
|
+
→ Auto-detects and returns (team_name, team_id) for the single team (no prompt)
|
|
40
|
+
SCENARIO 8: User has exactly one team, existing model exists in different team
|
|
41
|
+
→ Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)
|
|
42
|
+
"""
|
|
43
|
+
if existing_teams is None:
|
|
44
|
+
existing_teams = remote_provider.api.get_teams()
|
|
45
|
+
|
|
46
|
+
def _get_team_id(team_name: Optional[str]) -> Optional[str]:
|
|
47
|
+
if team_name and existing_teams:
|
|
48
|
+
team_data = existing_teams.get(team_name)
|
|
49
|
+
return team_data["id"] if team_data else None
|
|
50
|
+
return None
|
|
51
|
+
|
|
52
|
+
def _get_matching_models_in_accessible_teams(model_name: str) -> list[dict]:
|
|
53
|
+
"""Get models matching the name that are in teams the user has access to."""
|
|
54
|
+
all_models_data = remote_provider.api.models()
|
|
55
|
+
accessible_team_ids = {team_data["id"] for team_data in existing_teams.values()}
|
|
56
|
+
|
|
57
|
+
return [
|
|
58
|
+
m
|
|
59
|
+
for m in all_models_data.get("models", [])
|
|
60
|
+
if m.get("name") == model_name
|
|
61
|
+
and m.get("team", {}).get("id") in accessible_team_ids
|
|
62
|
+
]
|
|
63
|
+
|
|
64
|
+
if provided_team_name is not None:
|
|
65
|
+
if provided_team_name not in existing_teams:
|
|
66
|
+
available_teams_str = remote_cli.format_available_teams(existing_teams)
|
|
67
|
+
raise click.ClickException(
|
|
68
|
+
f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
|
|
69
|
+
)
|
|
70
|
+
return (provided_team_name, _get_team_id(provided_team_name))
|
|
71
|
+
|
|
72
|
+
if existing_model_name is not None:
|
|
73
|
+
matching_models = _get_matching_models_in_accessible_teams(existing_model_name)
|
|
74
|
+
|
|
75
|
+
if len(matching_models) == 1:
|
|
76
|
+
# Exactly one model in an accessible team - auto-detect
|
|
77
|
+
team = matching_models[0].get("team", {})
|
|
78
|
+
model_team_name = team.get("name")
|
|
79
|
+
model_team_id = team.get("id")
|
|
80
|
+
if model_team_name and model_team_name in existing_teams:
|
|
81
|
+
return (model_team_name, model_team_id)
|
|
82
|
+
# If len > 1, multiple models exist - fall through to prompt logic
|
|
83
|
+
# If len == 0, no models exist - fall through to prompt logic
|
|
84
|
+
|
|
85
|
+
if len(existing_teams) == 1:
|
|
86
|
+
single_team_name = list(existing_teams.keys())[0]
|
|
87
|
+
return (single_team_name, _get_team_id(single_team_name))
|
|
88
|
+
|
|
89
|
+
selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
|
|
90
|
+
return (selected_team_name, _get_team_id(selected_team_name))
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Team resolution logic for training projects."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
|
|
7
|
+
from truss.cli import remote_cli
|
|
8
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def resolve_training_project_team_name(
|
|
12
|
+
remote_provider: BasetenRemote,
|
|
13
|
+
provided_team_name: Optional[str],
|
|
14
|
+
existing_project_name: Optional[str] = None,
|
|
15
|
+
existing_teams: Optional[dict[str, dict[str, str]]] = None,
|
|
16
|
+
) -> tuple[Optional[str], Optional[str]]:
|
|
17
|
+
"""Resolve team name and team_id from provided team name or by prompting the user.
|
|
18
|
+
Returns a tuple of (team_name, team_id).
|
|
19
|
+
This function handles 8 distinct scenarios organized into 3 high-level categories:
|
|
20
|
+
|
|
21
|
+
HIGH-LEVEL SCENARIO 1: --team PROVIDED
|
|
22
|
+
SCENARIO 1: Valid team name, user has access
|
|
23
|
+
→ Returns (team_name, team_id) for that team (no prompt, no error)
|
|
24
|
+
SCENARIO 2: Invalid team name (does not exist)
|
|
25
|
+
→ Raises ClickException with error message listing available teams
|
|
26
|
+
|
|
27
|
+
HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Training project does not exist
|
|
28
|
+
SCENARIO 3: User has multiple teams, no existing project
|
|
29
|
+
→ Prompts user to select a team via inquire_team()
|
|
30
|
+
SCENARIO 6: User has exactly one team, no existing project
|
|
31
|
+
→ Returns (team_name, team_id) for the single team automatically (no prompt)
|
|
32
|
+
|
|
33
|
+
HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Training project exists
|
|
34
|
+
SCENARIO 4: User has multiple teams, existing project in exactly one team
|
|
35
|
+
→ Auto-detects and returns (team_name, team_id) for that team (no prompt)
|
|
36
|
+
SCENARIO 5: User has multiple teams, existing project exists in multiple teams
|
|
37
|
+
→ Prompts user to select a team via inquire_team()
|
|
38
|
+
SCENARIO 7: User has exactly one team, existing project matches the team
|
|
39
|
+
→ Auto-detects and returns (team_name, team_id) for the single team (no prompt)
|
|
40
|
+
SCENARIO 8: User has exactly one team, existing project exists in different team
|
|
41
|
+
→ Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)
|
|
42
|
+
"""
|
|
43
|
+
if existing_teams is None:
|
|
44
|
+
existing_teams = remote_provider.api.get_teams()
|
|
45
|
+
|
|
46
|
+
def _get_team_id(team_name: Optional[str]) -> Optional[str]:
|
|
47
|
+
if team_name and existing_teams:
|
|
48
|
+
team_data = existing_teams.get(team_name)
|
|
49
|
+
return team_data["id"] if team_data else None
|
|
50
|
+
return None
|
|
51
|
+
|
|
52
|
+
if provided_team_name is not None:
|
|
53
|
+
if provided_team_name not in existing_teams:
|
|
54
|
+
available_teams_str = remote_cli.format_available_teams(existing_teams)
|
|
55
|
+
raise click.ClickException(
|
|
56
|
+
f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
|
|
57
|
+
)
|
|
58
|
+
return (provided_team_name, _get_team_id(provided_team_name))
|
|
59
|
+
|
|
60
|
+
existing_projects = None
|
|
61
|
+
if existing_project_name is not None:
|
|
62
|
+
existing_projects = remote_provider.api.list_training_projects()
|
|
63
|
+
matching_projects = [
|
|
64
|
+
p for p in existing_projects if p.get("name") == existing_project_name
|
|
65
|
+
]
|
|
66
|
+
|
|
67
|
+
if len(matching_projects) > 1:
|
|
68
|
+
selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
|
|
69
|
+
return (selected_team_name, _get_team_id(selected_team_name))
|
|
70
|
+
|
|
71
|
+
if len(matching_projects) == 1:
|
|
72
|
+
project_team_name = matching_projects[0].get("team_name")
|
|
73
|
+
if project_team_name in existing_teams:
|
|
74
|
+
return (project_team_name, _get_team_id(project_team_name))
|
|
75
|
+
|
|
76
|
+
if len(existing_teams) == 1:
|
|
77
|
+
single_team_name = list(existing_teams.keys())[0]
|
|
78
|
+
return (single_team_name, _get_team_id(single_team_name))
|
|
79
|
+
|
|
80
|
+
selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
|
|
81
|
+
return (selected_team_name, _get_team_id(selected_team_name))
|