truss 0.11.14rc500__py3-none-any.whl → 0.11.15rc10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

@@ -211,6 +211,14 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
211
211
  default=False,
212
212
  help=common.INCLUDE_GIT_INFO_DOC,
213
213
  )
214
+ @click.option(
215
+ "--disable-chain-download",
216
+ "disable_chain_download",
217
+ is_flag=True,
218
+ required=False,
219
+ default=False,
220
+ help="Disable downloading of pushed chain source code from the UI.",
221
+ )
214
222
  @click.pass_context
215
223
  @common.common_options()
216
224
  def push_chain(
@@ -227,6 +235,7 @@ def push_chain(
227
235
  environment: Optional[str],
228
236
  experimental_watch_chainlet_names: Optional[str],
229
237
  include_git_info: bool = False,
238
+ disable_chain_download: bool = False,
230
239
  ) -> None:
231
240
  """
232
241
  Deploys a chain remotely.
@@ -284,6 +293,7 @@ def push_chain(
284
293
  environment=environment,
285
294
  include_git_info=include_git_info,
286
295
  working_dir=source.parent if source.is_file() else source.resolve(),
296
+ disable_chain_download=disable_chain_download,
287
297
  )
288
298
  service = deployment_client.push(
289
299
  entrypoint_cls, options, progress_bar=progress.Progress
@@ -11,10 +11,11 @@ from typing import Optional, Type
11
11
  import boto3
12
12
  from botocore import UNSIGNED
13
13
  from botocore.client import Config
14
- from botocore.exceptions import ClientError, NoCredentialsError
15
14
  from google.cloud import storage
16
15
  from huggingface_hub import hf_hub_download
17
16
 
17
+ from truss.util.error_utils import handle_client_error
18
+
18
19
  B10CP_PATH_TRUSS_ENV_VAR_NAME = "B10CP_PATH_TRUSS"
19
20
 
20
21
  GCS_CREDENTIALS = "/app/data/service_account.json"
@@ -188,24 +189,14 @@ class S3File(RepositoryFile):
188
189
  if not dst_file.parent.exists():
189
190
  dst_file.parent.mkdir(parents=True)
190
191
 
191
- try:
192
+ with handle_client_error(
193
+ f"accessing S3 bucket {bucket_name} for file {file_name}"
194
+ ):
192
195
  url = client.generate_presigned_url(
193
196
  "get_object",
194
197
  Params={"Bucket": bucket_name, "Key": file_name},
195
198
  ExpiresIn=3600,
196
199
  )
197
- except NoCredentialsError as nce:
198
- raise RuntimeError(
199
- f"No AWS credentials found\nOriginal exception: {str(nce)}"
200
- )
201
- except ClientError as ce:
202
- raise RuntimeError(
203
- f"Client error when accessing the S3 bucket (check your credentials): {str(ce)}"
204
- )
205
- except Exception as exc:
206
- raise RuntimeError(
207
- f"File not found on S3 bucket: {file_name}\nOriginal exception: {str(exc)}"
208
- )
209
200
 
210
201
  download_file_using_b10cp(url, dst_file, self.file_name)
211
202
 
@@ -797,11 +797,6 @@ class ServingImageBuilder(ImageBuilder):
797
797
  config
798
798
  )
799
799
 
800
- user_id = None
801
- if config.model_cache and config.model_cache.is_v2:
802
- if config.docker_server and config.docker_server.run_as_user_id:
803
- user_id = config.docker_server.run_as_user_id
804
-
805
800
  non_root_user = os.getenv("BT_USE_NON_ROOT_USER", False)
806
801
  dockerfile_contents = dockerfile_template.render(
807
802
  should_install_server_requirements=should_install_server_requirements,
@@ -837,7 +832,6 @@ class ServingImageBuilder(ImageBuilder):
837
832
  use_local_src=config.use_local_src,
838
833
  passthrough_environment_variables=passthrough_environment_variables,
839
834
  non_root_user=non_root_user,
840
- user_id=user_id,
841
835
  **FILENAME_CONSTANTS_MAP,
842
836
  )
843
837
  # Consolidate repeated empty lines to single empty lines.
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional
5
5
  import requests
6
6
  from pydantic import BaseModel, Field
7
7
 
8
+ from truss.base.custom_types import SafeModel
8
9
  from truss.remote.baseten import custom_types as b10_types
9
10
  from truss.remote.baseten.auth import ApiKey, AuthService
10
11
  from truss.remote.baseten.custom_types import APIKeyCategory
@@ -13,6 +14,29 @@ from truss.remote.baseten.rest_client import RestAPIClient
13
14
  from truss.remote.baseten.utils.transfer import base64_encoded_json_str
14
15
 
15
16
  logger = logging.getLogger(__name__)
17
+ PARAMS_INDENT = "\n "
18
+
19
+
20
+ class ChainAWSCredential(SafeModel):
21
+ aws_access_key_id: str
22
+ aws_secret_access_key: str
23
+ aws_session_token: str
24
+
25
+
26
+ class ChainUploadCredentials(SafeModel):
27
+ s3_bucket: str
28
+ s3_key: str
29
+ aws_access_key_id: str
30
+ aws_secret_access_key: str
31
+ aws_session_token: str
32
+
33
+ @property
34
+ def aws_credentials(self) -> ChainAWSCredential:
35
+ return ChainAWSCredential(
36
+ aws_access_key_id=self.aws_access_key_id,
37
+ aws_secret_access_key=self.aws_secret_access_key,
38
+ aws_session_token=self.aws_session_token,
39
+ )
16
40
 
17
41
 
18
42
  class InstanceTypeV1(BaseModel):
@@ -299,7 +323,11 @@ class BasetenApi:
299
323
  chain_name: Optional[str] = None,
300
324
  environment: Optional[str] = None,
301
325
  is_draft: bool = False,
326
+ original_source_artifact_s3_key: Optional[str] = None,
327
+ allow_truss_download: Optional[bool] = True,
302
328
  ):
329
+ if allow_truss_download is None:
330
+ allow_truss_download = True
303
331
  entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
304
332
 
305
333
  dependencies_str = ", ".join(
@@ -309,13 +337,28 @@ class BasetenApi:
309
337
  ]
310
338
  )
311
339
 
340
+ params = []
341
+ if chain_id:
342
+ params.append(f'chain_id: "{chain_id}"')
343
+ if chain_name:
344
+ params.append(f'chain_name: "{chain_name}"')
345
+ if environment:
346
+ params.append(f'environment: "{environment}"')
347
+ if original_source_artifact_s3_key:
348
+ params.append(
349
+ f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
350
+ )
351
+
352
+ params.append(f"is_draft: {str(is_draft).lower()}")
353
+ if allow_truss_download is False:
354
+ params.append("allow_truss_download: false")
355
+
356
+ params_str = PARAMS_INDENT.join(params)
357
+
312
358
  query_string = f"""
313
359
  mutation ($trussUserEnv: String) {{
314
360
  deploy_chain_atomic(
315
- {f'chain_id: "{chain_id}"' if chain_id else ""}
316
- {f'chain_name: "{chain_name}"' if chain_name else ""}
317
- {f'environment: "{environment}"' if environment else ""}
318
- is_draft: {str(is_draft).lower()}
361
+ {params_str}
319
362
  entrypoint: {entrypoint_str}
320
363
  dependencies: [{dependencies_str}]
321
364
  truss_user_env: $trussUserEnv
@@ -657,8 +700,29 @@ class BasetenApi:
657
700
  return resp_json["training_projects"]
658
701
 
659
702
  def get_blob_credentials(self, blob_type: b10_types.BlobType):
703
+ if blob_type == b10_types.BlobType.CHAIN:
704
+ return self.get_chain_s3_upload_credentials()
660
705
  return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
661
706
 
707
+ def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
708
+ """Get chain artifact credentials using GraphQL query."""
709
+ query = """
710
+ query {
711
+ chain_s3_upload_credentials {
712
+ s3_bucket
713
+ s3_key
714
+ aws_access_key_id
715
+ aws_secret_access_key
716
+ aws_session_token
717
+ }
718
+ }
719
+ """
720
+ response = self._post_graphql_query(query)
721
+
722
+ return ChainUploadCredentials.model_validate(
723
+ response["data"]["chain_s3_upload_credentials"]
724
+ )
725
+
662
726
  def get_training_job_metrics(
663
727
  self,
664
728
  project_id: str,
@@ -8,6 +8,7 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
8
8
  import requests
9
9
 
10
10
  from truss.base.errors import ValidationError
11
+ from truss.util.error_utils import handle_client_error
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from rich import progress
@@ -80,6 +81,8 @@ class ChainDeploymentHandleAtomic(NamedTuple):
80
81
  chain_id: str
81
82
  chain_deployment_id: str
82
83
  is_draft: bool
84
+ original_source_artifact_s3_key: Optional[str] = None
85
+ allow_truss_download: Optional[bool] = True
83
86
 
84
87
 
85
88
  class ModelVersionHandle(NamedTuple):
@@ -127,6 +130,8 @@ def create_chain_atomic(
127
130
  is_draft: bool,
128
131
  truss_user_env: b10_types.TrussUserEnv,
129
132
  environment: Optional[str],
133
+ original_source_artifact_s3_key: Optional[str] = None,
134
+ allow_truss_download: bool = True,
130
135
  ) -> ChainDeploymentHandleAtomic:
131
136
  if environment and is_draft:
132
137
  logging.info(
@@ -149,6 +154,8 @@ def create_chain_atomic(
149
154
  chain_name=chain_name,
150
155
  is_draft=True,
151
156
  truss_user_env=truss_user_env,
157
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
158
+ allow_truss_download=allow_truss_download,
152
159
  )
153
160
  elif chain_id:
154
161
  # This is the only case where promote has relevance, since
@@ -162,6 +169,8 @@ def create_chain_atomic(
162
169
  chain_id=chain_id,
163
170
  environment=environment,
164
171
  truss_user_env=truss_user_env,
172
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
173
+ allow_truss_download=allow_truss_download,
165
174
  )
166
175
  except ApiError as e:
167
176
  if (
@@ -182,6 +191,8 @@ def create_chain_atomic(
182
191
  dependencies=dependencies,
183
192
  chain_name=chain_name,
184
193
  truss_user_env=truss_user_env,
194
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
195
+ allow_truss_download=allow_truss_download,
185
196
  )
186
197
 
187
198
  return ChainDeploymentHandleAtomic(
@@ -189,6 +200,8 @@ def create_chain_atomic(
189
200
  chain_id=res["chain_deployment"]["chain"]["id"],
190
201
  hostname=res["chain_deployment"]["chain"]["hostname"],
191
202
  is_draft=is_draft,
203
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
204
+ allow_truss_download=allow_truss_download,
192
205
  )
193
206
 
194
207
 
@@ -342,6 +355,33 @@ def upload_truss(
342
355
  return s3_key
343
356
 
344
357
 
358
+ def upload_chain_artifact(
359
+ api: BasetenApi,
360
+ serialize_file: IO,
361
+ progress_bar: Optional[Type["progress.Progress"]],
362
+ ) -> str:
363
+ """
364
+ Upload a chain artifact to the Baseten remote.
365
+
366
+ Args:
367
+ api: BasetenApi instance
368
+ serialize_file: File-like object containing the serialized chain artifact
369
+
370
+ Returns:
371
+ The S3 key of the uploaded file
372
+ """
373
+ credentials = api.get_chain_s3_upload_credentials()
374
+ with handle_client_error("Uploading chain source"):
375
+ multipart_upload_boto3(
376
+ serialize_file.name,
377
+ credentials.s3_bucket,
378
+ credentials.s3_key,
379
+ credentials.aws_credentials.model_dump(),
380
+ progress_bar,
381
+ )
382
+ return credentials.s3_key
383
+
384
+
345
385
  def create_truss_service(
346
386
  api: BasetenApi,
347
387
  model_name: str,
@@ -120,6 +120,7 @@ class TrussUserEnv(pydantic.BaseModel):
120
120
  class BlobType(Enum):
121
121
  MODEL = "model"
122
122
  TRAIN = "train"
123
+ CHAIN = "chain"
123
124
 
124
125
 
125
126
  class FileSummary(pydantic.BaseModel):
@@ -31,6 +31,7 @@ from truss.remote.baseten.core import (
31
31
  get_model_and_versions,
32
32
  get_prod_version_from_versions,
33
33
  get_truss_watch_state,
34
+ upload_chain_artifact,
34
35
  upload_truss,
35
36
  validate_truss_config_against_backend,
36
37
  )
@@ -263,9 +264,11 @@ class BasetenRemote(TrussRemote):
263
264
  entrypoint_artifact: custom_types.ChainletArtifact,
264
265
  dependency_artifacts: List[custom_types.ChainletArtifact],
265
266
  truss_user_env: b10_types.TrussUserEnv,
267
+ chain_root: Optional[Path] = None,
266
268
  publish: bool = False,
267
269
  environment: Optional[str] = None,
268
270
  progress_bar: Optional[Type["progress.Progress"]] = None,
271
+ disable_chain_download: bool = False,
269
272
  ) -> ChainDeploymentHandleAtomic:
270
273
  # If we are promoting a model to an environment after deploy, it must be published.
271
274
  # Draft models cannot be promoted.
@@ -285,6 +288,7 @@ class BasetenRemote(TrussRemote):
285
288
  publish=publish,
286
289
  origin=custom_types.ModelOrigin.CHAINS,
287
290
  progress_bar=progress_bar,
291
+ disable_truss_download=disable_chain_download,
288
292
  )
289
293
  oracle_data = custom_types.OracleData(
290
294
  model_name=push_data.model_name,
@@ -300,6 +304,18 @@ class BasetenRemote(TrussRemote):
300
304
  )
301
305
  )
302
306
 
307
+ # Upload raw chain artifact if chain_root is provided
308
+ raw_chain_s3_key = None
309
+ if chain_root is not None:
310
+ logging.info("Uploading source artifact")
311
+ # Create a tar file from the chain root directory
312
+ original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
313
+ # Upload the chain artifact to S3
314
+ raw_chain_s3_key = upload_chain_artifact(
315
+ api=self._api,
316
+ serialize_file=original_source_tar,
317
+ progress_bar=progress_bar,
318
+ )
303
319
  chain_deployment_handle = create_chain_atomic(
304
320
  api=self._api,
305
321
  chain_name=chain_name,
@@ -308,6 +324,8 @@ class BasetenRemote(TrussRemote):
308
324
  is_draft=not publish,
309
325
  truss_user_env=truss_user_env,
310
326
  environment=environment,
327
+ original_source_artifact_s3_key=raw_chain_s3_key,
328
+ allow_truss_download=not disable_chain_download,
311
329
  )
312
330
  logging.info("Successfully pushed to baseten. Chain is building and deploying.")
313
331
  return chain_deployment_handle
@@ -8,7 +8,7 @@ logfile_maxbytes=0 ; No size limit on logfile (since logging is disabl
8
8
  command={{start_command}} ; Command to start the model server (provided by Jinja variable)
9
9
  startsecs=30 ; Wait 30 seconds before assuming the server is running
10
10
  autostart=true ; Automatically start the program when supervisord starts
11
- autorestart=true ; Always restart the program if it exits, no matter what the exit code
11
+ autorestart=unexpected ; Don't restart the program if it exits with clear exit code
12
12
  stdout_logfile=/dev/fd/1 ; Send stdout to the first file descriptor (stdout)
13
13
  stdout_logfile_maxbytes=0 ; No size limit on stdout log
14
14
  redirect_stderr=true ; Redirect stderr to stdout
@@ -1,5 +1 @@
1
1
  FROM {{ config.base_image.image }}
2
- {# Add COPY for bptr-manifest if model_cache v2 is enabled #}
3
- {% if user_id %}
4
- COPY --chown={{ user_id }}:{{ user_id }} ./bptr-manifest /static-bptr/static-bptr-manifest.json
5
- {% endif %}
@@ -0,0 +1,100 @@
1
+ """Tests for truss chains CLI commands."""
2
+
3
+ from unittest.mock import Mock, patch
4
+
5
+ from click.testing import CliRunner
6
+
7
+ from truss.cli.cli import truss_cli
8
+
9
+
10
+ def test_chains_push_with_disable_chain_download_flag():
11
+ """Test that --disable-chain-download flag is properly parsed and passed through."""
12
+ runner = CliRunner()
13
+
14
+ mock_entrypoint_cls = Mock()
15
+ mock_entrypoint_cls.meta_data.chain_name = "test_chain"
16
+ mock_entrypoint_cls.display_name = "TestChain"
17
+
18
+ mock_service = Mock()
19
+ mock_service.run_remote_url = "http://test.com/run_remote"
20
+ mock_service.is_websocket = False
21
+
22
+ with patch(
23
+ "truss_chains.framework.ChainletImporter.import_target"
24
+ ) as mock_importer:
25
+ with patch("truss_chains.deployment.deployment_client.push") as mock_push:
26
+ mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
27
+ mock_push.return_value = mock_service
28
+
29
+ result = runner.invoke(
30
+ truss_cli,
31
+ [
32
+ "chains",
33
+ "push",
34
+ "test_chain.py",
35
+ "--disable-chain-download",
36
+ "--remote",
37
+ "test_remote",
38
+ "--dryrun",
39
+ ],
40
+ )
41
+
42
+ assert result.exit_code == 0
43
+
44
+ mock_push.assert_called_once()
45
+ call_args = mock_push.call_args
46
+ options = call_args[0][1]
47
+
48
+ assert hasattr(options, "disable_chain_download")
49
+ assert options.disable_chain_download is True
50
+
51
+
52
+ def test_chains_push_without_disable_chain_download_flag():
53
+ """Test that disable_chain_download defaults to False when flag is not provided."""
54
+ runner = CliRunner()
55
+
56
+ mock_entrypoint_cls = Mock()
57
+ mock_entrypoint_cls.meta_data.chain_name = "test_chain"
58
+ mock_entrypoint_cls.display_name = "TestChain"
59
+
60
+ mock_service = Mock()
61
+ mock_service.run_remote_url = "http://test.com/run_remote"
62
+ mock_service.is_websocket = False
63
+
64
+ with patch(
65
+ "truss_chains.framework.ChainletImporter.import_target"
66
+ ) as mock_importer:
67
+ with patch("truss_chains.deployment.deployment_client.push") as mock_push:
68
+ mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
69
+ mock_push.return_value = mock_service
70
+
71
+ result = runner.invoke(
72
+ truss_cli,
73
+ [
74
+ "chains",
75
+ "push",
76
+ "test_chain.py",
77
+ "--remote",
78
+ "test_remote",
79
+ "--dryrun",
80
+ ],
81
+ )
82
+
83
+ assert result.exit_code == 0
84
+
85
+ mock_push.assert_called_once()
86
+ call_args = mock_push.call_args
87
+ options = call_args[0][1]
88
+
89
+ assert hasattr(options, "disable_chain_download")
90
+ assert options.disable_chain_download is False
91
+
92
+
93
+ def test_chains_push_help_includes_disable_chain_download():
94
+ """Test that --disable-chain-download appears in the help output."""
95
+ runner = CliRunner()
96
+
97
+ result = runner.invoke(truss_cli, ["chains", "push", "--help"])
98
+
99
+ assert result.exit_code == 0
100
+ assert "--disable-chain-download" in result.output
@@ -0,0 +1,285 @@
1
+ import pathlib
2
+ import tempfile
3
+ from unittest.mock import Mock, patch
4
+
5
+ import pytest
6
+
7
+ from truss.remote.baseten import custom_types as b10_types
8
+ from truss.remote.baseten.api import BasetenApi
9
+ from truss.remote.baseten.core import upload_chain_artifact
10
+ from truss.remote.baseten.remote import BasetenRemote
11
+
12
+
13
+ @pytest.fixture
14
+ def mock_push_data():
15
+ """Fixture providing mock push data for tests."""
16
+ mock_push_data = Mock()
17
+ mock_push_data.model_name = "test-model"
18
+ mock_push_data.s3_key = "models/test-key"
19
+ mock_push_data.encoded_config_str = "encoded_config"
20
+ mock_push_data.is_draft = False
21
+ mock_push_data.model_id = "model-id"
22
+ mock_push_data.version_name = None
23
+ return mock_push_data
24
+
25
+
26
+ @pytest.fixture
27
+ def mock_remote_context():
28
+ """Fixture providing mock remote and context managers for tests."""
29
+ api = Mock(spec=BasetenApi)
30
+
31
+ remote = BasetenRemote("https://test.baseten.co", "test-api-key")
32
+ remote._api = api
33
+
34
+ chain_name = "test-chain"
35
+ entrypoint_artifact = Mock()
36
+ entrypoint_artifact.truss_dir = "/path/to/truss"
37
+ entrypoint_artifact.display_name = "entrypoint"
38
+
39
+ dependency_artifacts = []
40
+ truss_user_env = Mock()
41
+ chain_root = pathlib.Path("/path/to/chain")
42
+
43
+ with patch.object(remote, "_prepare_push") as mock_prepare_push:
44
+ with patch("truss.remote.baseten.remote.truss_build.load") as mock_load:
45
+ mock_truss_handle = Mock()
46
+ mock_truss_handle.spec.config.model_name = "test-model"
47
+ mock_load.return_value = mock_truss_handle
48
+
49
+ yield {
50
+ "remote": remote,
51
+ "api": api,
52
+ "chain_name": chain_name,
53
+ "entrypoint_artifact": entrypoint_artifact,
54
+ "dependency_artifacts": dependency_artifacts,
55
+ "truss_user_env": truss_user_env,
56
+ "chain_root": chain_root,
57
+ "mock_prepare_push": mock_prepare_push,
58
+ "mock_load": mock_load,
59
+ "mock_truss_handle": mock_truss_handle,
60
+ }
61
+
62
+
63
+ def test_get_blob_credentials_for_chain():
64
+ """Test that get_blob_credentials works correctly for chain blob type using GraphQL."""
65
+ mock_graphql_response = {
66
+ "data": {
67
+ "chain_s3_upload_credentials": {
68
+ "s3_bucket": "test-chain-bucket",
69
+ "s3_key": "chains/test-uuid/chain.tgz",
70
+ "aws_access_key_id": "test_access_key",
71
+ "aws_secret_access_key": "test_secret_key",
72
+ "aws_session_token": "test_session_token",
73
+ }
74
+ }
75
+ }
76
+
77
+ # Create a real API instance and mock the GraphQL call
78
+ mock_auth_service = Mock()
79
+ mock_auth_service.authenticate.return_value = Mock(value="test-token")
80
+ api = BasetenApi("https://test.baseten.co", mock_auth_service)
81
+ with patch.object(api, "_post_graphql_query") as mock_graphql:
82
+ mock_graphql.return_value = mock_graphql_response
83
+
84
+ result = api.get_chain_s3_upload_credentials()
85
+
86
+ assert result.s3_bucket == "test-chain-bucket"
87
+ assert result.s3_key == "chains/test-uuid/chain.tgz"
88
+ assert result.aws_access_key_id == "test_access_key"
89
+ assert result.aws_secret_access_key == "test_secret_key"
90
+ assert result.aws_session_token == "test_session_token"
91
+
92
+ mock_graphql.assert_called_once()
93
+ call_args = mock_graphql.call_args
94
+ assert "chain_s3_upload_credentials" in call_args[0][0]
95
+
96
+
97
+ def test_get_blob_credentials_for_other_types_uses_rest():
98
+ """Test that get_blob_credentials uses REST API for non-chain blob types."""
99
+ mock_response = {
100
+ "s3_bucket": "test-bucket",
101
+ "s3_key": "test-key",
102
+ "creds": {
103
+ "aws_access_key_id": "test_access_key",
104
+ "aws_secret_access_key": "test_secret_key",
105
+ "aws_session_token": "test_session_token",
106
+ },
107
+ }
108
+
109
+ mock_auth_service = Mock()
110
+ mock_auth_service.authenticate.return_value = Mock(value="test-token")
111
+ api = BasetenApi("https://test.baseten.co", mock_auth_service)
112
+ with patch.object(api, "_rest_api_client") as mock_client, patch.object(
113
+ api, "_post_graphql_query"
114
+ ) as mock_graphql:
115
+ mock_client.get.return_value = mock_response
116
+
117
+ result = api.get_blob_credentials(b10_types.BlobType.MODEL)
118
+
119
+ assert result["s3_bucket"] == "test-bucket"
120
+ assert result["s3_key"] == "test-key"
121
+
122
+ mock_client.get.assert_called_once_with("v1/blobs/credentials/model")
123
+ mock_graphql.assert_not_called()
124
+
125
+
126
+ @patch("truss.remote.baseten.core.multipart_upload_boto3")
127
+ def test_upload_chain_artifact_function(mock_multipart_upload):
128
+ """Test the upload_chain_artifact function."""
129
+ # Mock ChainUploadCredentials object
130
+ mock_credentials = Mock()
131
+ mock_credentials.s3_bucket = "test-chain-bucket"
132
+ mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
133
+ mock_credentials.aws_credentials = Mock()
134
+ mock_credentials.aws_credentials.model_dump.return_value = {
135
+ "aws_access_key_id": "test_access_key",
136
+ "aws_secret_access_key": "test_secret_key",
137
+ "aws_session_token": "test_session_token",
138
+ }
139
+
140
+ api = Mock(spec=BasetenApi)
141
+ api.get_chain_s3_upload_credentials.return_value = mock_credentials
142
+
143
+ with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as temp_file:
144
+ temp_file.write(b"test chain content")
145
+ temp_file.flush()
146
+
147
+ result = upload_chain_artifact(api, temp_file, None)
148
+
149
+ assert result == "chains/test-uuid/chain.tgz"
150
+
151
+ api.get_chain_s3_upload_credentials.assert_called_once_with()
152
+
153
+ mock_multipart_upload.assert_called_once()
154
+ call_args = mock_multipart_upload.call_args
155
+ assert call_args[0][0] == temp_file.name # file path
156
+ assert call_args[0][1] == "test-chain-bucket" # bucket
157
+ assert call_args[0][2] == "chains/test-uuid/chain.tgz" # key
158
+ assert call_args[0][3] == { # credentials
159
+ "aws_access_key_id": "test_access_key",
160
+ "aws_secret_access_key": "test_secret_key",
161
+ "aws_session_token": "test_session_token",
162
+ }
163
+
164
+
165
+ @patch("truss.remote.baseten.remote.upload_chain_artifact")
166
+ @patch("truss.remote.baseten.remote.archive_dir")
167
+ @patch("truss.remote.baseten.remote.create_chain_atomic")
168
+ def test_push_chain_atomic_with_chain_upload(
169
+ mock_create_chain_atomic,
170
+ mock_archive_dir,
171
+ mock_upload_chain_artifact,
172
+ mock_push_data,
173
+ mock_remote_context,
174
+ ):
175
+ """Test that push_chain_atomic uploads raw chain artifact when chain_root is provided."""
176
+ mock_create_chain_atomic.return_value = Mock()
177
+ mock_archive_dir.return_value = Mock()
178
+ mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"
179
+
180
+ context = mock_remote_context
181
+ remote = context["remote"]
182
+ chain_name = context["chain_name"]
183
+ entrypoint_artifact = context["entrypoint_artifact"]
184
+ dependency_artifacts = context["dependency_artifacts"]
185
+ truss_user_env = context["truss_user_env"]
186
+ chain_root = context["chain_root"]
187
+
188
+ context["mock_prepare_push"].return_value = mock_push_data
189
+
190
+ result = remote.push_chain_atomic(
191
+ chain_name=chain_name,
192
+ entrypoint_artifact=entrypoint_artifact,
193
+ dependency_artifacts=dependency_artifacts,
194
+ truss_user_env=truss_user_env,
195
+ chain_root=chain_root,
196
+ publish=True,
197
+ )
198
+ assert result == mock_create_chain_atomic.return_value
199
+
200
+ mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
201
+ mock_upload_chain_artifact.assert_called_once()
202
+
203
+ mock_create_chain_atomic.assert_called_once()
204
+
205
+
206
+ @patch("truss.remote.baseten.remote.create_chain_atomic")
207
+ def test_push_chain_atomic_without_chain_upload(
208
+ mock_create_chain_atomic, mock_push_data, mock_remote_context
209
+ ):
210
+ """Test that push_chain_atomic skips chain upload when chain_root is None."""
211
+ mock_create_chain_atomic.return_value = Mock()
212
+
213
+ context = mock_remote_context
214
+ remote = context["remote"]
215
+ chain_name = context["chain_name"]
216
+ entrypoint_artifact = context["entrypoint_artifact"]
217
+ dependency_artifacts = context["dependency_artifacts"]
218
+ truss_user_env = context["truss_user_env"]
219
+
220
+ context["mock_prepare_push"].return_value = mock_push_data
221
+
222
+ with patch("truss.remote.baseten.remote.upload_chain_artifact") as mock_upload:
223
+ with patch(
224
+ "truss.remote.baseten.core.create_tar_with_progress_bar"
225
+ ) as mock_tar:
226
+ # Call push_chain_atomic without chain_root
227
+ result = remote.push_chain_atomic(
228
+ chain_name=chain_name,
229
+ entrypoint_artifact=entrypoint_artifact,
230
+ dependency_artifacts=dependency_artifacts,
231
+ truss_user_env=truss_user_env,
232
+ chain_root=None, # No chain root
233
+ publish=True,
234
+ )
235
+
236
+ assert result
237
+ # Verify chain artifact upload was NOT called
238
+ mock_tar.assert_not_called()
239
+ mock_upload.assert_not_called()
240
+
241
+ mock_create_chain_atomic.assert_called_once()
242
+
243
+
244
+ @patch("truss.remote.baseten.core.multipart_upload_boto3")
245
+ def test_upload_chain_artifact_error_handling(mock_multipart_upload):
246
+ """Test error handling in upload_chain_artifact."""
247
+ # Mock API to raise an exception
248
+ api = Mock(spec=BasetenApi)
249
+ api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")
250
+
251
+ with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
252
+ # Should raise the exception
253
+ with pytest.raises(Exception, match="API Error"):
254
+ upload_chain_artifact(api, temp_file, None)
255
+
256
+
257
+ def test_upload_chain_artifact_credentials_extraction():
258
+ """Test that credentials are properly extracted from API response."""
259
+ # Mock ChainUploadCredentials object
260
+ mock_credentials = Mock()
261
+ mock_credentials.s3_bucket = "test-bucket"
262
+ mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
263
+ mock_credentials.aws_credentials = Mock()
264
+ mock_credentials.aws_credentials.model_dump.return_value = {
265
+ "aws_access_key_id": "access_key",
266
+ "aws_secret_access_key": "secret_key",
267
+ "aws_session_token": "session_token",
268
+ }
269
+
270
+ api = Mock(spec=BasetenApi)
271
+ api.get_chain_s3_upload_credentials.return_value = mock_credentials
272
+
273
+ with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
274
+ with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
275
+ upload_chain_artifact(api, temp_file, None)
276
+
277
+ call_args = mock_upload.call_args
278
+ credentials = call_args[0][3]
279
+
280
+ assert credentials == {
281
+ "aws_access_key_id": "access_key",
282
+ "aws_secret_access_key": "secret_key",
283
+ "aws_session_token": "session_token",
284
+ }
285
+ assert "extra_field" not in credentials
@@ -0,0 +1,34 @@
1
+ from contextlib import contextmanager
2
+ from typing import Generator
3
+
4
+ from botocore.exceptions import ClientError, NoCredentialsError
5
+
6
+
7
+ @contextmanager
8
+ def handle_client_error(
9
+ operation_description: str = "AWS operation",
10
+ ) -> Generator[None, None, None]:
11
+ """
12
+ Context manager to handle common boto3 errors and convert them to RuntimeError.
13
+
14
+ Args:
15
+ operation_description: Description of the operation being performed for error messages
16
+
17
+ Raises:
18
+ RuntimeError: For NoCredentialsError, ClientError, and other exceptions
19
+ """
20
+ try:
21
+ yield
22
+ except NoCredentialsError as nce:
23
+ raise RuntimeError(
24
+ f"No AWS credentials found for {operation_description}\nOriginal exception: {str(nce)}"
25
+ )
26
+ except ClientError as ce:
27
+ raise RuntimeError(
28
+ f"AWS client error when {operation_description} (check your credentials): {str(ce)}"
29
+ )
30
+
31
+ except Exception as exc:
32
+ raise RuntimeError(
33
+ f"Unexpected error `{exc}` during `{operation_description}`\nOriginal exception: {str(exc)}"
34
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.14rc500
3
+ Version: 0.11.15rc10
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -8,7 +8,7 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
9
9
  truss/base/truss_config.py,sha256=s39Xc1e20s8IV07YLl_aVnp-uRS18ZQ2TV-3FILx4nY,28416
10
10
  truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
- truss/cli/chains_commands.py,sha256=Kpa5mCg6URAJQE2ZmZfVQFhjBHEitKT28tKiW0H6XAI,17406
11
+ truss/cli/chains_commands.py,sha256=QijtACpuAt2O1RV_qhTNPw0jcFg-u0dX9PP-ct0t-rs,17716
12
12
  truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
13
13
  truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
14
  truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
@@ -34,9 +34,9 @@ truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
34
34
  truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
35
35
  truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomUDpalvKo,3925
36
36
  truss/contexts/truss_context.py,sha256=uS6L-ACHxNk0BsJwESOHh1lA0OGGw0pb33aFKGsASj4,436
37
- truss/contexts/image_builder/cache_warmer.py,sha256=TGMV1Mh87n2e_dSowH0sf0rZhZraDOR-LVapZL3a5r8,7377
37
+ truss/contexts/image_builder/cache_warmer.py,sha256=EETFAgZk7C6rQezzFxz4XqjS5LIyF7uM1VVscQt_cBA,6959
38
38
  truss/contexts/image_builder/image_builder.py,sha256=IuRgDeeoHVLzIkJvKtX3807eeqEyaroCs_KWDcIHZUg,1461
39
- truss/contexts/image_builder/serving_image_builder.py,sha256=G7FeSLeqSyT9rQTaJpuAALVmcAI3e3ZXRASR_WRKnGM,34210
39
+ truss/contexts/image_builder/serving_image_builder.py,sha256=1PfHtkTEdNPhSQAX8Ajk_0LN3KR2EfLKwOJsnECtKXQ,33958
40
40
  truss/contexts/image_builder/util.py,sha256=y2-CjUKv0XV-0w2sr1fUCflysDJLsoU4oPp6tvvoFnk,1203
41
41
  truss/contexts/local_loader/docker_build_emulator.py,sha256=3n0eIlJblz_sldh4AN8AHQDyfjQGdYyld5FabBdd9wE,3563
42
42
  truss/contexts/local_loader/dockerfile_parser.py,sha256=GoRJ0Af_3ILyLhjovK5lrCGn1rMxz6W3l681ro17ZzI,1344
@@ -52,12 +52,12 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
52
52
  truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
53
53
  truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
54
54
  truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
55
- truss/remote/baseten/api.py,sha256=5B5IXNy0v8hRHNH2ar3rldDa47kwt5s1PtKZQ9_pfmE,28263
55
+ truss/remote/baseten/api.py,sha256=2Es2afWKnz7OlQJHIbvYAKoSrb1dn9SnsAY--uHXbTs,30210
56
56
  truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
57
- truss/remote/baseten/core.py,sha256=uxtmBI9RAVHu1glIEJb5Q4ccJYLeZM1Cp5Svb9W68Yw,21965
58
- truss/remote/baseten/custom_types.py,sha256=bYrfTzGgYr6FDoya0omyadCLSTcTc-83U2scQORyUj0,4715
57
+ truss/remote/baseten/core.py,sha256=69utHGGFRw1ZQUobj80TSmaBgU3plnsfHZfiR15dPrY,23502
58
+ truss/remote/baseten/custom_types.py,sha256=g7yWkE8p6uIAG5JqgfELFGHzjFLvO7vLPzbe-yl1nYs,4735
59
59
  truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
60
- truss/remote/baseten/remote.py,sha256=Se8AES5mk8jxa8S9fN2DSG7wnsaV7ftRjJ4Uwc_w_S0,22544
60
+ truss/remote/baseten/remote.py,sha256=aKG1BODtrnmuRV-M8T3F3pw8oHawGwI09caKANJ19BM,23420
61
61
  truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
62
62
  truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
63
63
  truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
@@ -71,7 +71,7 @@ truss/templates/cache.Dockerfile.jinja,sha256=1qZqDo1phrcqi-Vwol-VafYJkADsBbQWU6
71
71
  truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj7wGs61nUs,87
72
72
  truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
73
73
  truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
74
- truss/templates/no_build.Dockerfile.jinja,sha256=KB9GlRKLFolIvSn7G7L4kOpTkz3FdM5qx8HVL2Q7whI,222
74
+ truss/templates/no_build.Dockerfile.jinja,sha256=8x2PJUxr_gHai0St8ue2aWyih36t8kBytXMGr_5LG4w,35
75
75
  truss/templates/server.Dockerfile.jinja,sha256=Mu5_ZxuAknwaEOsF0l-XssA9pDg3pD3eLl6JBzNJ4rg,7091
76
76
  truss/templates/control/requirements.txt,sha256=tJGr83WoE0CZm2FrloZ9VScK84q-_FTuVXjDYrexhW0,250
77
77
  truss/templates/control/control/application.py,sha256=5Kam6M-XtfKGaXQz8cc3d0bwDkB80o2MskABWROx1gk,5321
@@ -93,7 +93,7 @@ truss/templates/custom/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
93
93
  truss/templates/custom/model/model.py,sha256=J04rLxK09Pwt2F4GoKOLKL-H-CqZUdYIM-PL2CE9PoE,1079
94
94
  truss/templates/custom_python_dx/my_model.py,sha256=NG75mQ6wxzB1BYUemDFZvRLBET-UrzuUK4FuHjqI29U,910
95
95
  truss/templates/docker_server/proxy.conf.jinja,sha256=Axily8EtznvrF7mUCgy2VFY99BYRt4BycZ0p9uWfd0s,2025
96
- truss/templates/docker_server/supervisord.conf.jinja,sha256=dd37fwZE--cutrvOUCqEyJQQQhlp61H2IUs2huKWsSk,1808
96
+ truss/templates/docker_server/supervisord.conf.jinja,sha256=N95c7nQvPZ5i-Oypy_7nYaV7JgcYQ25M5D4F2exS2HI,1804
97
97
  truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
98
98
  truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
99
99
  truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
@@ -141,6 +141,7 @@ truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX
141
141
  truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
142
142
  truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
143
143
  truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
144
+ truss/tests/cli/test_chains_cli.py,sha256=l9GTQrhRm9SRZn43WkMY4tdRslLmdsVyiydRPa1_Ja4,3162
144
145
  truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
145
146
  truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
146
147
  truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
@@ -162,6 +163,7 @@ truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0a
162
163
  truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
163
164
  truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
164
165
  truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
166
+ truss/tests/remote/baseten/test_chain_upload.py,sha256=XaaF1ocovkBYsLMJ8EpXB9FUGfQZAwu4iyOWqoVn7tc,10886
165
167
  truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
166
168
  truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
167
169
  truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
@@ -340,6 +342,7 @@ truss/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
340
342
  truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
341
343
  truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
342
344
  truss/util/env_vars.py,sha256=7Bv686eER71Barrs6fNamk_TrTJGmu9yV2TxaVmupn0,1232
345
+ truss/util/error_utils.py,sha256=aO76Vf8LMlvhM28jRJ1qzNl4E5ZyvKK4TFQq_UhbQrk,1095
343
346
  truss/util/gpu.py,sha256=YiEF_JZyzur0MDMJOebMuJBQxrHD9ApGI0aPpWdb5BU,553
344
347
  truss/util/jinja.py,sha256=7KbuYNq55I3DGtImAiCvBwR0K9-z1Jo6gMhmsy4lNZE,333
345
348
  truss/util/log_utils.py,sha256=LwSgRh2K7KFjKKqBxr-IirFxGIzHi1mUM7YEvujvHsE,1985
@@ -349,8 +352,8 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
349
352
  truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
350
353
  truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
351
354
  truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
352
- truss_chains/private_types.py,sha256=6CaQEPawFLXjEbJ-01lqfexJtUIekF_q61LNENWegFo,8917
353
- truss_chains/public_api.py,sha256=0AXV6UdZIFAMycUNG_klgo4aLFmBZeKGfrulZEWzR0M,9532
355
+ truss_chains/private_types.py,sha256=vdcl8FuVsL9JGIu_9K7fd2EW9Ytzoq8nfEx5pmuMKTA,9063
356
+ truss_chains/public_api.py,sha256=civY8juJU92jSGBI7zM1qMnA7hlUdCq7L8o4IOo5meA,9722
354
357
  truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
355
358
  truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
356
359
  truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
@@ -358,7 +361,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
358
361
  truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
359
362
  truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
360
363
  truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
361
- truss_chains/deployment/deployment_client.py,sha256=OoqkO3daktYzR2YsIcDvsuGfjR05X2K7QlA7wvFduzc,34208
364
+ truss_chains/deployment/deployment_client.py,sha256=4cHuvaynVCclJ6M9pw8ukhO1E2NRKohIRxftvOfNvOE,34499
362
365
  truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
363
366
  truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
364
367
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -371,8 +374,8 @@ truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,307
371
374
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
372
375
  truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
373
376
  truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
374
- truss-0.11.14rc500.dist-info/METADATA,sha256=081Ne_g7Fvmciw4BxLWfyGBJQ1OtyI8C2UuebFsfHm8,6683
375
- truss-0.11.14rc500.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
376
- truss-0.11.14rc500.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
377
- truss-0.11.14rc500.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
378
- truss-0.11.14rc500.dist-info/RECORD,,
377
+ truss-0.11.15rc10.dist-info/METADATA,sha256=eFZQXzoDyrBqITGt5jfqUKBdFWLaL92ZxYIPjAd1b58,6682
378
+ truss-0.11.15rc10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
379
+ truss-0.11.15rc10.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
380
+ truss-0.11.15rc10.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
381
+ truss-0.11.15rc10.dist-info/RECORD,,
@@ -516,14 +516,21 @@ def _create_baseten_chain(
516
516
 
517
517
  _create_chains_secret_if_missing(remote_provider)
518
518
 
519
+ # Get chain root for raw chain artifact upload
520
+ chain_root = None
521
+ if entrypoint_descriptor is not None:
522
+ chain_root = _get_chain_root(entrypoint_descriptor.chainlet_cls)
523
+
519
524
  chain_deployment_handle = remote_provider.push_chain_atomic(
520
525
  baseten_options.chain_name,
521
526
  entrypoint_artifact,
522
527
  dependency_artifacts,
523
528
  truss_user_env,
529
+ chain_root=chain_root,
524
530
  publish=baseten_options.publish,
525
531
  environment=baseten_options.environment,
526
532
  progress_bar=progress_bar,
533
+ disable_chain_download=baseten_options.disable_chain_download,
527
534
  )
528
535
  return BasetenChainService(
529
536
  baseten_options.chain_name,
@@ -265,6 +265,7 @@ class PushOptionsBaseten(PushOptions):
265
265
  environment: Optional[str]
266
266
  include_git_info: bool
267
267
  working_dir: pathlib.Path
268
+ disable_chain_download: bool = False
268
269
 
269
270
  @classmethod
270
271
  def create(
@@ -277,6 +278,7 @@ class PushOptionsBaseten(PushOptions):
277
278
  include_git_info: bool,
278
279
  working_dir: pathlib.Path,
279
280
  environment: Optional[str] = None,
281
+ disable_chain_download: bool = False,
280
282
  ) -> "PushOptionsBaseten":
281
283
  if promote and not environment:
282
284
  environment = PRODUCTION_ENVIRONMENT_NAME
@@ -290,6 +292,7 @@ class PushOptionsBaseten(PushOptions):
290
292
  environment=environment,
291
293
  include_git_info=include_git_info,
292
294
  working_dir=working_dir,
295
+ disable_chain_download=disable_chain_download,
293
296
  )
294
297
 
295
298
 
@@ -151,6 +151,7 @@ def push(
151
151
  environment: Optional[str] = None,
152
152
  progress_bar: Optional[Type["progress.Progress"]] = None,
153
153
  include_git_info: bool = False,
154
+ disable_chain_download: bool = False,
154
155
  ) -> deployment_client.BasetenChainService:
155
156
  """
156
157
  Deploys a chain remotely (with all dependent chainlets).
@@ -172,6 +173,7 @@ def push(
172
173
  include_git_info: Whether to attach git versioning info (sha, branch, tag) to
173
174
  deployments made from within a git repo. If set to True in `.trussrc`, it
174
175
  will always be attached.
176
+ disable_chain_download: Disable downloading of pushed chain source code from the UI.
175
177
 
176
178
  Returns:
177
179
  A chain service handle to the deployed chain.
@@ -186,6 +188,7 @@ def push(
186
188
  environment=environment,
187
189
  include_git_info=include_git_info,
188
190
  working_dir=pathlib.Path(inspect.getfile(entrypoint)).parent,
191
+ disable_chain_download=disable_chain_download,
189
192
  )
190
193
  service = deployment_client.push(entrypoint, options, progress_bar=progress_bar)
191
194
  assert isinstance(