truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. truss/api/__init__.py +5 -2
  2. truss/base/truss_config.py +10 -3
  3. truss/cli/chains_commands.py +39 -1
  4. truss/cli/cli.py +35 -5
  5. truss/cli/remote_cli.py +29 -0
  6. truss/cli/resolvers/chain_team_resolver.py +82 -0
  7. truss/cli/resolvers/model_team_resolver.py +90 -0
  8. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  9. truss/cli/train/cache.py +332 -0
  10. truss/cli/train/core.py +19 -143
  11. truss/cli/train_commands.py +69 -11
  12. truss/cli/utils/common.py +40 -3
  13. truss/remote/baseten/api.py +58 -5
  14. truss/remote/baseten/core.py +22 -4
  15. truss/remote/baseten/remote.py +24 -2
  16. truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
  17. truss/templates/server/requirements.txt +1 -1
  18. truss/templates/server.Dockerfile.jinja +10 -10
  19. truss/templates/shared/util.py +6 -5
  20. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  21. truss/tests/cli/test_chains_cli.py +44 -0
  22. truss/tests/cli/test_cli.py +134 -1
  23. truss/tests/cli/test_cli_utils_common.py +11 -0
  24. truss/tests/cli/test_model_team_resolver.py +279 -0
  25. truss/tests/cli/train/test_cache_view.py +240 -3
  26. truss/tests/cli/train/test_train_cli_core.py +2 -2
  27. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  28. truss/tests/conftest.py +187 -0
  29. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  30. truss/tests/remote/baseten/test_api.py +122 -3
  31. truss/tests/remote/baseten/test_chain_upload.py +10 -1
  32. truss/tests/remote/baseten/test_core.py +86 -0
  33. truss/tests/remote/baseten/test_remote.py +216 -288
  34. truss/tests/test_config.py +21 -12
  35. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  36. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  37. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  38. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  39. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  40. truss/tests/test_model_inference.py +13 -0
  41. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
  42. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
  43. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  44. truss_chains/deployment/deployment_client.py +9 -4
  45. truss_chains/private_types.py +15 -0
  46. truss_train/definitions.py +3 -1
  47. truss_train/deployment.py +43 -21
  48. truss_train/public_api.py +4 -2
  49. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  50. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/api/__init__.py CHANGED
@@ -65,6 +65,7 @@ def push(
65
65
  progress_bar: Optional[Type["progress.Progress"]] = None,
66
66
  include_git_info: bool = False,
67
67
  preserve_env_instance_type: bool = True,
68
+ deploy_timeout_minutes: Optional[int] = None,
68
69
  ) -> definitions.ModelDeployment:
69
70
  """
70
71
  Pushes a Truss to Baseten.
@@ -77,13 +78,13 @@ def push(
77
78
  promote the truss to production after deploy completes.
78
79
  promote: Push the truss as a published deployment. Even if a production deployment exists,
79
80
  promote the truss to production after deploy completes.
80
- preserve_previous_production_deployment: Preserve the previous production deployments autoscaling
81
+ preserve_previous_production_deployment: Preserve the previous production deployment's autoscaling
81
82
  setting. When not specified, the previous production deployment will be updated to allow it to
82
83
  scale to zero. Can only be use in combination with `promote` option.
83
84
  trusted: [DEPRECATED]
84
85
  deployment_name: Name of the deployment created by the push. Can only be
85
86
  used in combination with `publish` or `promote`. Deployment name must
86
- only contain alphanumeric, ’.’, ’-’ or _ characters.
87
+ only contain alphanumeric, '.', '-' or '_' characters.
87
88
  environment: Name of stable environment on baseten.
88
89
  progress_bar: Optional `rich.progress.Progress` if output is desired.
89
90
  include_git_info: Whether to attach git versioning info (sha, branch, tag) to
@@ -92,6 +93,7 @@ def push(
92
93
  preserve_env_instance_type: When pushing a truss to an environment, whether to use the resources
93
94
  specified in the truss config to resolve the instance type or preserve the instance type
94
95
  configured in the specified environment.
96
+ deploy_timeout_minutes: Optional timeout in minutes for the deployment operation.
95
97
 
96
98
  Returns:
97
99
  The newly created ModelDeployment.
@@ -135,6 +137,7 @@ def push(
135
137
  progress_bar=progress_bar,
136
138
  include_git_info=include_git_info,
137
139
  preserve_env_instance_type=preserve_env_instance_type,
140
+ deploy_timeout_minutes=deploy_timeout_minutes,
138
141
  ) # type: ignore
139
142
 
140
143
  return definitions.ModelDeployment(cast(BasetenService, service))
@@ -147,7 +147,7 @@ class ModelRepo(custom_types.ConfigModel):
147
147
  volume_folder: Optional[
148
148
  Annotated[str, pydantic.StringConstraints(min_length=1)]
149
149
  ] = None
150
- use_volume: bool = False
150
+ use_volume: bool
151
151
  kind: ModelRepoSourceKind = ModelRepoSourceKind.HF
152
152
  runtime_secret_name: str = "hf_access_token"
153
153
 
@@ -163,7 +163,7 @@ class ModelRepo(custom_types.ConfigModel):
163
163
  return v
164
164
  if v.get("kind") == ModelRepoSourceKind.HF.value and v.get("revision") is None:
165
165
  logger.warning(
166
- "the key `revision: str` is required for use_volume=True huggingface repos."
166
+ "the key `revision: str` is required for use_volume=True huggingface repos. For S3/GCS/Azure repos, set it to any non-empty string."
167
167
  )
168
168
  raise_insufficent_revision(v.get("repo_id"), v.get("revision"))
169
169
  if v.get("volume_folder") is None or len(v["volume_folder"]) == 0:
@@ -202,7 +202,14 @@ class ModelCache(pydantic.RootModel[list[ModelRepo]]):
202
202
  )
203
203
 
204
204
 
205
- class CacheInternal(ModelCache): ...
205
class ModelRepoCacheInternal(ModelRepo):
    """Variant of ``ModelRepo`` used as the element type of ``CacheInternal``.

    ``ModelRepo.use_volume`` is declared without a default, so callers must
    supply it; this subclass gives it a default of ``False`` instead.
    """

    use_volume: bool = False  # overrides the required field on ModelRepo


class CacheInternal(pydantic.RootModel[list[ModelRepoCacheInternal]]):
    """Pydantic root model wrapping a list of ``ModelRepoCacheInternal`` entries."""

    @property
    def models(self) -> list[ModelRepoCacheInternal]:
        """Return the wrapped list itself (``self.root``), not a copy."""
        return self.root
206
213
 
207
214
 
208
215
  class HealthChecks(custom_types.ConfigModel):
@@ -1,6 +1,6 @@
1
1
  import time
2
2
  from pathlib import Path
3
- from typing import List, Optional, Tuple
3
+ from typing import List, Optional, Tuple, cast
4
4
 
5
5
  import rich
6
6
  import rich.live
@@ -14,10 +14,13 @@ from rich import progress
14
14
 
15
15
  from truss.cli import remote_cli
16
16
  from truss.cli.cli import truss_cli
17
+ from truss.cli.resolvers.chain_team_resolver import resolve_chain_team_name
17
18
  from truss.cli.utils import common, output
18
19
  from truss.cli.utils.output import console
19
20
  from truss.remote.baseten.core import ACTIVE_STATUS, DEPLOYING_STATUSES
21
+ from truss.remote.baseten.remote import BasetenRemote
20
22
  from truss.remote.baseten.utils.status import get_displayable_status
23
+ from truss.remote.remote_factory import RemoteFactory
21
24
  from truss.util import user_config
22
25
  from truss.util.log_utils import LogInterceptor
23
26
 
@@ -219,6 +222,22 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
219
222
  default=False,
220
223
  help="Disable downloading of pushed chain source code from the UI.",
221
224
  )
225
+ @click.option(
226
+ "--deployment-name",
227
+ type=str,
228
+ required=False,
229
+ help=(
230
+ "Name of the deployment created by the publish. Can only be used "
231
+ "in combination with '--publish' or '--promote'."
232
+ ),
233
+ )
234
+ @click.option(
235
+ "--team",
236
+ "provided_team_name",
237
+ type=str,
238
+ required=False,
239
+ help="Team name for the chain deployment",
240
+ )
222
241
  @click.pass_context
223
242
  @common.common_options()
224
243
  def push_chain(
@@ -236,6 +255,8 @@ def push_chain(
236
255
  experimental_watch_chainlet_names: Optional[str],
237
256
  include_git_info: bool = False,
238
257
  disable_chain_download: bool = False,
258
+ deployment_name: Optional[str] = None,
259
+ provided_team_name: Optional[str] = None,
239
260
  ) -> None:
240
261
  """
241
262
  Deploys a chain remotely.
@@ -280,10 +301,24 @@ def push_chain(
280
301
  if not include_git_info:
281
302
  include_git_info = user_config.settings.include_git_info
282
303
 
304
+ # Resolve team if not in dryrun mode
305
+ team_id = None
283
306
  with framework.ChainletImporter.import_target(source, entrypoint) as entrypoint_cls:
284
307
  chain_name = (
285
308
  name or entrypoint_cls.meta_data.chain_name or entrypoint_cls.display_name
286
309
  )
310
+
311
+ remote_provider = None
312
+ if not dryrun and remote:
313
+ remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
314
+ existing_teams = remote_provider.api.get_teams()
315
+ _, team_id = resolve_chain_team_name(
316
+ remote_provider,
317
+ provided_team_name,
318
+ existing_chain_name=chain_name,
319
+ existing_teams=existing_teams,
320
+ )
321
+
287
322
  options = chains_def.PushOptionsBaseten.create(
288
323
  chain_name=chain_name,
289
324
  promote=promote,
@@ -294,6 +329,9 @@ def push_chain(
294
329
  include_git_info=include_git_info,
295
330
  working_dir=source.parent if source.is_file() else source.resolve(),
296
331
  disable_chain_download=disable_chain_download,
332
+ deployment_name=deployment_name,
333
+ team_id=team_id,
334
+ remote_provider=remote_provider,
297
335
  )
298
336
  service = deployment_client.push(
299
337
  entrypoint_cls, options, progress_bar=progress.Progress
truss/cli/cli.py CHANGED
@@ -19,6 +19,7 @@ from truss.base.truss_config import Build, ModelServer, TransportKind
19
19
  from truss.cli import remote_cli
20
20
  from truss.cli.logs import utils as cli_log_utils
21
21
  from truss.cli.logs.model_log_watcher import ModelDeploymentLogWatcher
22
+ from truss.cli.resolvers.model_team_resolver import resolve_model_team_name
22
23
  from truss.cli.utils import common
23
24
  from truss.cli.utils.output import console, error_console
24
25
  from truss.remote.baseten.core import (
@@ -462,7 +463,7 @@ def run_python(script, target_directory):
462
463
  required=False,
463
464
  help=(
464
465
  "Name of the deployment created by the push. Can only be "
465
- "used in combination with '--publish' or '--promote'."
466
+ "used in combination with --publish or --promote."
466
467
  ),
467
468
  )
468
469
  @click.option(
@@ -501,6 +502,19 @@ def run_python(script, target_directory):
501
502
  "Default is --preserve-env-instance-type."
502
503
  ),
503
504
  )
505
+ @click.option(
506
+ "--deploy-timeout-minutes",
507
+ type=int,
508
+ required=False,
509
+ help="Timeout in minutes for the deploy operation.",
510
+ )
511
+ @click.option(
512
+ "--team",
513
+ "provided_team_name",
514
+ type=str,
515
+ required=False,
516
+ help="Team name for the model",
517
+ )
504
518
  @common.common_options()
505
519
  def push(
506
520
  target_directory: str,
@@ -518,6 +532,8 @@ def push(
518
532
  include_git_info: bool = False,
519
533
  tail: bool = False,
520
534
  preserve_env_instance_type: bool = True,
535
+ deploy_timeout_minutes: Optional[int] = None,
536
+ provided_team_name: Optional[str] = None,
521
537
  ) -> None:
522
538
  """
523
539
  Pushes a truss to a TrussRemote.
@@ -547,6 +563,17 @@ def push(
547
563
  if not model_name:
548
564
  model_name = remote_cli.inquire_model_name()
549
565
 
566
+ # Resolve team_id if BasetenRemote
567
+ team_id = None
568
+ if isinstance(remote_provider, BasetenRemote):
569
+ existing_teams = remote_provider.api.get_teams()
570
+ _, team_id = resolve_model_team_name(
571
+ remote_provider=remote_provider,
572
+ provided_team_name=provided_team_name,
573
+ existing_model_name=model_name,
574
+ existing_teams=existing_teams,
575
+ )
576
+
550
577
  if promote and environment:
551
578
  raise click.UsageError(
552
579
  "'promote' flag and 'environment' flag cannot both be specified."
@@ -612,11 +639,12 @@ def push(
612
639
  console.print(fp8_and_num_builder_gpus_text, style="yellow")
613
640
 
614
641
  source = Path(target_directory)
615
- # TODO(Abu): This needs to be refactored to be more generic
642
+ working_dir = source.parent if source.is_file() else source.resolve()
643
+
616
644
  service = remote_provider.push(
617
- tr,
645
+ truss_handle=tr,
618
646
  model_name=model_name,
619
- working_dir=source.parent if source.is_file() else source.resolve(),
647
+ working_dir=working_dir,
620
648
  publish=publish,
621
649
  promote=promote,
622
650
  preserve_previous_prod_deployment=preserve_previous_production_deployment,
@@ -626,7 +654,9 @@ def push(
626
654
  progress_bar=progress.Progress,
627
655
  include_git_info=include_git_info,
628
656
  preserve_env_instance_type=preserve_env_instance_type,
629
- ) # type: ignore
657
+ deploy_timeout_minutes=deploy_timeout_minutes,
658
+ team_id=team_id,
659
+ )
630
660
 
631
661
  click.echo(f"✨ Model {model_name} was successfully pushed ✨")
632
662
 
truss/cli/remote_cli.py CHANGED
@@ -1,3 +1,5 @@
1
+ from typing import Optional
2
+
1
3
  from InquirerPy import inquirer
2
4
  from InquirerPy.validator import ValidationError, Validator
3
5
 
@@ -56,3 +58,30 @@ def inquire_remote_name() -> str:
56
58
 
57
59
  def inquire_model_name() -> str:
58
60
  return inquirer.text("📦 Name this model:", qmark="").execute()
61
+
62
+
63
def get_team_id_from_name(
    teams: dict[str, dict[str, str]], team_name: str
) -> Optional[str]:
    """Return the ``"id"`` of ``team_name`` in ``teams``.

    Args:
        teams: Mapping of team name to team data containing an ``"id"`` key.
        team_name: Name of the team to look up.

    Returns:
        The team id, or None when the team is absent (or its data is empty).
    """
    team = teams.get(team_name)
    return team["id"] if team else None


def format_available_teams(teams: dict[str, dict[str, str]]) -> str:
    """Return the team names joined by ``", "``, or ``"none"`` if there are none."""
    return ", ".join(teams) if teams else "none"


def inquire_team(
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> Optional[str]:
    """Prompt the user to pick one of ``existing_teams`` and return its name.

    Args:
        existing_teams: Mapping of team name to team data; its keys are the
            choices offered to the user.

    Returns:
        The selected team name, or None without prompting when
        ``existing_teams`` is None or empty (so no team param is propagated).
    """
    # An empty mapping must also short-circuit: there is nothing to choose
    # from, and the select prompt is not given an empty choice list.
    if not existing_teams:
        return None

    return inquirer.select(
        "👥 Which team do you want to push to?",
        qmark="",
        choices=list(existing_teams),
    ).execute()
@@ -0,0 +1,82 @@
1
+ """Team resolution logic for chains."""
2
+
3
+ from typing import Optional
4
+
5
+ import click
6
+
7
+ from truss.cli import remote_cli
8
+ from truss.remote.baseten.remote import BasetenRemote
9
+
10
+
11
def resolve_chain_team_name(
    remote_provider: BasetenRemote,
    provided_team_name: Optional[str],
    existing_chain_name: Optional[str] = None,
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Resolve team name and team_id from provided team name or by prompting the user.

    Returns a tuple of (team_name, team_id).
    This function handles 8 distinct scenarios organized into 3 high-level categories:

    HIGH-LEVEL SCENARIO 1: --team PROVIDED
      SCENARIO 1: Valid team name, user has access
        → Returns (team_name, team_id) for that team (no prompt, no error)
      SCENARIO 2: Invalid team name (does not exist)
        → Raises ClickException with error message listing available teams

    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Chain does not exist
      SCENARIO 3: User has multiple teams, no existing chain
        → Prompts user to select a team via inquire_team()
      SCENARIO 6: User has exactly one team, no existing chain
        → Returns (team_name, team_id) for the single team automatically (no prompt)

    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Chain exists
      SCENARIO 4: User has multiple teams, existing chain in exactly one of them
        → Auto-detects and returns (team_name, team_id) for that team (no prompt)
      SCENARIO 5: User has multiple teams, existing chain exists in several of them
        → Prompts user to select a team via inquire_team()
      SCENARIO 7: User has exactly one team, existing chain matches the team
        → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
      SCENARIO 8: User has exactly one team, existing chain exists in different team
        → Returns (team_name, team_id) for the single team automatically (no prompt)

    Only chains whose team is one of ``existing_teams`` count as matches.
    A same-named chain in a team the user cannot access therefore neither
    triggers a prompt nor drives auto-detection.

    Args:
        remote_provider: Remote used to fetch teams and chains.
        provided_team_name: Value of ``--team``, or None if not given.
        existing_chain_name: Name of the chain being pushed. When given,
            existing chains with that name are used to auto-detect the team.
        existing_teams: Teams keyed by name, each with an ``"id"`` entry.
            Fetched via ``remote_provider.api.get_teams()`` when None.

    Raises:
        click.ClickException: If ``provided_team_name`` is not in ``existing_teams``.
    """
    if existing_teams is None:
        existing_teams = remote_provider.api.get_teams()

    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
        if team_name and existing_teams:
            team_data = existing_teams.get(team_name)
            return team_data["id"] if team_data else None
        return None

    def _chain_team_name(chain: dict) -> Optional[str]:
        # "team" may be missing or null; both mean "no team name".
        chain_team = chain.get("team")
        return chain_team.get("name") if chain_team else None

    if provided_team_name is not None:
        if provided_team_name not in existing_teams:
            available_teams_str = remote_cli.format_available_teams(existing_teams)
            raise click.ClickException(
                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
            )
        return (provided_team_name, _get_team_id(provided_team_name))

    if existing_chain_name is not None:
        # Team names of same-named chains the user can actually push to.
        matching_team_names: list[str] = []
        for chain in remote_provider.api.get_chains():
            if chain.get("name") != existing_chain_name:
                continue
            team_name = _chain_team_name(chain)
            if team_name in existing_teams:
                matching_team_names.append(team_name)

        if len(matching_team_names) == 1:
            chain_team_name = matching_team_names[0]
            return (chain_team_name, _get_team_id(chain_team_name))
        # 0 matches: the chain is new to the user's teams.
        # >1 matches: the choice is ambiguous.
        # Either way, fall through to the single-team / prompt logic.

    if len(existing_teams) == 1:
        single_team_name = next(iter(existing_teams))
        return (single_team_name, _get_team_id(single_team_name))

    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
    return (selected_team_name, _get_team_id(selected_team_name))
@@ -0,0 +1,90 @@
1
+ """Team resolution logic for models."""
2
+
3
+ from typing import Optional
4
+
5
+ import click
6
+
7
+ from truss.cli import remote_cli
8
+ from truss.remote.baseten.remote import BasetenRemote
9
+
10
+
11
def resolve_model_team_name(
    remote_provider: BasetenRemote,
    provided_team_name: Optional[str],
    existing_model_name: Optional[str] = None,
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Resolve team name and team_id from provided team name or by prompting the user.

    Returns a tuple of (team_name, team_id).
    This function handles 8 distinct scenarios organized into 3 high-level categories:

    HIGH-LEVEL SCENARIO 1: --team PROVIDED
      SCENARIO 1: Valid team name, user has access
        → Returns (team_name, team_id) for that team (no prompt, no error)
      SCENARIO 2: Invalid team name (does not exist)
        → Raises ClickException with error message listing available teams

    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Model does not exist
      SCENARIO 3: User has multiple teams, no existing model
        → Prompts user to select a team via inquire_team()
      SCENARIO 6: User has exactly one team, no existing model
        → Returns (team_name, team_id) for the single team automatically (no prompt)

    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Model exists
      SCENARIO 4: User has multiple teams, existing model in exactly one team
        → Auto-detects and returns (team_name, team_id) for that team (no prompt)
      SCENARIO 5: User has multiple teams, existing model exists in multiple teams
        → Prompts user to select a team via inquire_team()
      SCENARIO 7: User has exactly one team, existing model matches the team
        → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
      SCENARIO 8: User has exactly one team, existing model exists in different team
        → Returns (team_name, team_id) for the single team automatically (no prompt)

    Args:
        remote_provider: Remote used to fetch teams and models.
        provided_team_name: Value of ``--team``, or None if not given.
        existing_model_name: Name of the model being pushed. When given,
            existing models with that name in the user's teams are used to
            auto-detect the team.
        existing_teams: Teams keyed by name, each with an ``"id"`` entry.
            Fetched via ``remote_provider.api.get_teams()`` when None.

    Raises:
        click.ClickException: If ``provided_team_name`` is not in ``existing_teams``.
    """
    if existing_teams is None:
        existing_teams = remote_provider.api.get_teams()

    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
        if team_name and existing_teams:
            team_data = existing_teams.get(team_name)
            return team_data["id"] if team_data else None
        return None

    def _model_team(model: dict) -> dict:
        # "team" may be missing *or* explicitly null; `.get("team", {})` alone
        # would return None in the latter case and crash on `.get`.
        return model.get("team") or {}

    def _get_matching_models_in_accessible_teams(model_name: str) -> list[dict]:
        """Get models matching the name that are in teams the user has access to."""
        all_models_data = remote_provider.api.models()
        accessible_team_ids = {team_data["id"] for team_data in existing_teams.values()}

        return [
            m
            for m in all_models_data.get("models", [])
            if m.get("name") == model_name
            and _model_team(m).get("id") in accessible_team_ids
        ]

    if provided_team_name is not None:
        if provided_team_name not in existing_teams:
            available_teams_str = remote_cli.format_available_teams(existing_teams)
            raise click.ClickException(
                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
            )
        return (provided_team_name, _get_team_id(provided_team_name))

    if existing_model_name is not None:
        matching_models = _get_matching_models_in_accessible_teams(existing_model_name)

        if len(matching_models) == 1:
            # Exactly one model in an accessible team - auto-detect.
            team = _model_team(matching_models[0])
            model_team_name = team.get("name")
            model_team_id = team.get("id")
            if model_team_name and model_team_name in existing_teams:
                return (model_team_name, model_team_id)
        # If len > 1, multiple models exist - fall through to prompt logic.
        # If len == 0, no models exist - fall through to prompt logic.

    if len(existing_teams) == 1:
        single_team_name = next(iter(existing_teams))
        return (single_team_name, _get_team_id(single_team_name))

    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
    return (selected_team_name, _get_team_id(selected_team_name))
@@ -0,0 +1,81 @@
1
+ """Team resolution logic for training projects."""
2
+
3
+ from typing import Optional
4
+
5
+ import click
6
+
7
+ from truss.cli import remote_cli
8
+ from truss.remote.baseten.remote import BasetenRemote
9
+
10
+
11
def resolve_training_project_team_name(
    remote_provider: BasetenRemote,
    provided_team_name: Optional[str],
    existing_project_name: Optional[str] = None,
    existing_teams: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Resolve team name and team_id from provided team name or by prompting the user.

    Returns a tuple of (team_name, team_id).
    This function handles 8 distinct scenarios organized into 3 high-level categories:

    HIGH-LEVEL SCENARIO 1: --team PROVIDED
      SCENARIO 1: Valid team name, user has access
        → Returns (team_name, team_id) for that team (no prompt, no error)
      SCENARIO 2: Invalid team name (does not exist)
        → Raises ClickException with error message listing available teams

    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Training project does not exist
      SCENARIO 3: User has multiple teams, no existing project
        → Prompts user to select a team via inquire_team()
      SCENARIO 6: User has exactly one team, no existing project
        → Returns (team_name, team_id) for the single team automatically (no prompt)

    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Training project exists
      SCENARIO 4: User has multiple teams, existing project in exactly one of them
        → Auto-detects and returns (team_name, team_id) for that team (no prompt)
      SCENARIO 5: User has multiple teams, existing project exists in several of them
        → Prompts user to select a team via inquire_team()
      SCENARIO 7: User has exactly one team, existing project matches the team
        → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
      SCENARIO 8: User has exactly one team, existing project exists in different team
        → Returns (team_name, team_id) for the single team automatically (no prompt)

    Only projects whose ``team_name`` is one of ``existing_teams`` count as
    matches. A same-named project in a team the user cannot access therefore
    neither triggers a prompt nor drives auto-detection.

    Args:
        remote_provider: Remote used to fetch teams and training projects.
        provided_team_name: Value of ``--team``, or None if not given.
        existing_project_name: Name of the training project being pushed.
            When given, existing projects with that name are used to
            auto-detect the team.
        existing_teams: Teams keyed by name, each with an ``"id"`` entry.
            Fetched via ``remote_provider.api.get_teams()`` when None.

    Raises:
        click.ClickException: If ``provided_team_name`` is not in ``existing_teams``.
    """
    if existing_teams is None:
        existing_teams = remote_provider.api.get_teams()

    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
        if team_name and existing_teams:
            team_data = existing_teams.get(team_name)
            return team_data["id"] if team_data else None
        return None

    if provided_team_name is not None:
        if provided_team_name not in existing_teams:
            available_teams_str = remote_cli.format_available_teams(existing_teams)
            raise click.ClickException(
                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
            )
        return (provided_team_name, _get_team_id(provided_team_name))

    if existing_project_name is not None:
        # Team names of same-named projects the user can actually push to.
        matching_team_names = [
            p.get("team_name")
            for p in remote_provider.api.list_training_projects()
            if p.get("name") == existing_project_name
            and p.get("team_name") in existing_teams
        ]

        if len(matching_team_names) == 1:
            project_team_name = matching_team_names[0]
            return (project_team_name, _get_team_id(project_team_name))
        # 0 matches: the project is new to the user's teams.
        # >1 matches: the choice is ambiguous.
        # Either way, fall through to the single-team / prompt logic.

    if len(existing_teams) == 1:
        single_team_name = next(iter(existing_teams))
        return (single_team_name, _get_team_id(single_team_name))

    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
    return (selected_team_name, _get_team_id(selected_team_name))