truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/api/__init__.py
CHANGED
@@ -65,6 +65,7 @@ def push(
     progress_bar: Optional[Type["progress.Progress"]] = None,
     include_git_info: bool = False,
     preserve_env_instance_type: bool = True,
+    deploy_timeout_minutes: Optional[int] = None,
 ) -> definitions.ModelDeployment:
     """
     Pushes a Truss to Baseten.
@@ -77,13 +78,13 @@ def push(
             promote the truss to production after deploy completes.
         promote: Push the truss as a published deployment. Even if a production deployment exists,
             promote the truss to production after deploy completes.
-        preserve_previous_production_deployment: Preserve the previous production deployment
+        preserve_previous_production_deployment: Preserve the previous production deployment's autoscaling
            setting. When not specified, the previous production deployment will be updated to allow it to
            scale to zero. Can only be use in combination with `promote` option.
        trusted: [DEPRECATED]
        deployment_name: Name of the deployment created by the push. Can only be
            used in combination with `publish` or `promote`. Deployment name must
-           only contain alphanumeric,
+           only contain alphanumeric, '.', '-' or '_' characters.
        environment: Name of stable environment on baseten.
        progress_bar: Optional `rich.progress.Progress` if output is desired.
        include_git_info: Whether to attach git versioning info (sha, branch, tag) to
@@ -92,6 +93,7 @@ def push(
        preserve_env_instance_type: When pushing a truss to an environment, whether to use the resources
            specified in the truss config to resolve the instance type or preserve the instance type
            configured in the specified environment.
+       deploy_timeout_minutes: Optional timeout in minutes for the deployment operation.

    Returns:
        The newly created ModelDeployment.
@@ -135,6 +137,7 @@ def push(
         progress_bar=progress_bar,
         include_git_info=include_git_info,
         preserve_env_instance_type=preserve_env_instance_type,
+        deploy_timeout_minutes=deploy_timeout_minutes,
     )  # type: ignore

     return definitions.ModelDeployment(cast(BasetenService, service))
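For orientation, a minimal sketch of the new parameter in the public API. The directory path is hypothetical and all other arguments keep their defaults; `deploy_timeout_minutes` simply bounds how long the deploy operation waits.

    import truss

    # Push a Truss directory and cap the deploy wait at 30 minutes.
    deployment = truss.push("./my-truss", deploy_timeout_minutes=30)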
truss/base/constants.py
CHANGED
@@ -49,6 +49,7 @@ FILENAME_CONSTANTS_MAP = {
 }

 SERVER_DOCKERFILE_TEMPLATE_NAME = "server.Dockerfile.jinja"
+NO_BUILD_DOCKERFILE_TEMPLATE_NAME = "no_build.Dockerfile.jinja"
 MODEL_DOCKERFILE_NAME = "Dockerfile"
 MODEL_CACHE_PATH = pathlib.Path("/app/model_cache")
 README_TEMPLATE_NAME = "README.md.jinja"
truss/base/trt_llm_config.py
CHANGED
@@ -68,6 +68,7 @@ class TrussTRTLLMQuantizationType(str, Enum):
     FP8_KV = "fp8_kv"
     FP4 = "fp4"
     FP4_KV = "fp4_kv"
+    FP4_MLP_ONLY = "fp4_mlp_only"


 class TrussTRTLLMPluginConfiguration(PydanticTrTBaseModel):
@@ -329,9 +330,16 @@ pip install truss==0.10.8
             raise ValueError("Using fp8 context fmha requires paged context fmha")
         if (
             self.plugin_configuration.use_fp8_context_fmha
-            and
+            and self.quantization_type
+            not in (
+                TrussTRTLLMQuantizationType.FP8_KV,
+                TrussTRTLLMQuantizationType.FP4_KV,
+            )
         ):
-            raise ValueError(
+            raise ValueError(
+                "Using fp8 context fmha requires fp8 kv, or fp4 with kv cache dtype"
+            )
+
         return self

     def _validate_speculator_config(self):
@@ -570,6 +578,7 @@ class TRTLLMConfigurationV2(PydanticTrTBaseModel):
             "quantization_type",
             "quantization_config",
             "skip_build_result",
+            "num_builder_gpus",
         ]

         build_settings = self.build.model_dump(exclude_unset=True)
@@ -705,7 +714,9 @@ def trt_llm_common_validation(config: "TrussConfig"):
                 "accelerators or newer (CUDA_COMPUTE>=89)"
             )
         elif trt_llm_config.build.quantization_type in [
-            TrussTRTLLMQuantizationType.FP4
+            TrussTRTLLMQuantizationType.FP4,
+            TrussTRTLLMQuantizationType.FP4_KV,
+            TrussTRTLLMQuantizationType.FP4_MLP_ONLY,
         ] and config.resources.accelerator.accelerator in [
             truss_config.Accelerator.H100,
             truss_config.Accelerator.L4,
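To make the relaxed fp8-context-fmha check concrete, here is a small standalone re-statement of the rule as it reads in the hunk above; this is an illustration, not the library's own validator, and the quantization strings mirror the enum values.

    # Illustrative sketch: use_fp8_context_fmha is now accepted only together
    # with the fp8_kv or fp4_kv quantization modes.
    ALLOWED_WITH_FP8_CONTEXT_FMHA = {"fp8_kv", "fp4_kv"}

    def check_fp8_context_fmha(use_fp8_context_fmha: bool, quantization_type: str) -> None:
        if use_fp8_context_fmha and quantization_type not in ALLOWED_WITH_FP8_CONTEXT_FMHA:
            raise ValueError(
                "Using fp8 context fmha requires fp8 kv, or fp4 with kv cache dtype"
            )

    check_fp8_context_fmha(True, "fp8_kv")   # accepted
    check_fp8_context_fmha(True, "fp4_kv")   # accepted (fp4 with kv cache dtype)
    # check_fp8_context_fmha(True, "fp4")    # would raise ValueError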
truss/base/truss_config.py
CHANGED
@@ -147,7 +147,7 @@ class ModelRepo(custom_types.ConfigModel):
     volume_folder: Optional[
         Annotated[str, pydantic.StringConstraints(min_length=1)]
     ] = None
-    use_volume: bool
+    use_volume: bool
     kind: ModelRepoSourceKind = ModelRepoSourceKind.HF
     runtime_secret_name: str = "hf_access_token"

@@ -163,7 +163,7 @@ class ModelRepo(custom_types.ConfigModel):
         return v
     if v.get("kind") == ModelRepoSourceKind.HF.value and v.get("revision") is None:
         logger.warning(
-            "the key `revision: str` is required for use_volume=True huggingface repos."
+            "the key `revision: str` is required for use_volume=True huggingface repos. For S3/GCS/Azure repos, set it to any non-empty string."
         )
     raise_insufficent_revision(v.get("repo_id"), v.get("revision"))
     if v.get("volume_folder") is None or len(v["volume_folder"]) == 0:
@@ -202,7 +202,14 @@ class ModelCache(pydantic.RootModel[list[ModelRepo]]):
     )


-class
+class ModelRepoCacheInternal(ModelRepo):
+    use_volume: bool = False  # override
+
+
+class CacheInternal(pydantic.RootModel[list[ModelRepoCacheInternal]]):
+    @property
+    def models(self) -> list[ModelRepoCacheInternal]:
+        return self.root


 class HealthChecks(custom_types.ConfigModel):
@@ -541,11 +548,19 @@ class BaseImage(custom_types.ConfigModel):


 class DockerServer(custom_types.ConfigModel):
-    start_command: str
+    start_command: Optional[str] = None
     server_port: int
     predict_endpoint: str
     readiness_endpoint: str
     liveness_endpoint: str
+    run_as_user_id: Optional[int] = None
+    no_build: Optional[bool] = None
+
+    @pydantic.model_validator(mode="after")
+    def _validate_start_command(self) -> "DockerServer":
+        if not self.no_build and self.start_command is None:
+            raise ValueError("start_command is required when no_build is not true")
+        return self


 class TrainingArtifactReference(custom_types.ConfigModel):
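A minimal sketch of the new validator's behavior, assuming `DockerServer` is importable from `truss.base.truss_config` as in this diff; the port and endpoint values are hypothetical.

    from truss.base.truss_config import DockerServer

    # With no_build set, start_command may be omitted and the image's own
    # entrypoint is expected to run the server.
    DockerServer(
        server_port=8000,
        predict_endpoint="/predict",
        readiness_endpoint="/health",
        liveness_endpoint="/health",
        no_build=True,
    )
    # Omitting both no_build and start_command now fails validation with
    # "start_command is required when no_build is not true".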
truss/cli/chains_commands.py
CHANGED
@@ -1,6 +1,6 @@
 import time
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, cast

 import rich
 import rich.live
@@ -14,10 +14,13 @@ from rich import progress

 from truss.cli import remote_cli
 from truss.cli.cli import truss_cli
+from truss.cli.resolvers.chain_team_resolver import resolve_chain_team_name
 from truss.cli.utils import common, output
 from truss.cli.utils.output import console
 from truss.remote.baseten.core import ACTIVE_STATUS, DEPLOYING_STATUSES
+from truss.remote.baseten.remote import BasetenRemote
 from truss.remote.baseten.utils.status import get_displayable_status
+from truss.remote.remote_factory import RemoteFactory
 from truss.util import user_config
 from truss.util.log_utils import LogInterceptor

@@ -211,6 +214,30 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
     default=False,
     help=common.INCLUDE_GIT_INFO_DOC,
 )
+@click.option(
+    "--disable-chain-download",
+    "disable_chain_download",
+    is_flag=True,
+    required=False,
+    default=False,
+    help="Disable downloading of pushed chain source code from the UI.",
+)
+@click.option(
+    "--deployment-name",
+    type=str,
+    required=False,
+    help=(
+        "Name of the deployment created by the publish. Can only be used "
+        "in combination with '--publish' or '--promote'."
+    ),
+)
+@click.option(
+    "--team",
+    "provided_team_name",
+    type=str,
+    required=False,
+    help="Team name for the chain deployment",
+)
 @click.pass_context
 @common.common_options()
 def push_chain(
@@ -227,6 +254,9 @@ def push_chain(
     environment: Optional[str],
     experimental_watch_chainlet_names: Optional[str],
     include_git_info: bool = False,
+    disable_chain_download: bool = False,
+    deployment_name: Optional[str] = None,
+    provided_team_name: Optional[str] = None,
 ) -> None:
     """
     Deploys a chain remotely.
@@ -271,10 +301,24 @@ def push_chain(
     if not include_git_info:
         include_git_info = user_config.settings.include_git_info

+    # Resolve team if not in dryrun mode
+    team_id = None
     with framework.ChainletImporter.import_target(source, entrypoint) as entrypoint_cls:
         chain_name = (
             name or entrypoint_cls.meta_data.chain_name or entrypoint_cls.display_name
         )
+
+        remote_provider = None
+        if not dryrun and remote:
+            remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
+            existing_teams = remote_provider.api.get_teams()
+            _, team_id = resolve_chain_team_name(
+                remote_provider,
+                provided_team_name,
+                existing_chain_name=chain_name,
+                existing_teams=existing_teams,
+            )
+
         options = chains_def.PushOptionsBaseten.create(
             chain_name=chain_name,
             promote=promote,
@@ -284,6 +328,10 @@ def push_chain(
             environment=environment,
             include_git_info=include_git_info,
             working_dir=source.parent if source.is_file() else source.resolve(),
+            disable_chain_download=disable_chain_download,
+            deployment_name=deployment_name,
+            team_id=team_id,
+            remote_provider=remote_provider,
         )
         service = deployment_client.push(
             entrypoint_cls, options, progress_bar=progress.Progress
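For readers following the control flow, a condensed sketch of the team-resolution step the hunk above adds to `push_chain`, with the new `--team` value passed straight through; the remote alias, team name, and chain name are hypothetical placeholders and a configured Baseten remote is assumed.

    from typing import cast

    from truss.cli.resolvers.chain_team_resolver import resolve_chain_team_name
    from truss.remote.baseten.remote import BasetenRemote
    from truss.remote.remote_factory import RemoteFactory

    remote_provider = cast(BasetenRemote, RemoteFactory.create(remote="baseten"))
    existing_teams = remote_provider.api.get_teams()
    team_name, team_id = resolve_chain_team_name(
        remote_provider,
        "my-team",                       # value of --team (or None to auto-resolve)
        existing_chain_name="my-chain",  # hypothetical chain name
        existing_teams=existing_teams,
    )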
truss/cli/cli.py
CHANGED
@@ -19,6 +19,7 @@ from truss.base.truss_config import Build, ModelServer, TransportKind
 from truss.cli import remote_cli
 from truss.cli.logs import utils as cli_log_utils
 from truss.cli.logs.model_log_watcher import ModelDeploymentLogWatcher
+from truss.cli.resolvers.model_team_resolver import resolve_model_team_name
 from truss.cli.utils import common
 from truss.cli.utils.output import console, error_console
 from truss.remote.baseten.core import (
@@ -462,7 +463,7 @@ def run_python(script, target_directory):
     required=False,
     help=(
         "Name of the deployment created by the push. Can only be "
-        "used in combination with
+        "used in combination with --publish or --promote."
     ),
 )
 @click.option(
@@ -501,6 +502,19 @@ def run_python(script, target_directory):
         "Default is --preserve-env-instance-type."
     ),
 )
+@click.option(
+    "--deploy-timeout-minutes",
+    type=int,
+    required=False,
+    help="Timeout in minutes for the deploy operation.",
+)
+@click.option(
+    "--team",
+    "provided_team_name",
+    type=str,
+    required=False,
+    help="Team name for the model",
+)
 @common.common_options()
 def push(
     target_directory: str,
@@ -518,6 +532,8 @@ def push(
     include_git_info: bool = False,
     tail: bool = False,
     preserve_env_instance_type: bool = True,
+    deploy_timeout_minutes: Optional[int] = None,
+    provided_team_name: Optional[str] = None,
 ) -> None:
     """
     Pushes a truss to a TrussRemote.
@@ -547,9 +563,21 @@ def push(
     if not model_name:
         model_name = remote_cli.inquire_model_name()

+    # Resolve team_id if BasetenRemote
+    team_id = None
+    if isinstance(remote_provider, BasetenRemote):
+        existing_teams = remote_provider.api.get_teams()
+        _, team_id = resolve_model_team_name(
+            remote_provider=remote_provider,
+            provided_team_name=provided_team_name,
+            existing_model_name=model_name,
+            existing_teams=existing_teams,
+        )
+
     if promote and environment:
-
-
+        raise click.UsageError(
+            "'promote' flag and 'environment' flag cannot both be specified."
+        )
     if promote and not environment:
         environment = PRODUCTION_ENVIRONMENT_NAME

@@ -611,11 +639,12 @@ def push(
         console.print(fp8_and_num_builder_gpus_text, style="yellow")

     source = Path(target_directory)
-
+    working_dir = source.parent if source.is_file() else source.resolve()
+
     service = remote_provider.push(
-        tr,
+        truss_handle=tr,
         model_name=model_name,
-        working_dir=
+        working_dir=working_dir,
         publish=publish,
         promote=promote,
         preserve_previous_prod_deployment=preserve_previous_production_deployment,
@@ -625,7 +654,9 @@ def push(
         progress_bar=progress.Progress,
         include_git_info=include_git_info,
         preserve_env_instance_type=preserve_env_instance_type,
-
+        deploy_timeout_minutes=deploy_timeout_minutes,
+        team_id=team_id,
+    )

     click.echo(f"✨ Model {model_name} was successfully pushed ✨")
truss/cli/logs/base_watcher.py
CHANGED
@@ -17,7 +17,9 @@ class LogWatcher(ABC):
     # NB(nikhil): we add buffer for clock skew, so this helps us detect duplicates.
     # TODO(nikhil): clean up hashes so this doesn't grow indefinitely.
     _log_hashes: set[str] = set()
-
+
+    _last_poll_time_ms: Optional[int] = None
+    _last_log_time_ms: Optional[int] = None

     def __init__(self, api: BasetenApi):
         self.api = api
@@ -26,37 +28,54 @@ class LogWatcher(ABC):
         log_str = f"{log.timestamp}-{log.message}-{log.replica}"
         return hashlib.sha256(log_str.encode("utf-8")).hexdigest()

-    def
-
-
-
-
+    def get_start_epoch_ms(self, now_ms: int) -> Optional[int]:
+        if self._last_poll_time_ms:
+            return self._last_poll_time_ms - CLOCK_SKEW_BUFFER_MS
+
+        return None

+    def fetch_and_parse_logs(
+        self, start_epoch_millis: Optional[int], end_epoch_millis: Optional[int]
+    ) -> Iterator[ParsedLog]:
         api_logs = self.fetch_logs(
-            start_epoch_millis=
+            start_epoch_millis=start_epoch_millis, end_epoch_millis=end_epoch_millis
         )

         parsed_logs = parse_logs(api_logs)
+
         for log in parsed_logs:
-            h
-            if h not in self._log_hashes:
+            if (h := self._hash_log(log)) not in self._log_hashes:
                 self._log_hashes.add(h)
+
                 yield log

-
+    def poll(self) -> Iterator[ParsedLog]:
+        now_ms = int(time.time() * 1000)
+        start_epoch_ms = self.get_start_epoch_ms(now_ms)
+
+        for log in self.fetch_and_parse_logs(
+            start_epoch_millis=start_epoch_ms,
+            end_epoch_millis=now_ms + CLOCK_SKEW_BUFFER_MS,
+        ):
+            yield log
+
+            epoch_ns = int(log.timestamp)
+            self._last_log_time_ms = int(epoch_ns / 1e6)
+
+        self._last_poll_time_ms = now_ms

     def watch(self) -> Iterator[ParsedLog]:
         self.before_polling()
         with console.status("Polling logs", spinner="aesthetic"):
             while True:
-                for log in self.
+                for log in self.poll():
                     yield log
                 if self._log_hashes:
                     break
                 time.sleep(POLL_INTERVAL_SEC)

             while self.should_poll_again():
-                for log in self.
+                for log in self.poll():
                     yield log
                 time.sleep(POLL_INTERVAL_SEC)
             self.post_poll()
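The overlapping-window plus hash-set pattern above is easy to miss in diff form. A tiny self-contained illustration of the same idea follows (not the watcher class itself, and the buffer value is hypothetical): consecutive polls overlap by a clock-skew buffer, and the hash set drops logs that were already yielded.

    import hashlib
    import time
    from typing import Callable, Iterator, Optional

    CLOCK_SKEW_BUFFER_MS = 5_000  # hypothetical buffer; the real constant lives in the CLI
    _seen: set[str] = set()
    _last_poll_ms: Optional[int] = None

    def poll_once(fetch_logs: Callable[[Optional[int], int], list]) -> Iterator:
        """Fetch logs in an overlapping window, yielding only logs not seen before."""
        global _last_poll_ms
        now_ms = int(time.time() * 1000)
        start = _last_poll_ms - CLOCK_SKEW_BUFFER_MS if _last_poll_ms else None
        for log in fetch_logs(start, now_ms + CLOCK_SKEW_BUFFER_MS):
            digest = hashlib.sha256(repr(log).encode("utf-8")).hexdigest()
            if digest not in _seen:
                _seen.add(digest)
                yield log
        _last_poll_ms = now_ms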
truss/cli/logs/model_log_watcher.py
CHANGED

@@ -1,9 +1,12 @@
+from functools import cached_property
 from typing import Any, List, Optional

 from truss.cli.logs.base_watcher import LogWatcher
 from truss.remote.baseten.api import BasetenApi
 from truss.remote.baseten.utils.status import MODEL_RUNNING_STATES

+MAX_LOOK_BACK_MS = 1000 * 60 * 60  # 1 hour.
+

 class ModelDeploymentLogWatcher(LogWatcher):
     _model_id: str
@@ -25,11 +28,31 @@ class ModelDeploymentLogWatcher(LogWatcher):
             self._model_id, self._deployment_id, start_epoch_millis, end_epoch_millis
         )

+    def get_start_epoch_ms(self, now_ms: int) -> Optional[int]:
+        # NOTE(Tyron): If there can be multiple replicas,
+        # we can't use a timestamp cursor to poll for logs.
+        if not self._is_development:
+            return super().get_start_epoch_ms(now_ms)
+
+        # Cursor logic.
+
+        if self._last_log_time_ms:
+            return max(self._last_log_time_ms, now_ms - MAX_LOOK_BACK_MS)
+
+        return None
+
     def should_poll_again(self) -> bool:
         return self._current_status in MODEL_RUNNING_STATES

+    def _get_deployment(self) -> Any:
+        return self.api.get_deployment(self._model_id, self._deployment_id)
+
     def _get_current_status(self) -> str:
-        return self.
+        return self._get_deployment()["status"]
+
+    @cached_property
+    def _is_development(self) -> bool:
+        return self._get_deployment()["is_development"]

     def post_poll(self) -> None:
         self._current_status = self._get_current_status()
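A worked example of the new cursor clamp for development deployments: with a look-back of one hour, a last-seen log older than an hour starts the next poll window at now minus one hour rather than at the stale timestamp (values below are hypothetical).

    MAX_LOOK_BACK_MS = 1000 * 60 * 60
    now_ms = 1_700_000_000_000                         # hypothetical "now"
    last_log_time_ms = now_ms - 3 * MAX_LOOK_BACK_MS   # last log seen three hours ago
    start = max(last_log_time_ms, now_ms - MAX_LOOK_BACK_MS)  # clamped to one hour back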
truss/cli/remote_cli.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from InquirerPy import inquirer
 from InquirerPy.validator import ValidationError, Validator

@@ -56,3 +58,30 @@ def inquire_remote_name() -> str:

 def inquire_model_name() -> str:
     return inquirer.text("📦 Name this model:", qmark="").execute()
+
+
+def get_team_id_from_name(
+    teams: dict[str, dict[str, str]], team_name: str
+) -> Optional[str]:
+    team = teams.get(team_name)
+    return team["id"] if team else None
+
+
+def format_available_teams(teams: dict[str, dict[str, str]]) -> str:
+    team_names = list(teams.keys())
+    return ", ".join(team_names) if team_names else "none"
+
+
+def inquire_team(
+    existing_teams: Optional[dict[str, dict[str, str]]] = None,
+) -> Optional[str]:
+    if existing_teams is not None:
+        selected_team_name = inquirer.select(
+            "👥 Which team do you want to push to?",
+            qmark="",
+            choices=list[str](existing_teams.keys()),
+        ).execute()
+        return selected_team_name
+
+    # If no existing teams, return None (don't propagate team param)
+    return None
truss/cli/resolvers/chain_team_resolver.py
ADDED

@@ -0,0 +1,82 @@
+"""Team resolution logic for chains."""
+
+from typing import Optional
+
+import click
+
+from truss.cli import remote_cli
+from truss.remote.baseten.remote import BasetenRemote
+
+
+def resolve_chain_team_name(
+    remote_provider: BasetenRemote,
+    provided_team_name: Optional[str],
+    existing_chain_name: Optional[str] = None,
+    existing_teams: Optional[dict[str, dict[str, str]]] = None,
+) -> tuple[Optional[str], Optional[str]]:
+    """Resolve team name and team_id from provided team name or by prompting the user.
+    Returns a tuple of (team_name, team_id).
+    This function handles 8 distinct scenarios organized into 3 high-level categories:
+
+    HIGH-LEVEL SCENARIO 1: --team PROVIDED
+        SCENARIO 1: Valid team name, user has access
+            → Returns (team_name, team_id) for that team (no prompt, no error)
+        SCENARIO 2: Invalid team name (does not exist)
+            → Raises ClickException with error message listing available teams
+
+    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Chain does not exist
+        SCENARIO 3: User has multiple teams, no existing chain
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 6: User has exactly one team, no existing chain
+            → Returns (team_name, team_id) for the single team automatically (no prompt)
+
+    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Chain exists
+        SCENARIO 4: User has multiple teams, existing chain in exactly one team
+            → Auto-detects and returns (team_name, team_id) for that team (no prompt)
+        SCENARIO 5: User has multiple teams, existing chain exists in multiple teams
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 7: User has exactly one team, existing chain matches the team
+            → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
+        SCENARIO 8: User has exactly one team, existing chain exists in different team
+            → Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)
+    """
+    if existing_teams is None:
+        existing_teams = remote_provider.api.get_teams()
+
+    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
+        if team_name and existing_teams:
+            team_data = existing_teams.get(team_name)
+            return team_data["id"] if team_data else None
+        return None
+
+    if provided_team_name is not None:
+        if provided_team_name not in existing_teams:
+            available_teams_str = remote_cli.format_available_teams(existing_teams)
+            raise click.ClickException(
+                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
+            )
+        return (provided_team_name, _get_team_id(provided_team_name))
+
+    existing_chains = None
+    if existing_chain_name is not None:
+        existing_chains = remote_provider.api.get_chains()
+        matching_chains = [
+            c for c in existing_chains if c.get("name") == existing_chain_name
+        ]
+
+        if len(matching_chains) > 1:
+            selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
+            return (selected_team_name, _get_team_id(selected_team_name))
+
+        if len(matching_chains) == 1:
+            chain_team = matching_chains[0].get("team")
+            chain_team_name = chain_team.get("name") if chain_team else None
+            if chain_team_name and chain_team_name in existing_teams:
+                return (chain_team_name, _get_team_id(chain_team_name))
+
+    if len(existing_teams) == 1:
+        single_team_name = list(existing_teams.keys())[0]
+        return (single_team_name, _get_team_id(single_team_name))
+
+    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
+    return (selected_team_name, _get_team_id(selected_team_name))
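A short usage sketch of the new resolver, exercising the explicit-team path (scenarios 1 and 2 above); the remote alias and team name are hypothetical, and a configured Baseten remote is assumed.

    import click

    from truss.cli.resolvers.chain_team_resolver import resolve_chain_team_name
    from truss.remote.remote_factory import RemoteFactory

    remote_provider = RemoteFactory.create(remote="baseten")  # assumed remote alias
    try:
        team_name, team_id = resolve_chain_team_name(remote_provider, "platform-team")
    except click.ClickException as exc:
        # Scenario 2: the team is unknown; the message lists the teams you can access.
        print(exc.message)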
truss/cli/resolvers/model_team_resolver.py
ADDED

@@ -0,0 +1,90 @@
+"""Team resolution logic for models."""
+
+from typing import Optional
+
+import click
+
+from truss.cli import remote_cli
+from truss.remote.baseten.remote import BasetenRemote
+
+
+def resolve_model_team_name(
+    remote_provider: BasetenRemote,
+    provided_team_name: Optional[str],
+    existing_model_name: Optional[str] = None,
+    existing_teams: Optional[dict[str, dict[str, str]]] = None,
+) -> tuple[Optional[str], Optional[str]]:
+    """Resolve team name and team_id from provided team name or by prompting the user.
+    Returns a tuple of (team_name, team_id).
+    This function handles 8 distinct scenarios organized into 3 high-level categories:
+
+    HIGH-LEVEL SCENARIO 1: --team PROVIDED
+        SCENARIO 1: Valid team name, user has access
+            → Returns (team_name, team_id) for that team (no prompt, no error)
+        SCENARIO 2: Invalid team name (does not exist)
+            → Raises ClickException with error message listing available teams
+
+    HIGH-LEVEL SCENARIO 2: --team NOT PROVIDED, Model does not exist
+        SCENARIO 3: User has multiple teams, no existing model
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 6: User has exactly one team, no existing model
+            → Returns (team_name, team_id) for the single team automatically (no prompt)
+
+    HIGH-LEVEL SCENARIO 3: --team NOT PROVIDED, Model exists
+        SCENARIO 4: User has multiple teams, existing model in exactly one team
+            → Auto-detects and returns (team_name, team_id) for that team (no prompt)
+        SCENARIO 5: User has multiple teams, existing model exists in multiple teams
+            → Prompts user to select a team via inquire_team()
+        SCENARIO 7: User has exactly one team, existing model matches the team
+            → Auto-detects and returns (team_name, team_id) for the single team (no prompt)
+        SCENARIO 8: User has exactly one team, existing model exists in different team
+            → Returns (team_name, team_id) for the single team automatically (no prompt, uses user's only team)
+    """
+    if existing_teams is None:
+        existing_teams = remote_provider.api.get_teams()
+
+    def _get_team_id(team_name: Optional[str]) -> Optional[str]:
+        if team_name and existing_teams:
+            team_data = existing_teams.get(team_name)
+            return team_data["id"] if team_data else None
+        return None
+
+    def _get_matching_models_in_accessible_teams(model_name: str) -> list[dict]:
+        """Get models matching the name that are in teams the user has access to."""
+        all_models_data = remote_provider.api.models()
+        accessible_team_ids = {team_data["id"] for team_data in existing_teams.values()}
+
+        return [
+            m
+            for m in all_models_data.get("models", [])
+            if m.get("name") == model_name
+            and m.get("team", {}).get("id") in accessible_team_ids
+        ]
+
+    if provided_team_name is not None:
+        if provided_team_name not in existing_teams:
+            available_teams_str = remote_cli.format_available_teams(existing_teams)
+            raise click.ClickException(
+                f"Team '{provided_team_name}' does not exist. Available teams: {available_teams_str}"
+            )
+        return (provided_team_name, _get_team_id(provided_team_name))
+
+    if existing_model_name is not None:
+        matching_models = _get_matching_models_in_accessible_teams(existing_model_name)
+
+        if len(matching_models) == 1:
+            # Exactly one model in an accessible team - auto-detect
+            team = matching_models[0].get("team", {})
+            model_team_name = team.get("name")
+            model_team_id = team.get("id")
+            if model_team_name and model_team_name in existing_teams:
+                return (model_team_name, model_team_id)
+        # If len > 1, multiple models exist - fall through to prompt logic
+        # If len == 0, no models exist - fall through to prompt logic
+
+    if len(existing_teams) == 1:
+        single_team_name = list(existing_teams.keys())[0]
+        return (single_team_name, _get_team_id(single_team_name))
+
+    selected_team_name = remote_cli.inquire_team(existing_teams=existing_teams)
+    return (selected_team_name, _get_team_id(selected_team_name))
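The model resolver mirrors the chain resolver above but auto-detects the team by listing models instead of chains. A condensed sketch, reusing the hypothetical `remote_provider` from the chain example:

    from truss.cli.resolvers.model_team_resolver import resolve_model_team_name

    # With --team omitted, a model named "my-model" that exists in exactly one
    # accessible team resolves without prompting (scenarios 4 and 7 above).
    team_name, team_id = resolve_model_team_name(
        remote_provider,                  # a BasetenRemote, as in the chain example
        provided_team_name=None,
        existing_model_name="my-model",   # hypothetical model name
    )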