truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -32,6 +32,7 @@ from truss.base.constants import (
|
|
|
32
32
|
FILENAME_CONSTANTS_MAP,
|
|
33
33
|
MODEL_CACHE_PATH,
|
|
34
34
|
MODEL_DOCKERFILE_NAME,
|
|
35
|
+
NO_BUILD_DOCKERFILE_TEMPLATE_NAME,
|
|
35
36
|
REQUIREMENTS_TXT_FILENAME,
|
|
36
37
|
SERVER_CODE_DIR,
|
|
37
38
|
SERVER_DOCKERFILE_TEMPLATE_NAME,
|
|
@@ -577,7 +578,10 @@ class ServingImageBuilder(ImageBuilder):
|
|
|
577
578
|
else:
|
|
578
579
|
self.prepare_trtllm_decoder_build_dir(build_dir=build_dir)
|
|
579
580
|
|
|
580
|
-
if
|
|
581
|
+
if (
|
|
582
|
+
config.docker_server is not None
|
|
583
|
+
and config.docker_server.no_build is not True
|
|
584
|
+
):
|
|
581
585
|
self._copy_into_build_dir(
|
|
582
586
|
TEMPLATES_DIR / "docker_server_requirements.txt",
|
|
583
587
|
build_dir,
|
|
@@ -750,12 +754,21 @@ class ServingImageBuilder(ImageBuilder):
|
|
|
750
754
|
build_commands: List[str],
|
|
751
755
|
):
|
|
752
756
|
config = self._spec.config
|
|
757
|
+
|
|
753
758
|
data_dir = build_dir / config.data_dir
|
|
754
759
|
model_dir = build_dir / config.model_module_dir
|
|
755
760
|
bundled_packages_dir = build_dir / config.bundled_packages_dir
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
761
|
+
|
|
762
|
+
# Note: no-build deployment template doesn't use most of the template variables,
|
|
763
|
+
# because it tries to run the base image as-is to the extent possible.
|
|
764
|
+
if config.docker_server and config.docker_server.no_build:
|
|
765
|
+
dockerfile_template = read_template_from_fs(
|
|
766
|
+
TEMPLATES_DIR, NO_BUILD_DOCKERFILE_TEMPLATE_NAME
|
|
767
|
+
)
|
|
768
|
+
else:
|
|
769
|
+
dockerfile_template = read_template_from_fs(
|
|
770
|
+
TEMPLATES_DIR, SERVER_DOCKERFILE_TEMPLATE_NAME
|
|
771
|
+
)
|
|
759
772
|
python_version = truss_config.to_dotted_python_version(config.python_version)
|
|
760
773
|
if config.base_image:
|
|
761
774
|
base_image_name_and_tag = config.base_image.image
|
truss/remote/baseten/api.py
CHANGED
|
@@ -3,7 +3,9 @@ from enum import Enum
|
|
|
3
3
|
from typing import Any, Dict, List, Mapping, Optional
|
|
4
4
|
|
|
5
5
|
import requests
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
6
7
|
|
|
8
|
+
from truss.base.custom_types import SafeModel
|
|
7
9
|
from truss.remote.baseten import custom_types as b10_types
|
|
8
10
|
from truss.remote.baseten.auth import ApiKey, AuthService
|
|
9
11
|
from truss.remote.baseten.custom_types import APIKeyCategory
|
|
@@ -12,6 +14,62 @@ from truss.remote.baseten.rest_client import RestAPIClient
|
|
|
12
14
|
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
|
|
13
15
|
|
|
14
16
|
logger = logging.getLogger(__name__)

# Separator used to join the optional arguments of the ``deploy_chain_atomic``
# GraphQL mutation (``PARAMS_INDENT.join(params)``), so that each argument
# starts on its own line in the generated query text.
PARAMS_INDENT = "\n "
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ChainAWSCredential(SafeModel):
    """AWS access key, secret and session token for a chain artifact upload.

    Instances are produced by ``ChainUploadCredentials.aws_credentials``; the
    ``model_dump()`` of this model is the credential mapping handed to
    ``multipart_upload_boto3`` when uploading chain source.
    """

    aws_access_key_id: str
    aws_secret_access_key: str
    aws_session_token: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ChainUploadCredentials(SafeModel):
    """S3 destination plus AWS credentials for uploading a chain artifact.

    The fields mirror those selected by the ``chain_s3_upload_credentials``
    GraphQL query.
    """

    s3_bucket: str
    s3_key: str
    aws_access_key_id: str
    aws_secret_access_key: str
    aws_session_token: str

    @property
    def aws_credentials(self) -> ChainAWSCredential:
        """Return only the AWS credential fields, without the bucket and key."""
        # Copy exactly the fields ChainAWSCredential declares, in its order.
        credential_values = {
            field_name: getattr(self, field_name)
            for field_name in ChainAWSCredential.model_fields
        }
        return ChainAWSCredential(**credential_values)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class InstanceTypeV1(BaseModel):
    """An instance type.

    Built by ``get_instance_types`` from the ``listed_instances`` GraphQL query,
    whose fields are aliased to camelCase (e.g. ``gpuCount``). Those camelCase
    names are the pydantic aliases below; because ``populate_by_name`` is on,
    the snake_case attribute names are accepted as input too.
    """

    # Pydantic v2 configuration. This replaces the v1-style inner
    # ``class Config``, which pydantic v2 still accepts but deprecates.
    model_config = {"populate_by_name": True}

    id: str = Field(description="Identifier string for the instance type")
    name: str = Field(description="Name of the instance type")
    display_name: str = Field(
        alias="displayName", description="Display name of the instance type"
    )
    gpu_count: int = Field(
        alias="gpuCount", description="Number of GPUs on the instance type"
    )
    default: bool = Field(description="Whether this is the default instance type")
    # In pydantic v2, ``Optional[X]`` without a default is still a *required*
    # field. The explicit ``default=None`` on the optional fields lets a
    # missing key validate as None instead of raising.
    gpu_memory: Optional[int] = Field(
        default=None, alias="gpuMemory", description="GPU memory in MB"
    )
    node_count: int = Field(alias="nodeCount", description="Number of nodes")
    gpu_type: Optional[str] = Field(
        default=None, alias="gpuType", description="Type of GPU on the instance type"
    )
    millicpu_limit: int = Field(
        alias="millicpuLimit", description="CPU limit of the instance type in millicpu"
    )
    memory_limit: int = Field(
        alias="memoryLimit", description="Memory limit of the instance type in MB"
    )
    price: Optional[float] = Field(
        default=None, description="Price of the instance type"
    )
    limited_capacity: Optional[bool] = Field(
        default=None,
        alias="limitedCapacity",
        description="Whether this instance type has limited capacity",
    )
|
|
15
73
|
|
|
16
74
|
|
|
17
75
|
API_URL_MAPPING = {
|
|
@@ -141,6 +199,9 @@ class BasetenApi:
|
|
|
141
199
|
allow_truss_download: bool = True,
|
|
142
200
|
deployment_name: Optional[str] = None,
|
|
143
201
|
origin: Optional[b10_types.ModelOrigin] = None,
|
|
202
|
+
environment: Optional[str] = None,
|
|
203
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
204
|
+
team_id: Optional[str] = None,
|
|
144
205
|
):
|
|
145
206
|
query_string = f"""
|
|
146
207
|
mutation ($trussUserEnv: String) {{
|
|
@@ -153,6 +214,9 @@ class BasetenApi:
|
|
|
153
214
|
allow_truss_download: {"true" if allow_truss_download else "false"}
|
|
154
215
|
{f'version_name: "{deployment_name}"' if deployment_name else ""}
|
|
155
216
|
{f"model_origin: {origin.value}" if origin else ""}
|
|
217
|
+
{f'environment_name: "{environment}"' if environment else ""}
|
|
218
|
+
{f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
|
|
219
|
+
{f'team_id: "{team_id}"' if team_id else ""}
|
|
156
220
|
) {{
|
|
157
221
|
model_version {{
|
|
158
222
|
id
|
|
@@ -184,6 +248,7 @@ class BasetenApi:
|
|
|
184
248
|
deployment_name: Optional[str] = None,
|
|
185
249
|
environment: Optional[str] = None,
|
|
186
250
|
preserve_env_instance_type: bool = True,
|
|
251
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
187
252
|
):
|
|
188
253
|
query_string = f"""
|
|
189
254
|
mutation ($trussUserEnv: String) {{
|
|
@@ -197,6 +262,7 @@ class BasetenApi:
|
|
|
197
262
|
preserve_env_instance_type: {"true" if preserve_env_instance_type else "false"}
|
|
198
263
|
{f'name: "{deployment_name}"' if deployment_name else ""}
|
|
199
264
|
{f'environment_name: "{environment}"' if environment else ""}
|
|
265
|
+
{f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
|
|
200
266
|
) {{
|
|
201
267
|
model_version {{
|
|
202
268
|
id
|
|
@@ -226,6 +292,8 @@ class BasetenApi:
|
|
|
226
292
|
truss_user_env: b10_types.TrussUserEnv,
|
|
227
293
|
allow_truss_download=True,
|
|
228
294
|
origin: Optional[b10_types.ModelOrigin] = None,
|
|
295
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
296
|
+
team_id: Optional[str] = None,
|
|
229
297
|
):
|
|
230
298
|
query_string = f"""
|
|
231
299
|
mutation ($trussUserEnv: String) {{
|
|
@@ -235,6 +303,8 @@ class BasetenApi:
|
|
|
235
303
|
truss_user_env: $trussUserEnv
|
|
236
304
|
allow_truss_download: {"true" if allow_truss_download else "false"}
|
|
237
305
|
{f"model_origin: {origin.value}" if origin else ""}
|
|
306
|
+
{f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
|
|
307
|
+
{f'team_id: "{team_id}"' if team_id else ""}
|
|
238
308
|
) {{
|
|
239
309
|
model_version {{
|
|
240
310
|
id
|
|
@@ -265,7 +335,13 @@ class BasetenApi:
|
|
|
265
335
|
chain_name: Optional[str] = None,
|
|
266
336
|
environment: Optional[str] = None,
|
|
267
337
|
is_draft: bool = False,
|
|
338
|
+
original_source_artifact_s3_key: Optional[str] = None,
|
|
339
|
+
allow_truss_download: Optional[bool] = True,
|
|
340
|
+
deployment_name: Optional[str] = None,
|
|
341
|
+
team_id: Optional[str] = None,
|
|
268
342
|
):
|
|
343
|
+
if allow_truss_download is None:
|
|
344
|
+
allow_truss_download = True
|
|
269
345
|
entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
|
|
270
346
|
|
|
271
347
|
dependencies_str = ", ".join(
|
|
@@ -275,13 +351,32 @@ class BasetenApi:
|
|
|
275
351
|
]
|
|
276
352
|
)
|
|
277
353
|
|
|
354
|
+
params = []
|
|
355
|
+
if chain_id:
|
|
356
|
+
params.append(f'chain_id: "{chain_id}"')
|
|
357
|
+
if chain_name:
|
|
358
|
+
params.append(f'chain_name: "{chain_name}"')
|
|
359
|
+
if environment:
|
|
360
|
+
params.append(f'environment: "{environment}"')
|
|
361
|
+
if original_source_artifact_s3_key:
|
|
362
|
+
params.append(
|
|
363
|
+
f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
|
|
364
|
+
)
|
|
365
|
+
if team_id:
|
|
366
|
+
params.append(f'team_id: "{team_id}"')
|
|
367
|
+
|
|
368
|
+
params.append(f"is_draft: {str(is_draft).lower()}")
|
|
369
|
+
if allow_truss_download is False:
|
|
370
|
+
params.append("allow_truss_download: false")
|
|
371
|
+
if deployment_name:
|
|
372
|
+
params.append(f'deployment_name: "{deployment_name}"')
|
|
373
|
+
|
|
374
|
+
params_str = PARAMS_INDENT.join(params)
|
|
375
|
+
|
|
278
376
|
query_string = f"""
|
|
279
377
|
mutation ($trussUserEnv: String) {{
|
|
280
378
|
deploy_chain_atomic(
|
|
281
|
-
{
|
|
282
|
-
{f'chain_name: "{chain_name}"' if chain_name else ""}
|
|
283
|
-
{f'environment: "{environment}"' if environment else ""}
|
|
284
|
-
is_draft: {str(is_draft).lower()}
|
|
379
|
+
{params_str}
|
|
285
380
|
entrypoint: {entrypoint_str}
|
|
286
381
|
dependencies: [{dependencies_str}]
|
|
287
382
|
truss_user_env: $trussUserEnv
|
|
@@ -303,18 +398,24 @@ class BasetenApi:
|
|
|
303
398
|
|
|
304
399
|
return resp["data"]["deploy_chain_atomic"]
|
|
305
400
|
|
|
306
|
-
def get_chains(self):
|
|
401
|
+
def get_chains(self, team_id: Optional[str] = None):
|
|
307
402
|
query_string = """
|
|
308
403
|
{
|
|
309
404
|
chains {
|
|
310
405
|
id
|
|
311
406
|
name
|
|
407
|
+
team {
|
|
408
|
+
name
|
|
409
|
+
}
|
|
312
410
|
}
|
|
313
411
|
}
|
|
314
412
|
"""
|
|
315
413
|
|
|
316
414
|
resp = self._post_graphql_query(query_string)
|
|
317
|
-
|
|
415
|
+
chains = resp["data"]["chains"]
|
|
416
|
+
|
|
417
|
+
# TODO(COR-492): Filter by team_id in the backend
|
|
418
|
+
return chains
|
|
318
419
|
|
|
319
420
|
def get_chain_deployments(self, chain_id: str):
|
|
320
421
|
query_string = f"""
|
|
@@ -377,6 +478,10 @@ class BasetenApi:
|
|
|
377
478
|
models {
|
|
378
479
|
id,
|
|
379
480
|
name
|
|
481
|
+
team {
|
|
482
|
+
id
|
|
483
|
+
name
|
|
484
|
+
}
|
|
380
485
|
versions{
|
|
381
486
|
id,
|
|
382
487
|
semver,
|
|
@@ -416,6 +521,10 @@ class BasetenApi:
|
|
|
416
521
|
id
|
|
417
522
|
name
|
|
418
523
|
hostname
|
|
524
|
+
team {{
|
|
525
|
+
id
|
|
526
|
+
name
|
|
527
|
+
}}
|
|
419
528
|
versions {{
|
|
420
529
|
id
|
|
421
530
|
semver
|
|
@@ -568,10 +677,14 @@ class BasetenApi:
|
|
|
568
677
|
"v1/api_keys", body={"type": api_key_type.value, "name": name}
|
|
569
678
|
)
|
|
570
679
|
|
|
571
|
-
def upsert_training_project(self, training_project):
|
|
680
|
+
def upsert_training_project(self, training_project, team_id: Optional[str] = None):
|
|
681
|
+
if team_id:
|
|
682
|
+
endpoint = f"v1/teams/{team_id}/training_projects"
|
|
683
|
+
else:
|
|
684
|
+
endpoint = "v1/training_projects"
|
|
572
685
|
resp_json = self._rest_api_client.post(
|
|
573
|
-
|
|
574
|
-
body={"training_project": training_project.model_dump()},
|
|
686
|
+
endpoint,
|
|
687
|
+
body={"training_project": training_project.model_dump(exclude_none=True)},
|
|
575
688
|
)
|
|
576
689
|
return resp_json["training_project"]
|
|
577
690
|
|
|
@@ -623,8 +736,29 @@ class BasetenApi:
|
|
|
623
736
|
return resp_json["training_projects"]
|
|
624
737
|
|
|
625
738
|
def get_blob_credentials(self, blob_type: b10_types.BlobType):
    """Fetch upload credentials for the given blob type.

    The two branches return different shapes, and callers must handle both:

    - ``BlobType.CHAIN`` uses the GraphQL ``chain_s3_upload_credentials`` query
      and returns a ``ChainUploadCredentials`` model.
    - Every other blob type calls the ``v1/blobs/credentials/<type>`` REST
      endpoint and returns whatever ``_rest_api_client.get`` yields.
    """
    if blob_type != b10_types.BlobType.CHAIN:
        return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
    return self.get_chain_s3_upload_credentials()
|
|
627
742
|
|
|
743
|
+
def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
    """Get chain artifact upload credentials via GraphQL.

    Runs the ``chain_s3_upload_credentials`` query and validates the result
    into a ``ChainUploadCredentials`` model. The model holds the target
    bucket/key and the AWS access key, secret and session token.
    """
    credentials_query = """
    query {
        chain_s3_upload_credentials {
            s3_bucket
            s3_key
            aws_access_key_id
            aws_secret_access_key
            aws_session_token
        }
    }
    """
    payload = self._post_graphql_query(credentials_query)["data"]
    return ChainUploadCredentials.model_validate(
        payload["chain_s3_upload_credentials"]
    )
|
|
761
|
+
|
|
628
762
|
def get_training_job_metrics(
|
|
629
763
|
self,
|
|
630
764
|
project_id: str,
|
|
@@ -750,3 +884,75 @@ class BasetenApi:
|
|
|
750
884
|
|
|
751
885
|
# NB(nikhil): reverse order so latest logs are at the end
|
|
752
886
|
return resp_json["logs"][::-1]
|
|
887
|
+
|
|
888
|
+
def create_model_version_from_inference_template(self, request_data: dict):
    """Create a model version from an inference template.

    Sends ``request_data`` as the ``$request`` variable of the
    ``create_model_version_from_inference_template`` GraphQL mutation.

    Args:
        request_data: Dictionary containing the request structure with metadata,
            weights_sources, inference_stack, and instance_type_id.

    Returns:
        The mutation result, which contains the created ``model_version``
        (``id``, ``name``) and its ``truss_config``.
    """
    mutation = """
    mutation ($request: CreateModelVersionFromInferenceTemplateRequest!) {
        create_model_version_from_inference_template(request: $request) {
            model_version {
                id
                name
            }
            truss_config
        }
    }
    """
    graphql_variables = {"request": request_data}
    response = self._post_graphql_query(mutation, variables=graphql_variables)
    return response["data"]["create_model_version_from_inference_template"]
|
|
911
|
+
|
|
912
|
+
def get_instance_types(self) -> List[InstanceTypeV1]:
    """Get all available instance types via the GraphQL API.

    The query aliases each snake_case server field to the camelCase name that
    ``InstanceTypeV1`` declares as the field alias. Each returned record
    therefore validates directly into the model.
    """
    query_string = """
    query Instances {
        listedInstances: listed_instances {
            id
            name
            millicpuLimit: millicpu_limit
            memoryLimit: memory_limit
            gpuCount: gpu_count
            gpuType: gpu_type
            gpuMemory: gpu_memory
            default
            displayName: display_name
            nodeCount: node_count
            price
            limitedCapacity: limited_capacity
        }
    }
    """
    listed = self._post_graphql_query(query_string)["data"]["listedInstances"]
    return [InstanceTypeV1.model_validate(raw_instance) for raw_instance in listed]
|
|
940
|
+
|
|
941
|
+
def get_teams(self) -> Dict[str, Dict[str, str]]:
    """Get all available teams via the GraphQL API.

    Returns:
        A dictionary keyed by team name. Each value is the team record as
        returned by the query (with ``id`` and ``name`` keys). If several teams
        share a name, the one listed last in the response is kept.
    """
    query_string = """
    query Teams {
        teams {
            id
            name
        }
    }
    """
    teams_by_name: Dict[str, Dict[str, str]] = {}
    for team in self._post_graphql_query(query_string)["data"]["teams"]:
        teams_by_name[team["name"]] = team
    return teams_by_name
|
truss/remote/baseten/core.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
|
|
|
8
8
|
import requests
|
|
9
9
|
|
|
10
10
|
from truss.base.errors import ValidationError
|
|
11
|
+
from truss.util.error_utils import handle_client_error
|
|
11
12
|
|
|
12
13
|
if TYPE_CHECKING:
|
|
13
14
|
from rich import progress
|
|
@@ -80,6 +81,8 @@ class ChainDeploymentHandleAtomic(NamedTuple):
|
|
|
80
81
|
chain_id: str
|
|
81
82
|
chain_deployment_id: str
|
|
82
83
|
is_draft: bool
|
|
84
|
+
original_source_artifact_s3_key: Optional[str] = None
|
|
85
|
+
allow_truss_download: Optional[bool] = True
|
|
83
86
|
|
|
84
87
|
|
|
85
88
|
class ModelVersionHandle(NamedTuple):
|
|
@@ -89,19 +92,21 @@ class ModelVersionHandle(NamedTuple):
|
|
|
89
92
|
instance_type_name: Optional[str] = None
|
|
90
93
|
|
|
91
94
|
|
|
92
|
-
def get_chain_id_by_name(
|
|
95
|
+
def get_chain_id_by_name(
|
|
96
|
+
api: BasetenApi, chain_name: str, team_id: Optional[str] = None
|
|
97
|
+
) -> Optional[str]:
|
|
93
98
|
"""
|
|
94
99
|
Check if a chain with the given name exists in the Baseten remote.
|
|
95
100
|
|
|
96
101
|
Args:
|
|
97
102
|
api: BasetenApi instance
|
|
98
103
|
chain_name: Name of the chain to check for existence
|
|
104
|
+
team_id: Optional team_id to filter chains by team
|
|
99
105
|
|
|
100
106
|
Returns:
|
|
101
107
|
chain_id if present, otherwise None
|
|
102
108
|
"""
|
|
103
|
-
chains = api.get_chains()
|
|
104
|
-
|
|
109
|
+
chains = api.get_chains(team_id=team_id)
|
|
105
110
|
chain_name_id_mapping = {chain["name"]: chain["id"] for chain in chains}
|
|
106
111
|
return chain_name_id_mapping.get(chain_name)
|
|
107
112
|
|
|
@@ -127,6 +132,10 @@ def create_chain_atomic(
|
|
|
127
132
|
is_draft: bool,
|
|
128
133
|
truss_user_env: b10_types.TrussUserEnv,
|
|
129
134
|
environment: Optional[str],
|
|
135
|
+
original_source_artifact_s3_key: Optional[str] = None,
|
|
136
|
+
allow_truss_download: bool = True,
|
|
137
|
+
deployment_name: Optional[str] = None,
|
|
138
|
+
team_id: Optional[str] = None,
|
|
130
139
|
) -> ChainDeploymentHandleAtomic:
|
|
131
140
|
if environment and is_draft:
|
|
132
141
|
logging.info(
|
|
@@ -135,7 +144,7 @@ def create_chain_atomic(
|
|
|
135
144
|
)
|
|
136
145
|
is_draft = False
|
|
137
146
|
|
|
138
|
-
chain_id = get_chain_id_by_name(api, chain_name)
|
|
147
|
+
chain_id = get_chain_id_by_name(api, chain_name, team_id=team_id)
|
|
139
148
|
|
|
140
149
|
# TODO(Tyron): Refactor for better readability:
|
|
141
150
|
# 1. Prepare all arguments for `deploy_chain_atomic`.
|
|
@@ -149,6 +158,10 @@ def create_chain_atomic(
|
|
|
149
158
|
chain_name=chain_name,
|
|
150
159
|
is_draft=True,
|
|
151
160
|
truss_user_env=truss_user_env,
|
|
161
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
162
|
+
allow_truss_download=allow_truss_download,
|
|
163
|
+
deployment_name=deployment_name,
|
|
164
|
+
team_id=team_id,
|
|
152
165
|
)
|
|
153
166
|
elif chain_id:
|
|
154
167
|
# This is the only case where promote has relevance, since
|
|
@@ -162,6 +175,10 @@ def create_chain_atomic(
|
|
|
162
175
|
chain_id=chain_id,
|
|
163
176
|
environment=environment,
|
|
164
177
|
truss_user_env=truss_user_env,
|
|
178
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
179
|
+
allow_truss_download=allow_truss_download,
|
|
180
|
+
deployment_name=deployment_name,
|
|
181
|
+
team_id=team_id,
|
|
165
182
|
)
|
|
166
183
|
except ApiError as e:
|
|
167
184
|
if (
|
|
@@ -182,6 +199,10 @@ def create_chain_atomic(
|
|
|
182
199
|
dependencies=dependencies,
|
|
183
200
|
chain_name=chain_name,
|
|
184
201
|
truss_user_env=truss_user_env,
|
|
202
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
203
|
+
allow_truss_download=allow_truss_download,
|
|
204
|
+
deployment_name=deployment_name,
|
|
205
|
+
team_id=team_id,
|
|
185
206
|
)
|
|
186
207
|
|
|
187
208
|
return ChainDeploymentHandleAtomic(
|
|
@@ -189,6 +210,8 @@ def create_chain_atomic(
|
|
|
189
210
|
chain_id=res["chain_deployment"]["chain"]["id"],
|
|
190
211
|
hostname=res["chain_deployment"]["chain"]["hostname"],
|
|
191
212
|
is_draft=is_draft,
|
|
213
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
214
|
+
allow_truss_download=allow_truss_download,
|
|
192
215
|
)
|
|
193
216
|
|
|
194
217
|
|
|
@@ -342,6 +365,33 @@ def upload_truss(
|
|
|
342
365
|
return s3_key
|
|
343
366
|
|
|
344
367
|
|
|
368
|
+
def upload_chain_artifact(
    api: BasetenApi,
    serialize_file: IO,
    progress_bar: Optional[Type["progress.Progress"]],
) -> str:
    """Upload a serialized chain artifact to the S3 location issued by Baseten.

    Args:
        api: BasetenApi instance used to obtain the chain upload credentials.
        serialize_file: File-like object containing the serialized chain
            artifact. Only its ``name`` attribute is passed to the uploader.
        progress_bar: Optional progress-bar class, forwarded unchanged to
            ``multipart_upload_boto3``.

    Returns:
        The S3 key the artifact was uploaded to.
    """
    upload_creds = api.get_chain_s3_upload_credentials()
    # Errors raised while preparing or performing the upload are routed
    # through handle_client_error with a descriptive label.
    with handle_client_error("Uploading chain source"):
        target_bucket = upload_creds.s3_bucket
        target_key = upload_creds.s3_key
        boto_credentials = upload_creds.aws_credentials.model_dump()
        multipart_upload_boto3(
            serialize_file.name,
            target_bucket,
            target_key,
            boto_credentials,
            progress_bar,
        )
    return upload_creds.s3_key
|
|
393
|
+
|
|
394
|
+
|
|
345
395
|
def create_truss_service(
|
|
346
396
|
api: BasetenApi,
|
|
347
397
|
model_name: str,
|
|
@@ -357,6 +407,8 @@ def create_truss_service(
|
|
|
357
407
|
origin: Optional[b10_types.ModelOrigin] = None,
|
|
358
408
|
environment: Optional[str] = None,
|
|
359
409
|
preserve_env_instance_type: bool = True,
|
|
410
|
+
deploy_timeout_minutes: Optional[int] = None,
|
|
411
|
+
team_id: Optional[str] = None,
|
|
360
412
|
) -> ModelVersionHandle:
|
|
361
413
|
"""
|
|
362
414
|
Create a model in the Baseten remote.
|
|
@@ -372,6 +424,7 @@ def create_truss_service(
|
|
|
372
424
|
to zero.
|
|
373
425
|
deployment_name: Name to apply to the created deployment. Not applied to
|
|
374
426
|
development model.
|
|
427
|
+
team_id: ID of the team to create the model in.
|
|
375
428
|
|
|
376
429
|
Returns:
|
|
377
430
|
A Model Version handle.
|
|
@@ -384,6 +437,8 @@ def create_truss_service(
|
|
|
384
437
|
truss_user_env,
|
|
385
438
|
allow_truss_download=allow_truss_download,
|
|
386
439
|
origin=origin,
|
|
440
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
441
|
+
team_id=team_id,
|
|
387
442
|
)
|
|
388
443
|
|
|
389
444
|
return ModelVersionHandle(
|
|
@@ -398,9 +453,6 @@ def create_truss_service(
|
|
|
398
453
|
)
|
|
399
454
|
|
|
400
455
|
if model_id is None:
|
|
401
|
-
if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
|
|
402
|
-
raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
|
|
403
|
-
|
|
404
456
|
model_version_json = api.create_model_from_truss(
|
|
405
457
|
model_name,
|
|
406
458
|
s3_key,
|
|
@@ -410,6 +462,9 @@ def create_truss_service(
|
|
|
410
462
|
allow_truss_download=allow_truss_download,
|
|
411
463
|
deployment_name=deployment_name,
|
|
412
464
|
origin=origin,
|
|
465
|
+
environment=environment,
|
|
466
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
467
|
+
team_id=team_id,
|
|
413
468
|
)
|
|
414
469
|
|
|
415
470
|
return ModelVersionHandle(
|
|
@@ -434,6 +489,7 @@ def create_truss_service(
|
|
|
434
489
|
deployment_name=deployment_name,
|
|
435
490
|
environment=environment,
|
|
436
491
|
preserve_env_instance_type=preserve_env_instance_type,
|
|
492
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
437
493
|
)
|
|
438
494
|
except ApiError as e:
|
|
439
495
|
if (
|