truss 0.11.15rc13__py3-none-any.whl → 0.11.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/cli/chains_commands.py +0 -10
- truss/contexts/image_builder/cache_warmer.py +14 -5
- truss/remote/baseten/api.py +4 -68
- truss/remote/baseten/core.py +0 -40
- truss/remote/baseten/custom_types.py +0 -1
- truss/remote/baseten/remote.py +0 -18
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/env_vars.py +19 -8
- {truss-0.11.15rc13.dist-info → truss-0.11.16.dist-info}/METADATA +1 -1
- {truss-0.11.15rc13.dist-info → truss-0.11.16.dist-info}/RECORD +17 -20
- truss_chains/deployment/deployment_client.py +0 -7
- truss_chains/private_types.py +0 -3
- truss_chains/public_api.py +0 -3
- truss/tests/cli/test_chains_cli.py +0 -100
- truss/tests/remote/baseten/test_chain_upload.py +0 -285
- truss/util/error_utils.py +0 -34
- {truss-0.11.15rc13.dist-info → truss-0.11.16.dist-info}/WHEEL +0 -0
- {truss-0.11.15rc13.dist-info → truss-0.11.16.dist-info}/entry_points.txt +0 -0
- {truss-0.11.15rc13.dist-info → truss-0.11.16.dist-info}/licenses/LICENSE +0 -0
truss/cli/chains_commands.py
CHANGED
|
@@ -211,14 +211,6 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
|
|
|
211
211
|
default=False,
|
|
212
212
|
help=common.INCLUDE_GIT_INFO_DOC,
|
|
213
213
|
)
|
|
214
|
-
@click.option(
|
|
215
|
-
"--disable-chain-download",
|
|
216
|
-
"disable_chain_download",
|
|
217
|
-
is_flag=True,
|
|
218
|
-
required=False,
|
|
219
|
-
default=False,
|
|
220
|
-
help="Disable downloading of pushed chain source code from the UI.",
|
|
221
|
-
)
|
|
222
214
|
@click.pass_context
|
|
223
215
|
@common.common_options()
|
|
224
216
|
def push_chain(
|
|
@@ -235,7 +227,6 @@ def push_chain(
|
|
|
235
227
|
environment: Optional[str],
|
|
236
228
|
experimental_watch_chainlet_names: Optional[str],
|
|
237
229
|
include_git_info: bool = False,
|
|
238
|
-
disable_chain_download: bool = False,
|
|
239
230
|
) -> None:
|
|
240
231
|
"""
|
|
241
232
|
Deploys a chain remotely.
|
|
@@ -293,7 +284,6 @@ def push_chain(
|
|
|
293
284
|
environment=environment,
|
|
294
285
|
include_git_info=include_git_info,
|
|
295
286
|
working_dir=source.parent if source.is_file() else source.resolve(),
|
|
296
|
-
disable_chain_download=disable_chain_download,
|
|
297
287
|
)
|
|
298
288
|
service = deployment_client.push(
|
|
299
289
|
entrypoint_cls, options, progress_bar=progress.Progress
|
|
@@ -11,11 +11,10 @@ from typing import Optional, Type
|
|
|
11
11
|
import boto3
|
|
12
12
|
from botocore import UNSIGNED
|
|
13
13
|
from botocore.client import Config
|
|
14
|
+
from botocore.exceptions import ClientError, NoCredentialsError
|
|
14
15
|
from google.cloud import storage
|
|
15
16
|
from huggingface_hub import hf_hub_download
|
|
16
17
|
|
|
17
|
-
from truss.util.error_utils import handle_client_error
|
|
18
|
-
|
|
19
18
|
B10CP_PATH_TRUSS_ENV_VAR_NAME = "B10CP_PATH_TRUSS"
|
|
20
19
|
|
|
21
20
|
GCS_CREDENTIALS = "/app/data/service_account.json"
|
|
@@ -189,14 +188,24 @@ class S3File(RepositoryFile):
|
|
|
189
188
|
if not dst_file.parent.exists():
|
|
190
189
|
dst_file.parent.mkdir(parents=True)
|
|
191
190
|
|
|
192
|
-
|
|
193
|
-
f"accessing S3 bucket {bucket_name} for file {file_name}"
|
|
194
|
-
):
|
|
191
|
+
try:
|
|
195
192
|
url = client.generate_presigned_url(
|
|
196
193
|
"get_object",
|
|
197
194
|
Params={"Bucket": bucket_name, "Key": file_name},
|
|
198
195
|
ExpiresIn=3600,
|
|
199
196
|
)
|
|
197
|
+
except NoCredentialsError as nce:
|
|
198
|
+
raise RuntimeError(
|
|
199
|
+
f"No AWS credentials found\nOriginal exception: {str(nce)}"
|
|
200
|
+
)
|
|
201
|
+
except ClientError as ce:
|
|
202
|
+
raise RuntimeError(
|
|
203
|
+
f"Client error when accessing the S3 bucket (check your credentials): {str(ce)}"
|
|
204
|
+
)
|
|
205
|
+
except Exception as exc:
|
|
206
|
+
raise RuntimeError(
|
|
207
|
+
f"File not found on S3 bucket: {file_name}\nOriginal exception: {str(exc)}"
|
|
208
|
+
)
|
|
200
209
|
|
|
201
210
|
download_file_using_b10cp(url, dst_file, self.file_name)
|
|
202
211
|
|
truss/remote/baseten/api.py
CHANGED
|
@@ -5,7 +5,6 @@ from typing import Any, Dict, List, Mapping, Optional
|
|
|
5
5
|
import requests
|
|
6
6
|
from pydantic import BaseModel, Field
|
|
7
7
|
|
|
8
|
-
from truss.base.custom_types import SafeModel
|
|
9
8
|
from truss.remote.baseten import custom_types as b10_types
|
|
10
9
|
from truss.remote.baseten.auth import ApiKey, AuthService
|
|
11
10
|
from truss.remote.baseten.custom_types import APIKeyCategory
|
|
@@ -14,29 +13,6 @@ from truss.remote.baseten.rest_client import RestAPIClient
|
|
|
14
13
|
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
|
|
15
14
|
|
|
16
15
|
logger = logging.getLogger(__name__)
|
|
17
|
-
PARAMS_INDENT = "\n "
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class ChainAWSCredential(SafeModel):
|
|
21
|
-
aws_access_key_id: str
|
|
22
|
-
aws_secret_access_key: str
|
|
23
|
-
aws_session_token: str
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class ChainUploadCredentials(SafeModel):
|
|
27
|
-
s3_bucket: str
|
|
28
|
-
s3_key: str
|
|
29
|
-
aws_access_key_id: str
|
|
30
|
-
aws_secret_access_key: str
|
|
31
|
-
aws_session_token: str
|
|
32
|
-
|
|
33
|
-
@property
|
|
34
|
-
def aws_credentials(self) -> ChainAWSCredential:
|
|
35
|
-
return ChainAWSCredential(
|
|
36
|
-
aws_access_key_id=self.aws_access_key_id,
|
|
37
|
-
aws_secret_access_key=self.aws_secret_access_key,
|
|
38
|
-
aws_session_token=self.aws_session_token,
|
|
39
|
-
)
|
|
40
16
|
|
|
41
17
|
|
|
42
18
|
class InstanceTypeV1(BaseModel):
|
|
@@ -323,11 +299,7 @@ class BasetenApi:
|
|
|
323
299
|
chain_name: Optional[str] = None,
|
|
324
300
|
environment: Optional[str] = None,
|
|
325
301
|
is_draft: bool = False,
|
|
326
|
-
original_source_artifact_s3_key: Optional[str] = None,
|
|
327
|
-
allow_truss_download: Optional[bool] = True,
|
|
328
302
|
):
|
|
329
|
-
if allow_truss_download is None:
|
|
330
|
-
allow_truss_download = True
|
|
331
303
|
entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
|
|
332
304
|
|
|
333
305
|
dependencies_str = ", ".join(
|
|
@@ -337,28 +309,13 @@ class BasetenApi:
|
|
|
337
309
|
]
|
|
338
310
|
)
|
|
339
311
|
|
|
340
|
-
params = []
|
|
341
|
-
if chain_id:
|
|
342
|
-
params.append(f'chain_id: "{chain_id}"')
|
|
343
|
-
if chain_name:
|
|
344
|
-
params.append(f'chain_name: "{chain_name}"')
|
|
345
|
-
if environment:
|
|
346
|
-
params.append(f'environment: "{environment}"')
|
|
347
|
-
if original_source_artifact_s3_key:
|
|
348
|
-
params.append(
|
|
349
|
-
f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
|
|
350
|
-
)
|
|
351
|
-
|
|
352
|
-
params.append(f"is_draft: {str(is_draft).lower()}")
|
|
353
|
-
if allow_truss_download is False:
|
|
354
|
-
params.append("allow_truss_download: false")
|
|
355
|
-
|
|
356
|
-
params_str = PARAMS_INDENT.join(params)
|
|
357
|
-
|
|
358
312
|
query_string = f"""
|
|
359
313
|
mutation ($trussUserEnv: String) {{
|
|
360
314
|
deploy_chain_atomic(
|
|
361
|
-
{
|
|
315
|
+
{f'chain_id: "{chain_id}"' if chain_id else ""}
|
|
316
|
+
{f'chain_name: "{chain_name}"' if chain_name else ""}
|
|
317
|
+
{f'environment: "{environment}"' if environment else ""}
|
|
318
|
+
is_draft: {str(is_draft).lower()}
|
|
362
319
|
entrypoint: {entrypoint_str}
|
|
363
320
|
dependencies: [{dependencies_str}]
|
|
364
321
|
truss_user_env: $trussUserEnv
|
|
@@ -700,29 +657,8 @@ class BasetenApi:
|
|
|
700
657
|
return resp_json["training_projects"]
|
|
701
658
|
|
|
702
659
|
def get_blob_credentials(self, blob_type: b10_types.BlobType):
|
|
703
|
-
if blob_type == b10_types.BlobType.CHAIN:
|
|
704
|
-
return self.get_chain_s3_upload_credentials()
|
|
705
660
|
return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
|
|
706
661
|
|
|
707
|
-
def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
|
|
708
|
-
"""Get chain artifact credentials using GraphQL query."""
|
|
709
|
-
query = """
|
|
710
|
-
query {
|
|
711
|
-
chain_s3_upload_credentials {
|
|
712
|
-
s3_bucket
|
|
713
|
-
s3_key
|
|
714
|
-
aws_access_key_id
|
|
715
|
-
aws_secret_access_key
|
|
716
|
-
aws_session_token
|
|
717
|
-
}
|
|
718
|
-
}
|
|
719
|
-
"""
|
|
720
|
-
response = self._post_graphql_query(query)
|
|
721
|
-
|
|
722
|
-
return ChainUploadCredentials.model_validate(
|
|
723
|
-
response["data"]["chain_s3_upload_credentials"]
|
|
724
|
-
)
|
|
725
|
-
|
|
726
662
|
def get_training_job_metrics(
|
|
727
663
|
self,
|
|
728
664
|
project_id: str,
|
truss/remote/baseten/core.py
CHANGED
|
@@ -8,7 +8,6 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
|
|
|
8
8
|
import requests
|
|
9
9
|
|
|
10
10
|
from truss.base.errors import ValidationError
|
|
11
|
-
from truss.util.error_utils import handle_client_error
|
|
12
11
|
|
|
13
12
|
if TYPE_CHECKING:
|
|
14
13
|
from rich import progress
|
|
@@ -81,8 +80,6 @@ class ChainDeploymentHandleAtomic(NamedTuple):
|
|
|
81
80
|
chain_id: str
|
|
82
81
|
chain_deployment_id: str
|
|
83
82
|
is_draft: bool
|
|
84
|
-
original_source_artifact_s3_key: Optional[str] = None
|
|
85
|
-
allow_truss_download: Optional[bool] = True
|
|
86
83
|
|
|
87
84
|
|
|
88
85
|
class ModelVersionHandle(NamedTuple):
|
|
@@ -130,8 +127,6 @@ def create_chain_atomic(
|
|
|
130
127
|
is_draft: bool,
|
|
131
128
|
truss_user_env: b10_types.TrussUserEnv,
|
|
132
129
|
environment: Optional[str],
|
|
133
|
-
original_source_artifact_s3_key: Optional[str] = None,
|
|
134
|
-
allow_truss_download: bool = True,
|
|
135
130
|
) -> ChainDeploymentHandleAtomic:
|
|
136
131
|
if environment and is_draft:
|
|
137
132
|
logging.info(
|
|
@@ -154,8 +149,6 @@ def create_chain_atomic(
|
|
|
154
149
|
chain_name=chain_name,
|
|
155
150
|
is_draft=True,
|
|
156
151
|
truss_user_env=truss_user_env,
|
|
157
|
-
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
158
|
-
allow_truss_download=allow_truss_download,
|
|
159
152
|
)
|
|
160
153
|
elif chain_id:
|
|
161
154
|
# This is the only case where promote has relevance, since
|
|
@@ -169,8 +162,6 @@ def create_chain_atomic(
|
|
|
169
162
|
chain_id=chain_id,
|
|
170
163
|
environment=environment,
|
|
171
164
|
truss_user_env=truss_user_env,
|
|
172
|
-
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
173
|
-
allow_truss_download=allow_truss_download,
|
|
174
165
|
)
|
|
175
166
|
except ApiError as e:
|
|
176
167
|
if (
|
|
@@ -191,8 +182,6 @@ def create_chain_atomic(
|
|
|
191
182
|
dependencies=dependencies,
|
|
192
183
|
chain_name=chain_name,
|
|
193
184
|
truss_user_env=truss_user_env,
|
|
194
|
-
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
195
|
-
allow_truss_download=allow_truss_download,
|
|
196
185
|
)
|
|
197
186
|
|
|
198
187
|
return ChainDeploymentHandleAtomic(
|
|
@@ -200,8 +189,6 @@ def create_chain_atomic(
|
|
|
200
189
|
chain_id=res["chain_deployment"]["chain"]["id"],
|
|
201
190
|
hostname=res["chain_deployment"]["chain"]["hostname"],
|
|
202
191
|
is_draft=is_draft,
|
|
203
|
-
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
204
|
-
allow_truss_download=allow_truss_download,
|
|
205
192
|
)
|
|
206
193
|
|
|
207
194
|
|
|
@@ -355,33 +342,6 @@ def upload_truss(
|
|
|
355
342
|
return s3_key
|
|
356
343
|
|
|
357
344
|
|
|
358
|
-
def upload_chain_artifact(
|
|
359
|
-
api: BasetenApi,
|
|
360
|
-
serialize_file: IO,
|
|
361
|
-
progress_bar: Optional[Type["progress.Progress"]],
|
|
362
|
-
) -> str:
|
|
363
|
-
"""
|
|
364
|
-
Upload a chain artifact to the Baseten remote.
|
|
365
|
-
|
|
366
|
-
Args:
|
|
367
|
-
api: BasetenApi instance
|
|
368
|
-
serialize_file: File-like object containing the serialized chain artifact
|
|
369
|
-
|
|
370
|
-
Returns:
|
|
371
|
-
The S3 key of the uploaded file
|
|
372
|
-
"""
|
|
373
|
-
credentials = api.get_chain_s3_upload_credentials()
|
|
374
|
-
with handle_client_error("Uploading chain source"):
|
|
375
|
-
multipart_upload_boto3(
|
|
376
|
-
serialize_file.name,
|
|
377
|
-
credentials.s3_bucket,
|
|
378
|
-
credentials.s3_key,
|
|
379
|
-
credentials.aws_credentials.model_dump(),
|
|
380
|
-
progress_bar,
|
|
381
|
-
)
|
|
382
|
-
return credentials.s3_key
|
|
383
|
-
|
|
384
|
-
|
|
385
345
|
def create_truss_service(
|
|
386
346
|
api: BasetenApi,
|
|
387
347
|
model_name: str,
|
truss/remote/baseten/remote.py
CHANGED
|
@@ -31,7 +31,6 @@ from truss.remote.baseten.core import (
|
|
|
31
31
|
get_model_and_versions,
|
|
32
32
|
get_prod_version_from_versions,
|
|
33
33
|
get_truss_watch_state,
|
|
34
|
-
upload_chain_artifact,
|
|
35
34
|
upload_truss,
|
|
36
35
|
validate_truss_config_against_backend,
|
|
37
36
|
)
|
|
@@ -264,11 +263,9 @@ class BasetenRemote(TrussRemote):
|
|
|
264
263
|
entrypoint_artifact: custom_types.ChainletArtifact,
|
|
265
264
|
dependency_artifacts: List[custom_types.ChainletArtifact],
|
|
266
265
|
truss_user_env: b10_types.TrussUserEnv,
|
|
267
|
-
chain_root: Optional[Path] = None,
|
|
268
266
|
publish: bool = False,
|
|
269
267
|
environment: Optional[str] = None,
|
|
270
268
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
271
|
-
disable_chain_download: bool = False,
|
|
272
269
|
) -> ChainDeploymentHandleAtomic:
|
|
273
270
|
# If we are promoting a model to an environment after deploy, it must be published.
|
|
274
271
|
# Draft models cannot be promoted.
|
|
@@ -288,7 +285,6 @@ class BasetenRemote(TrussRemote):
|
|
|
288
285
|
publish=publish,
|
|
289
286
|
origin=custom_types.ModelOrigin.CHAINS,
|
|
290
287
|
progress_bar=progress_bar,
|
|
291
|
-
disable_truss_download=disable_chain_download,
|
|
292
288
|
)
|
|
293
289
|
oracle_data = custom_types.OracleData(
|
|
294
290
|
model_name=push_data.model_name,
|
|
@@ -304,18 +300,6 @@ class BasetenRemote(TrussRemote):
|
|
|
304
300
|
)
|
|
305
301
|
)
|
|
306
302
|
|
|
307
|
-
# Upload raw chain artifact if chain_root is provided
|
|
308
|
-
raw_chain_s3_key = None
|
|
309
|
-
if chain_root is not None:
|
|
310
|
-
logging.info("Uploading source artifact")
|
|
311
|
-
# Create a tar file from the chain root directory
|
|
312
|
-
original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
|
|
313
|
-
# Upload the chain artifact to S3
|
|
314
|
-
raw_chain_s3_key = upload_chain_artifact(
|
|
315
|
-
api=self._api,
|
|
316
|
-
serialize_file=original_source_tar,
|
|
317
|
-
progress_bar=progress_bar,
|
|
318
|
-
)
|
|
319
303
|
chain_deployment_handle = create_chain_atomic(
|
|
320
304
|
api=self._api,
|
|
321
305
|
chain_name=chain_name,
|
|
@@ -324,8 +308,6 @@ class BasetenRemote(TrussRemote):
|
|
|
324
308
|
is_draft=not publish,
|
|
325
309
|
truss_user_env=truss_user_env,
|
|
326
310
|
environment=environment,
|
|
327
|
-
original_source_artifact_s3_key=raw_chain_s3_key,
|
|
328
|
-
allow_truss_download=not disable_chain_download,
|
|
329
311
|
)
|
|
330
312
|
logging.info("Successfully pushed to baseten. Chain is building and deploying.")
|
|
331
313
|
return chain_deployment_handle
|
|
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional, Type
|
|
|
7
7
|
import boto3
|
|
8
8
|
from boto3.s3.transfer import TransferConfig
|
|
9
9
|
|
|
10
|
-
from truss.util.env_vars import
|
|
10
|
+
from truss.util.env_vars import modify_env_vars
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from rich import progress
|
|
@@ -26,7 +26,10 @@ def multipart_upload_boto3(
|
|
|
26
26
|
) -> None:
|
|
27
27
|
# In the CLI flow, ignore any local ~/.aws/config files,
|
|
28
28
|
# which can interfere with uploading the Truss to S3.
|
|
29
|
-
|
|
29
|
+
aws_env_vars = set(
|
|
30
|
+
env_var for env_var in os.environ.keys() if env_var.startswith("AWS_")
|
|
31
|
+
)
|
|
32
|
+
with modify_env_vars(deletions=aws_env_vars):
|
|
30
33
|
s3_resource = boto3.resource("s3", **credentials)
|
|
31
34
|
filesize = os.stat(file_path).st_size
|
|
32
35
|
|
|
@@ -1,14 +1,19 @@
|
|
|
1
1
|
import os
|
|
2
2
|
|
|
3
|
-
from truss.util.env_vars import
|
|
3
|
+
from truss.util.env_vars import modify_env_vars
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
def
|
|
6
|
+
def test_modify_env_vars():
|
|
7
7
|
os.environ["API_KEY"] = "original_key"
|
|
8
|
+
os.environ["AWS_CONFIG_FILE"] = "original_config_file"
|
|
8
9
|
|
|
9
|
-
with
|
|
10
|
+
with modify_env_vars(
|
|
11
|
+
overrides={"API_KEY": "new_key", "DEBUG": "true"}, deletions={"AWS_CONFIG_FILE"}
|
|
12
|
+
):
|
|
10
13
|
assert os.environ["API_KEY"] == "new_key"
|
|
11
14
|
assert os.environ["DEBUG"] == "true"
|
|
15
|
+
assert "AWS_CONFIG_FILE" not in os.environ
|
|
12
16
|
|
|
13
17
|
assert os.environ["API_KEY"] == "original_key"
|
|
14
18
|
assert "DEBUG" not in os.environ
|
|
19
|
+
assert os.environ["AWS_CONFIG_FILE"] == "original_config_file"
|
truss/util/env_vars.py
CHANGED
|
@@ -1,32 +1,43 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import Dict, Optional
|
|
2
|
+
from typing import Dict, Optional, Set
|
|
3
3
|
|
|
4
4
|
|
|
5
|
-
class
|
|
5
|
+
class modify_env_vars:
|
|
6
6
|
"""A context manager for temporarily overwriting environment variables.
|
|
7
7
|
|
|
8
8
|
Usage:
|
|
9
|
-
with
|
|
9
|
+
with modify_env_vars(overrides={'API_KEY': 'test_key', 'DEBUG': 'true'}, deletions={'AWS_CONFIG_FILE'}):
|
|
10
10
|
# Environment variables are modified here
|
|
11
11
|
...
|
|
12
12
|
# Original environment is restored here
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
|
-
def __init__(
|
|
15
|
+
def __init__(
|
|
16
|
+
self,
|
|
17
|
+
overrides: Optional[Dict[str, str]] = None,
|
|
18
|
+
deletions: Optional[Set[str]] = None,
|
|
19
|
+
):
|
|
16
20
|
"""
|
|
17
21
|
Args:
|
|
18
|
-
|
|
22
|
+
overrides: Dictionary of environment variables to set
|
|
23
|
+
deletions: Set of environment variables to delete
|
|
19
24
|
"""
|
|
20
|
-
self.
|
|
25
|
+
self.overrides: Dict[str, str] = overrides or dict()
|
|
26
|
+
self.deletions: Set[str] = deletions or set()
|
|
21
27
|
self.original_vars: Dict[str, Optional[str]] = {}
|
|
22
28
|
|
|
23
29
|
def __enter__(self):
|
|
24
|
-
|
|
30
|
+
all_keys = set(self.overrides.keys()) | self.deletions
|
|
31
|
+
for key in all_keys:
|
|
25
32
|
self.original_vars[key] = os.environ.get(key)
|
|
26
33
|
|
|
27
|
-
for key, value in self.
|
|
34
|
+
for key, value in self.overrides.items():
|
|
28
35
|
os.environ[key] = value
|
|
29
36
|
|
|
37
|
+
for key in self.deletions:
|
|
38
|
+
if key in os.environ:
|
|
39
|
+
del os.environ[key]
|
|
40
|
+
|
|
30
41
|
return self
|
|
31
42
|
|
|
32
43
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
@@ -8,7 +8,7 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
|
|
|
8
8
|
truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
|
|
9
9
|
truss/base/truss_config.py,sha256=s39Xc1e20s8IV07YLl_aVnp-uRS18ZQ2TV-3FILx4nY,28416
|
|
10
10
|
truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
|
|
11
|
-
truss/cli/chains_commands.py,sha256=
|
|
11
|
+
truss/cli/chains_commands.py,sha256=Kpa5mCg6URAJQE2ZmZfVQFhjBHEitKT28tKiW0H6XAI,17406
|
|
12
12
|
truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
|
|
13
13
|
truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
|
|
14
14
|
truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
|
|
@@ -34,7 +34,7 @@ truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
|
|
|
34
34
|
truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
|
|
35
35
|
truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomUDpalvKo,3925
|
|
36
36
|
truss/contexts/truss_context.py,sha256=uS6L-ACHxNk0BsJwESOHh1lA0OGGw0pb33aFKGsASj4,436
|
|
37
|
-
truss/contexts/image_builder/cache_warmer.py,sha256=
|
|
37
|
+
truss/contexts/image_builder/cache_warmer.py,sha256=TGMV1Mh87n2e_dSowH0sf0rZhZraDOR-LVapZL3a5r8,7377
|
|
38
38
|
truss/contexts/image_builder/image_builder.py,sha256=IuRgDeeoHVLzIkJvKtX3807eeqEyaroCs_KWDcIHZUg,1461
|
|
39
39
|
truss/contexts/image_builder/serving_image_builder.py,sha256=1PfHtkTEdNPhSQAX8Ajk_0LN3KR2EfLKwOJsnECtKXQ,33958
|
|
40
40
|
truss/contexts/image_builder/util.py,sha256=y2-CjUKv0XV-0w2sr1fUCflysDJLsoU4oPp6tvvoFnk,1203
|
|
@@ -52,18 +52,18 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
|
|
|
52
52
|
truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
|
|
53
53
|
truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
|
|
54
54
|
truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
|
|
55
|
-
truss/remote/baseten/api.py,sha256=
|
|
55
|
+
truss/remote/baseten/api.py,sha256=5B5IXNy0v8hRHNH2ar3rldDa47kwt5s1PtKZQ9_pfmE,28263
|
|
56
56
|
truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
|
|
57
|
-
truss/remote/baseten/core.py,sha256=
|
|
58
|
-
truss/remote/baseten/custom_types.py,sha256=
|
|
57
|
+
truss/remote/baseten/core.py,sha256=uxtmBI9RAVHu1glIEJb5Q4ccJYLeZM1Cp5Svb9W68Yw,21965
|
|
58
|
+
truss/remote/baseten/custom_types.py,sha256=bYrfTzGgYr6FDoya0omyadCLSTcTc-83U2scQORyUj0,4715
|
|
59
59
|
truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
|
|
60
|
-
truss/remote/baseten/remote.py,sha256=
|
|
60
|
+
truss/remote/baseten/remote.py,sha256=Se8AES5mk8jxa8S9fN2DSG7wnsaV7ftRjJ4Uwc_w_S0,22544
|
|
61
61
|
truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
|
|
62
62
|
truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
|
|
63
63
|
truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
|
|
64
64
|
truss/remote/baseten/utils/tar.py,sha256=pMUv--YkwXDngUx1WUOK-KmAIKMcOg2E-CD5x4heh3s,2514
|
|
65
65
|
truss/remote/baseten/utils/time.py,sha256=Ry9GMjYnbIGYVIGwtmv4V8ljWjvdcaCf5NOQzlNeGxI,397
|
|
66
|
-
truss/remote/baseten/utils/transfer.py,sha256=
|
|
66
|
+
truss/remote/baseten/utils/transfer.py,sha256=vMI-Dcd_HaRkqVjWI02y3eK9DjiqgG0ULayS0KxIIvA,1652
|
|
67
67
|
truss/templates/README.md.jinja,sha256=N7CJdyldZuJamj5jLh47le0hFBdu9irVsTBqoxhPNPQ,2476
|
|
68
68
|
truss/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
69
69
|
truss/templates/base.Dockerfile.jinja,sha256=tdMmK5TeiQuYbz4gqbACM3R-l-mazqL9tAZtJ4sxC4g,5331
|
|
@@ -141,7 +141,6 @@ truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX
|
|
|
141
141
|
truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
|
|
142
142
|
truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
|
|
143
143
|
truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
|
|
144
|
-
truss/tests/cli/test_chains_cli.py,sha256=l9GTQrhRm9SRZn43WkMY4tdRslLmdsVyiydRPa1_Ja4,3162
|
|
145
144
|
truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
|
|
146
145
|
truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
|
|
147
146
|
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
|
|
@@ -163,7 +162,6 @@ truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0a
|
|
|
163
162
|
truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
|
|
164
163
|
truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
|
|
165
164
|
truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
|
|
166
|
-
truss/tests/remote/baseten/test_chain_upload.py,sha256=XaaF1ocovkBYsLMJ8EpXB9FUGfQZAwu4iyOWqoVn7tc,10886
|
|
167
165
|
truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
|
|
168
166
|
truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
|
|
169
167
|
truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
|
|
@@ -318,7 +316,7 @@ truss/tests/test_data/test_truss_with_error/packages/helpers_2.py,sha256=q_UpVfX
|
|
|
318
316
|
truss/tests/trt_llm/test_trt_llm_config.py,sha256=lNQ4EEkOsiT17KvnvW1snCeEBd7K_cl9_Y0dko3qpn8,8505
|
|
319
317
|
truss/tests/trt_llm/test_validation.py,sha256=dmax2EHxRfqxJvWzV8uubkTef50833KBBHw-WkHufL8,2120
|
|
320
318
|
truss/tests/util/test_config_checks.py,sha256=aoZF_Q-eRd3qz5wjUqa8Cr_7qF2SxodXbBIY_DBuFWg,522
|
|
321
|
-
truss/tests/util/test_env_vars.py,sha256=
|
|
319
|
+
truss/tests/util/test_env_vars.py,sha256=kz5FlynWvXhNR9bhf-xC-0cXuyYyAOvtKLLxKLDUGf4,616
|
|
322
320
|
truss/tests/util/test_path.py,sha256=YfW3-IM_7iRsdR1Cb26KB1BkDsG_53_BUGBzoxY2Nog,7408
|
|
323
321
|
truss/trt_llm/config_checks.py,sha256=Efxb5l7vRNveDglse78untq9V-IgtLxejk3_8JKxN5I,4671
|
|
324
322
|
truss/trt_llm/validation.py,sha256=cse-EnmuHmRpwBuSc3IvmSnl-StSQCIFM1nssgnaRUQ,1848
|
|
@@ -341,8 +339,7 @@ truss/util/.truss_ignore,sha256=jpQA9ou-r_JEIcEHsUqGLHhir_m3d4IPGNyzKXtS-2g,3131
|
|
|
341
339
|
truss/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
342
340
|
truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
|
|
343
341
|
truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
|
|
344
|
-
truss/util/env_vars.py,sha256=
|
|
345
|
-
truss/util/error_utils.py,sha256=aO76Vf8LMlvhM28jRJ1qzNl4E5ZyvKK4TFQq_UhbQrk,1095
|
|
342
|
+
truss/util/env_vars.py,sha256=PmKsXdN-PX2-_xk9XcdHTFuRRWiFaMw2iNUKxE8B1Ro,1671
|
|
346
343
|
truss/util/gpu.py,sha256=YiEF_JZyzur0MDMJOebMuJBQxrHD9ApGI0aPpWdb5BU,553
|
|
347
344
|
truss/util/jinja.py,sha256=7KbuYNq55I3DGtImAiCvBwR0K9-z1Jo6gMhmsy4lNZE,333
|
|
348
345
|
truss/util/log_utils.py,sha256=LwSgRh2K7KFjKKqBxr-IirFxGIzHi1mUM7YEvujvHsE,1985
|
|
@@ -352,8 +349,8 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
|
|
|
352
349
|
truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
|
|
353
350
|
truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
|
|
354
351
|
truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
|
|
355
|
-
truss_chains/private_types.py,sha256=
|
|
356
|
-
truss_chains/public_api.py,sha256=
|
|
352
|
+
truss_chains/private_types.py,sha256=6CaQEPawFLXjEbJ-01lqfexJtUIekF_q61LNENWegFo,8917
|
|
353
|
+
truss_chains/public_api.py,sha256=0AXV6UdZIFAMycUNG_klgo4aLFmBZeKGfrulZEWzR0M,9532
|
|
357
354
|
truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
|
|
358
355
|
truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
359
356
|
truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
|
|
@@ -361,7 +358,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
|
|
|
361
358
|
truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
|
|
362
359
|
truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
363
360
|
truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
|
|
364
|
-
truss_chains/deployment/deployment_client.py,sha256=
|
|
361
|
+
truss_chains/deployment/deployment_client.py,sha256=OoqkO3daktYzR2YsIcDvsuGfjR05X2K7QlA7wvFduzc,34208
|
|
365
362
|
truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
|
|
366
363
|
truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
|
|
367
364
|
truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -374,8 +371,8 @@ truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,307
|
|
|
374
371
|
truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
|
|
375
372
|
truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
|
|
376
373
|
truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
|
|
377
|
-
truss-0.11.
|
|
378
|
-
truss-0.11.
|
|
379
|
-
truss-0.11.
|
|
380
|
-
truss-0.11.
|
|
381
|
-
truss-0.11.
|
|
374
|
+
truss-0.11.16.dist-info/METADATA,sha256=ykIbKzLWO3FqVipKHGs7Ad8yNkqGXoHm74AchWFVgHA,6678
|
|
375
|
+
truss-0.11.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
376
|
+
truss-0.11.16.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
|
|
377
|
+
truss-0.11.16.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
|
|
378
|
+
truss-0.11.16.dist-info/RECORD,,
|
|
@@ -516,21 +516,14 @@ def _create_baseten_chain(
|
|
|
516
516
|
|
|
517
517
|
_create_chains_secret_if_missing(remote_provider)
|
|
518
518
|
|
|
519
|
-
# Get chain root for raw chain artifact upload
|
|
520
|
-
chain_root = None
|
|
521
|
-
if entrypoint_descriptor is not None:
|
|
522
|
-
chain_root = _get_chain_root(entrypoint_descriptor.chainlet_cls)
|
|
523
|
-
|
|
524
519
|
chain_deployment_handle = remote_provider.push_chain_atomic(
|
|
525
520
|
baseten_options.chain_name,
|
|
526
521
|
entrypoint_artifact,
|
|
527
522
|
dependency_artifacts,
|
|
528
523
|
truss_user_env,
|
|
529
|
-
chain_root=chain_root,
|
|
530
524
|
publish=baseten_options.publish,
|
|
531
525
|
environment=baseten_options.environment,
|
|
532
526
|
progress_bar=progress_bar,
|
|
533
|
-
disable_chain_download=baseten_options.disable_chain_download,
|
|
534
527
|
)
|
|
535
528
|
return BasetenChainService(
|
|
536
529
|
baseten_options.chain_name,
|
truss_chains/private_types.py
CHANGED
|
@@ -265,7 +265,6 @@ class PushOptionsBaseten(PushOptions):
|
|
|
265
265
|
environment: Optional[str]
|
|
266
266
|
include_git_info: bool
|
|
267
267
|
working_dir: pathlib.Path
|
|
268
|
-
disable_chain_download: bool = False
|
|
269
268
|
|
|
270
269
|
@classmethod
|
|
271
270
|
def create(
|
|
@@ -278,7 +277,6 @@ class PushOptionsBaseten(PushOptions):
|
|
|
278
277
|
include_git_info: bool,
|
|
279
278
|
working_dir: pathlib.Path,
|
|
280
279
|
environment: Optional[str] = None,
|
|
281
|
-
disable_chain_download: bool = False,
|
|
282
280
|
) -> "PushOptionsBaseten":
|
|
283
281
|
if promote and not environment:
|
|
284
282
|
environment = PRODUCTION_ENVIRONMENT_NAME
|
|
@@ -292,7 +290,6 @@ class PushOptionsBaseten(PushOptions):
|
|
|
292
290
|
environment=environment,
|
|
293
291
|
include_git_info=include_git_info,
|
|
294
292
|
working_dir=working_dir,
|
|
295
|
-
disable_chain_download=disable_chain_download,
|
|
296
293
|
)
|
|
297
294
|
|
|
298
295
|
|
truss_chains/public_api.py
CHANGED
|
@@ -151,7 +151,6 @@ def push(
|
|
|
151
151
|
environment: Optional[str] = None,
|
|
152
152
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
153
153
|
include_git_info: bool = False,
|
|
154
|
-
disable_chain_download: bool = False,
|
|
155
154
|
) -> deployment_client.BasetenChainService:
|
|
156
155
|
"""
|
|
157
156
|
Deploys a chain remotely (with all dependent chainlets).
|
|
@@ -173,7 +172,6 @@ def push(
|
|
|
173
172
|
include_git_info: Whether to attach git versioning info (sha, branch, tag) to
|
|
174
173
|
deployments made from within a git repo. If set to True in `.trussrc`, it
|
|
175
174
|
will always be attached.
|
|
176
|
-
disable_chain_download: Disable downloading of pushed chain source code from the UI.
|
|
177
175
|
|
|
178
176
|
Returns:
|
|
179
177
|
A chain service handle to the deployed chain.
|
|
@@ -188,7 +186,6 @@ def push(
|
|
|
188
186
|
environment=environment,
|
|
189
187
|
include_git_info=include_git_info,
|
|
190
188
|
working_dir=pathlib.Path(inspect.getfile(entrypoint)).parent,
|
|
191
|
-
disable_chain_download=disable_chain_download,
|
|
192
189
|
)
|
|
193
190
|
service = deployment_client.push(entrypoint, options, progress_bar=progress_bar)
|
|
194
191
|
assert isinstance(
|
|
@@ -1,100 +0,0 @@
|
|
|
1
|
-
"""Tests for truss chains CLI commands."""
|
|
2
|
-
|
|
3
|
-
from unittest.mock import Mock, patch
|
|
4
|
-
|
|
5
|
-
from click.testing import CliRunner
|
|
6
|
-
|
|
7
|
-
from truss.cli.cli import truss_cli
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def test_chains_push_with_disable_chain_download_flag():
|
|
11
|
-
"""Test that --disable-chain-download flag is properly parsed and passed through."""
|
|
12
|
-
runner = CliRunner()
|
|
13
|
-
|
|
14
|
-
mock_entrypoint_cls = Mock()
|
|
15
|
-
mock_entrypoint_cls.meta_data.chain_name = "test_chain"
|
|
16
|
-
mock_entrypoint_cls.display_name = "TestChain"
|
|
17
|
-
|
|
18
|
-
mock_service = Mock()
|
|
19
|
-
mock_service.run_remote_url = "http://test.com/run_remote"
|
|
20
|
-
mock_service.is_websocket = False
|
|
21
|
-
|
|
22
|
-
with patch(
|
|
23
|
-
"truss_chains.framework.ChainletImporter.import_target"
|
|
24
|
-
) as mock_importer:
|
|
25
|
-
with patch("truss_chains.deployment.deployment_client.push") as mock_push:
|
|
26
|
-
mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
|
|
27
|
-
mock_push.return_value = mock_service
|
|
28
|
-
|
|
29
|
-
result = runner.invoke(
|
|
30
|
-
truss_cli,
|
|
31
|
-
[
|
|
32
|
-
"chains",
|
|
33
|
-
"push",
|
|
34
|
-
"test_chain.py",
|
|
35
|
-
"--disable-chain-download",
|
|
36
|
-
"--remote",
|
|
37
|
-
"test_remote",
|
|
38
|
-
"--dryrun",
|
|
39
|
-
],
|
|
40
|
-
)
|
|
41
|
-
|
|
42
|
-
assert result.exit_code == 0
|
|
43
|
-
|
|
44
|
-
mock_push.assert_called_once()
|
|
45
|
-
call_args = mock_push.call_args
|
|
46
|
-
options = call_args[0][1]
|
|
47
|
-
|
|
48
|
-
assert hasattr(options, "disable_chain_download")
|
|
49
|
-
assert options.disable_chain_download is True
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def test_chains_push_without_disable_chain_download_flag():
|
|
53
|
-
"""Test that disable_chain_download defaults to False when flag is not provided."""
|
|
54
|
-
runner = CliRunner()
|
|
55
|
-
|
|
56
|
-
mock_entrypoint_cls = Mock()
|
|
57
|
-
mock_entrypoint_cls.meta_data.chain_name = "test_chain"
|
|
58
|
-
mock_entrypoint_cls.display_name = "TestChain"
|
|
59
|
-
|
|
60
|
-
mock_service = Mock()
|
|
61
|
-
mock_service.run_remote_url = "http://test.com/run_remote"
|
|
62
|
-
mock_service.is_websocket = False
|
|
63
|
-
|
|
64
|
-
with patch(
|
|
65
|
-
"truss_chains.framework.ChainletImporter.import_target"
|
|
66
|
-
) as mock_importer:
|
|
67
|
-
with patch("truss_chains.deployment.deployment_client.push") as mock_push:
|
|
68
|
-
mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
|
|
69
|
-
mock_push.return_value = mock_service
|
|
70
|
-
|
|
71
|
-
result = runner.invoke(
|
|
72
|
-
truss_cli,
|
|
73
|
-
[
|
|
74
|
-
"chains",
|
|
75
|
-
"push",
|
|
76
|
-
"test_chain.py",
|
|
77
|
-
"--remote",
|
|
78
|
-
"test_remote",
|
|
79
|
-
"--dryrun",
|
|
80
|
-
],
|
|
81
|
-
)
|
|
82
|
-
|
|
83
|
-
assert result.exit_code == 0
|
|
84
|
-
|
|
85
|
-
mock_push.assert_called_once()
|
|
86
|
-
call_args = mock_push.call_args
|
|
87
|
-
options = call_args[0][1]
|
|
88
|
-
|
|
89
|
-
assert hasattr(options, "disable_chain_download")
|
|
90
|
-
assert options.disable_chain_download is False
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def test_chains_push_help_includes_disable_chain_download():
|
|
94
|
-
"""Test that --disable-chain-download appears in the help output."""
|
|
95
|
-
runner = CliRunner()
|
|
96
|
-
|
|
97
|
-
result = runner.invoke(truss_cli, ["chains", "push", "--help"])
|
|
98
|
-
|
|
99
|
-
assert result.exit_code == 0
|
|
100
|
-
assert "--disable-chain-download" in result.output
|
|
@@ -1,285 +0,0 @@
|
|
|
1
|
-
import pathlib
|
|
2
|
-
import tempfile
|
|
3
|
-
from unittest.mock import Mock, patch
|
|
4
|
-
|
|
5
|
-
import pytest
|
|
6
|
-
|
|
7
|
-
from truss.remote.baseten import custom_types as b10_types
|
|
8
|
-
from truss.remote.baseten.api import BasetenApi
|
|
9
|
-
from truss.remote.baseten.core import upload_chain_artifact
|
|
10
|
-
from truss.remote.baseten.remote import BasetenRemote
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
@pytest.fixture
|
|
14
|
-
def mock_push_data():
|
|
15
|
-
"""Fixture providing mock push data for tests."""
|
|
16
|
-
mock_push_data = Mock()
|
|
17
|
-
mock_push_data.model_name = "test-model"
|
|
18
|
-
mock_push_data.s3_key = "models/test-key"
|
|
19
|
-
mock_push_data.encoded_config_str = "encoded_config"
|
|
20
|
-
mock_push_data.is_draft = False
|
|
21
|
-
mock_push_data.model_id = "model-id"
|
|
22
|
-
mock_push_data.version_name = None
|
|
23
|
-
return mock_push_data
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@pytest.fixture
|
|
27
|
-
def mock_remote_context():
|
|
28
|
-
"""Fixture providing mock remote and context managers for tests."""
|
|
29
|
-
api = Mock(spec=BasetenApi)
|
|
30
|
-
|
|
31
|
-
remote = BasetenRemote("https://test.baseten.co", "test-api-key")
|
|
32
|
-
remote._api = api
|
|
33
|
-
|
|
34
|
-
chain_name = "test-chain"
|
|
35
|
-
entrypoint_artifact = Mock()
|
|
36
|
-
entrypoint_artifact.truss_dir = "/path/to/truss"
|
|
37
|
-
entrypoint_artifact.display_name = "entrypoint"
|
|
38
|
-
|
|
39
|
-
dependency_artifacts = []
|
|
40
|
-
truss_user_env = Mock()
|
|
41
|
-
chain_root = pathlib.Path("/path/to/chain")
|
|
42
|
-
|
|
43
|
-
with patch.object(remote, "_prepare_push") as mock_prepare_push:
|
|
44
|
-
with patch("truss.remote.baseten.remote.truss_build.load") as mock_load:
|
|
45
|
-
mock_truss_handle = Mock()
|
|
46
|
-
mock_truss_handle.spec.config.model_name = "test-model"
|
|
47
|
-
mock_load.return_value = mock_truss_handle
|
|
48
|
-
|
|
49
|
-
yield {
|
|
50
|
-
"remote": remote,
|
|
51
|
-
"api": api,
|
|
52
|
-
"chain_name": chain_name,
|
|
53
|
-
"entrypoint_artifact": entrypoint_artifact,
|
|
54
|
-
"dependency_artifacts": dependency_artifacts,
|
|
55
|
-
"truss_user_env": truss_user_env,
|
|
56
|
-
"chain_root": chain_root,
|
|
57
|
-
"mock_prepare_push": mock_prepare_push,
|
|
58
|
-
"mock_load": mock_load,
|
|
59
|
-
"mock_truss_handle": mock_truss_handle,
|
|
60
|
-
}
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def test_get_blob_credentials_for_chain():
|
|
64
|
-
"""Test that get_blob_credentials works correctly for chain blob type using GraphQL."""
|
|
65
|
-
mock_graphql_response = {
|
|
66
|
-
"data": {
|
|
67
|
-
"chain_s3_upload_credentials": {
|
|
68
|
-
"s3_bucket": "test-chain-bucket",
|
|
69
|
-
"s3_key": "chains/test-uuid/chain.tgz",
|
|
70
|
-
"aws_access_key_id": "test_access_key",
|
|
71
|
-
"aws_secret_access_key": "test_secret_key",
|
|
72
|
-
"aws_session_token": "test_session_token",
|
|
73
|
-
}
|
|
74
|
-
}
|
|
75
|
-
}
|
|
76
|
-
|
|
77
|
-
# Create a real API instance and mock the GraphQL call
|
|
78
|
-
mock_auth_service = Mock()
|
|
79
|
-
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
80
|
-
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
81
|
-
with patch.object(api, "_post_graphql_query") as mock_graphql:
|
|
82
|
-
mock_graphql.return_value = mock_graphql_response
|
|
83
|
-
|
|
84
|
-
result = api.get_chain_s3_upload_credentials()
|
|
85
|
-
|
|
86
|
-
assert result.s3_bucket == "test-chain-bucket"
|
|
87
|
-
assert result.s3_key == "chains/test-uuid/chain.tgz"
|
|
88
|
-
assert result.aws_access_key_id == "test_access_key"
|
|
89
|
-
assert result.aws_secret_access_key == "test_secret_key"
|
|
90
|
-
assert result.aws_session_token == "test_session_token"
|
|
91
|
-
|
|
92
|
-
mock_graphql.assert_called_once()
|
|
93
|
-
call_args = mock_graphql.call_args
|
|
94
|
-
assert "chain_s3_upload_credentials" in call_args[0][0]
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def test_get_blob_credentials_for_other_types_uses_rest():
|
|
98
|
-
"""Test that get_blob_credentials uses REST API for non-chain blob types."""
|
|
99
|
-
mock_response = {
|
|
100
|
-
"s3_bucket": "test-bucket",
|
|
101
|
-
"s3_key": "test-key",
|
|
102
|
-
"creds": {
|
|
103
|
-
"aws_access_key_id": "test_access_key",
|
|
104
|
-
"aws_secret_access_key": "test_secret_key",
|
|
105
|
-
"aws_session_token": "test_session_token",
|
|
106
|
-
},
|
|
107
|
-
}
|
|
108
|
-
|
|
109
|
-
mock_auth_service = Mock()
|
|
110
|
-
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
111
|
-
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
112
|
-
with patch.object(api, "_rest_api_client") as mock_client, patch.object(
|
|
113
|
-
api, "_post_graphql_query"
|
|
114
|
-
) as mock_graphql:
|
|
115
|
-
mock_client.get.return_value = mock_response
|
|
116
|
-
|
|
117
|
-
result = api.get_blob_credentials(b10_types.BlobType.MODEL)
|
|
118
|
-
|
|
119
|
-
assert result["s3_bucket"] == "test-bucket"
|
|
120
|
-
assert result["s3_key"] == "test-key"
|
|
121
|
-
|
|
122
|
-
mock_client.get.assert_called_once_with("v1/blobs/credentials/model")
|
|
123
|
-
mock_graphql.assert_not_called()
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
127
|
-
def test_upload_chain_artifact_function(mock_multipart_upload):
|
|
128
|
-
"""Test the upload_chain_artifact function."""
|
|
129
|
-
# Mock ChainUploadCredentials object
|
|
130
|
-
mock_credentials = Mock()
|
|
131
|
-
mock_credentials.s3_bucket = "test-chain-bucket"
|
|
132
|
-
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
133
|
-
mock_credentials.aws_credentials = Mock()
|
|
134
|
-
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
135
|
-
"aws_access_key_id": "test_access_key",
|
|
136
|
-
"aws_secret_access_key": "test_secret_key",
|
|
137
|
-
"aws_session_token": "test_session_token",
|
|
138
|
-
}
|
|
139
|
-
|
|
140
|
-
api = Mock(spec=BasetenApi)
|
|
141
|
-
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
142
|
-
|
|
143
|
-
with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as temp_file:
|
|
144
|
-
temp_file.write(b"test chain content")
|
|
145
|
-
temp_file.flush()
|
|
146
|
-
|
|
147
|
-
result = upload_chain_artifact(api, temp_file, None)
|
|
148
|
-
|
|
149
|
-
assert result == "chains/test-uuid/chain.tgz"
|
|
150
|
-
|
|
151
|
-
api.get_chain_s3_upload_credentials.assert_called_once_with()
|
|
152
|
-
|
|
153
|
-
mock_multipart_upload.assert_called_once()
|
|
154
|
-
call_args = mock_multipart_upload.call_args
|
|
155
|
-
assert call_args[0][0] == temp_file.name # file path
|
|
156
|
-
assert call_args[0][1] == "test-chain-bucket" # bucket
|
|
157
|
-
assert call_args[0][2] == "chains/test-uuid/chain.tgz" # key
|
|
158
|
-
assert call_args[0][3] == { # credentials
|
|
159
|
-
"aws_access_key_id": "test_access_key",
|
|
160
|
-
"aws_secret_access_key": "test_secret_key",
|
|
161
|
-
"aws_session_token": "test_session_token",
|
|
162
|
-
}
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
@patch("truss.remote.baseten.remote.upload_chain_artifact")
|
|
166
|
-
@patch("truss.remote.baseten.remote.archive_dir")
|
|
167
|
-
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
168
|
-
def test_push_chain_atomic_with_chain_upload(
|
|
169
|
-
mock_create_chain_atomic,
|
|
170
|
-
mock_archive_dir,
|
|
171
|
-
mock_upload_chain_artifact,
|
|
172
|
-
mock_push_data,
|
|
173
|
-
mock_remote_context,
|
|
174
|
-
):
|
|
175
|
-
"""Test that push_chain_atomic uploads raw chain artifact when chain_root is provided."""
|
|
176
|
-
mock_create_chain_atomic.return_value = Mock()
|
|
177
|
-
mock_archive_dir.return_value = Mock()
|
|
178
|
-
mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"
|
|
179
|
-
|
|
180
|
-
context = mock_remote_context
|
|
181
|
-
remote = context["remote"]
|
|
182
|
-
chain_name = context["chain_name"]
|
|
183
|
-
entrypoint_artifact = context["entrypoint_artifact"]
|
|
184
|
-
dependency_artifacts = context["dependency_artifacts"]
|
|
185
|
-
truss_user_env = context["truss_user_env"]
|
|
186
|
-
chain_root = context["chain_root"]
|
|
187
|
-
|
|
188
|
-
context["mock_prepare_push"].return_value = mock_push_data
|
|
189
|
-
|
|
190
|
-
result = remote.push_chain_atomic(
|
|
191
|
-
chain_name=chain_name,
|
|
192
|
-
entrypoint_artifact=entrypoint_artifact,
|
|
193
|
-
dependency_artifacts=dependency_artifacts,
|
|
194
|
-
truss_user_env=truss_user_env,
|
|
195
|
-
chain_root=chain_root,
|
|
196
|
-
publish=True,
|
|
197
|
-
)
|
|
198
|
-
assert result == mock_create_chain_atomic.return_value
|
|
199
|
-
|
|
200
|
-
mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
|
|
201
|
-
mock_upload_chain_artifact.assert_called_once()
|
|
202
|
-
|
|
203
|
-
mock_create_chain_atomic.assert_called_once()
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
207
|
-
def test_push_chain_atomic_without_chain_upload(
|
|
208
|
-
mock_create_chain_atomic, mock_push_data, mock_remote_context
|
|
209
|
-
):
|
|
210
|
-
"""Test that push_chain_atomic skips chain upload when chain_root is None."""
|
|
211
|
-
mock_create_chain_atomic.return_value = Mock()
|
|
212
|
-
|
|
213
|
-
context = mock_remote_context
|
|
214
|
-
remote = context["remote"]
|
|
215
|
-
chain_name = context["chain_name"]
|
|
216
|
-
entrypoint_artifact = context["entrypoint_artifact"]
|
|
217
|
-
dependency_artifacts = context["dependency_artifacts"]
|
|
218
|
-
truss_user_env = context["truss_user_env"]
|
|
219
|
-
|
|
220
|
-
context["mock_prepare_push"].return_value = mock_push_data
|
|
221
|
-
|
|
222
|
-
with patch("truss.remote.baseten.remote.upload_chain_artifact") as mock_upload:
|
|
223
|
-
with patch(
|
|
224
|
-
"truss.remote.baseten.core.create_tar_with_progress_bar"
|
|
225
|
-
) as mock_tar:
|
|
226
|
-
# Call push_chain_atomic without chain_root
|
|
227
|
-
result = remote.push_chain_atomic(
|
|
228
|
-
chain_name=chain_name,
|
|
229
|
-
entrypoint_artifact=entrypoint_artifact,
|
|
230
|
-
dependency_artifacts=dependency_artifacts,
|
|
231
|
-
truss_user_env=truss_user_env,
|
|
232
|
-
chain_root=None, # No chain root
|
|
233
|
-
publish=True,
|
|
234
|
-
)
|
|
235
|
-
|
|
236
|
-
assert result
|
|
237
|
-
# Verify chain artifact upload was NOT called
|
|
238
|
-
mock_tar.assert_not_called()
|
|
239
|
-
mock_upload.assert_not_called()
|
|
240
|
-
|
|
241
|
-
mock_create_chain_atomic.assert_called_once()
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
245
|
-
def test_upload_chain_artifact_error_handling(mock_multipart_upload):
|
|
246
|
-
"""Test error handling in upload_chain_artifact."""
|
|
247
|
-
# Mock API to raise an exception
|
|
248
|
-
api = Mock(spec=BasetenApi)
|
|
249
|
-
api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")
|
|
250
|
-
|
|
251
|
-
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
252
|
-
# Should raise the exception
|
|
253
|
-
with pytest.raises(Exception, match="API Error"):
|
|
254
|
-
upload_chain_artifact(api, temp_file, None)
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
def test_upload_chain_artifact_credentials_extraction():
|
|
258
|
-
"""Test that credentials are properly extracted from API response."""
|
|
259
|
-
# Mock ChainUploadCredentials object
|
|
260
|
-
mock_credentials = Mock()
|
|
261
|
-
mock_credentials.s3_bucket = "test-bucket"
|
|
262
|
-
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
263
|
-
mock_credentials.aws_credentials = Mock()
|
|
264
|
-
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
265
|
-
"aws_access_key_id": "access_key",
|
|
266
|
-
"aws_secret_access_key": "secret_key",
|
|
267
|
-
"aws_session_token": "session_token",
|
|
268
|
-
}
|
|
269
|
-
|
|
270
|
-
api = Mock(spec=BasetenApi)
|
|
271
|
-
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
272
|
-
|
|
273
|
-
with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
|
|
274
|
-
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
275
|
-
upload_chain_artifact(api, temp_file, None)
|
|
276
|
-
|
|
277
|
-
call_args = mock_upload.call_args
|
|
278
|
-
credentials = call_args[0][3]
|
|
279
|
-
|
|
280
|
-
assert credentials == {
|
|
281
|
-
"aws_access_key_id": "access_key",
|
|
282
|
-
"aws_secret_access_key": "secret_key",
|
|
283
|
-
"aws_session_token": "session_token",
|
|
284
|
-
}
|
|
285
|
-
assert "extra_field" not in credentials
|
truss/util/error_utils.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
from contextlib import contextmanager
|
|
2
|
-
from typing import Generator
|
|
3
|
-
|
|
4
|
-
from botocore.exceptions import ClientError, NoCredentialsError
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
@contextmanager
|
|
8
|
-
def handle_client_error(
|
|
9
|
-
operation_description: str = "AWS operation",
|
|
10
|
-
) -> Generator[None, None, None]:
|
|
11
|
-
"""
|
|
12
|
-
Context manager to handle common boto3 errors and convert them to RuntimeError.
|
|
13
|
-
|
|
14
|
-
Args:
|
|
15
|
-
operation_description: Description of the operation being performed for error messages
|
|
16
|
-
|
|
17
|
-
Raises:
|
|
18
|
-
RuntimeError: For NoCredentialsError, ClientError, and other exceptions
|
|
19
|
-
"""
|
|
20
|
-
try:
|
|
21
|
-
yield
|
|
22
|
-
except NoCredentialsError as nce:
|
|
23
|
-
raise RuntimeError(
|
|
24
|
-
f"No AWS credentials found for {operation_description}\nOriginal exception: {str(nce)}"
|
|
25
|
-
)
|
|
26
|
-
except ClientError as ce:
|
|
27
|
-
raise RuntimeError(
|
|
28
|
-
f"AWS client error when {operation_description} (check your credentials): {str(ce)}"
|
|
29
|
-
)
|
|
30
|
-
|
|
31
|
-
except Exception as exc:
|
|
32
|
-
raise RuntimeError(
|
|
33
|
-
f"Unexpected error `{exc}` during `{operation_description}`\nOriginal exception: {str(exc)}"
|
|
34
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|