truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/api/__init__.py CHANGED
@@ -65,6 +65,7 @@ def push(
     progress_bar: Optional[Type["progress.Progress"]] = None,
     include_git_info: bool = False,
     preserve_env_instance_type: bool = True,
+    deploy_timeout_minutes: Optional[int] = None,
 ) -> definitions.ModelDeployment:
     """
     Pushes a Truss to Baseten.
@@ -77,13 +78,13 @@ def push(
             promote the truss to production after deploy completes.
         promote: Push the truss as a published deployment. Even if a production deployment exists,
             promote the truss to production after deploy completes.
-        preserve_previous_production_deployment: Preserve the previous production deployments autoscaling
+        preserve_previous_production_deployment: Preserve the previous production deployment's autoscaling
             setting. When not specified, the previous production deployment will be updated to allow it to
             scale to zero. Can only be use in combination with `promote` option.
         trusted: [DEPRECATED]
         deployment_name: Name of the deployment created by the push. Can only be
             used in combination with `publish` or `promote`. Deployment name must
-            only contain alphanumeric, ’.’, ’-’ or _ characters.
+            only contain alphanumeric, '.', '-' or '_' characters.
         environment: Name of stable environment on baseten.
         progress_bar: Optional `rich.progress.Progress` if output is desired.
         include_git_info: Whether to attach git versioning info (sha, branch, tag) to
@@ -92,6 +93,7 @@ def push(
         preserve_env_instance_type: When pushing a truss to an environment, whether to use the resources
             specified in the truss config to resolve the instance type or preserve the instance type
             configured in the specified environment.
+        deploy_timeout_minutes: Optional timeout in minutes for the deployment operation.

     Returns:
         The newly created ModelDeployment.
@@ -135,6 +137,7 @@ def push(
         progress_bar=progress_bar,
        include_git_info=include_git_info,
        preserve_env_instance_type=preserve_env_instance_type,
+       deploy_timeout_minutes=deploy_timeout_minutes,
    ) # type: ignore

    return definitions.ModelDeployment(cast(BasetenService, service))
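The new `deploy_timeout_minutes` argument is also exposed on the CLI as `--deploy-timeout-minutes` (see the truss/cli/cli.py hunks below). A minimal, hedged sketch of the Python API call; only `deploy_timeout_minutes` and the documented keyword arguments are taken from this diff, and the positional truss directory is assumed:

    # Sketch only: the target directory path and deployment name are placeholders.
    from truss.api import push

    deployment = push(
        "./my-truss",
        publish=True,
        deployment_name="my-deployment",
        deploy_timeout_minutes=30,  # new: timeout in minutes for the deployment operation
    )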
truss/base/constants.py CHANGED
@@ -49,6 +49,7 @@ FILENAME_CONSTANTS_MAP = {
 }

 SERVER_DOCKERFILE_TEMPLATE_NAME = "server.Dockerfile.jinja"
+NO_BUILD_DOCKERFILE_TEMPLATE_NAME = "no_build.Dockerfile.jinja"
 MODEL_DOCKERFILE_NAME = "Dockerfile"
 MODEL_CACHE_PATH = pathlib.Path("/app/model_cache")
 README_TEMPLATE_NAME = "README.md.jinja"
truss/base/trt_llm_config.py CHANGED
@@ -68,6 +68,7 @@ class TrussTRTLLMQuantizationType(str, Enum):
     FP8_KV = "fp8_kv"
     FP4 = "fp4"
     FP4_KV = "fp4_kv"
+    FP4_MLP_ONLY = "fp4_mlp_only"


 class TrussTRTLLMPluginConfiguration(PydanticTrTBaseModel):
@@ -329,9 +330,16 @@ pip install truss==0.10.8
            raise ValueError("Using fp8 context fmha requires paged context fmha")
        if (
            self.plugin_configuration.use_fp8_context_fmha
-           and not self.quantization_type == TrussTRTLLMQuantizationType.FP8_KV
+           and self.quantization_type
+           not in (
+               TrussTRTLLMQuantizationType.FP8_KV,
+               TrussTRTLLMQuantizationType.FP4_KV,
+           )
        ):
-           raise ValueError("Using fp8 context fmha requires fp8 kv cache dtype")
+           raise ValueError(
+               "Using fp8 context fmha requires fp8 kv, or fp4 with kv cache dtype"
+           )
+
        return self

    def _validate_speculator_config(self):
@@ -570,6 +578,7 @@ class TRTLLMConfigurationV2(PydanticTrTBaseModel):
            "quantization_type",
            "quantization_config",
            "skip_build_result",
+           "num_builder_gpus",
        ]

        build_settings = self.build.model_dump(exclude_unset=True)
@@ -705,7 +714,9 @@ def trt_llm_common_validation(config: "TrussConfig"):
            "accelerators or newer (CUDA_COMPUTE>=89)"
        )
    elif trt_llm_config.build.quantization_type in [
-       TrussTRTLLMQuantizationType.FP4
+       TrussTRTLLMQuantizationType.FP4,
+       TrussTRTLLMQuantizationType.FP4_KV,
+       TrussTRTLLMQuantizationType.FP4_MLP_ONLY,
    ] and config.resources.accelerator.accelerator in [
        truss_config.Accelerator.H100,
        truss_config.Accelerator.L4,
truss/base/truss_config.py CHANGED
@@ -147,7 +147,7 @@ class ModelRepo(custom_types.ConfigModel):
     volume_folder: Optional[
         Annotated[str, pydantic.StringConstraints(min_length=1)]
     ] = None
-    use_volume: bool = False
+    use_volume: bool
     kind: ModelRepoSourceKind = ModelRepoSourceKind.HF
     runtime_secret_name: str = "hf_access_token"

@@ -163,7 +163,7 @@ class ModelRepo(custom_types.ConfigModel):
             return v
         if v.get("kind") == ModelRepoSourceKind.HF.value and v.get("revision") is None:
             logger.warning(
-                "the key `revision: str` is required for use_volume=True huggingface repos."
+                "the key `revision: str` is required for use_volume=True huggingface repos. For S3/GCS/Azure repos, set it to any non-empty string."
             )
             raise_insufficent_revision(v.get("repo_id"), v.get("revision"))
         if v.get("volume_folder") is None or len(v["volume_folder"]) == 0:
@@ -202,7 +202,14 @@ class ModelCache(pydantic.RootModel[list[ModelRepo]]):
         )


-class CacheInternal(ModelCache): ...
+class ModelRepoCacheInternal(ModelRepo):
+    use_volume: bool = False  # override
+
+
+class CacheInternal(pydantic.RootModel[list[ModelRepoCacheInternal]]):
+    @property
+    def models(self) -> list[ModelRepoCacheInternal]:
+        return self.root


 class HealthChecks(custom_types.ConfigModel):
@@ -541,11 +548,19 @@ class BaseImage(custom_types.ConfigModel):


 class DockerServer(custom_types.ConfigModel):
-    start_command: str
+    start_command: Optional[str] = None
     server_port: int
     predict_endpoint: str
     readiness_endpoint: str
     liveness_endpoint: str
+    run_as_user_id: Optional[int] = None
+    no_build: Optional[bool] = None
+
+    @pydantic.model_validator(mode="after")
+    def _validate_start_command(self) -> "DockerServer":
+        if not self.no_build and self.start_command is None:
+            raise ValueError("start_command is required when no_build is not true")
+        return self


 class TrainingArtifactReference(custom_types.ConfigModel):
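A small illustration of the new `DockerServer` validator above. This is a sketch: it assumes `DockerServer` is importable from `truss.base.truss_config` (the file this hunk belongs to) and uses only the fields shown in the diff; endpoint values are placeholders.

    from truss.base.truss_config import DockerServer  # assumed import path per this diff

    common = dict(
        server_port=8080,
        predict_endpoint="/predict",
        readiness_endpoint="/health",
        liveness_endpoint="/health",
    )

    # Passes validation: with no_build=True, start_command may be omitted.
    DockerServer(no_build=True, **common)

    # Fails validation: no_build is not true and start_command is None, so the
    # model validator raises "start_command is required when no_build is not true".
    DockerServer(**common)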
@@ -1,6 +1,6 @@
1
1
  import time
2
2
  from pathlib import Path
3
- from typing import List, Optional, Tuple
3
+ from typing import List, Optional, Tuple, cast
4
4
 
5
5
  import rich
6
6
  import rich.live
@@ -14,10 +14,13 @@ from rich import progress
14
14
 
15
15
  from truss.cli import remote_cli
16
16
  from truss.cli.cli import truss_cli
17
+ from truss.cli.resolvers.chain_team_resolver import resolve_chain_team_name
17
18
  from truss.cli.utils import common, output
18
19
  from truss.cli.utils.output import console
19
20
  from truss.remote.baseten.core import ACTIVE_STATUS, DEPLOYING_STATUSES
21
+ from truss.remote.baseten.remote import BasetenRemote
20
22
  from truss.remote.baseten.utils.status import get_displayable_status
23
+ from truss.remote.remote_factory import RemoteFactory
21
24
  from truss.util import user_config
22
25
  from truss.util.log_utils import LogInterceptor
23
26
 
@@ -211,6 +214,30 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
211
214
  default=False,
212
215
  help=common.INCLUDE_GIT_INFO_DOC,
213
216
  )
217
+ @click.option(
218
+ "--disable-chain-download",
219
+ "disable_chain_download",
220
+ is_flag=True,
221
+ required=False,
222
+ default=False,
223
+ help="Disable downloading of pushed chain source code from the UI.",
224
+ )
225
+ @click.option(
226
+ "--deployment-name",
227
+ type=str,
228
+ required=False,
229
+ help=(
230
+ "Name of the deployment created by the publish. Can only be used "
231
+ "in combination with '--publish' or '--promote'."
232
+ ),
233
+ )
234
+ @click.option(
235
+ "--team",
236
+ "provided_team_name",
237
+ type=str,
238
+ required=False,
239
+ help="Team name for the chain deployment",
240
+ )
214
241
  @click.pass_context
215
242
  @common.common_options()
216
243
  def push_chain(
@@ -227,6 +254,9 @@ def push_chain(
227
254
  environment: Optional[str],
228
255
  experimental_watch_chainlet_names: Optional[str],
229
256
  include_git_info: bool = False,
257
+ disable_chain_download: bool = False,
258
+ deployment_name: Optional[str] = None,
259
+ provided_team_name: Optional[str] = None,
230
260
  ) -> None:
231
261
  """
232
262
  Deploys a chain remotely.
@@ -271,10 +301,24 @@ def push_chain(
271
301
  if not include_git_info:
272
302
  include_git_info = user_config.settings.include_git_info
273
303
 
304
+ # Resolve team if not in dryrun mode
305
+ team_id = None
274
306
  with framework.ChainletImporter.import_target(source, entrypoint) as entrypoint_cls:
275
307
  chain_name = (
276
308
  name or entrypoint_cls.meta_data.chain_name or entrypoint_cls.display_name
277
309
  )
310
+
311
+ remote_provider = None
312
+ if not dryrun and remote:
313
+ remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
314
+ existing_teams = remote_provider.api.get_teams()
315
+ _, team_id = resolve_chain_team_name(
316
+ remote_provider,
317
+ provided_team_name,
318
+ existing_chain_name=chain_name,
319
+ existing_teams=existing_teams,
320
+ )
321
+
278
322
  options = chains_def.PushOptionsBaseten.create(
279
323
  chain_name=chain_name,
280
324
  promote=promote,
@@ -284,6 +328,10 @@ def push_chain(
284
328
  environment=environment,
285
329
  include_git_info=include_git_info,
286
330
  working_dir=source.parent if source.is_file() else source.resolve(),
331
+ disable_chain_download=disable_chain_download,
332
+ deployment_name=deployment_name,
333
+ team_id=team_id,
334
+ remote_provider=remote_provider,
287
335
  )
288
336
  service = deployment_client.push(
289
337
  entrypoint_cls, options, progress_bar=progress.Progress
truss/cli/cli.py CHANGED
@@ -19,6 +19,7 @@ from truss.base.truss_config import Build, ModelServer, TransportKind
 from truss.cli import remote_cli
 from truss.cli.logs import utils as cli_log_utils
 from truss.cli.logs.model_log_watcher import ModelDeploymentLogWatcher
+from truss.cli.resolvers.model_team_resolver import resolve_model_team_name
 from truss.cli.utils import common
 from truss.cli.utils.output import console, error_console
 from truss.remote.baseten.core import (
@@ -462,7 +463,7 @@ def run_python(script, target_directory):
     required=False,
     help=(
         "Name of the deployment created by the push. Can only be "
-        "used in combination with '--publish' or '--promote'."
+        "used in combination with --publish or --promote."
     ),
 )
 @click.option(
@@ -501,6 +502,19 @@ def run_python(script, target_directory):
         "Default is --preserve-env-instance-type."
     ),
 )
+@click.option(
+    "--deploy-timeout-minutes",
+    type=int,
+    required=False,
+    help="Timeout in minutes for the deploy operation.",
+)
+@click.option(
+    "--team",
+    "provided_team_name",
+    type=str,
+    required=False,
+    help="Team name for the model",
+)
 @common.common_options()
 def push(
     target_directory: str,
@@ -518,6 +532,8 @@ def push(
     include_git_info: bool = False,
     tail: bool = False,
     preserve_env_instance_type: bool = True,
+    deploy_timeout_minutes: Optional[int] = None,
+    provided_team_name: Optional[str] = None,
 ) -> None:
     """
     Pushes a truss to a TrussRemote.
@@ -547,9 +563,21 @@ def push(
     if not model_name:
         model_name = remote_cli.inquire_model_name()

+    # Resolve team_id if BasetenRemote
+    team_id = None
+    if isinstance(remote_provider, BasetenRemote):
+        existing_teams = remote_provider.api.get_teams()
+        _, team_id = resolve_model_team_name(
+            remote_provider=remote_provider,
+            provided_team_name=provided_team_name,
+            existing_model_name=model_name,
+            existing_teams=existing_teams,
+        )
+
     if promote and environment:
-        promote_warning = "'promote' flag and 'environment' flag were both specified. Ignoring the value of 'promote'"
-        console.print(promote_warning, style="yellow")
+        raise click.UsageError(
+            "'promote' flag and 'environment' flag cannot both be specified."
+        )
     if promote and not environment:
         environment = PRODUCTION_ENVIRONMENT_NAME

@@ -611,11 +639,12 @@ def push(
         console.print(fp8_and_num_builder_gpus_text, style="yellow")

     source = Path(target_directory)
-    # TODO(Abu): This needs to be refactored to be more generic
+    working_dir = source.parent if source.is_file() else source.resolve()
+
     service = remote_provider.push(
-        tr,
+        truss_handle=tr,
         model_name=model_name,
-        working_dir=source.parent if source.is_file() else source.resolve(),
+        working_dir=working_dir,
         publish=publish,
         promote=promote,
         preserve_previous_prod_deployment=preserve_previous_production_deployment,
@@ -625,7 +654,9 @@ def push(
         progress_bar=progress.Progress,
         include_git_info=include_git_info,
         preserve_env_instance_type=preserve_env_instance_type,
-    ) # type: ignore
+        deploy_timeout_minutes=deploy_timeout_minutes,
+        team_id=team_id,
+    )

    click.echo(f"✨ Model {model_name} was successfully pushed ✨")

truss/cli/logs/base_watcher.py CHANGED
@@ -17,7 +17,9 @@ class LogWatcher(ABC):
     # NB(nikhil): we add buffer for clock skew, so this helps us detect duplicates.
     # TODO(nikhil): clean up hashes so this doesn't grow indefinitely.
     _log_hashes: set[str] = set()
-    _last_poll_time: Optional[int] = None
+
+    _last_poll_time_ms: Optional[int] = None
+    _last_log_time_ms: Optional[int] = None

     def __init__(self, api: BasetenApi):
         self.api = api
@@ -26,37 +28,54 @@ class LogWatcher(ABC):
         log_str = f"{log.timestamp}-{log.message}-{log.replica}"
         return hashlib.sha256(log_str.encode("utf-8")).hexdigest()

-    def _poll(self) -> Iterator[ParsedLog]:
-        start_epoch: Optional[int] = None
-        now = int(time.time() * 1000)
-        if self._last_poll_time is not None:
-            start_epoch = self._last_poll_time - CLOCK_SKEW_BUFFER_MS
+    def get_start_epoch_ms(self, now_ms: int) -> Optional[int]:
+        if self._last_poll_time_ms:
+            return self._last_poll_time_ms - CLOCK_SKEW_BUFFER_MS
+
+        return None

+    def fetch_and_parse_logs(
+        self, start_epoch_millis: Optional[int], end_epoch_millis: Optional[int]
+    ) -> Iterator[ParsedLog]:
         api_logs = self.fetch_logs(
-            start_epoch_millis=start_epoch, end_epoch_millis=now + CLOCK_SKEW_BUFFER_MS
+            start_epoch_millis=start_epoch_millis, end_epoch_millis=end_epoch_millis
         )

         parsed_logs = parse_logs(api_logs)
+
         for log in parsed_logs:
-            h = self._hash_log(log)
-            if h not in self._log_hashes:
+            if (h := self._hash_log(log)) not in self._log_hashes:
                 self._log_hashes.add(h)
+
                 yield log

-        self._last_poll_time = now
+    def poll(self) -> Iterator[ParsedLog]:
+        now_ms = int(time.time() * 1000)
+        start_epoch_ms = self.get_start_epoch_ms(now_ms)
+
+        for log in self.fetch_and_parse_logs(
+            start_epoch_millis=start_epoch_ms,
+            end_epoch_millis=now_ms + CLOCK_SKEW_BUFFER_MS,
+        ):
+            yield log
+
+            epoch_ns = int(log.timestamp)
+            self._last_log_time_ms = int(epoch_ns / 1e6)
+
+        self._last_poll_time_ms = now_ms

     def watch(self) -> Iterator[ParsedLog]:
         self.before_polling()
         with console.status("Polling logs", spinner="aesthetic"):
             while True:
-                for log in self._poll():
+                for log in self.poll():
                     yield log
                 if self._log_hashes:
                     break
                 time.sleep(POLL_INTERVAL_SEC)

         while self.should_poll_again():
-            for log in self._poll():
+            for log in self.poll():
                 yield log
             time.sleep(POLL_INTERVAL_SEC)
         self.post_poll()
truss/cli/logs/model_log_watcher.py CHANGED
@@ -1,9 +1,12 @@
+from functools import cached_property
 from typing import Any, List, Optional

 from truss.cli.logs.base_watcher import LogWatcher
 from truss.remote.baseten.api import BasetenApi
 from truss.remote.baseten.utils.status import MODEL_RUNNING_STATES

+MAX_LOOK_BACK_MS = 1000 * 60 * 60  # 1 hour.
+


 class ModelDeploymentLogWatcher(LogWatcher):
     _model_id: str
@@ -25,11 +28,31 @@ class ModelDeploymentLogWatcher(LogWatcher):
             self._model_id, self._deployment_id, start_epoch_millis, end_epoch_millis
         )

+    def get_start_epoch_ms(self, now_ms: int) -> Optional[int]:
+        # NOTE(Tyron): If there can be multiple replicas,
+        # we can't use a timestamp cursor to poll for logs.
+        if not self._is_development:
+            return super().get_start_epoch_ms(now_ms)
+
+        # Cursor logic.
+
+        if self._last_log_time_ms:
+            return max(self._last_log_time_ms, now_ms - MAX_LOOK_BACK_MS)
+
+        return None
+
     def should_poll_again(self) -> bool:
         return self._current_status in MODEL_RUNNING_STATES

+    def _get_deployment(self) -> Any:
+        return self.api.get_deployment(self._model_id, self._deployment_id)
+
     def _get_current_status(self) -> str:
-        return self.api.get_deployment(self._model_id, self._deployment_id)["status"]
+        return self._get_deployment()["status"]
+
+    @cached_property
+    def _is_development(self) -> bool:
+        return self._get_deployment()["is_development"]

     def post_poll(self) -> None:
truss/cli/remote_cli.py CHANGED
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from InquirerPy import inquirer
 from InquirerPy.validator import ValidationError, Validator

@@ -56,3 +58,30 @@ def inquire_remote_name() -> str:

 def inquire_model_name() -> str:
     return inquirer.text("📦 Name this model:", qmark="").execute()
+
+
+def get_team_id_from_name(
+    teams: dict[str, dict[str, str]], team_name: str
+) -> Optional[str]:
+    team = teams.get(team_name)
+    return team["id"] if team else None
+
+
+def format_available_teams(teams: dict[str, dict[str, str]]) -> str:
+    team_names = list(teams.keys())
+    return ", ".join(team_names) if team_names else "none"
+
+
+def inquire_team(
+    existing_teams: Optional[dict[str, dict[str, str]]] = None,
+) -> Optional[str]:
+    if existing_teams is not None:
+        selected_team_name = inquirer.select(
+            "👥 Which team do you want to push to?",
+            qmark="",
+            choices=list[str](existing_teams.keys()),
+        ).execute()
+        return selected_team_name
+
+    # If no existing teams, return None (don't propagate team param)
+    return None
truss/cli/resolvers/chain_team_resolver.py ADDED
@@ -0,0 +1,82 @@
+"""Team resolution logic for chains."""
+
+from typing import Optional
+
+import click
+
+from truss.cli import remote_cli
+from truss.remote.baseten.remote import BasetenRemote
+
+
+def resolve_chain_team_name(
+    remote_provider: BasetenRemote,
+    provided_team_name: Optional[str],
+    existing_chain_name: Optional[str] = None,
+    existing_teams: Optional[dict[str, dict[str, str]]] = None,
+) -> tuple[Optional[str], Optional[str]]:
+    """Resolve team name and team_id from provided team name or by prompting the user.
+    Returns a tuple of (team_name, team_id).
+    This function handles 8 distinct scenarios organized into 3 high-level categories:
+
+    HIGH-LEVEL SCENARIO 1: --team PROVIDED
+        SCENARIO 1: Valid team name, user has access
+            → Returns (team_name, team_id) for that team (no prompt, no error)
+        SCENARIO 2: Invalid team name (does not exist)
+            → Raises ClickException with error message listing available teams
+
+    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Chain does not exist
+        SCENARIO 3: User has multiple teams, no existing chain
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 6: User has exactly one team, no existing chain
+            → Returns (team_name, team_id) for the single team automatically (no prompt)
+
+    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Chain exists
+        SCENARIO 4: User has multiple teams, existing chain in exactly one team
+            → Auto-detects and returns (team_name, team_id) for that team (no prompt)
+        SCENARIO 5: User has multiple teams, existing chain exists in multiple teams
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 7: User has exactly one team, existing chain matches the team
+            → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
+        SCENARIO 8: User has exactly one team, existing chain exists in different team
+            → Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)
+    """
+    if existing_teams is None:
+        existing_teams = remote_provider.api.get_teams()
+
+    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
+        if team_name and existing_teams:
+            team_data = existing_teams.get(team_name)
+            return team_data["id"] if team_data else None
+        return None
+
+    if provided_team_name is not None:
+        if provided_team_name not in existing_teams:
+            available_teams_str = remote_cli.format_available_teams(existing_teams)
+            raise click.ClickException(
+                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
+            )
+        return (provided_team_name, _get_team_id(provided_team_name))
+
+    existing_chains = None
+    if existing_chain_name is not None:
+        existing_chains = remote_provider.api.get_chains()
+        matching_chains = [
+            c for c in existing_chains if c.get("name") == existing_chain_name
+        ]
+
+        if len(matching_chains) > 1:
+            selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
+            return (selected_team_name, _get_team_id(selected_team_name))
+
+        if len(matching_chains) == 1:
+            chain_team = matching_chains[0].get("team")
+            chain_team_name = chain_team.get("name") if chain_team else None
+            if chain_team_name and chain_team_name in existing_teams:
+                return (chain_team_name, _get_team_id(chain_team_name))
+
+    if len(existing_teams) == 1:
+        single_team_name = list(existing_teams.keys())[0]
+        return (single_team_name, _get_team_id(single_team_name))
+
+    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
+    return (selected_team_name, _get_team_id(selected_team_name))
truss/cli/resolvers/model_team_resolver.py ADDED
@@ -0,0 +1,90 @@
+"""Team resolution logic for models."""
+
+from typing import Optional
+
+import click
+
+from truss.cli import remote_cli
+from truss.remote.baseten.remote import BasetenRemote
+
+
+def resolve_model_team_name(
+    remote_provider: BasetenRemote,
+    provided_team_name: Optional[str],
+    existing_model_name: Optional[str] = None,
+    existing_teams: Optional[dict[str, dict[str, str]]] = None,
+) -> tuple[Optional[str], Optional[str]]:
+    """Resolve team name and team_id from provided team name or by prompting the user.
+    Returns a tuple of (team_name, team_id).
+    This function handles 8 distinct scenarios organized into 3 high-level categories:
+
+    HIGH-LEVEL SCENARIO 1: --team PROVIDED
+        SCENARIO 1: Valid team name, user has access
+            → Returns (team_name, team_id) for that team (no prompt, no error)
+        SCENARIO 2: Invalid team name (does not exist)
+            → Raises ClickException with error message listing available teams
+
+    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Model does not exist
+        SCENARIO 3: User has multiple teams, no existing model
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 6: User has exactly one team, no existing model
+            → Returns (team_name, team_id) for the single team automatically (no prompt)
+
+    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Model exists
+        SCENARIO 4: User has multiple teams, existing model in exactly one team
+            → Auto-detects and returns (team_name, team_id) for that team (no prompt)
+        SCENARIO 5: User has multiple teams, existing model exists in multiple teams
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 7: User has exactly one team, existing model matches the team
+            → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
+        SCENARIO 8: User has exactly one team, existing model exists in different team
+            → Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)
+    """
+    if existing_teams is None:
+        existing_teams = remote_provider.api.get_teams()
+
+    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
+        if team_name and existing_teams:
+            team_data = existing_teams.get(team_name)
+            return team_data["id"] if team_data else None
+        return None
+
+    def _get_matching_models_in_accessible_teams(model_name: str) -> list[dict]:
+        """Get models matching the name that are in teams the user has access to."""
+        all_models_data = remote_provider.api.models()
+        accessible_team_ids = {team_data["id"] for team_data in existing_teams.values()}
+
+        return [
+            m
+            for m in all_models_data.get("models", [])
+            if m.get("name") == model_name
+            and m.get("team", {}).get("id") in accessible_team_ids
+        ]
+
+    if provided_team_name is not None:
+        if provided_team_name not in existing_teams:
+            available_teams_str = remote_cli.format_available_teams(existing_teams)
+            raise click.ClickException(
+                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
+            )
+        return (provided_team_name, _get_team_id(provided_team_name))
+
+    if existing_model_name is not None:
+        matching_models = _get_matching_models_in_accessible_teams(existing_model_name)
+
+        if len(matching_models) == 1:
+            # Exactly one model in an accessible team - auto-detect
+            team = matching_models[0].get("team", {})
+            model_team_name = team.get("name")
+            model_team_id = team.get("id")
+            if model_team_name and model_team_name in existing_teams:
+                return (model_team_name, model_team_id)
+        # If len > 1, multiple models exist - fall through to prompt logic
+        # If len == 0, no models exist - fall through to prompt logic
+
+    if len(existing_teams) == 1:
+        single_team_name = list(existing_teams.keys())[0]
+        return (single_team_name, _get_team_id(single_team_name))
+
+    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
+    return (selected_team_name, _get_team_id(selected_team_name))
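For illustration, a hedged sketch of the model resolver's auto-detect path (Scenario 4 above), driving it with a stubbed remote rather than a real BasetenRemote. The payload shapes mirror what the resolver reads in this file, but the team names, ids, and model name are invented:

    from truss.cli.resolvers.model_team_resolver import resolve_model_team_name

    class _FakeApi:
        def get_teams(self):
            return {"ml-platform": {"id": "team_a"}, "research": {"id": "team_b"}}

        def models(self):
            return {"models": [{"name": "whisper", "team": {"id": "team_a", "name": "ml-platform"}}]}

    class _FakeRemote:
        api = _FakeApi()

    # --team not given, and "whisper" already exists in exactly one accessible team:
    name, team_id = resolve_model_team_name(
        _FakeRemote(),  # type: ignore[arg-type]  # stands in for BasetenRemote
        provided_team_name=None,
        existing_model_name="whisper",
    )
    print(name, team_id)  # ml-platform team_a  (auto-detected, no prompt)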