truss 0.11.16__py3-none-any.whl → 0.11.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/cli/chains_commands.py +10 -0
- truss/cli/cli.py +3 -2
- truss/remote/baseten/api.py +70 -4
- truss/remote/baseten/core.py +41 -3
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +18 -0
- truss/tests/cli/test_chains_cli.py +100 -0
- truss/tests/remote/baseten/test_chain_upload.py +285 -0
- truss/util/error_utils.py +37 -0
- {truss-0.11.16.dist-info → truss-0.11.18.dist-info}/METADATA +1 -1
- {truss-0.11.16.dist-info → truss-0.11.18.dist-info}/RECORD +17 -14
- truss_chains/deployment/deployment_client.py +7 -0
- truss_chains/private_types.py +3 -0
- truss_chains/public_api.py +3 -0
- {truss-0.11.16.dist-info → truss-0.11.18.dist-info}/WHEEL +0 -0
- {truss-0.11.16.dist-info → truss-0.11.18.dist-info}/entry_points.txt +0 -0
- {truss-0.11.16.dist-info → truss-0.11.18.dist-info}/licenses/LICENSE +0 -0
truss/cli/chains_commands.py
CHANGED
|
@@ -211,6 +211,14 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
|
|
|
211
211
|
default=False,
|
|
212
212
|
help=common.INCLUDE_GIT_INFO_DOC,
|
|
213
213
|
)
|
|
214
|
+
@click.option(
|
|
215
|
+
"--disable-chain-download",
|
|
216
|
+
"disable_chain_download",
|
|
217
|
+
is_flag=True,
|
|
218
|
+
required=False,
|
|
219
|
+
default=False,
|
|
220
|
+
help="Disable downloading of pushed chain source code from the UI.",
|
|
221
|
+
)
|
|
214
222
|
@click.pass_context
|
|
215
223
|
@common.common_options()
|
|
216
224
|
def push_chain(
|
|
@@ -227,6 +235,7 @@ def push_chain(
|
|
|
227
235
|
environment: Optional[str],
|
|
228
236
|
experimental_watch_chainlet_names: Optional[str],
|
|
229
237
|
include_git_info: bool = False,
|
|
238
|
+
disable_chain_download: bool = False,
|
|
230
239
|
) -> None:
|
|
231
240
|
"""
|
|
232
241
|
Deploys a chain remotely.
|
|
@@ -284,6 +293,7 @@ def push_chain(
|
|
|
284
293
|
environment=environment,
|
|
285
294
|
include_git_info=include_git_info,
|
|
286
295
|
working_dir=source.parent if source.is_file() else source.resolve(),
|
|
296
|
+
disable_chain_download=disable_chain_download,
|
|
287
297
|
)
|
|
288
298
|
service = deployment_client.push(
|
|
289
299
|
entrypoint_cls, options, progress_bar=progress.Progress
|
truss/cli/cli.py
CHANGED
|
@@ -548,8 +548,9 @@ def push(
|
|
|
548
548
|
model_name = remote_cli.inquire_model_name()
|
|
549
549
|
|
|
550
550
|
if promote and environment:
|
|
551
|
-
|
|
552
|
-
|
|
551
|
+
raise click.UsageError(
|
|
552
|
+
"'promote' flag and 'environment' flag cannot both be specified."
|
|
553
|
+
)
|
|
553
554
|
if promote and not environment:
|
|
554
555
|
environment = PRODUCTION_ENVIRONMENT_NAME
|
|
555
556
|
|
truss/remote/baseten/api.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
|
|
5
5
|
import requests
|
|
6
6
|
from pydantic import BaseModel, Field
|
|
7
7
|
|
|
8
|
+
from truss.base.custom_types import SafeModel
|
|
8
9
|
from truss.remote.baseten import custom_types as b10_types
|
|
9
10
|
from truss.remote.baseten.auth import ApiKey, AuthService
|
|
10
11
|
from truss.remote.baseten.custom_types import APIKeyCategory
|
|
@@ -13,6 +14,29 @@ from truss.remote.baseten.rest_client import RestAPIClient
|
|
|
13
14
|
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
|
|
14
15
|
|
|
15
16
|
logger = logging.getLogger(__name__)
|
|
17
|
+
PARAMS_INDENT = "\n "
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ChainAWSCredential(SafeModel):
|
|
21
|
+
aws_access_key_id: str
|
|
22
|
+
aws_secret_access_key: str
|
|
23
|
+
aws_session_token: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ChainUploadCredentials(SafeModel):
|
|
27
|
+
s3_bucket: str
|
|
28
|
+
s3_key: str
|
|
29
|
+
aws_access_key_id: str
|
|
30
|
+
aws_secret_access_key: str
|
|
31
|
+
aws_session_token: str
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def aws_credentials(self) -> ChainAWSCredential:
|
|
35
|
+
return ChainAWSCredential(
|
|
36
|
+
aws_access_key_id=self.aws_access_key_id,
|
|
37
|
+
aws_secret_access_key=self.aws_secret_access_key,
|
|
38
|
+
aws_session_token=self.aws_session_token,
|
|
39
|
+
)
|
|
16
40
|
|
|
17
41
|
|
|
18
42
|
class InstanceTypeV1(BaseModel):
|
|
@@ -175,6 +199,7 @@ class BasetenApi:
|
|
|
175
199
|
allow_truss_download: bool = True,
|
|
176
200
|
deployment_name: Optional[str] = None,
|
|
177
201
|
origin: Optional[b10_types.ModelOrigin] = None,
|
|
202
|
+
environment: Optional[str] = None,
|
|
178
203
|
):
|
|
179
204
|
query_string = f"""
|
|
180
205
|
mutation ($trussUserEnv: String) {{
|
|
@@ -187,6 +212,7 @@ class BasetenApi:
|
|
|
187
212
|
allow_truss_download: {"true" if allow_truss_download else "false"}
|
|
188
213
|
{f'version_name: "{deployment_name}"' if deployment_name else ""}
|
|
189
214
|
{f"model_origin: {origin.value}" if origin else ""}
|
|
215
|
+
{f'environment_name: "{environment}"' if environment else ""}
|
|
190
216
|
) {{
|
|
191
217
|
model_version {{
|
|
192
218
|
id
|
|
@@ -299,7 +325,11 @@ class BasetenApi:
|
|
|
299
325
|
chain_name: Optional[str] = None,
|
|
300
326
|
environment: Optional[str] = None,
|
|
301
327
|
is_draft: bool = False,
|
|
328
|
+
original_source_artifact_s3_key: Optional[str] = None,
|
|
329
|
+
allow_truss_download: Optional[bool] = True,
|
|
302
330
|
):
|
|
331
|
+
if allow_truss_download is None:
|
|
332
|
+
allow_truss_download = True
|
|
303
333
|
entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
|
|
304
334
|
|
|
305
335
|
dependencies_str = ", ".join(
|
|
@@ -309,13 +339,28 @@ class BasetenApi:
|
|
|
309
339
|
]
|
|
310
340
|
)
|
|
311
341
|
|
|
342
|
+
params = []
|
|
343
|
+
if chain_id:
|
|
344
|
+
params.append(f'chain_id: "{chain_id}"')
|
|
345
|
+
if chain_name:
|
|
346
|
+
params.append(f'chain_name: "{chain_name}"')
|
|
347
|
+
if environment:
|
|
348
|
+
params.append(f'environment: "{environment}"')
|
|
349
|
+
if original_source_artifact_s3_key:
|
|
350
|
+
params.append(
|
|
351
|
+
f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
params.append(f"is_draft: {str(is_draft).lower()}")
|
|
355
|
+
if allow_truss_download is False:
|
|
356
|
+
params.append("allow_truss_download: false")
|
|
357
|
+
|
|
358
|
+
params_str = PARAMS_INDENT.join(params)
|
|
359
|
+
|
|
312
360
|
query_string = f"""
|
|
313
361
|
mutation ($trussUserEnv: String) {{
|
|
314
362
|
deploy_chain_atomic(
|
|
315
|
-
{
|
|
316
|
-
{f'chain_name: "{chain_name}"' if chain_name else ""}
|
|
317
|
-
{f'environment: "{environment}"' if environment else ""}
|
|
318
|
-
is_draft: {str(is_draft).lower()}
|
|
363
|
+
{params_str}
|
|
319
364
|
entrypoint: {entrypoint_str}
|
|
320
365
|
dependencies: [{dependencies_str}]
|
|
321
366
|
truss_user_env: $trussUserEnv
|
|
@@ -657,8 +702,29 @@ class BasetenApi:
|
|
|
657
702
|
return resp_json["training_projects"]
|
|
658
703
|
|
|
659
704
|
def get_blob_credentials(self, blob_type: b10_types.BlobType):
|
|
705
|
+
if blob_type == b10_types.BlobType.CHAIN:
|
|
706
|
+
return self.get_chain_s3_upload_credentials()
|
|
660
707
|
return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
|
|
661
708
|
|
|
709
|
+
def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
|
|
710
|
+
"""Get chain artifact credentials using GraphQL query."""
|
|
711
|
+
query = """
|
|
712
|
+
query {
|
|
713
|
+
chain_s3_upload_credentials {
|
|
714
|
+
s3_bucket
|
|
715
|
+
s3_key
|
|
716
|
+
aws_access_key_id
|
|
717
|
+
aws_secret_access_key
|
|
718
|
+
aws_session_token
|
|
719
|
+
}
|
|
720
|
+
}
|
|
721
|
+
"""
|
|
722
|
+
response = self._post_graphql_query(query)
|
|
723
|
+
|
|
724
|
+
return ChainUploadCredentials.model_validate(
|
|
725
|
+
response["data"]["chain_s3_upload_credentials"]
|
|
726
|
+
)
|
|
727
|
+
|
|
662
728
|
def get_training_job_metrics(
|
|
663
729
|
self,
|
|
664
730
|
project_id: str,
|
truss/remote/baseten/core.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
|
|
|
8
8
|
import requests
|
|
9
9
|
|
|
10
10
|
from truss.base.errors import ValidationError
|
|
11
|
+
from truss.util.error_utils import handle_client_error
|
|
11
12
|
|
|
12
13
|
if TYPE_CHECKING:
|
|
13
14
|
from rich import progress
|
|
@@ -80,6 +81,8 @@ class ChainDeploymentHandleAtomic(NamedTuple):
|
|
|
80
81
|
chain_id: str
|
|
81
82
|
chain_deployment_id: str
|
|
82
83
|
is_draft: bool
|
|
84
|
+
original_source_artifact_s3_key: Optional[str] = None
|
|
85
|
+
allow_truss_download: Optional[bool] = True
|
|
83
86
|
|
|
84
87
|
|
|
85
88
|
class ModelVersionHandle(NamedTuple):
|
|
@@ -127,6 +130,8 @@ def create_chain_atomic(
|
|
|
127
130
|
is_draft: bool,
|
|
128
131
|
truss_user_env: b10_types.TrussUserEnv,
|
|
129
132
|
environment: Optional[str],
|
|
133
|
+
original_source_artifact_s3_key: Optional[str] = None,
|
|
134
|
+
allow_truss_download: bool = True,
|
|
130
135
|
) -> ChainDeploymentHandleAtomic:
|
|
131
136
|
if environment and is_draft:
|
|
132
137
|
logging.info(
|
|
@@ -149,6 +154,8 @@ def create_chain_atomic(
|
|
|
149
154
|
chain_name=chain_name,
|
|
150
155
|
is_draft=True,
|
|
151
156
|
truss_user_env=truss_user_env,
|
|
157
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
158
|
+
allow_truss_download=allow_truss_download,
|
|
152
159
|
)
|
|
153
160
|
elif chain_id:
|
|
154
161
|
# This is the only case where promote has relevance, since
|
|
@@ -162,6 +169,8 @@ def create_chain_atomic(
|
|
|
162
169
|
chain_id=chain_id,
|
|
163
170
|
environment=environment,
|
|
164
171
|
truss_user_env=truss_user_env,
|
|
172
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
173
|
+
allow_truss_download=allow_truss_download,
|
|
165
174
|
)
|
|
166
175
|
except ApiError as e:
|
|
167
176
|
if (
|
|
@@ -182,6 +191,8 @@ def create_chain_atomic(
|
|
|
182
191
|
dependencies=dependencies,
|
|
183
192
|
chain_name=chain_name,
|
|
184
193
|
truss_user_env=truss_user_env,
|
|
194
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
195
|
+
allow_truss_download=allow_truss_download,
|
|
185
196
|
)
|
|
186
197
|
|
|
187
198
|
return ChainDeploymentHandleAtomic(
|
|
@@ -189,6 +200,8 @@ def create_chain_atomic(
|
|
|
189
200
|
chain_id=res["chain_deployment"]["chain"]["id"],
|
|
190
201
|
hostname=res["chain_deployment"]["chain"]["hostname"],
|
|
191
202
|
is_draft=is_draft,
|
|
203
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
204
|
+
allow_truss_download=allow_truss_download,
|
|
192
205
|
)
|
|
193
206
|
|
|
194
207
|
|
|
@@ -342,6 +355,33 @@ def upload_truss(
|
|
|
342
355
|
return s3_key
|
|
343
356
|
|
|
344
357
|
|
|
358
|
+
def upload_chain_artifact(
|
|
359
|
+
api: BasetenApi,
|
|
360
|
+
serialize_file: IO,
|
|
361
|
+
progress_bar: Optional[Type["progress.Progress"]],
|
|
362
|
+
) -> str:
|
|
363
|
+
"""
|
|
364
|
+
Upload a chain artifact to the Baseten remote.
|
|
365
|
+
|
|
366
|
+
Args:
|
|
367
|
+
api: BasetenApi instance
|
|
368
|
+
serialize_file: File-like object containing the serialized chain artifact
|
|
369
|
+
|
|
370
|
+
Returns:
|
|
371
|
+
The S3 key of the uploaded file
|
|
372
|
+
"""
|
|
373
|
+
credentials = api.get_chain_s3_upload_credentials()
|
|
374
|
+
with handle_client_error("Uploading chain source"):
|
|
375
|
+
multipart_upload_boto3(
|
|
376
|
+
serialize_file.name,
|
|
377
|
+
credentials.s3_bucket,
|
|
378
|
+
credentials.s3_key,
|
|
379
|
+
credentials.aws_credentials.model_dump(),
|
|
380
|
+
progress_bar,
|
|
381
|
+
)
|
|
382
|
+
return credentials.s3_key
|
|
383
|
+
|
|
384
|
+
|
|
345
385
|
def create_truss_service(
|
|
346
386
|
api: BasetenApi,
|
|
347
387
|
model_name: str,
|
|
@@ -398,9 +438,6 @@ def create_truss_service(
|
|
|
398
438
|
)
|
|
399
439
|
|
|
400
440
|
if model_id is None:
|
|
401
|
-
if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
|
|
402
|
-
raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
|
|
403
|
-
|
|
404
441
|
model_version_json = api.create_model_from_truss(
|
|
405
442
|
model_name,
|
|
406
443
|
s3_key,
|
|
@@ -410,6 +447,7 @@ def create_truss_service(
|
|
|
410
447
|
allow_truss_download=allow_truss_download,
|
|
411
448
|
deployment_name=deployment_name,
|
|
412
449
|
origin=origin,
|
|
450
|
+
environment=environment,
|
|
413
451
|
)
|
|
414
452
|
|
|
415
453
|
return ModelVersionHandle(
|
truss/remote/baseten/remote.py
CHANGED
|
@@ -31,6 +31,7 @@ from truss.remote.baseten.core import (
|
|
|
31
31
|
get_model_and_versions,
|
|
32
32
|
get_prod_version_from_versions,
|
|
33
33
|
get_truss_watch_state,
|
|
34
|
+
upload_chain_artifact,
|
|
34
35
|
upload_truss,
|
|
35
36
|
validate_truss_config_against_backend,
|
|
36
37
|
)
|
|
@@ -263,9 +264,11 @@ class BasetenRemote(TrussRemote):
|
|
|
263
264
|
entrypoint_artifact: custom_types.ChainletArtifact,
|
|
264
265
|
dependency_artifacts: List[custom_types.ChainletArtifact],
|
|
265
266
|
truss_user_env: b10_types.TrussUserEnv,
|
|
267
|
+
chain_root: Optional[Path] = None,
|
|
266
268
|
publish: bool = False,
|
|
267
269
|
environment: Optional[str] = None,
|
|
268
270
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
271
|
+
disable_chain_download: bool = False,
|
|
269
272
|
) -> ChainDeploymentHandleAtomic:
|
|
270
273
|
# If we are promoting a model to an environment after deploy, it must be published.
|
|
271
274
|
# Draft models cannot be promoted.
|
|
@@ -285,6 +288,7 @@ class BasetenRemote(TrussRemote):
|
|
|
285
288
|
publish=publish,
|
|
286
289
|
origin=custom_types.ModelOrigin.CHAINS,
|
|
287
290
|
progress_bar=progress_bar,
|
|
291
|
+
disable_truss_download=disable_chain_download,
|
|
288
292
|
)
|
|
289
293
|
oracle_data = custom_types.OracleData(
|
|
290
294
|
model_name=push_data.model_name,
|
|
@@ -300,6 +304,18 @@ class BasetenRemote(TrussRemote):
|
|
|
300
304
|
)
|
|
301
305
|
)
|
|
302
306
|
|
|
307
|
+
# Upload raw chain artifact if chain_root is provided
|
|
308
|
+
raw_chain_s3_key = None
|
|
309
|
+
if chain_root is not None:
|
|
310
|
+
logging.info("Uploading source artifact")
|
|
311
|
+
# Create a tar file from the chain root directory
|
|
312
|
+
original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
|
|
313
|
+
# Upload the chain artifact to S3
|
|
314
|
+
raw_chain_s3_key = upload_chain_artifact(
|
|
315
|
+
api=self._api,
|
|
316
|
+
serialize_file=original_source_tar,
|
|
317
|
+
progress_bar=progress_bar,
|
|
318
|
+
)
|
|
303
319
|
chain_deployment_handle = create_chain_atomic(
|
|
304
320
|
api=self._api,
|
|
305
321
|
chain_name=chain_name,
|
|
@@ -308,6 +324,8 @@ class BasetenRemote(TrussRemote):
|
|
|
308
324
|
is_draft=not publish,
|
|
309
325
|
truss_user_env=truss_user_env,
|
|
310
326
|
environment=environment,
|
|
327
|
+
original_source_artifact_s3_key=raw_chain_s3_key,
|
|
328
|
+
allow_truss_download=not disable_chain_download,
|
|
311
329
|
)
|
|
312
330
|
logging.info("Successfully pushed to baseten. Chain is building and deploying.")
|
|
313
331
|
return chain_deployment_handle
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Tests for truss chains CLI commands."""
|
|
2
|
+
|
|
3
|
+
from unittest.mock import Mock, patch
|
|
4
|
+
|
|
5
|
+
from click.testing import CliRunner
|
|
6
|
+
|
|
7
|
+
from truss.cli.cli import truss_cli
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_chains_push_with_disable_chain_download_flag():
|
|
11
|
+
"""Test that --disable-chain-download flag is properly parsed and passed through."""
|
|
12
|
+
runner = CliRunner()
|
|
13
|
+
|
|
14
|
+
mock_entrypoint_cls = Mock()
|
|
15
|
+
mock_entrypoint_cls.meta_data.chain_name = "test_chain"
|
|
16
|
+
mock_entrypoint_cls.display_name = "TestChain"
|
|
17
|
+
|
|
18
|
+
mock_service = Mock()
|
|
19
|
+
mock_service.run_remote_url = "http://test.com/run_remote"
|
|
20
|
+
mock_service.is_websocket = False
|
|
21
|
+
|
|
22
|
+
with patch(
|
|
23
|
+
"truss_chains.framework.ChainletImporter.import_target"
|
|
24
|
+
) as mock_importer:
|
|
25
|
+
with patch("truss_chains.deployment.deployment_client.push") as mock_push:
|
|
26
|
+
mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
|
|
27
|
+
mock_push.return_value = mock_service
|
|
28
|
+
|
|
29
|
+
result = runner.invoke(
|
|
30
|
+
truss_cli,
|
|
31
|
+
[
|
|
32
|
+
"chains",
|
|
33
|
+
"push",
|
|
34
|
+
"test_chain.py",
|
|
35
|
+
"--disable-chain-download",
|
|
36
|
+
"--remote",
|
|
37
|
+
"test_remote",
|
|
38
|
+
"--dryrun",
|
|
39
|
+
],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
assert result.exit_code == 0
|
|
43
|
+
|
|
44
|
+
mock_push.assert_called_once()
|
|
45
|
+
call_args = mock_push.call_args
|
|
46
|
+
options = call_args[0][1]
|
|
47
|
+
|
|
48
|
+
assert hasattr(options, "disable_chain_download")
|
|
49
|
+
assert options.disable_chain_download is True
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_chains_push_without_disable_chain_download_flag():
|
|
53
|
+
"""Test that disable_chain_download defaults to False when flag is not provided."""
|
|
54
|
+
runner = CliRunner()
|
|
55
|
+
|
|
56
|
+
mock_entrypoint_cls = Mock()
|
|
57
|
+
mock_entrypoint_cls.meta_data.chain_name = "test_chain"
|
|
58
|
+
mock_entrypoint_cls.display_name = "TestChain"
|
|
59
|
+
|
|
60
|
+
mock_service = Mock()
|
|
61
|
+
mock_service.run_remote_url = "http://test.com/run_remote"
|
|
62
|
+
mock_service.is_websocket = False
|
|
63
|
+
|
|
64
|
+
with patch(
|
|
65
|
+
"truss_chains.framework.ChainletImporter.import_target"
|
|
66
|
+
) as mock_importer:
|
|
67
|
+
with patch("truss_chains.deployment.deployment_client.push") as mock_push:
|
|
68
|
+
mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
|
|
69
|
+
mock_push.return_value = mock_service
|
|
70
|
+
|
|
71
|
+
result = runner.invoke(
|
|
72
|
+
truss_cli,
|
|
73
|
+
[
|
|
74
|
+
"chains",
|
|
75
|
+
"push",
|
|
76
|
+
"test_chain.py",
|
|
77
|
+
"--remote",
|
|
78
|
+
"test_remote",
|
|
79
|
+
"--dryrun",
|
|
80
|
+
],
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
assert result.exit_code == 0
|
|
84
|
+
|
|
85
|
+
mock_push.assert_called_once()
|
|
86
|
+
call_args = mock_push.call_args
|
|
87
|
+
options = call_args[0][1]
|
|
88
|
+
|
|
89
|
+
assert hasattr(options, "disable_chain_download")
|
|
90
|
+
assert options.disable_chain_download is False
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_chains_push_help_includes_disable_chain_download():
|
|
94
|
+
"""Test that --disable-chain-download appears in the help output."""
|
|
95
|
+
runner = CliRunner()
|
|
96
|
+
|
|
97
|
+
result = runner.invoke(truss_cli, ["chains", "push", "--help"])
|
|
98
|
+
|
|
99
|
+
assert result.exit_code == 0
|
|
100
|
+
assert "--disable-chain-download" in result.output
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
import pathlib
|
|
2
|
+
import tempfile
|
|
3
|
+
from unittest.mock import Mock, patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from truss.remote.baseten import custom_types as b10_types
|
|
8
|
+
from truss.remote.baseten.api import BasetenApi
|
|
9
|
+
from truss.remote.baseten.core import upload_chain_artifact
|
|
10
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.fixture
|
|
14
|
+
def mock_push_data():
|
|
15
|
+
"""Fixture providing mock push data for tests."""
|
|
16
|
+
mock_push_data = Mock()
|
|
17
|
+
mock_push_data.model_name = "test-model"
|
|
18
|
+
mock_push_data.s3_key = "models/test-key"
|
|
19
|
+
mock_push_data.encoded_config_str = "encoded_config"
|
|
20
|
+
mock_push_data.is_draft = False
|
|
21
|
+
mock_push_data.model_id = "model-id"
|
|
22
|
+
mock_push_data.version_name = None
|
|
23
|
+
return mock_push_data
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture
|
|
27
|
+
def mock_remote_context():
|
|
28
|
+
"""Fixture providing mock remote and context managers for tests."""
|
|
29
|
+
api = Mock(spec=BasetenApi)
|
|
30
|
+
|
|
31
|
+
remote = BasetenRemote("https://test.baseten.co", "test-api-key")
|
|
32
|
+
remote._api = api
|
|
33
|
+
|
|
34
|
+
chain_name = "test-chain"
|
|
35
|
+
entrypoint_artifact = Mock()
|
|
36
|
+
entrypoint_artifact.truss_dir = "/path/to/truss"
|
|
37
|
+
entrypoint_artifact.display_name = "entrypoint"
|
|
38
|
+
|
|
39
|
+
dependency_artifacts = []
|
|
40
|
+
truss_user_env = Mock()
|
|
41
|
+
chain_root = pathlib.Path("/path/to/chain")
|
|
42
|
+
|
|
43
|
+
with patch.object(remote, "_prepare_push") as mock_prepare_push:
|
|
44
|
+
with patch("truss.remote.baseten.remote.truss_build.load") as mock_load:
|
|
45
|
+
mock_truss_handle = Mock()
|
|
46
|
+
mock_truss_handle.spec.config.model_name = "test-model"
|
|
47
|
+
mock_load.return_value = mock_truss_handle
|
|
48
|
+
|
|
49
|
+
yield {
|
|
50
|
+
"remote": remote,
|
|
51
|
+
"api": api,
|
|
52
|
+
"chain_name": chain_name,
|
|
53
|
+
"entrypoint_artifact": entrypoint_artifact,
|
|
54
|
+
"dependency_artifacts": dependency_artifacts,
|
|
55
|
+
"truss_user_env": truss_user_env,
|
|
56
|
+
"chain_root": chain_root,
|
|
57
|
+
"mock_prepare_push": mock_prepare_push,
|
|
58
|
+
"mock_load": mock_load,
|
|
59
|
+
"mock_truss_handle": mock_truss_handle,
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_get_blob_credentials_for_chain():
|
|
64
|
+
"""Test that get_blob_credentials works correctly for chain blob type using GraphQL."""
|
|
65
|
+
mock_graphql_response = {
|
|
66
|
+
"data": {
|
|
67
|
+
"chain_s3_upload_credentials": {
|
|
68
|
+
"s3_bucket": "test-chain-bucket",
|
|
69
|
+
"s3_key": "chains/test-uuid/chain.tgz",
|
|
70
|
+
"aws_access_key_id": "test_access_key",
|
|
71
|
+
"aws_secret_access_key": "test_secret_key",
|
|
72
|
+
"aws_session_token": "test_session_token",
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
# Create a real API instance and mock the GraphQL call
|
|
78
|
+
mock_auth_service = Mock()
|
|
79
|
+
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
80
|
+
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
81
|
+
with patch.object(api, "_post_graphql_query") as mock_graphql:
|
|
82
|
+
mock_graphql.return_value = mock_graphql_response
|
|
83
|
+
|
|
84
|
+
result = api.get_chain_s3_upload_credentials()
|
|
85
|
+
|
|
86
|
+
assert result.s3_bucket == "test-chain-bucket"
|
|
87
|
+
assert result.s3_key == "chains/test-uuid/chain.tgz"
|
|
88
|
+
assert result.aws_access_key_id == "test_access_key"
|
|
89
|
+
assert result.aws_secret_access_key == "test_secret_key"
|
|
90
|
+
assert result.aws_session_token == "test_session_token"
|
|
91
|
+
|
|
92
|
+
mock_graphql.assert_called_once()
|
|
93
|
+
call_args = mock_graphql.call_args
|
|
94
|
+
assert "chain_s3_upload_credentials" in call_args[0][0]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def test_get_blob_credentials_for_other_types_uses_rest():
|
|
98
|
+
"""Test that get_blob_credentials uses REST API for non-chain blob types."""
|
|
99
|
+
mock_response = {
|
|
100
|
+
"s3_bucket": "test-bucket",
|
|
101
|
+
"s3_key": "test-key",
|
|
102
|
+
"creds": {
|
|
103
|
+
"aws_access_key_id": "test_access_key",
|
|
104
|
+
"aws_secret_access_key": "test_secret_key",
|
|
105
|
+
"aws_session_token": "test_session_token",
|
|
106
|
+
},
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
mock_auth_service = Mock()
|
|
110
|
+
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
111
|
+
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
112
|
+
with patch.object(api, "_rest_api_client") as mock_client, patch.object(
|
|
113
|
+
api, "_post_graphql_query"
|
|
114
|
+
) as mock_graphql:
|
|
115
|
+
mock_client.get.return_value = mock_response
|
|
116
|
+
|
|
117
|
+
result = api.get_blob_credentials(b10_types.BlobType.MODEL)
|
|
118
|
+
|
|
119
|
+
assert result["s3_bucket"] == "test-bucket"
|
|
120
|
+
assert result["s3_key"] == "test-key"
|
|
121
|
+
|
|
122
|
+
mock_client.get.assert_called_once_with("v1/blobs/credentials/model")
|
|
123
|
+
mock_graphql.assert_not_called()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
127
|
+
def test_upload_chain_artifact_function(mock_multipart_upload):
|
|
128
|
+
"""Test the upload_chain_artifact function."""
|
|
129
|
+
# Mock ChainUploadCredentials object
|
|
130
|
+
mock_credentials = Mock()
|
|
131
|
+
mock_credentials.s3_bucket = "test-chain-bucket"
|
|
132
|
+
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
133
|
+
mock_credentials.aws_credentials = Mock()
|
|
134
|
+
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
135
|
+
"aws_access_key_id": "test_access_key",
|
|
136
|
+
"aws_secret_access_key": "test_secret_key",
|
|
137
|
+
"aws_session_token": "test_session_token",
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
api = Mock(spec=BasetenApi)
|
|
141
|
+
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
142
|
+
|
|
143
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as temp_file:
|
|
144
|
+
temp_file.write(b"test chain content")
|
|
145
|
+
temp_file.flush()
|
|
146
|
+
|
|
147
|
+
result = upload_chain_artifact(api, temp_file, None)
|
|
148
|
+
|
|
149
|
+
assert result == "chains/test-uuid/chain.tgz"
|
|
150
|
+
|
|
151
|
+
api.get_chain_s3_upload_credentials.assert_called_once_with()
|
|
152
|
+
|
|
153
|
+
mock_multipart_upload.assert_called_once()
|
|
154
|
+
call_args = mock_multipart_upload.call_args
|
|
155
|
+
assert call_args[0][0] == temp_file.name # file path
|
|
156
|
+
assert call_args[0][1] == "test-chain-bucket" # bucket
|
|
157
|
+
assert call_args[0][2] == "chains/test-uuid/chain.tgz" # key
|
|
158
|
+
assert call_args[0][3] == { # credentials
|
|
159
|
+
"aws_access_key_id": "test_access_key",
|
|
160
|
+
"aws_secret_access_key": "test_secret_key",
|
|
161
|
+
"aws_session_token": "test_session_token",
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@patch("truss.remote.baseten.remote.upload_chain_artifact")
|
|
166
|
+
@patch("truss.remote.baseten.remote.archive_dir")
|
|
167
|
+
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
168
|
+
def test_push_chain_atomic_with_chain_upload(
|
|
169
|
+
mock_create_chain_atomic,
|
|
170
|
+
mock_archive_dir,
|
|
171
|
+
mock_upload_chain_artifact,
|
|
172
|
+
mock_push_data,
|
|
173
|
+
mock_remote_context,
|
|
174
|
+
):
|
|
175
|
+
"""Test that push_chain_atomic uploads raw chain artifact when chain_root is provided."""
|
|
176
|
+
mock_create_chain_atomic.return_value = Mock()
|
|
177
|
+
mock_archive_dir.return_value = Mock()
|
|
178
|
+
mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"
|
|
179
|
+
|
|
180
|
+
context = mock_remote_context
|
|
181
|
+
remote = context["remote"]
|
|
182
|
+
chain_name = context["chain_name"]
|
|
183
|
+
entrypoint_artifact = context["entrypoint_artifact"]
|
|
184
|
+
dependency_artifacts = context["dependency_artifacts"]
|
|
185
|
+
truss_user_env = context["truss_user_env"]
|
|
186
|
+
chain_root = context["chain_root"]
|
|
187
|
+
|
|
188
|
+
context["mock_prepare_push"].return_value = mock_push_data
|
|
189
|
+
|
|
190
|
+
result = remote.push_chain_atomic(
|
|
191
|
+
chain_name=chain_name,
|
|
192
|
+
entrypoint_artifact=entrypoint_artifact,
|
|
193
|
+
dependency_artifacts=dependency_artifacts,
|
|
194
|
+
truss_user_env=truss_user_env,
|
|
195
|
+
chain_root=chain_root,
|
|
196
|
+
publish=True,
|
|
197
|
+
)
|
|
198
|
+
assert result == mock_create_chain_atomic.return_value
|
|
199
|
+
|
|
200
|
+
mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
|
|
201
|
+
mock_upload_chain_artifact.assert_called_once()
|
|
202
|
+
|
|
203
|
+
mock_create_chain_atomic.assert_called_once()
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
207
|
+
def test_push_chain_atomic_without_chain_upload(
|
|
208
|
+
mock_create_chain_atomic, mock_push_data, mock_remote_context
|
|
209
|
+
):
|
|
210
|
+
"""Test that push_chain_atomic skips chain upload when chain_root is None."""
|
|
211
|
+
mock_create_chain_atomic.return_value = Mock()
|
|
212
|
+
|
|
213
|
+
context = mock_remote_context
|
|
214
|
+
remote = context["remote"]
|
|
215
|
+
chain_name = context["chain_name"]
|
|
216
|
+
entrypoint_artifact = context["entrypoint_artifact"]
|
|
217
|
+
dependency_artifacts = context["dependency_artifacts"]
|
|
218
|
+
truss_user_env = context["truss_user_env"]
|
|
219
|
+
|
|
220
|
+
context["mock_prepare_push"].return_value = mock_push_data
|
|
221
|
+
|
|
222
|
+
with patch("truss.remote.baseten.remote.upload_chain_artifact") as mock_upload:
|
|
223
|
+
with patch(
|
|
224
|
+
"truss.remote.baseten.core.create_tar_with_progress_bar"
|
|
225
|
+
) as mock_tar:
|
|
226
|
+
# Call push_chain_atomic without chain_root
|
|
227
|
+
result = remote.push_chain_atomic(
|
|
228
|
+
chain_name=chain_name,
|
|
229
|
+
entrypoint_artifact=entrypoint_artifact,
|
|
230
|
+
dependency_artifacts=dependency_artifacts,
|
|
231
|
+
truss_user_env=truss_user_env,
|
|
232
|
+
chain_root=None, # No chain root
|
|
233
|
+
publish=True,
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
assert result
|
|
237
|
+
# Verify chain artifact upload was NOT called
|
|
238
|
+
mock_tar.assert_not_called()
|
|
239
|
+
mock_upload.assert_not_called()
|
|
240
|
+
|
|
241
|
+
mock_create_chain_atomic.assert_called_once()
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
245
|
+
def test_upload_chain_artifact_error_handling(mock_multipart_upload):
|
|
246
|
+
"""Test error handling in upload_chain_artifact."""
|
|
247
|
+
# Mock API to raise an exception
|
|
248
|
+
api = Mock(spec=BasetenApi)
|
|
249
|
+
api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")
|
|
250
|
+
|
|
251
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
252
|
+
# Should raise the exception
|
|
253
|
+
with pytest.raises(Exception, match="API Error"):
|
|
254
|
+
upload_chain_artifact(api, temp_file, None)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def test_upload_chain_artifact_credentials_extraction():
|
|
258
|
+
"""Test that credentials are properly extracted from API response."""
|
|
259
|
+
# Mock ChainUploadCredentials object
|
|
260
|
+
mock_credentials = Mock()
|
|
261
|
+
mock_credentials.s3_bucket = "test-bucket"
|
|
262
|
+
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
263
|
+
mock_credentials.aws_credentials = Mock()
|
|
264
|
+
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
265
|
+
"aws_access_key_id": "access_key",
|
|
266
|
+
"aws_secret_access_key": "secret_key",
|
|
267
|
+
"aws_session_token": "session_token",
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
api = Mock(spec=BasetenApi)
|
|
271
|
+
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
272
|
+
|
|
273
|
+
with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
|
|
274
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
275
|
+
upload_chain_artifact(api, temp_file, None)
|
|
276
|
+
|
|
277
|
+
call_args = mock_upload.call_args
|
|
278
|
+
credentials = call_args[0][3]
|
|
279
|
+
|
|
280
|
+
assert credentials == {
|
|
281
|
+
"aws_access_key_id": "access_key",
|
|
282
|
+
"aws_secret_access_key": "secret_key",
|
|
283
|
+
"aws_session_token": "session_token",
|
|
284
|
+
}
|
|
285
|
+
assert "extra_field" not in credentials
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from typing import Generator
|
|
4
|
+
|
|
5
|
+
from botocore.exceptions import ClientError, NoCredentialsError
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@contextmanager
|
|
11
|
+
def handle_client_error(
|
|
12
|
+
operation_description: str = "AWS operation",
|
|
13
|
+
) -> Generator[None, None, None]:
|
|
14
|
+
"""
|
|
15
|
+
Context manager to handle common boto3 errors and convert them to RuntimeError.
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
operation_description: Description of the operation being performed for error messages
|
|
19
|
+
|
|
20
|
+
Raises:
|
|
21
|
+
RuntimeError: For NoCredentialsError, ClientError, and other exceptions
|
|
22
|
+
"""
|
|
23
|
+
try:
|
|
24
|
+
yield
|
|
25
|
+
except NoCredentialsError as nce:
|
|
26
|
+
raise RuntimeError(
|
|
27
|
+
f"No AWS credentials found for {operation_description}\nOriginal exception: {str(nce)}"
|
|
28
|
+
)
|
|
29
|
+
except ClientError as ce:
|
|
30
|
+
raise RuntimeError(
|
|
31
|
+
f"AWS client error when {operation_description} (check your credentials): {str(ce)}"
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
except Exception as exc:
|
|
35
|
+
raise RuntimeError(
|
|
36
|
+
f"Unexpected error `{exc}` during `{operation_description}`\nOriginal exception: {str(exc)}"
|
|
37
|
+
)
|
|
@@ -8,8 +8,8 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
|
|
|
8
8
|
truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
|
|
9
9
|
truss/base/truss_config.py,sha256=s39Xc1e20s8IV07YLl_aVnp-uRS18ZQ2TV-3FILx4nY,28416
|
|
10
10
|
truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
|
|
11
|
-
truss/cli/chains_commands.py,sha256=
|
|
12
|
-
truss/cli/cli.py,sha256=
|
|
11
|
+
truss/cli/chains_commands.py,sha256=QijtACpuAt2O1RV_qhTNPw0jcFg-u0dX9PP-ct0t-rs,17716
|
|
12
|
+
truss/cli/cli.py,sha256=VGOw1ell7h9bna64UmopavCpVPdjDerSaGPDoizIsRI,30313
|
|
13
13
|
truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
|
|
14
14
|
truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
|
|
15
15
|
truss/cli/logs/base_watcher.py,sha256=vuqteoaMVGX34cgKcETf4X_gOkvnSnDaWz1_pbeFhqs,3343
|
|
@@ -52,12 +52,12 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
|
|
|
52
52
|
truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
|
|
53
53
|
truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
|
|
54
54
|
truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
|
|
55
|
-
truss/remote/baseten/api.py,sha256=
|
|
55
|
+
truss/remote/baseten/api.py,sha256=54Cl_2zHpRU4g2VXzK-BYlxPJeHHImceFrbxD9AASXo,30335
|
|
56
56
|
truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
|
|
57
|
-
truss/remote/baseten/core.py,sha256=
|
|
58
|
-
truss/remote/baseten/custom_types.py,sha256=
|
|
57
|
+
truss/remote/baseten/core.py,sha256=FC5-87Vs2f0NR8eddtSRvr3Z5W2rF7mpiq9jCPrbzr4,23399
|
|
58
|
+
truss/remote/baseten/custom_types.py,sha256=g7yWkE8p6uIAG5JqgfELFGHzjFLvO7vLPzbe-yl1nYs,4735
|
|
59
59
|
truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
|
|
60
|
-
truss/remote/baseten/remote.py,sha256=
|
|
60
|
+
truss/remote/baseten/remote.py,sha256=aKG1BODtrnmuRV-M8T3F3pw8oHawGwI09caKANJ19BM,23420
|
|
61
61
|
truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
|
|
62
62
|
truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
|
|
63
63
|
truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
|
|
@@ -141,6 +141,7 @@ truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX
|
|
|
141
141
|
truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
|
|
142
142
|
truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
|
|
143
143
|
truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
|
|
144
|
+
truss/tests/cli/test_chains_cli.py,sha256=l9GTQrhRm9SRZn43WkMY4tdRslLmdsVyiydRPa1_Ja4,3162
|
|
144
145
|
truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
|
|
145
146
|
truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
|
|
146
147
|
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
|
|
@@ -162,6 +163,7 @@ truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0a
|
|
|
162
163
|
truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
|
|
163
164
|
truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
|
|
164
165
|
truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
|
|
166
|
+
truss/tests/remote/baseten/test_chain_upload.py,sha256=XaaF1ocovkBYsLMJ8EpXB9FUGfQZAwu4iyOWqoVn7tc,10886
|
|
165
167
|
truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
|
|
166
168
|
truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
|
|
167
169
|
truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
|
|
@@ -340,6 +342,7 @@ truss/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
340
342
|
truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
|
|
341
343
|
truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
|
|
342
344
|
truss/util/env_vars.py,sha256=PmKsXdN-PX2-_xk9XcdHTFuRRWiFaMw2iNUKxE8B1Ro,1671
|
|
345
|
+
truss/util/error_utils.py,sha256=pvBH_opyCVpfUbFvm8bgOtjOWit23x2smkBPPdkSZwc,1148
|
|
343
346
|
truss/util/gpu.py,sha256=YiEF_JZyzur0MDMJOebMuJBQxrHD9ApGI0aPpWdb5BU,553
|
|
344
347
|
truss/util/jinja.py,sha256=7KbuYNq55I3DGtImAiCvBwR0K9-z1Jo6gMhmsy4lNZE,333
|
|
345
348
|
truss/util/log_utils.py,sha256=LwSgRh2K7KFjKKqBxr-IirFxGIzHi1mUM7YEvujvHsE,1985
|
|
@@ -349,8 +352,8 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
|
|
|
349
352
|
truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
|
|
350
353
|
truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
|
|
351
354
|
truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
|
|
352
|
-
truss_chains/private_types.py,sha256=
|
|
353
|
-
truss_chains/public_api.py,sha256=
|
|
355
|
+
truss_chains/private_types.py,sha256=vdcl8FuVsL9JGIu_9K7fd2EW9Ytzoq8nfEx5pmuMKTA,9063
|
|
356
|
+
truss_chains/public_api.py,sha256=civY8juJU92jSGBI7zM1qMnA7hlUdCq7L8o4IOo5meA,9722
|
|
354
357
|
truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
|
|
355
358
|
truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
356
359
|
truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
|
|
@@ -358,7 +361,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
|
|
|
358
361
|
truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
|
|
359
362
|
truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
360
363
|
truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
|
|
361
|
-
truss_chains/deployment/deployment_client.py,sha256=
|
|
364
|
+
truss_chains/deployment/deployment_client.py,sha256=4cHuvaynVCclJ6M9pw8ukhO1E2NRKohIRxftvOfNvOE,34499
|
|
362
365
|
truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
|
|
363
366
|
truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
|
|
364
367
|
truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -371,8 +374,8 @@ truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,307
|
|
|
371
374
|
truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
|
|
372
375
|
truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
|
|
373
376
|
truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
|
|
374
|
-
truss-0.11.
|
|
375
|
-
truss-0.11.
|
|
376
|
-
truss-0.11.
|
|
377
|
-
truss-0.11.
|
|
378
|
-
truss-0.11.
|
|
377
|
+
truss-0.11.18.dist-info/METADATA,sha256=vlaTGYYSaK0iMt1NvmUEuF6uhy-hygDj2MLDbQQpJO0,6678
|
|
378
|
+
truss-0.11.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
379
|
+
truss-0.11.18.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
|
|
380
|
+
truss-0.11.18.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
|
|
381
|
+
truss-0.11.18.dist-info/RECORD,,
|
|
@@ -516,14 +516,21 @@ def _create_baseten_chain(
|
|
|
516
516
|
|
|
517
517
|
_create_chains_secret_if_missing(remote_provider)
|
|
518
518
|
|
|
519
|
+
# Get chain root for raw chain artifact upload
|
|
520
|
+
chain_root = None
|
|
521
|
+
if entrypoint_descriptor is not None:
|
|
522
|
+
chain_root = _get_chain_root(entrypoint_descriptor.chainlet_cls)
|
|
523
|
+
|
|
519
524
|
chain_deployment_handle = remote_provider.push_chain_atomic(
|
|
520
525
|
baseten_options.chain_name,
|
|
521
526
|
entrypoint_artifact,
|
|
522
527
|
dependency_artifacts,
|
|
523
528
|
truss_user_env,
|
|
529
|
+
chain_root=chain_root,
|
|
524
530
|
publish=baseten_options.publish,
|
|
525
531
|
environment=baseten_options.environment,
|
|
526
532
|
progress_bar=progress_bar,
|
|
533
|
+
disable_chain_download=baseten_options.disable_chain_download,
|
|
527
534
|
)
|
|
528
535
|
return BasetenChainService(
|
|
529
536
|
baseten_options.chain_name,
|
truss_chains/private_types.py
CHANGED
|
@@ -265,6 +265,7 @@ class PushOptionsBaseten(PushOptions):
|
|
|
265
265
|
environment: Optional[str]
|
|
266
266
|
include_git_info: bool
|
|
267
267
|
working_dir: pathlib.Path
|
|
268
|
+
disable_chain_download: bool = False
|
|
268
269
|
|
|
269
270
|
@classmethod
|
|
270
271
|
def create(
|
|
@@ -277,6 +278,7 @@ class PushOptionsBaseten(PushOptions):
|
|
|
277
278
|
include_git_info: bool,
|
|
278
279
|
working_dir: pathlib.Path,
|
|
279
280
|
environment: Optional[str] = None,
|
|
281
|
+
disable_chain_download: bool = False,
|
|
280
282
|
) -> "PushOptionsBaseten":
|
|
281
283
|
if promote and not environment:
|
|
282
284
|
environment = PRODUCTION_ENVIRONMENT_NAME
|
|
@@ -290,6 +292,7 @@ class PushOptionsBaseten(PushOptions):
|
|
|
290
292
|
environment=environment,
|
|
291
293
|
include_git_info=include_git_info,
|
|
292
294
|
working_dir=working_dir,
|
|
295
|
+
disable_chain_download=disable_chain_download,
|
|
293
296
|
)
|
|
294
297
|
|
|
295
298
|
|
truss_chains/public_api.py
CHANGED
|
@@ -151,6 +151,7 @@ def push(
|
|
|
151
151
|
environment: Optional[str] = None,
|
|
152
152
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
153
153
|
include_git_info: bool = False,
|
|
154
|
+
disable_chain_download: bool = False,
|
|
154
155
|
) -> deployment_client.BasetenChainService:
|
|
155
156
|
"""
|
|
156
157
|
Deploys a chain remotely (with all dependent chainlets).
|
|
@@ -172,6 +173,7 @@ def push(
|
|
|
172
173
|
include_git_info: Whether to attach git versioning info (sha, branch, tag) to
|
|
173
174
|
deployments made from within a git repo. If set to True in `.trussrc`, it
|
|
174
175
|
will always be attached.
|
|
176
|
+
disable_chain_download: Disable downloading of pushed chain source code from the UI.
|
|
175
177
|
|
|
176
178
|
Returns:
|
|
177
179
|
A chain service handle to the deployed chain.
|
|
@@ -186,6 +188,7 @@ def push(
|
|
|
186
188
|
environment=environment,
|
|
187
189
|
include_git_info=include_git_info,
|
|
188
190
|
working_dir=pathlib.Path(inspect.getfile(entrypoint)).parent,
|
|
191
|
+
disable_chain_download=disable_chain_download,
|
|
189
192
|
)
|
|
190
193
|
service = deployment_client.push(entrypoint, options, progress_bar=progress_bar)
|
|
191
194
|
assert isinstance(
|
|
File without changes
|
|
File without changes
|
|
File without changes
|