truss 0.11.14__py3-none-any.whl → 0.11.15rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/cli/chains_commands.py +10 -0
- truss/contexts/image_builder/cache_warmer.py +5 -14
- truss/remote/baseten/api.py +68 -4
- truss/remote/baseten/core.py +40 -0
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +18 -0
- truss/templates/docker_server/supervisord.conf.jinja +1 -1
- truss/tests/cli/test_chains_cli.py +100 -0
- truss/tests/remote/baseten/test_chain_upload.py +285 -0
- truss/util/error_utils.py +34 -0
- {truss-0.11.14.dist-info → truss-0.11.15rc10.dist-info}/METADATA +1 -1
- {truss-0.11.14.dist-info → truss-0.11.15rc10.dist-info}/RECORD +18 -15
- truss_chains/deployment/deployment_client.py +7 -0
- truss_chains/private_types.py +3 -0
- truss_chains/public_api.py +3 -0
- {truss-0.11.14.dist-info → truss-0.11.15rc10.dist-info}/WHEEL +0 -0
- {truss-0.11.14.dist-info → truss-0.11.15rc10.dist-info}/entry_points.txt +0 -0
- {truss-0.11.14.dist-info → truss-0.11.15rc10.dist-info}/licenses/LICENSE +0 -0
truss/cli/chains_commands.py
CHANGED
|
@@ -211,6 +211,14 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
|
|
|
211
211
|
default=False,
|
|
212
212
|
help=common.INCLUDE_GIT_INFO_DOC,
|
|
213
213
|
)
|
|
214
|
+
@click.option(
|
|
215
|
+
"--disable-chain-download",
|
|
216
|
+
"disable_chain_download",
|
|
217
|
+
is_flag=True,
|
|
218
|
+
required=False,
|
|
219
|
+
default=False,
|
|
220
|
+
help="Disable downloading of pushed chain source code from the UI.",
|
|
221
|
+
)
|
|
214
222
|
@click.pass_context
|
|
215
223
|
@common.common_options()
|
|
216
224
|
def push_chain(
|
|
@@ -227,6 +235,7 @@ def push_chain(
|
|
|
227
235
|
environment: Optional[str],
|
|
228
236
|
experimental_watch_chainlet_names: Optional[str],
|
|
229
237
|
include_git_info: bool = False,
|
|
238
|
+
disable_chain_download: bool = False,
|
|
230
239
|
) -> None:
|
|
231
240
|
"""
|
|
232
241
|
Deploys a chain remotely.
|
|
@@ -284,6 +293,7 @@ def push_chain(
|
|
|
284
293
|
environment=environment,
|
|
285
294
|
include_git_info=include_git_info,
|
|
286
295
|
working_dir=source.parent if source.is_file() else source.resolve(),
|
|
296
|
+
disable_chain_download=disable_chain_download,
|
|
287
297
|
)
|
|
288
298
|
service = deployment_client.push(
|
|
289
299
|
entrypoint_cls, options, progress_bar=progress.Progress
|
|
@@ -11,10 +11,11 @@ from typing import Optional, Type
|
|
|
11
11
|
import boto3
|
|
12
12
|
from botocore import UNSIGNED
|
|
13
13
|
from botocore.client import Config
|
|
14
|
-
from botocore.exceptions import ClientError, NoCredentialsError
|
|
15
14
|
from google.cloud import storage
|
|
16
15
|
from huggingface_hub import hf_hub_download
|
|
17
16
|
|
|
17
|
+
from truss.util.error_utils import handle_client_error
|
|
18
|
+
|
|
18
19
|
B10CP_PATH_TRUSS_ENV_VAR_NAME = "B10CP_PATH_TRUSS"
|
|
19
20
|
|
|
20
21
|
GCS_CREDENTIALS = "/app/data/service_account.json"
|
|
@@ -188,24 +189,14 @@ class S3File(RepositoryFile):
|
|
|
188
189
|
if not dst_file.parent.exists():
|
|
189
190
|
dst_file.parent.mkdir(parents=True)
|
|
190
191
|
|
|
191
|
-
|
|
192
|
+
with handle_client_error(
|
|
193
|
+
f"accessing S3 bucket {bucket_name} for file {file_name}"
|
|
194
|
+
):
|
|
192
195
|
url = client.generate_presigned_url(
|
|
193
196
|
"get_object",
|
|
194
197
|
Params={"Bucket": bucket_name, "Key": file_name},
|
|
195
198
|
ExpiresIn=3600,
|
|
196
199
|
)
|
|
197
|
-
except NoCredentialsError as nce:
|
|
198
|
-
raise RuntimeError(
|
|
199
|
-
f"No AWS credentials found\nOriginal exception: {str(nce)}"
|
|
200
|
-
)
|
|
201
|
-
except ClientError as ce:
|
|
202
|
-
raise RuntimeError(
|
|
203
|
-
f"Client error when accessing the S3 bucket (check your credentials): {str(ce)}"
|
|
204
|
-
)
|
|
205
|
-
except Exception as exc:
|
|
206
|
-
raise RuntimeError(
|
|
207
|
-
f"File not found on S3 bucket: {file_name}\nOriginal exception: {str(exc)}"
|
|
208
|
-
)
|
|
209
200
|
|
|
210
201
|
download_file_using_b10cp(url, dst_file, self.file_name)
|
|
211
202
|
|
truss/remote/baseten/api.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
|
|
5
5
|
import requests
|
|
6
6
|
from pydantic import BaseModel, Field
|
|
7
7
|
|
|
8
|
+
from truss.base.custom_types import SafeModel
|
|
8
9
|
from truss.remote.baseten import custom_types as b10_types
|
|
9
10
|
from truss.remote.baseten.auth import ApiKey, AuthService
|
|
10
11
|
from truss.remote.baseten.custom_types import APIKeyCategory
|
|
@@ -13,6 +14,29 @@ from truss.remote.baseten.rest_client import RestAPIClient
|
|
|
13
14
|
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
|
|
14
15
|
|
|
15
16
|
logger = logging.getLogger(__name__)
|
|
17
|
+
PARAMS_INDENT = "\n "
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ChainAWSCredential(SafeModel):
|
|
21
|
+
aws_access_key_id: str
|
|
22
|
+
aws_secret_access_key: str
|
|
23
|
+
aws_session_token: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ChainUploadCredentials(SafeModel):
|
|
27
|
+
s3_bucket: str
|
|
28
|
+
s3_key: str
|
|
29
|
+
aws_access_key_id: str
|
|
30
|
+
aws_secret_access_key: str
|
|
31
|
+
aws_session_token: str
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def aws_credentials(self) -> ChainAWSCredential:
|
|
35
|
+
return ChainAWSCredential(
|
|
36
|
+
aws_access_key_id=self.aws_access_key_id,
|
|
37
|
+
aws_secret_access_key=self.aws_secret_access_key,
|
|
38
|
+
aws_session_token=self.aws_session_token,
|
|
39
|
+
)
|
|
16
40
|
|
|
17
41
|
|
|
18
42
|
class InstanceTypeV1(BaseModel):
|
|
@@ -299,7 +323,11 @@ class BasetenApi:
|
|
|
299
323
|
chain_name: Optional[str] = None,
|
|
300
324
|
environment: Optional[str] = None,
|
|
301
325
|
is_draft: bool = False,
|
|
326
|
+
original_source_artifact_s3_key: Optional[str] = None,
|
|
327
|
+
allow_truss_download: Optional[bool] = True,
|
|
302
328
|
):
|
|
329
|
+
if allow_truss_download is None:
|
|
330
|
+
allow_truss_download = True
|
|
303
331
|
entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
|
|
304
332
|
|
|
305
333
|
dependencies_str = ", ".join(
|
|
@@ -309,13 +337,28 @@ class BasetenApi:
|
|
|
309
337
|
]
|
|
310
338
|
)
|
|
311
339
|
|
|
340
|
+
params = []
|
|
341
|
+
if chain_id:
|
|
342
|
+
params.append(f'chain_id: "{chain_id}"')
|
|
343
|
+
if chain_name:
|
|
344
|
+
params.append(f'chain_name: "{chain_name}"')
|
|
345
|
+
if environment:
|
|
346
|
+
params.append(f'environment: "{environment}"')
|
|
347
|
+
if original_source_artifact_s3_key:
|
|
348
|
+
params.append(
|
|
349
|
+
f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
params.append(f"is_draft: {str(is_draft).lower()}")
|
|
353
|
+
if allow_truss_download is False:
|
|
354
|
+
params.append("allow_truss_download: false")
|
|
355
|
+
|
|
356
|
+
params_str = PARAMS_INDENT.join(params)
|
|
357
|
+
|
|
312
358
|
query_string = f"""
|
|
313
359
|
mutation ($trussUserEnv: String) {{
|
|
314
360
|
deploy_chain_atomic(
|
|
315
|
-
{
|
|
316
|
-
{f'chain_name: "{chain_name}"' if chain_name else ""}
|
|
317
|
-
{f'environment: "{environment}"' if environment else ""}
|
|
318
|
-
is_draft: {str(is_draft).lower()}
|
|
361
|
+
{params_str}
|
|
319
362
|
entrypoint: {entrypoint_str}
|
|
320
363
|
dependencies: [{dependencies_str}]
|
|
321
364
|
truss_user_env: $trussUserEnv
|
|
@@ -657,8 +700,29 @@ class BasetenApi:
|
|
|
657
700
|
return resp_json["training_projects"]
|
|
658
701
|
|
|
659
702
|
def get_blob_credentials(self, blob_type: b10_types.BlobType):
|
|
703
|
+
if blob_type == b10_types.BlobType.CHAIN:
|
|
704
|
+
return self.get_chain_s3_upload_credentials()
|
|
660
705
|
return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
|
|
661
706
|
|
|
707
|
+
def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
|
|
708
|
+
"""Get chain artifact credentials using GraphQL query."""
|
|
709
|
+
query = """
|
|
710
|
+
query {
|
|
711
|
+
chain_s3_upload_credentials {
|
|
712
|
+
s3_bucket
|
|
713
|
+
s3_key
|
|
714
|
+
aws_access_key_id
|
|
715
|
+
aws_secret_access_key
|
|
716
|
+
aws_session_token
|
|
717
|
+
}
|
|
718
|
+
}
|
|
719
|
+
"""
|
|
720
|
+
response = self._post_graphql_query(query)
|
|
721
|
+
|
|
722
|
+
return ChainUploadCredentials.model_validate(
|
|
723
|
+
response["data"]["chain_s3_upload_credentials"]
|
|
724
|
+
)
|
|
725
|
+
|
|
662
726
|
def get_training_job_metrics(
|
|
663
727
|
self,
|
|
664
728
|
project_id: str,
|
truss/remote/baseten/core.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
|
|
|
8
8
|
import requests
|
|
9
9
|
|
|
10
10
|
from truss.base.errors import ValidationError
|
|
11
|
+
from truss.util.error_utils import handle_client_error
|
|
11
12
|
|
|
12
13
|
if TYPE_CHECKING:
|
|
13
14
|
from rich import progress
|
|
@@ -80,6 +81,8 @@ class ChainDeploymentHandleAtomic(NamedTuple):
|
|
|
80
81
|
chain_id: str
|
|
81
82
|
chain_deployment_id: str
|
|
82
83
|
is_draft: bool
|
|
84
|
+
original_source_artifact_s3_key: Optional[str] = None
|
|
85
|
+
allow_truss_download: Optional[bool] = True
|
|
83
86
|
|
|
84
87
|
|
|
85
88
|
class ModelVersionHandle(NamedTuple):
|
|
@@ -127,6 +130,8 @@ def create_chain_atomic(
|
|
|
127
130
|
is_draft: bool,
|
|
128
131
|
truss_user_env: b10_types.TrussUserEnv,
|
|
129
132
|
environment: Optional[str],
|
|
133
|
+
original_source_artifact_s3_key: Optional[str] = None,
|
|
134
|
+
allow_truss_download: bool = True,
|
|
130
135
|
) -> ChainDeploymentHandleAtomic:
|
|
131
136
|
if environment and is_draft:
|
|
132
137
|
logging.info(
|
|
@@ -149,6 +154,8 @@ def create_chain_atomic(
|
|
|
149
154
|
chain_name=chain_name,
|
|
150
155
|
is_draft=True,
|
|
151
156
|
truss_user_env=truss_user_env,
|
|
157
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
158
|
+
allow_truss_download=allow_truss_download,
|
|
152
159
|
)
|
|
153
160
|
elif chain_id:
|
|
154
161
|
# This is the only case where promote has relevance, since
|
|
@@ -162,6 +169,8 @@ def create_chain_atomic(
|
|
|
162
169
|
chain_id=chain_id,
|
|
163
170
|
environment=environment,
|
|
164
171
|
truss_user_env=truss_user_env,
|
|
172
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
173
|
+
allow_truss_download=allow_truss_download,
|
|
165
174
|
)
|
|
166
175
|
except ApiError as e:
|
|
167
176
|
if (
|
|
@@ -182,6 +191,8 @@ def create_chain_atomic(
|
|
|
182
191
|
dependencies=dependencies,
|
|
183
192
|
chain_name=chain_name,
|
|
184
193
|
truss_user_env=truss_user_env,
|
|
194
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
195
|
+
allow_truss_download=allow_truss_download,
|
|
185
196
|
)
|
|
186
197
|
|
|
187
198
|
return ChainDeploymentHandleAtomic(
|
|
@@ -189,6 +200,8 @@ def create_chain_atomic(
|
|
|
189
200
|
chain_id=res["chain_deployment"]["chain"]["id"],
|
|
190
201
|
hostname=res["chain_deployment"]["chain"]["hostname"],
|
|
191
202
|
is_draft=is_draft,
|
|
203
|
+
original_source_artifact_s3_key=original_source_artifact_s3_key,
|
|
204
|
+
allow_truss_download=allow_truss_download,
|
|
192
205
|
)
|
|
193
206
|
|
|
194
207
|
|
|
@@ -342,6 +355,33 @@ def upload_truss(
|
|
|
342
355
|
return s3_key
|
|
343
356
|
|
|
344
357
|
|
|
358
|
+
def upload_chain_artifact(
|
|
359
|
+
api: BasetenApi,
|
|
360
|
+
serialize_file: IO,
|
|
361
|
+
progress_bar: Optional[Type["progress.Progress"]],
|
|
362
|
+
) -> str:
|
|
363
|
+
"""
|
|
364
|
+
Upload a chain artifact to the Baseten remote.
|
|
365
|
+
|
|
366
|
+
Args:
|
|
367
|
+
api: BasetenApi instance
|
|
368
|
+
serialize_file: File-like object containing the serialized chain artifact
|
|
369
|
+
|
|
370
|
+
Returns:
|
|
371
|
+
The S3 key of the uploaded file
|
|
372
|
+
"""
|
|
373
|
+
credentials = api.get_chain_s3_upload_credentials()
|
|
374
|
+
with handle_client_error("Uploading chain source"):
|
|
375
|
+
multipart_upload_boto3(
|
|
376
|
+
serialize_file.name,
|
|
377
|
+
credentials.s3_bucket,
|
|
378
|
+
credentials.s3_key,
|
|
379
|
+
credentials.aws_credentials.model_dump(),
|
|
380
|
+
progress_bar,
|
|
381
|
+
)
|
|
382
|
+
return credentials.s3_key
|
|
383
|
+
|
|
384
|
+
|
|
345
385
|
def create_truss_service(
|
|
346
386
|
api: BasetenApi,
|
|
347
387
|
model_name: str,
|
truss/remote/baseten/remote.py
CHANGED
|
@@ -31,6 +31,7 @@ from truss.remote.baseten.core import (
|
|
|
31
31
|
get_model_and_versions,
|
|
32
32
|
get_prod_version_from_versions,
|
|
33
33
|
get_truss_watch_state,
|
|
34
|
+
upload_chain_artifact,
|
|
34
35
|
upload_truss,
|
|
35
36
|
validate_truss_config_against_backend,
|
|
36
37
|
)
|
|
@@ -263,9 +264,11 @@ class BasetenRemote(TrussRemote):
|
|
|
263
264
|
entrypoint_artifact: custom_types.ChainletArtifact,
|
|
264
265
|
dependency_artifacts: List[custom_types.ChainletArtifact],
|
|
265
266
|
truss_user_env: b10_types.TrussUserEnv,
|
|
267
|
+
chain_root: Optional[Path] = None,
|
|
266
268
|
publish: bool = False,
|
|
267
269
|
environment: Optional[str] = None,
|
|
268
270
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
271
|
+
disable_chain_download: bool = False,
|
|
269
272
|
) -> ChainDeploymentHandleAtomic:
|
|
270
273
|
# If we are promoting a model to an environment after deploy, it must be published.
|
|
271
274
|
# Draft models cannot be promoted.
|
|
@@ -285,6 +288,7 @@ class BasetenRemote(TrussRemote):
|
|
|
285
288
|
publish=publish,
|
|
286
289
|
origin=custom_types.ModelOrigin.CHAINS,
|
|
287
290
|
progress_bar=progress_bar,
|
|
291
|
+
disable_truss_download=disable_chain_download,
|
|
288
292
|
)
|
|
289
293
|
oracle_data = custom_types.OracleData(
|
|
290
294
|
model_name=push_data.model_name,
|
|
@@ -300,6 +304,18 @@ class BasetenRemote(TrussRemote):
|
|
|
300
304
|
)
|
|
301
305
|
)
|
|
302
306
|
|
|
307
|
+
# Upload raw chain artifact if chain_root is provided
|
|
308
|
+
raw_chain_s3_key = None
|
|
309
|
+
if chain_root is not None:
|
|
310
|
+
logging.info("Uploading source artifact")
|
|
311
|
+
# Create a tar file from the chain root directory
|
|
312
|
+
original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
|
|
313
|
+
# Upload the chain artifact to S3
|
|
314
|
+
raw_chain_s3_key = upload_chain_artifact(
|
|
315
|
+
api=self._api,
|
|
316
|
+
serialize_file=original_source_tar,
|
|
317
|
+
progress_bar=progress_bar,
|
|
318
|
+
)
|
|
303
319
|
chain_deployment_handle = create_chain_atomic(
|
|
304
320
|
api=self._api,
|
|
305
321
|
chain_name=chain_name,
|
|
@@ -308,6 +324,8 @@ class BasetenRemote(TrussRemote):
|
|
|
308
324
|
is_draft=not publish,
|
|
309
325
|
truss_user_env=truss_user_env,
|
|
310
326
|
environment=environment,
|
|
327
|
+
original_source_artifact_s3_key=raw_chain_s3_key,
|
|
328
|
+
allow_truss_download=not disable_chain_download,
|
|
311
329
|
)
|
|
312
330
|
logging.info("Successfully pushed to baseten. Chain is building and deploying.")
|
|
313
331
|
return chain_deployment_handle
|
|
@@ -8,7 +8,7 @@ logfile_maxbytes=0 ; No size limit on logfile (since logging is disabl
|
|
|
8
8
|
command={{start_command}} ; Command to start the model server (provided by Jinja variable)
|
|
9
9
|
startsecs=30 ; Wait 30 seconds before assuming the server is running
|
|
10
10
|
autostart=true ; Automatically start the program when supervisord starts
|
|
11
|
-
autorestart=
|
|
11
|
+
autorestart=unexpected ; Don't restart the program if it exits with clear exit code
|
|
12
12
|
stdout_logfile=/dev/fd/1 ; Send stdout to the first file descriptor (stdout)
|
|
13
13
|
stdout_logfile_maxbytes=0 ; No size limit on stdout log
|
|
14
14
|
redirect_stderr=true ; Redirect stderr to stdout
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Tests for truss chains CLI commands."""
|
|
2
|
+
|
|
3
|
+
from unittest.mock import Mock, patch
|
|
4
|
+
|
|
5
|
+
from click.testing import CliRunner
|
|
6
|
+
|
|
7
|
+
from truss.cli.cli import truss_cli
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_chains_push_with_disable_chain_download_flag():
|
|
11
|
+
"""Test that --disable-chain-download flag is properly parsed and passed through."""
|
|
12
|
+
runner = CliRunner()
|
|
13
|
+
|
|
14
|
+
mock_entrypoint_cls = Mock()
|
|
15
|
+
mock_entrypoint_cls.meta_data.chain_name = "test_chain"
|
|
16
|
+
mock_entrypoint_cls.display_name = "TestChain"
|
|
17
|
+
|
|
18
|
+
mock_service = Mock()
|
|
19
|
+
mock_service.run_remote_url = "http://test.com/run_remote"
|
|
20
|
+
mock_service.is_websocket = False
|
|
21
|
+
|
|
22
|
+
with patch(
|
|
23
|
+
"truss_chains.framework.ChainletImporter.import_target"
|
|
24
|
+
) as mock_importer:
|
|
25
|
+
with patch("truss_chains.deployment.deployment_client.push") as mock_push:
|
|
26
|
+
mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
|
|
27
|
+
mock_push.return_value = mock_service
|
|
28
|
+
|
|
29
|
+
result = runner.invoke(
|
|
30
|
+
truss_cli,
|
|
31
|
+
[
|
|
32
|
+
"chains",
|
|
33
|
+
"push",
|
|
34
|
+
"test_chain.py",
|
|
35
|
+
"--disable-chain-download",
|
|
36
|
+
"--remote",
|
|
37
|
+
"test_remote",
|
|
38
|
+
"--dryrun",
|
|
39
|
+
],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
assert result.exit_code == 0
|
|
43
|
+
|
|
44
|
+
mock_push.assert_called_once()
|
|
45
|
+
call_args = mock_push.call_args
|
|
46
|
+
options = call_args[0][1]
|
|
47
|
+
|
|
48
|
+
assert hasattr(options, "disable_chain_download")
|
|
49
|
+
assert options.disable_chain_download is True
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_chains_push_without_disable_chain_download_flag():
|
|
53
|
+
"""Test that disable_chain_download defaults to False when flag is not provided."""
|
|
54
|
+
runner = CliRunner()
|
|
55
|
+
|
|
56
|
+
mock_entrypoint_cls = Mock()
|
|
57
|
+
mock_entrypoint_cls.meta_data.chain_name = "test_chain"
|
|
58
|
+
mock_entrypoint_cls.display_name = "TestChain"
|
|
59
|
+
|
|
60
|
+
mock_service = Mock()
|
|
61
|
+
mock_service.run_remote_url = "http://test.com/run_remote"
|
|
62
|
+
mock_service.is_websocket = False
|
|
63
|
+
|
|
64
|
+
with patch(
|
|
65
|
+
"truss_chains.framework.ChainletImporter.import_target"
|
|
66
|
+
) as mock_importer:
|
|
67
|
+
with patch("truss_chains.deployment.deployment_client.push") as mock_push:
|
|
68
|
+
mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
|
|
69
|
+
mock_push.return_value = mock_service
|
|
70
|
+
|
|
71
|
+
result = runner.invoke(
|
|
72
|
+
truss_cli,
|
|
73
|
+
[
|
|
74
|
+
"chains",
|
|
75
|
+
"push",
|
|
76
|
+
"test_chain.py",
|
|
77
|
+
"--remote",
|
|
78
|
+
"test_remote",
|
|
79
|
+
"--dryrun",
|
|
80
|
+
],
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
assert result.exit_code == 0
|
|
84
|
+
|
|
85
|
+
mock_push.assert_called_once()
|
|
86
|
+
call_args = mock_push.call_args
|
|
87
|
+
options = call_args[0][1]
|
|
88
|
+
|
|
89
|
+
assert hasattr(options, "disable_chain_download")
|
|
90
|
+
assert options.disable_chain_download is False
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_chains_push_help_includes_disable_chain_download():
|
|
94
|
+
"""Test that --disable-chain-download appears in the help output."""
|
|
95
|
+
runner = CliRunner()
|
|
96
|
+
|
|
97
|
+
result = runner.invoke(truss_cli, ["chains", "push", "--help"])
|
|
98
|
+
|
|
99
|
+
assert result.exit_code == 0
|
|
100
|
+
assert "--disable-chain-download" in result.output
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
import pathlib
|
|
2
|
+
import tempfile
|
|
3
|
+
from unittest.mock import Mock, patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from truss.remote.baseten import custom_types as b10_types
|
|
8
|
+
from truss.remote.baseten.api import BasetenApi
|
|
9
|
+
from truss.remote.baseten.core import upload_chain_artifact
|
|
10
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.fixture
|
|
14
|
+
def mock_push_data():
|
|
15
|
+
"""Fixture providing mock push data for tests."""
|
|
16
|
+
mock_push_data = Mock()
|
|
17
|
+
mock_push_data.model_name = "test-model"
|
|
18
|
+
mock_push_data.s3_key = "models/test-key"
|
|
19
|
+
mock_push_data.encoded_config_str = "encoded_config"
|
|
20
|
+
mock_push_data.is_draft = False
|
|
21
|
+
mock_push_data.model_id = "model-id"
|
|
22
|
+
mock_push_data.version_name = None
|
|
23
|
+
return mock_push_data
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture
|
|
27
|
+
def mock_remote_context():
|
|
28
|
+
"""Fixture providing mock remote and context managers for tests."""
|
|
29
|
+
api = Mock(spec=BasetenApi)
|
|
30
|
+
|
|
31
|
+
remote = BasetenRemote("https://test.baseten.co", "test-api-key")
|
|
32
|
+
remote._api = api
|
|
33
|
+
|
|
34
|
+
chain_name = "test-chain"
|
|
35
|
+
entrypoint_artifact = Mock()
|
|
36
|
+
entrypoint_artifact.truss_dir = "/path/to/truss"
|
|
37
|
+
entrypoint_artifact.display_name = "entrypoint"
|
|
38
|
+
|
|
39
|
+
dependency_artifacts = []
|
|
40
|
+
truss_user_env = Mock()
|
|
41
|
+
chain_root = pathlib.Path("/path/to/chain")
|
|
42
|
+
|
|
43
|
+
with patch.object(remote, "_prepare_push") as mock_prepare_push:
|
|
44
|
+
with patch("truss.remote.baseten.remote.truss_build.load") as mock_load:
|
|
45
|
+
mock_truss_handle = Mock()
|
|
46
|
+
mock_truss_handle.spec.config.model_name = "test-model"
|
|
47
|
+
mock_load.return_value = mock_truss_handle
|
|
48
|
+
|
|
49
|
+
yield {
|
|
50
|
+
"remote": remote,
|
|
51
|
+
"api": api,
|
|
52
|
+
"chain_name": chain_name,
|
|
53
|
+
"entrypoint_artifact": entrypoint_artifact,
|
|
54
|
+
"dependency_artifacts": dependency_artifacts,
|
|
55
|
+
"truss_user_env": truss_user_env,
|
|
56
|
+
"chain_root": chain_root,
|
|
57
|
+
"mock_prepare_push": mock_prepare_push,
|
|
58
|
+
"mock_load": mock_load,
|
|
59
|
+
"mock_truss_handle": mock_truss_handle,
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_get_blob_credentials_for_chain():
|
|
64
|
+
"""Test that get_blob_credentials works correctly for chain blob type using GraphQL."""
|
|
65
|
+
mock_graphql_response = {
|
|
66
|
+
"data": {
|
|
67
|
+
"chain_s3_upload_credentials": {
|
|
68
|
+
"s3_bucket": "test-chain-bucket",
|
|
69
|
+
"s3_key": "chains/test-uuid/chain.tgz",
|
|
70
|
+
"aws_access_key_id": "test_access_key",
|
|
71
|
+
"aws_secret_access_key": "test_secret_key",
|
|
72
|
+
"aws_session_token": "test_session_token",
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
# Create a real API instance and mock the GraphQL call
|
|
78
|
+
mock_auth_service = Mock()
|
|
79
|
+
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
80
|
+
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
81
|
+
with patch.object(api, "_post_graphql_query") as mock_graphql:
|
|
82
|
+
mock_graphql.return_value = mock_graphql_response
|
|
83
|
+
|
|
84
|
+
result = api.get_chain_s3_upload_credentials()
|
|
85
|
+
|
|
86
|
+
assert result.s3_bucket == "test-chain-bucket"
|
|
87
|
+
assert result.s3_key == "chains/test-uuid/chain.tgz"
|
|
88
|
+
assert result.aws_access_key_id == "test_access_key"
|
|
89
|
+
assert result.aws_secret_access_key == "test_secret_key"
|
|
90
|
+
assert result.aws_session_token == "test_session_token"
|
|
91
|
+
|
|
92
|
+
mock_graphql.assert_called_once()
|
|
93
|
+
call_args = mock_graphql.call_args
|
|
94
|
+
assert "chain_s3_upload_credentials" in call_args[0][0]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def test_get_blob_credentials_for_other_types_uses_rest():
|
|
98
|
+
"""Test that get_blob_credentials uses REST API for non-chain blob types."""
|
|
99
|
+
mock_response = {
|
|
100
|
+
"s3_bucket": "test-bucket",
|
|
101
|
+
"s3_key": "test-key",
|
|
102
|
+
"creds": {
|
|
103
|
+
"aws_access_key_id": "test_access_key",
|
|
104
|
+
"aws_secret_access_key": "test_secret_key",
|
|
105
|
+
"aws_session_token": "test_session_token",
|
|
106
|
+
},
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
mock_auth_service = Mock()
|
|
110
|
+
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
111
|
+
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
112
|
+
with patch.object(api, "_rest_api_client") as mock_client, patch.object(
|
|
113
|
+
api, "_post_graphql_query"
|
|
114
|
+
) as mock_graphql:
|
|
115
|
+
mock_client.get.return_value = mock_response
|
|
116
|
+
|
|
117
|
+
result = api.get_blob_credentials(b10_types.BlobType.MODEL)
|
|
118
|
+
|
|
119
|
+
assert result["s3_bucket"] == "test-bucket"
|
|
120
|
+
assert result["s3_key"] == "test-key"
|
|
121
|
+
|
|
122
|
+
mock_client.get.assert_called_once_with("v1/blobs/credentials/model")
|
|
123
|
+
mock_graphql.assert_not_called()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
127
|
+
def test_upload_chain_artifact_function(mock_multipart_upload):
|
|
128
|
+
"""Test the upload_chain_artifact function."""
|
|
129
|
+
# Mock ChainUploadCredentials object
|
|
130
|
+
mock_credentials = Mock()
|
|
131
|
+
mock_credentials.s3_bucket = "test-chain-bucket"
|
|
132
|
+
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
133
|
+
mock_credentials.aws_credentials = Mock()
|
|
134
|
+
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
135
|
+
"aws_access_key_id": "test_access_key",
|
|
136
|
+
"aws_secret_access_key": "test_secret_key",
|
|
137
|
+
"aws_session_token": "test_session_token",
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
api = Mock(spec=BasetenApi)
|
|
141
|
+
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
142
|
+
|
|
143
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as temp_file:
|
|
144
|
+
temp_file.write(b"test chain content")
|
|
145
|
+
temp_file.flush()
|
|
146
|
+
|
|
147
|
+
result = upload_chain_artifact(api, temp_file, None)
|
|
148
|
+
|
|
149
|
+
assert result == "chains/test-uuid/chain.tgz"
|
|
150
|
+
|
|
151
|
+
api.get_chain_s3_upload_credentials.assert_called_once_with()
|
|
152
|
+
|
|
153
|
+
mock_multipart_upload.assert_called_once()
|
|
154
|
+
call_args = mock_multipart_upload.call_args
|
|
155
|
+
assert call_args[0][0] == temp_file.name # file path
|
|
156
|
+
assert call_args[0][1] == "test-chain-bucket" # bucket
|
|
157
|
+
assert call_args[0][2] == "chains/test-uuid/chain.tgz" # key
|
|
158
|
+
assert call_args[0][3] == { # credentials
|
|
159
|
+
"aws_access_key_id": "test_access_key",
|
|
160
|
+
"aws_secret_access_key": "test_secret_key",
|
|
161
|
+
"aws_session_token": "test_session_token",
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@patch("truss.remote.baseten.remote.upload_chain_artifact")
|
|
166
|
+
@patch("truss.remote.baseten.remote.archive_dir")
|
|
167
|
+
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
168
|
+
def test_push_chain_atomic_with_chain_upload(
|
|
169
|
+
mock_create_chain_atomic,
|
|
170
|
+
mock_archive_dir,
|
|
171
|
+
mock_upload_chain_artifact,
|
|
172
|
+
mock_push_data,
|
|
173
|
+
mock_remote_context,
|
|
174
|
+
):
|
|
175
|
+
"""Test that push_chain_atomic uploads raw chain artifact when chain_root is provided."""
|
|
176
|
+
mock_create_chain_atomic.return_value = Mock()
|
|
177
|
+
mock_archive_dir.return_value = Mock()
|
|
178
|
+
mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"
|
|
179
|
+
|
|
180
|
+
context = mock_remote_context
|
|
181
|
+
remote = context["remote"]
|
|
182
|
+
chain_name = context["chain_name"]
|
|
183
|
+
entrypoint_artifact = context["entrypoint_artifact"]
|
|
184
|
+
dependency_artifacts = context["dependency_artifacts"]
|
|
185
|
+
truss_user_env = context["truss_user_env"]
|
|
186
|
+
chain_root = context["chain_root"]
|
|
187
|
+
|
|
188
|
+
context["mock_prepare_push"].return_value = mock_push_data
|
|
189
|
+
|
|
190
|
+
result = remote.push_chain_atomic(
|
|
191
|
+
chain_name=chain_name,
|
|
192
|
+
entrypoint_artifact=entrypoint_artifact,
|
|
193
|
+
dependency_artifacts=dependency_artifacts,
|
|
194
|
+
truss_user_env=truss_user_env,
|
|
195
|
+
chain_root=chain_root,
|
|
196
|
+
publish=True,
|
|
197
|
+
)
|
|
198
|
+
assert result == mock_create_chain_atomic.return_value
|
|
199
|
+
|
|
200
|
+
mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
|
|
201
|
+
mock_upload_chain_artifact.assert_called_once()
|
|
202
|
+
|
|
203
|
+
mock_create_chain_atomic.assert_called_once()
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
207
|
+
def test_push_chain_atomic_without_chain_upload(
|
|
208
|
+
mock_create_chain_atomic, mock_push_data, mock_remote_context
|
|
209
|
+
):
|
|
210
|
+
"""Test that push_chain_atomic skips chain upload when chain_root is None."""
|
|
211
|
+
mock_create_chain_atomic.return_value = Mock()
|
|
212
|
+
|
|
213
|
+
context = mock_remote_context
|
|
214
|
+
remote = context["remote"]
|
|
215
|
+
chain_name = context["chain_name"]
|
|
216
|
+
entrypoint_artifact = context["entrypoint_artifact"]
|
|
217
|
+
dependency_artifacts = context["dependency_artifacts"]
|
|
218
|
+
truss_user_env = context["truss_user_env"]
|
|
219
|
+
|
|
220
|
+
context["mock_prepare_push"].return_value = mock_push_data
|
|
221
|
+
|
|
222
|
+
with patch("truss.remote.baseten.remote.upload_chain_artifact") as mock_upload:
|
|
223
|
+
with patch(
|
|
224
|
+
"truss.remote.baseten.core.create_tar_with_progress_bar"
|
|
225
|
+
) as mock_tar:
|
|
226
|
+
# Call push_chain_atomic without chain_root
|
|
227
|
+
result = remote.push_chain_atomic(
|
|
228
|
+
chain_name=chain_name,
|
|
229
|
+
entrypoint_artifact=entrypoint_artifact,
|
|
230
|
+
dependency_artifacts=dependency_artifacts,
|
|
231
|
+
truss_user_env=truss_user_env,
|
|
232
|
+
chain_root=None, # No chain root
|
|
233
|
+
publish=True,
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
assert result
|
|
237
|
+
# Verify chain artifact upload was NOT called
|
|
238
|
+
mock_tar.assert_not_called()
|
|
239
|
+
mock_upload.assert_not_called()
|
|
240
|
+
|
|
241
|
+
mock_create_chain_atomic.assert_called_once()
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
245
|
+
def test_upload_chain_artifact_error_handling(mock_multipart_upload):
|
|
246
|
+
"""Test error handling in upload_chain_artifact."""
|
|
247
|
+
# Mock API to raise an exception
|
|
248
|
+
api = Mock(spec=BasetenApi)
|
|
249
|
+
api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")
|
|
250
|
+
|
|
251
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
252
|
+
# Should raise the exception
|
|
253
|
+
with pytest.raises(Exception, match="API Error"):
|
|
254
|
+
upload_chain_artifact(api, temp_file, None)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def test_upload_chain_artifact_credentials_extraction():
|
|
258
|
+
"""Test that credentials are properly extracted from API response."""
|
|
259
|
+
# Mock ChainUploadCredentials object
|
|
260
|
+
mock_credentials = Mock()
|
|
261
|
+
mock_credentials.s3_bucket = "test-bucket"
|
|
262
|
+
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
263
|
+
mock_credentials.aws_credentials = Mock()
|
|
264
|
+
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
265
|
+
"aws_access_key_id": "access_key",
|
|
266
|
+
"aws_secret_access_key": "secret_key",
|
|
267
|
+
"aws_session_token": "session_token",
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
api = Mock(spec=BasetenApi)
|
|
271
|
+
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
272
|
+
|
|
273
|
+
with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
|
|
274
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
275
|
+
upload_chain_artifact(api, temp_file, None)
|
|
276
|
+
|
|
277
|
+
call_args = mock_upload.call_args
|
|
278
|
+
credentials = call_args[0][3]
|
|
279
|
+
|
|
280
|
+
assert credentials == {
|
|
281
|
+
"aws_access_key_id": "access_key",
|
|
282
|
+
"aws_secret_access_key": "secret_key",
|
|
283
|
+
"aws_session_token": "session_token",
|
|
284
|
+
}
|
|
285
|
+
assert "extra_field" not in credentials
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from typing import Generator
|
|
3
|
+
|
|
4
|
+
from botocore.exceptions import ClientError, NoCredentialsError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@contextmanager
|
|
8
|
+
def handle_client_error(
|
|
9
|
+
operation_description: str = "AWS operation",
|
|
10
|
+
) -> Generator[None, None, None]:
|
|
11
|
+
"""
|
|
12
|
+
Context manager to handle common boto3 errors and convert them to RuntimeError.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
operation_description: Description of the operation being performed for error messages
|
|
16
|
+
|
|
17
|
+
Raises:
|
|
18
|
+
RuntimeError: For NoCredentialsError, ClientError, and other exceptions
|
|
19
|
+
"""
|
|
20
|
+
try:
|
|
21
|
+
yield
|
|
22
|
+
except NoCredentialsError as nce:
|
|
23
|
+
raise RuntimeError(
|
|
24
|
+
f"No AWS credentials found for {operation_description}\nOriginal exception: {str(nce)}"
|
|
25
|
+
)
|
|
26
|
+
except ClientError as ce:
|
|
27
|
+
raise RuntimeError(
|
|
28
|
+
f"AWS client error when {operation_description} (check your credentials): {str(ce)}"
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
except Exception as exc:
|
|
32
|
+
raise RuntimeError(
|
|
33
|
+
f"Unexpected error `{exc}` during `{operation_description}`\nOriginal exception: {str(exc)}"
|
|
34
|
+
)
|
|
@@ -8,7 +8,7 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
|
|
|
8
8
|
truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
|
|
9
9
|
truss/base/truss_config.py,sha256=s39Xc1e20s8IV07YLl_aVnp-uRS18ZQ2TV-3FILx4nY,28416
|
|
10
10
|
truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
|
|
11
|
-
truss/cli/chains_commands.py,sha256=
|
|
11
|
+
truss/cli/chains_commands.py,sha256=QijtACpuAt2O1RV_qhTNPw0jcFg-u0dX9PP-ct0t-rs,17716
|
|
12
12
|
truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
|
|
13
13
|
truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
|
|
14
14
|
truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
|
|
@@ -34,7 +34,7 @@ truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
|
|
|
34
34
|
truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
|
|
35
35
|
truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomUDpalvKo,3925
|
|
36
36
|
truss/contexts/truss_context.py,sha256=uS6L-ACHxNk0BsJwESOHh1lA0OGGw0pb33aFKGsASj4,436
|
|
37
|
-
truss/contexts/image_builder/cache_warmer.py,sha256=
|
|
37
|
+
truss/contexts/image_builder/cache_warmer.py,sha256=EETFAgZk7C6rQezzFxz4XqjS5LIyF7uM1VVscQt_cBA,6959
|
|
38
38
|
truss/contexts/image_builder/image_builder.py,sha256=IuRgDeeoHVLzIkJvKtX3807eeqEyaroCs_KWDcIHZUg,1461
|
|
39
39
|
truss/contexts/image_builder/serving_image_builder.py,sha256=1PfHtkTEdNPhSQAX8Ajk_0LN3KR2EfLKwOJsnECtKXQ,33958
|
|
40
40
|
truss/contexts/image_builder/util.py,sha256=y2-CjUKv0XV-0w2sr1fUCflysDJLsoU4oPp6tvvoFnk,1203
|
|
@@ -52,12 +52,12 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
|
|
|
52
52
|
truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
|
|
53
53
|
truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
|
|
54
54
|
truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
|
|
55
|
-
truss/remote/baseten/api.py,sha256=
|
|
55
|
+
truss/remote/baseten/api.py,sha256=2Es2afWKnz7OlQJHIbvYAKoSrb1dn9SnsAY--uHXbTs,30210
|
|
56
56
|
truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
|
|
57
|
-
truss/remote/baseten/core.py,sha256=
|
|
58
|
-
truss/remote/baseten/custom_types.py,sha256=
|
|
57
|
+
truss/remote/baseten/core.py,sha256=69utHGGFRw1ZQUobj80TSmaBgU3plnsfHZfiR15dPrY,23502
|
|
58
|
+
truss/remote/baseten/custom_types.py,sha256=g7yWkE8p6uIAG5JqgfELFGHzjFLvO7vLPzbe-yl1nYs,4735
|
|
59
59
|
truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
|
|
60
|
-
truss/remote/baseten/remote.py,sha256=
|
|
60
|
+
truss/remote/baseten/remote.py,sha256=aKG1BODtrnmuRV-M8T3F3pw8oHawGwI09caKANJ19BM,23420
|
|
61
61
|
truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
|
|
62
62
|
truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
|
|
63
63
|
truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
|
|
@@ -93,7 +93,7 @@ truss/templates/custom/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
|
|
|
93
93
|
truss/templates/custom/model/model.py,sha256=J04rLxK09Pwt2F4GoKOLKL-H-CqZUdYIM-PL2CE9PoE,1079
|
|
94
94
|
truss/templates/custom_python_dx/my_model.py,sha256=NG75mQ6wxzB1BYUemDFZvRLBET-UrzuUK4FuHjqI29U,910
|
|
95
95
|
truss/templates/docker_server/proxy.conf.jinja,sha256=Axily8EtznvrF7mUCgy2VFY99BYRt4BycZ0p9uWfd0s,2025
|
|
96
|
-
truss/templates/docker_server/supervisord.conf.jinja,sha256=
|
|
96
|
+
truss/templates/docker_server/supervisord.conf.jinja,sha256=N95c7nQvPZ5i-Oypy_7nYaV7JgcYQ25M5D4F2exS2HI,1804
|
|
97
97
|
truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
98
98
|
truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
|
|
99
99
|
truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
|
|
@@ -141,6 +141,7 @@ truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX
|
|
|
141
141
|
truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
|
|
142
142
|
truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
|
|
143
143
|
truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
|
|
144
|
+
truss/tests/cli/test_chains_cli.py,sha256=l9GTQrhRm9SRZn43WkMY4tdRslLmdsVyiydRPa1_Ja4,3162
|
|
144
145
|
truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
|
|
145
146
|
truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
|
|
146
147
|
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
|
|
@@ -162,6 +163,7 @@ truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0a
|
|
|
162
163
|
truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
|
|
163
164
|
truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
|
|
164
165
|
truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
|
|
166
|
+
truss/tests/remote/baseten/test_chain_upload.py,sha256=XaaF1ocovkBYsLMJ8EpXB9FUGfQZAwu4iyOWqoVn7tc,10886
|
|
165
167
|
truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
|
|
166
168
|
truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
|
|
167
169
|
truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
|
|
@@ -340,6 +342,7 @@ truss/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
340
342
|
truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
|
|
341
343
|
truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
|
|
342
344
|
truss/util/env_vars.py,sha256=7Bv686eER71Barrs6fNamk_TrTJGmu9yV2TxaVmupn0,1232
|
|
345
|
+
truss/util/error_utils.py,sha256=aO76Vf8LMlvhM28jRJ1qzNl4E5ZyvKK4TFQq_UhbQrk,1095
|
|
343
346
|
truss/util/gpu.py,sha256=YiEF_JZyzur0MDMJOebMuJBQxrHD9ApGI0aPpWdb5BU,553
|
|
344
347
|
truss/util/jinja.py,sha256=7KbuYNq55I3DGtImAiCvBwR0K9-z1Jo6gMhmsy4lNZE,333
|
|
345
348
|
truss/util/log_utils.py,sha256=LwSgRh2K7KFjKKqBxr-IirFxGIzHi1mUM7YEvujvHsE,1985
|
|
@@ -349,8 +352,8 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
|
|
|
349
352
|
truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
|
|
350
353
|
truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
|
|
351
354
|
truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
|
|
352
|
-
truss_chains/private_types.py,sha256=
|
|
353
|
-
truss_chains/public_api.py,sha256=
|
|
355
|
+
truss_chains/private_types.py,sha256=vdcl8FuVsL9JGIu_9K7fd2EW9Ytzoq8nfEx5pmuMKTA,9063
|
|
356
|
+
truss_chains/public_api.py,sha256=civY8juJU92jSGBI7zM1qMnA7hlUdCq7L8o4IOo5meA,9722
|
|
354
357
|
truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
|
|
355
358
|
truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
356
359
|
truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
|
|
@@ -358,7 +361,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
|
|
|
358
361
|
truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
|
|
359
362
|
truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
360
363
|
truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
|
|
361
|
-
truss_chains/deployment/deployment_client.py,sha256=
|
|
364
|
+
truss_chains/deployment/deployment_client.py,sha256=4cHuvaynVCclJ6M9pw8ukhO1E2NRKohIRxftvOfNvOE,34499
|
|
362
365
|
truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
|
|
363
366
|
truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
|
|
364
367
|
truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -371,8 +374,8 @@ truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,307
|
|
|
371
374
|
truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
|
|
372
375
|
truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
|
|
373
376
|
truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
|
|
374
|
-
truss-0.11.
|
|
375
|
-
truss-0.11.
|
|
376
|
-
truss-0.11.
|
|
377
|
-
truss-0.11.
|
|
378
|
-
truss-0.11.
|
|
377
|
+
truss-0.11.15rc10.dist-info/METADATA,sha256=eFZQXzoDyrBqITGt5jfqUKBdFWLaL92ZxYIPjAd1b58,6682
|
|
378
|
+
truss-0.11.15rc10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
379
|
+
truss-0.11.15rc10.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
|
|
380
|
+
truss-0.11.15rc10.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
|
|
381
|
+
truss-0.11.15rc10.dist-info/RECORD,,
|
|
@@ -516,14 +516,21 @@ def _create_baseten_chain(
|
|
|
516
516
|
|
|
517
517
|
_create_chains_secret_if_missing(remote_provider)
|
|
518
518
|
|
|
519
|
+
# Get chain root for raw chain artifact upload
|
|
520
|
+
chain_root = None
|
|
521
|
+
if entrypoint_descriptor is not None:
|
|
522
|
+
chain_root = _get_chain_root(entrypoint_descriptor.chainlet_cls)
|
|
523
|
+
|
|
519
524
|
chain_deployment_handle = remote_provider.push_chain_atomic(
|
|
520
525
|
baseten_options.chain_name,
|
|
521
526
|
entrypoint_artifact,
|
|
522
527
|
dependency_artifacts,
|
|
523
528
|
truss_user_env,
|
|
529
|
+
chain_root=chain_root,
|
|
524
530
|
publish=baseten_options.publish,
|
|
525
531
|
environment=baseten_options.environment,
|
|
526
532
|
progress_bar=progress_bar,
|
|
533
|
+
disable_chain_download=baseten_options.disable_chain_download,
|
|
527
534
|
)
|
|
528
535
|
return BasetenChainService(
|
|
529
536
|
baseten_options.chain_name,
|
truss_chains/private_types.py
CHANGED
|
@@ -265,6 +265,7 @@ class PushOptionsBaseten(PushOptions):
|
|
|
265
265
|
environment: Optional[str]
|
|
266
266
|
include_git_info: bool
|
|
267
267
|
working_dir: pathlib.Path
|
|
268
|
+
disable_chain_download: bool = False
|
|
268
269
|
|
|
269
270
|
@classmethod
|
|
270
271
|
def create(
|
|
@@ -277,6 +278,7 @@ class PushOptionsBaseten(PushOptions):
|
|
|
277
278
|
include_git_info: bool,
|
|
278
279
|
working_dir: pathlib.Path,
|
|
279
280
|
environment: Optional[str] = None,
|
|
281
|
+
disable_chain_download: bool = False,
|
|
280
282
|
) -> "PushOptionsBaseten":
|
|
281
283
|
if promote and not environment:
|
|
282
284
|
environment = PRODUCTION_ENVIRONMENT_NAME
|
|
@@ -290,6 +292,7 @@ class PushOptionsBaseten(PushOptions):
|
|
|
290
292
|
environment=environment,
|
|
291
293
|
include_git_info=include_git_info,
|
|
292
294
|
working_dir=working_dir,
|
|
295
|
+
disable_chain_download=disable_chain_download,
|
|
293
296
|
)
|
|
294
297
|
|
|
295
298
|
|
truss_chains/public_api.py
CHANGED
|
@@ -151,6 +151,7 @@ def push(
|
|
|
151
151
|
environment: Optional[str] = None,
|
|
152
152
|
progress_bar: Optional[Type["progress.Progress"]] = None,
|
|
153
153
|
include_git_info: bool = False,
|
|
154
|
+
disable_chain_download: bool = False,
|
|
154
155
|
) -> deployment_client.BasetenChainService:
|
|
155
156
|
"""
|
|
156
157
|
Deploys a chain remotely (with all dependent chainlets).
|
|
@@ -172,6 +173,7 @@ def push(
|
|
|
172
173
|
include_git_info: Whether to attach git versioning info (sha, branch, tag) to
|
|
173
174
|
deployments made from within a git repo. If set to True in `.trussrc`, it
|
|
174
175
|
will always be attached.
|
|
176
|
+
disable_chain_download: Disable downloading of pushed chain source code from the UI.
|
|
175
177
|
|
|
176
178
|
Returns:
|
|
177
179
|
A chain service handle to the deployed chain.
|
|
@@ -186,6 +188,7 @@ def push(
|
|
|
186
188
|
environment=environment,
|
|
187
189
|
include_git_info=include_git_info,
|
|
188
190
|
working_dir=pathlib.Path(inspect.getfile(entrypoint)).parent,
|
|
191
|
+
disable_chain_download=disable_chain_download,
|
|
189
192
|
)
|
|
190
193
|
service = deployment_client.push(entrypoint, options, progress_bar=progress_bar)
|
|
191
194
|
assert isinstance(
|
|
File without changes
|
|
File without changes
|
|
File without changes
|