truss 0.11.16__py3-none-any.whl → 0.11.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. See the package's registry page for more details.

@@ -211,6 +211,14 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
211
211
  default=False,
212
212
  help=common.INCLUDE_GIT_INFO_DOC,
213
213
  )
214
+ @click.option(
215
+ "--disable-chain-download",
216
+ "disable_chain_download",
217
+ is_flag=True,
218
+ required=False,
219
+ default=False,
220
+ help="Disable downloading of pushed chain source code from the UI.",
221
+ )
214
222
  @click.pass_context
215
223
  @common.common_options()
216
224
  def push_chain(
@@ -227,6 +235,7 @@ def push_chain(
227
235
  environment: Optional[str],
228
236
  experimental_watch_chainlet_names: Optional[str],
229
237
  include_git_info: bool = False,
238
+ disable_chain_download: bool = False,
230
239
  ) -> None:
231
240
  """
232
241
  Deploys a chain remotely.
@@ -284,6 +293,7 @@ def push_chain(
284
293
  environment=environment,
285
294
  include_git_info=include_git_info,
286
295
  working_dir=source.parent if source.is_file() else source.resolve(),
296
+ disable_chain_download=disable_chain_download,
287
297
  )
288
298
  service = deployment_client.push(
289
299
  entrypoint_cls, options, progress_bar=progress.Progress
truss/cli/cli.py CHANGED
@@ -548,8 +548,9 @@ def push(
548
548
  model_name = remote_cli.inquire_model_name()
549
549
 
550
550
  if promote and environment:
551
- promote_warning = "'promote' flag and 'environment' flag were both specified. Ignoring the value of 'promote'"
552
- console.print(promote_warning, style="yellow")
551
+ raise click.UsageError(
552
+ "'promote' flag and 'environment' flag cannot both be specified."
553
+ )
553
554
  if promote and not environment:
554
555
  environment = PRODUCTION_ENVIRONMENT_NAME
555
556
 
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional
5
5
  import requests
6
6
  from pydantic import BaseModel, Field
7
7
 
8
+ from truss.base.custom_types import SafeModel
8
9
  from truss.remote.baseten import custom_types as b10_types
9
10
  from truss.remote.baseten.auth import ApiKey, AuthService
10
11
  from truss.remote.baseten.custom_types import APIKeyCategory
@@ -13,6 +14,29 @@ from truss.remote.baseten.rest_client import RestAPIClient
13
14
  from truss.remote.baseten.utils.transfer import base64_encoded_json_str
14
15
 
15
16
  logger = logging.getLogger(__name__)
17
+ PARAMS_INDENT = "\n "
18
+
19
+
20
class ChainAWSCredential(SafeModel):
    """AWS credential triple used when uploading a chain source artifact to S3.

    The secret access key and session token are excluded from ``repr()`` so that
    printing or logging an instance (or a traceback that shows it) does not leak
    them. They are still returned by ``model_dump()``, which is what the uploader
    consumes.
    """

    aws_access_key_id: str
    # Sensitive values: hidden from repr(), still serialized by model_dump().
    aws_secret_access_key: str = Field(repr=False)
    aws_session_token: str = Field(repr=False)
24
+
25
+
26
class ChainUploadCredentials(SafeModel):
    """S3 destination plus AWS credentials for uploading a chain source artifact.

    Populated from the ``chain_s3_upload_credentials`` GraphQL payload. The secret
    access key and session token are excluded from ``repr()`` so logging an
    instance cannot leak them; ``model_dump()`` and attribute access are unchanged.
    """

    s3_bucket: str
    s3_key: str
    aws_access_key_id: str
    # Sensitive values: hidden from repr(), still serialized by model_dump().
    aws_secret_access_key: str = Field(repr=False)
    aws_session_token: str = Field(repr=False)

    @property
    def aws_credentials(self) -> ChainAWSCredential:
        """Return only the AWS credential fields, as a ``ChainAWSCredential``."""
        return ChainAWSCredential(
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
        )
16
40
 
17
41
 
18
42
  class InstanceTypeV1(BaseModel):
@@ -175,6 +199,7 @@ class BasetenApi:
175
199
  allow_truss_download: bool = True,
176
200
  deployment_name: Optional[str] = None,
177
201
  origin: Optional[b10_types.ModelOrigin] = None,
202
+ environment: Optional[str] = None,
178
203
  ):
179
204
  query_string = f"""
180
205
  mutation ($trussUserEnv: String) {{
@@ -187,6 +212,7 @@ class BasetenApi:
187
212
  allow_truss_download: {"true" if allow_truss_download else "false"}
188
213
  {f'version_name: "{deployment_name}"' if deployment_name else ""}
189
214
  {f"model_origin: {origin.value}" if origin else ""}
215
+ {f'environment_name: "{environment}"' if environment else ""}
190
216
  ) {{
191
217
  model_version {{
192
218
  id
@@ -299,7 +325,11 @@ class BasetenApi:
299
325
  chain_name: Optional[str] = None,
300
326
  environment: Optional[str] = None,
301
327
  is_draft: bool = False,
328
+ original_source_artifact_s3_key: Optional[str] = None,
329
+ allow_truss_download: Optional[bool] = True,
302
330
  ):
331
+ if allow_truss_download is None:
332
+ allow_truss_download = True
303
333
  entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
304
334
 
305
335
  dependencies_str = ", ".join(
@@ -309,13 +339,28 @@ class BasetenApi:
309
339
  ]
310
340
  )
311
341
 
342
+ params = []
343
+ if chain_id:
344
+ params.append(f'chain_id: "{chain_id}"')
345
+ if chain_name:
346
+ params.append(f'chain_name: "{chain_name}"')
347
+ if environment:
348
+ params.append(f'environment: "{environment}"')
349
+ if original_source_artifact_s3_key:
350
+ params.append(
351
+ f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
352
+ )
353
+
354
+ params.append(f"is_draft: {str(is_draft).lower()}")
355
+ if allow_truss_download is False:
356
+ params.append("allow_truss_download: false")
357
+
358
+ params_str = PARAMS_INDENT.join(params)
359
+
312
360
  query_string = f"""
313
361
  mutation ($trussUserEnv: String) {{
314
362
  deploy_chain_atomic(
315
- {f'chain_id: "{chain_id}"' if chain_id else ""}
316
- {f'chain_name: "{chain_name}"' if chain_name else ""}
317
- {f'environment: "{environment}"' if environment else ""}
318
- is_draft: {str(is_draft).lower()}
363
+ {params_str}
319
364
  entrypoint: {entrypoint_str}
320
365
  dependencies: [{dependencies_str}]
321
366
  truss_user_env: $trussUserEnv
@@ -657,8 +702,29 @@ class BasetenApi:
657
702
  return resp_json["training_projects"]
658
703
 
659
704
  def get_blob_credentials(self, blob_type: b10_types.BlobType):
705
+ if blob_type == b10_types.BlobType.CHAIN:
706
+ return self.get_chain_s3_upload_credentials()
660
707
  return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
661
708
 
709
+ def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
710
+ """Get chain artifact credentials using GraphQL query."""
711
+ query = """
712
+ query {
713
+ chain_s3_upload_credentials {
714
+ s3_bucket
715
+ s3_key
716
+ aws_access_key_id
717
+ aws_secret_access_key
718
+ aws_session_token
719
+ }
720
+ }
721
+ """
722
+ response = self._post_graphql_query(query)
723
+
724
+ return ChainUploadCredentials.model_validate(
725
+ response["data"]["chain_s3_upload_credentials"]
726
+ )
727
+
662
728
  def get_training_job_metrics(
663
729
  self,
664
730
  project_id: str,
@@ -8,6 +8,7 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
8
8
  import requests
9
9
 
10
10
  from truss.base.errors import ValidationError
11
+ from truss.util.error_utils import handle_client_error
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from rich import progress
@@ -80,6 +81,8 @@ class ChainDeploymentHandleAtomic(NamedTuple):
80
81
  chain_id: str
81
82
  chain_deployment_id: str
82
83
  is_draft: bool
84
+ original_source_artifact_s3_key: Optional[str] = None
85
+ allow_truss_download: Optional[bool] = True
83
86
 
84
87
 
85
88
  class ModelVersionHandle(NamedTuple):
@@ -127,6 +130,8 @@ def create_chain_atomic(
127
130
  is_draft: bool,
128
131
  truss_user_env: b10_types.TrussUserEnv,
129
132
  environment: Optional[str],
133
+ original_source_artifact_s3_key: Optional[str] = None,
134
+ allow_truss_download: bool = True,
130
135
  ) -> ChainDeploymentHandleAtomic:
131
136
  if environment and is_draft:
132
137
  logging.info(
@@ -149,6 +154,8 @@ def create_chain_atomic(
149
154
  chain_name=chain_name,
150
155
  is_draft=True,
151
156
  truss_user_env=truss_user_env,
157
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
158
+ allow_truss_download=allow_truss_download,
152
159
  )
153
160
  elif chain_id:
154
161
  # This is the only case where promote has relevance, since
@@ -162,6 +169,8 @@ def create_chain_atomic(
162
169
  chain_id=chain_id,
163
170
  environment=environment,
164
171
  truss_user_env=truss_user_env,
172
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
173
+ allow_truss_download=allow_truss_download,
165
174
  )
166
175
  except ApiError as e:
167
176
  if (
@@ -182,6 +191,8 @@ def create_chain_atomic(
182
191
  dependencies=dependencies,
183
192
  chain_name=chain_name,
184
193
  truss_user_env=truss_user_env,
194
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
195
+ allow_truss_download=allow_truss_download,
185
196
  )
186
197
 
187
198
  return ChainDeploymentHandleAtomic(
@@ -189,6 +200,8 @@ def create_chain_atomic(
189
200
  chain_id=res["chain_deployment"]["chain"]["id"],
190
201
  hostname=res["chain_deployment"]["chain"]["hostname"],
191
202
  is_draft=is_draft,
203
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
204
+ allow_truss_download=allow_truss_download,
192
205
  )
193
206
 
194
207
 
@@ -342,6 +355,33 @@ def upload_truss(
342
355
  return s3_key
343
356
 
344
357
 
358
def upload_chain_artifact(
    api: BasetenApi,
    serialize_file: IO,
    progress_bar: Optional[Type["progress.Progress"]],
) -> str:
    """
    Upload a serialized chain source artifact to S3 and return its key.

    The bucket, key and AWS credentials come from
    ``api.get_chain_s3_upload_credentials()``. The transfer is done by
    ``multipart_upload_boto3``. Errors raised during the transfer are converted
    to ``RuntimeError`` by ``handle_client_error``; errors while fetching the
    credentials propagate unchanged.

    Args:
        api: BasetenApi instance used to request the upload credentials.
        serialize_file: File object holding the archived chain source. Only its
            ``name`` (the path on disk) is passed to the uploader.
        progress_bar: Optional rich ``Progress`` class, forwarded to the uploader.

    Returns:
        The S3 key of the uploaded file.
    """
    upload_target = api.get_chain_s3_upload_credentials()
    with handle_client_error("Uploading chain source"):
        local_path = serialize_file.name
        bucket = upload_target.s3_bucket
        key = upload_target.s3_key
        aws_creds = upload_target.aws_credentials.model_dump()
        multipart_upload_boto3(local_path, bucket, key, aws_creds, progress_bar)
    return key
383
+
384
+
345
385
  def create_truss_service(
346
386
  api: BasetenApi,
347
387
  model_name: str,
@@ -398,9 +438,6 @@ def create_truss_service(
398
438
  )
399
439
 
400
440
  if model_id is None:
401
- if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
402
- raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
403
-
404
441
  model_version_json = api.create_model_from_truss(
405
442
  model_name,
406
443
  s3_key,
@@ -410,6 +447,7 @@ def create_truss_service(
410
447
  allow_truss_download=allow_truss_download,
411
448
  deployment_name=deployment_name,
412
449
  origin=origin,
450
+ environment=environment,
413
451
  )
414
452
 
415
453
  return ModelVersionHandle(
@@ -120,6 +120,7 @@ class TrussUserEnv(pydantic.BaseModel):
120
120
class BlobType(Enum):
    """Blob categories for credential requests.

    Each value is interpolated into the blob-credentials request path.
    ``CHAIN`` identifies chain source artifacts.
    """

    MODEL = "model"
    TRAIN = "train"
    CHAIN = "chain"
123
124
 
124
125
 
125
126
  class FileSummary(pydantic.BaseModel):
@@ -31,6 +31,7 @@ from truss.remote.baseten.core import (
31
31
  get_model_and_versions,
32
32
  get_prod_version_from_versions,
33
33
  get_truss_watch_state,
34
+ upload_chain_artifact,
34
35
  upload_truss,
35
36
  validate_truss_config_against_backend,
36
37
  )
@@ -263,9 +264,11 @@ class BasetenRemote(TrussRemote):
263
264
  entrypoint_artifact: custom_types.ChainletArtifact,
264
265
  dependency_artifacts: List[custom_types.ChainletArtifact],
265
266
  truss_user_env: b10_types.TrussUserEnv,
267
+ chain_root: Optional[Path] = None,
266
268
  publish: bool = False,
267
269
  environment: Optional[str] = None,
268
270
  progress_bar: Optional[Type["progress.Progress"]] = None,
271
+ disable_chain_download: bool = False,
269
272
  ) -> ChainDeploymentHandleAtomic:
270
273
  # If we are promoting a model to an environment after deploy, it must be published.
271
274
  # Draft models cannot be promoted.
@@ -285,6 +288,7 @@ class BasetenRemote(TrussRemote):
285
288
  publish=publish,
286
289
  origin=custom_types.ModelOrigin.CHAINS,
287
290
  progress_bar=progress_bar,
291
+ disable_truss_download=disable_chain_download,
288
292
  )
289
293
  oracle_data = custom_types.OracleData(
290
294
  model_name=push_data.model_name,
@@ -300,6 +304,18 @@ class BasetenRemote(TrussRemote):
300
304
  )
301
305
  )
302
306
 
307
+ # Upload raw chain artifact if chain_root is provided
308
+ raw_chain_s3_key = None
309
+ if chain_root is not None:
310
+ logging.info("Uploading source artifact")
311
+ # Create a tar file from the chain root directory
312
+ original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
313
+ # Upload the chain artifact to S3
314
+ raw_chain_s3_key = upload_chain_artifact(
315
+ api=self._api,
316
+ serialize_file=original_source_tar,
317
+ progress_bar=progress_bar,
318
+ )
303
319
  chain_deployment_handle = create_chain_atomic(
304
320
  api=self._api,
305
321
  chain_name=chain_name,
@@ -308,6 +324,8 @@ class BasetenRemote(TrussRemote):
308
324
  is_draft=not publish,
309
325
  truss_user_env=truss_user_env,
310
326
  environment=environment,
327
+ original_source_artifact_s3_key=raw_chain_s3_key,
328
+ allow_truss_download=not disable_chain_download,
311
329
  )
312
330
  logging.info("Successfully pushed to baseten. Chain is building and deploying.")
313
331
  return chain_deployment_handle
@@ -0,0 +1,100 @@
1
+ """Tests for truss chains CLI commands."""
2
+
3
+ from unittest.mock import Mock, patch
4
+
5
+ from click.testing import CliRunner
6
+
7
+ from truss.cli.cli import truss_cli
8
+
9
+
10
def test_chains_push_with_disable_chain_download_flag():
    """`chains push --disable-chain-download` must set the option to True."""
    cli_runner = CliRunner()

    entrypoint = Mock()
    entrypoint.meta_data.chain_name = "test_chain"
    entrypoint.display_name = "TestChain"

    service = Mock()
    service.run_remote_url = "http://test.com/run_remote"
    service.is_websocket = False

    cli_args = [
        "chains",
        "push",
        "test_chain.py",
        "--disable-chain-download",
        "--remote",
        "test_remote",
        "--dryrun",
    ]
    with patch(
        "truss_chains.framework.ChainletImporter.import_target"
    ) as importer, patch("truss_chains.deployment.deployment_client.push") as push:
        importer.return_value.__enter__.return_value = entrypoint
        push.return_value = service
        outcome = cli_runner.invoke(truss_cli, cli_args)

    assert outcome.exit_code == 0

    push.assert_called_once()
    positional_args, _ = push.call_args
    push_options = positional_args[1]

    assert hasattr(push_options, "disable_chain_download")
    assert push_options.disable_chain_download is True
50
+
51
+
52
def test_chains_push_without_disable_chain_download_flag():
    """Without the flag, `disable_chain_download` must default to False."""
    cli_runner = CliRunner()

    entrypoint = Mock()
    entrypoint.meta_data.chain_name = "test_chain"
    entrypoint.display_name = "TestChain"

    service = Mock()
    service.run_remote_url = "http://test.com/run_remote"
    service.is_websocket = False

    cli_args = ["chains", "push", "test_chain.py", "--remote", "test_remote", "--dryrun"]
    with patch(
        "truss_chains.framework.ChainletImporter.import_target"
    ) as importer, patch("truss_chains.deployment.deployment_client.push") as push:
        importer.return_value.__enter__.return_value = entrypoint
        push.return_value = service
        outcome = cli_runner.invoke(truss_cli, cli_args)

    assert outcome.exit_code == 0

    push.assert_called_once()
    positional_args, _ = push.call_args
    push_options = positional_args[1]

    assert hasattr(push_options, "disable_chain_download")
    assert push_options.disable_chain_download is False
91
+
92
+
93
def test_chains_push_help_includes_disable_chain_download():
    """The `chains push --help` output must advertise --disable-chain-download."""
    help_result = CliRunner().invoke(truss_cli, ["chains", "push", "--help"])

    assert help_result.exit_code == 0
    assert "--disable-chain-download" in help_result.output
@@ -0,0 +1,285 @@
1
+ import pathlib
2
+ import tempfile
3
+ from unittest.mock import Mock, patch
4
+
5
+ import pytest
6
+
7
+ from truss.remote.baseten import custom_types as b10_types
8
+ from truss.remote.baseten.api import BasetenApi
9
+ from truss.remote.baseten.core import upload_chain_artifact
10
+ from truss.remote.baseten.remote import BasetenRemote
11
+
12
+
13
@pytest.fixture
def mock_push_data():
    """Mock push data; the tests install it as the `_prepare_push` return value."""
    return Mock(
        model_name="test-model",
        s3_key="models/test-key",
        encoded_config_str="encoded_config",
        is_draft=False,
        model_id="model-id",
        version_name=None,
    )
24
+
25
+
26
@pytest.fixture
def mock_remote_context():
    """Yield a BasetenRemote backed by a mocked API, plus inputs for push_chain_atomic.

    `_prepare_push` and `truss_build.load` (as seen from the remote module) stay
    patched while the test runs. The patch mocks are included in the yielded dict
    so tests can configure them.
    """
    api = Mock(spec=BasetenApi)
    remote = BasetenRemote("https://test.baseten.co", "test-api-key")
    remote._api = api

    entrypoint_artifact = Mock()
    entrypoint_artifact.truss_dir = "/path/to/truss"
    entrypoint_artifact.display_name = "entrypoint"

    with patch.object(remote, "_prepare_push") as mock_prepare_push, patch(
        "truss.remote.baseten.remote.truss_build.load"
    ) as mock_load:
        mock_truss_handle = Mock()
        mock_truss_handle.spec.config.model_name = "test-model"
        mock_load.return_value = mock_truss_handle

        yield {
            "remote": remote,
            "api": api,
            "chain_name": "test-chain",
            "entrypoint_artifact": entrypoint_artifact,
            "dependency_artifacts": [],
            "truss_user_env": Mock(),
            "chain_root": pathlib.Path("/path/to/chain"),
            "mock_prepare_push": mock_prepare_push,
            "mock_load": mock_load,
            "mock_truss_handle": mock_truss_handle,
        }
61
+
62
+
63
def test_get_blob_credentials_for_chain():
    """Test that get_blob_credentials routes the chain blob type to GraphQL.

    The test goes through the public `get_blob_credentials` dispatch, not only
    the `get_chain_s3_upload_credentials` helper. It also checks that the REST
    client is never used for chain blobs.
    """
    mock_graphql_response = {
        "data": {
            "chain_s3_upload_credentials": {
                "s3_bucket": "test-chain-bucket",
                "s3_key": "chains/test-uuid/chain.tgz",
                "aws_access_key_id": "test_access_key",
                "aws_secret_access_key": "test_secret_key",
                "aws_session_token": "test_session_token",
            }
        }
    }

    # Create a real API instance and mock only the transport layers.
    mock_auth_service = Mock()
    mock_auth_service.authenticate.return_value = Mock(value="test-token")
    api = BasetenApi("https://test.baseten.co", mock_auth_service)
    with patch.object(api, "_rest_api_client") as mock_client, patch.object(
        api, "_post_graphql_query"
    ) as mock_graphql:
        mock_graphql.return_value = mock_graphql_response

        result = api.get_blob_credentials(b10_types.BlobType.CHAIN)

    assert result.s3_bucket == "test-chain-bucket"
    assert result.s3_key == "chains/test-uuid/chain.tgz"
    assert result.aws_access_key_id == "test_access_key"
    assert result.aws_secret_access_key == "test_secret_key"
    assert result.aws_session_token == "test_session_token"

    mock_graphql.assert_called_once()
    assert "chain_s3_upload_credentials" in mock_graphql.call_args[0][0]
    # Chain credentials must not fall through to the REST blob endpoint.
    mock_client.get.assert_not_called()
95
+
96
+
97
def test_get_blob_credentials_for_other_types_uses_rest():
    """Non-chain blob types are served by the REST endpoint, never GraphQL."""
    rest_payload = {
        "s3_bucket": "test-bucket",
        "s3_key": "test-key",
        "creds": {
            "aws_access_key_id": "test_access_key",
            "aws_secret_access_key": "test_secret_key",
            "aws_session_token": "test_session_token",
        },
    }

    auth_service = Mock()
    auth_service.authenticate.return_value = Mock(value="test-token")
    api = BasetenApi("https://test.baseten.co", auth_service)

    with patch.object(api, "_rest_api_client") as rest_client, patch.object(
        api, "_post_graphql_query"
    ) as graphql_query:
        rest_client.get.return_value = rest_payload
        credentials = api.get_blob_credentials(b10_types.BlobType.MODEL)

    assert credentials["s3_bucket"] == "test-bucket"
    assert credentials["s3_key"] == "test-key"

    rest_client.get.assert_called_once_with("v1/blobs/credentials/model")
    graphql_query.assert_not_called()
124
+
125
+
126
@patch("truss.remote.baseten.core.multipart_upload_boto3")
def test_upload_chain_artifact_function(mock_multipart_upload):
    """Test the upload_chain_artifact function.

    The upload is mocked, so the temporary file only needs to exist while
    `upload_chain_artifact` runs. It is removed automatically when the `with`
    block exits, leaving nothing behind on disk.
    """
    # Mock ChainUploadCredentials object
    mock_credentials = Mock()
    mock_credentials.s3_bucket = "test-chain-bucket"
    mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
    mock_credentials.aws_credentials = Mock()
    mock_credentials.aws_credentials.model_dump.return_value = {
        "aws_access_key_id": "test_access_key",
        "aws_secret_access_key": "test_secret_key",
        "aws_session_token": "test_session_token",
    }

    api = Mock(spec=BasetenApi)
    api.get_chain_s3_upload_credentials.return_value = mock_credentials

    with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
        temp_file.write(b"test chain content")
        temp_file.flush()
        temp_file_path = temp_file.name

        result = upload_chain_artifact(api, temp_file, None)

    assert result == "chains/test-uuid/chain.tgz"

    api.get_chain_s3_upload_credentials.assert_called_once_with()

    mock_multipart_upload.assert_called_once()
    call_args = mock_multipart_upload.call_args
    assert call_args[0][0] == temp_file_path  # file path
    assert call_args[0][1] == "test-chain-bucket"  # bucket
    assert call_args[0][2] == "chains/test-uuid/chain.tgz"  # key
    assert call_args[0][3] == {  # credentials
        "aws_access_key_id": "test_access_key",
        "aws_secret_access_key": "test_secret_key",
        "aws_session_token": "test_session_token",
    }
    assert call_args[0][4] is None  # progress bar forwarded unchanged
163
+
164
+
165
@patch("truss.remote.baseten.remote.upload_chain_artifact")
@patch("truss.remote.baseten.remote.archive_dir")
@patch("truss.remote.baseten.remote.create_chain_atomic")
def test_push_chain_atomic_with_chain_upload(
    mock_create_chain_atomic,
    mock_archive_dir,
    mock_upload_chain_artifact,
    mock_push_data,
    mock_remote_context,
):
    """Test that push_chain_atomic uploads raw chain artifact when chain_root is provided.

    It also checks that the archive is what gets uploaded, and that the resulting
    S3 key and the default download permission are passed on to create_chain_atomic.
    """
    mock_create_chain_atomic.return_value = Mock()
    mock_archive_dir.return_value = Mock()
    mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"

    context = mock_remote_context
    remote = context["remote"]
    chain_root = context["chain_root"]

    context["mock_prepare_push"].return_value = mock_push_data

    result = remote.push_chain_atomic(
        chain_name=context["chain_name"],
        entrypoint_artifact=context["entrypoint_artifact"],
        dependency_artifacts=context["dependency_artifacts"],
        truss_user_env=context["truss_user_env"],
        chain_root=chain_root,
        publish=True,
    )
    assert result == mock_create_chain_atomic.return_value

    mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
    # The archive produced above is exactly what is uploaded.
    mock_upload_chain_artifact.assert_called_once_with(
        api=context["api"],
        serialize_file=mock_archive_dir.return_value,
        progress_bar=None,
    )

    mock_create_chain_atomic.assert_called_once()
    _, create_kwargs = mock_create_chain_atomic.call_args
    assert (
        create_kwargs["original_source_artifact_s3_key"]
        == "chains/test-uuid/chain.tgz"
    )
    # disable_chain_download defaults to False, so downloads stay allowed.
    assert create_kwargs["allow_truss_download"] is True
204
+
205
+
206
@patch("truss.remote.baseten.remote.create_chain_atomic")
def test_push_chain_atomic_without_chain_upload(
    mock_create_chain_atomic, mock_push_data, mock_remote_context
):
    """Test that push_chain_atomic skips chain upload when chain_root is None.

    `archive_dir` is patched where push_chain_atomic looks it up (the remote
    module). That makes the "not archived" assertion meaningful.
    """
    mock_create_chain_atomic.return_value = Mock()

    context = mock_remote_context
    remote = context["remote"]

    context["mock_prepare_push"].return_value = mock_push_data

    with patch(
        "truss.remote.baseten.remote.upload_chain_artifact"
    ) as mock_upload, patch(
        "truss.remote.baseten.remote.archive_dir"
    ) as mock_archive_dir, patch(
        "truss.remote.baseten.core.create_tar_with_progress_bar"
    ) as mock_tar:
        # Call push_chain_atomic without chain_root
        result = remote.push_chain_atomic(
            chain_name=context["chain_name"],
            entrypoint_artifact=context["entrypoint_artifact"],
            dependency_artifacts=context["dependency_artifacts"],
            truss_user_env=context["truss_user_env"],
            chain_root=None,  # No chain root
            publish=True,
        )

    assert result == mock_create_chain_atomic.return_value
    # Verify the source was neither archived nor uploaded.
    mock_archive_dir.assert_not_called()
    mock_tar.assert_not_called()
    mock_upload.assert_not_called()

    mock_create_chain_atomic.assert_called_once()
    _, create_kwargs = mock_create_chain_atomic.call_args
    assert create_kwargs["original_source_artifact_s3_key"] is None
242
+
243
+
244
@patch("truss.remote.baseten.core.multipart_upload_boto3")
def test_upload_chain_artifact_error_handling(mock_multipart_upload):
    """A failure while fetching credentials propagates, and no upload is attempted."""
    # Mock API to raise an exception
    api = Mock(spec=BasetenApi)
    api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")

    with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
        # Should raise the exception
        with pytest.raises(Exception, match="API Error"):
            upload_chain_artifact(api, temp_file, None)

    # Without credentials the S3 transfer must never start.
    mock_multipart_upload.assert_not_called()
255
+
256
+
257
def test_upload_chain_artifact_credentials_extraction():
    """Test that credentials are properly extracted from API response.

    The uploader must receive exactly the dict produced by
    `credentials.aws_credentials.model_dump()`.
    """
    # Mock ChainUploadCredentials object
    mock_credentials = Mock()
    mock_credentials.s3_bucket = "test-bucket"
    mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
    mock_credentials.aws_credentials = Mock()
    mock_credentials.aws_credentials.model_dump.return_value = {
        "aws_access_key_id": "access_key",
        "aws_secret_access_key": "secret_key",
        "aws_session_token": "session_token",
    }

    api = Mock(spec=BasetenApi)
    api.get_chain_s3_upload_credentials.return_value = mock_credentials

    with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
        with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
            upload_chain_artifact(api, temp_file, None)

    # Fail with a clear message (not a TypeError on `None` call_args) if no upload happened.
    mock_upload.assert_called_once()
    mock_credentials.aws_credentials.model_dump.assert_called_once_with()

    credentials = mock_upload.call_args[0][3]

    assert credentials == {
        "aws_access_key_id": "access_key",
        "aws_secret_access_key": "secret_key",
        "aws_session_token": "session_token",
    }
    assert "extra_field" not in credentials
@@ -0,0 +1,37 @@
1
+ import logging
2
+ from contextlib import contextmanager
3
+ from typing import Generator
4
+
5
+ from botocore.exceptions import ClientError, NoCredentialsError
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
@contextmanager
def handle_client_error(
    operation_description: str = "AWS operation",
) -> Generator[None, None, None]:
    """
    Context manager that converts errors raised inside its body into RuntimeError.

    botocore's ``NoCredentialsError`` and ``ClientError`` get dedicated messages.
    Any other ``Exception`` is wrapped with a generic message. The original
    exception is attached as ``__cause__`` (``raise ... from``), so its type and
    traceback remain available to callers and in the printed traceback.

    Args:
        operation_description: Description of the operation being performed,
            interpolated into the error messages.

    Raises:
        RuntimeError: For NoCredentialsError, ClientError, and any other
            Exception raised inside the ``with`` body.
    """
    try:
        yield
    except NoCredentialsError as nce:
        raise RuntimeError(
            f"No AWS credentials found for {operation_description}\nOriginal exception: {str(nce)}"
        ) from nce
    except ClientError as ce:
        raise RuntimeError(
            f"AWS client error when {operation_description} (check your credentials): {str(ce)}"
        ) from ce
    except Exception as exc:
        # The exception text appears once; the full exception is chained as __cause__.
        raise RuntimeError(
            f"Unexpected error `{exc}` during `{operation_description}`"
        ) from exc
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.16
3
+ Version: 0.11.18
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -8,8 +8,8 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
9
9
  truss/base/truss_config.py,sha256=s39Xc1e20s8IV07YLl_aVnp-uRS18ZQ2TV-3FILx4nY,28416
10
10
  truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
- truss/cli/chains_commands.py,sha256=Kpa5mCg6URAJQE2ZmZfVQFhjBHEitKT28tKiW0H6XAI,17406
12
- truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
11
+ truss/cli/chains_commands.py,sha256=QijtACpuAt2O1RV_qhTNPw0jcFg-u0dX9PP-ct0t-rs,17716
12
+ truss/cli/cli.py,sha256=VGOw1ell7h9bna64UmopavCpVPdjDerSaGPDoizIsRI,30313
13
13
  truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
14
  truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
15
15
  truss/cli/logs/base_watcher.py,sha256=vuqteoaMVGX34cgKcETf4X_gOkvnSnDaWz1_pbeFhqs,3343
@@ -52,12 +52,12 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
52
52
  truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
53
53
  truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
54
54
  truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
55
- truss/remote/baseten/api.py,sha256=5B5IXNy0v8hRHNH2ar3rldDa47kwt5s1PtKZQ9_pfmE,28263
55
+ truss/remote/baseten/api.py,sha256=54Cl_2zHpRU4g2VXzK-BYlxPJeHHImceFrbxD9AASXo,30335
56
56
  truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
57
- truss/remote/baseten/core.py,sha256=uxtmBI9RAVHu1glIEJb5Q4ccJYLeZM1Cp5Svb9W68Yw,21965
58
- truss/remote/baseten/custom_types.py,sha256=bYrfTzGgYr6FDoya0omyadCLSTcTc-83U2scQORyUj0,4715
57
+ truss/remote/baseten/core.py,sha256=FC5-87Vs2f0NR8eddtSRvr3Z5W2rF7mpiq9jCPrbzr4,23399
58
+ truss/remote/baseten/custom_types.py,sha256=g7yWkE8p6uIAG5JqgfELFGHzjFLvO7vLPzbe-yl1nYs,4735
59
59
  truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
60
- truss/remote/baseten/remote.py,sha256=Se8AES5mk8jxa8S9fN2DSG7wnsaV7ftRjJ4Uwc_w_S0,22544
60
+ truss/remote/baseten/remote.py,sha256=aKG1BODtrnmuRV-M8T3F3pw8oHawGwI09caKANJ19BM,23420
61
61
  truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
62
62
  truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
63
63
  truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
@@ -141,6 +141,7 @@ truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX
141
141
  truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
142
142
  truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
143
143
  truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
144
+ truss/tests/cli/test_chains_cli.py,sha256=l9GTQrhRm9SRZn43WkMY4tdRslLmdsVyiydRPa1_Ja4,3162
144
145
  truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
145
146
  truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
146
147
  truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
@@ -162,6 +163,7 @@ truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0a
162
163
  truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
163
164
  truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
164
165
  truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
166
+ truss/tests/remote/baseten/test_chain_upload.py,sha256=XaaF1ocovkBYsLMJ8EpXB9FUGfQZAwu4iyOWqoVn7tc,10886
165
167
  truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
166
168
  truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
167
169
  truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
@@ -340,6 +342,7 @@ truss/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
340
342
  truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
341
343
  truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
342
344
  truss/util/env_vars.py,sha256=PmKsXdN-PX2-_xk9XcdHTFuRRWiFaMw2iNUKxE8B1Ro,1671
345
+ truss/util/error_utils.py,sha256=pvBH_opyCVpfUbFvm8bgOtjOWit23x2smkBPPdkSZwc,1148
343
346
  truss/util/gpu.py,sha256=YiEF_JZyzur0MDMJOebMuJBQxrHD9ApGI0aPpWdb5BU,553
344
347
  truss/util/jinja.py,sha256=7KbuYNq55I3DGtImAiCvBwR0K9-z1Jo6gMhmsy4lNZE,333
345
348
  truss/util/log_utils.py,sha256=LwSgRh2K7KFjKKqBxr-IirFxGIzHi1mUM7YEvujvHsE,1985
@@ -349,8 +352,8 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
349
352
  truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
350
353
  truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
351
354
  truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
352
- truss_chains/private_types.py,sha256=6CaQEPawFLXjEbJ-01lqfexJtUIekF_q61LNENWegFo,8917
353
- truss_chains/public_api.py,sha256=0AXV6UdZIFAMycUNG_klgo4aLFmBZeKGfrulZEWzR0M,9532
355
+ truss_chains/private_types.py,sha256=vdcl8FuVsL9JGIu_9K7fd2EW9Ytzoq8nfEx5pmuMKTA,9063
356
+ truss_chains/public_api.py,sha256=civY8juJU92jSGBI7zM1qMnA7hlUdCq7L8o4IOo5meA,9722
354
357
  truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
355
358
  truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
356
359
  truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
@@ -358,7 +361,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
358
361
  truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
359
362
  truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
360
363
  truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
361
- truss_chains/deployment/deployment_client.py,sha256=OoqkO3daktYzR2YsIcDvsuGfjR05X2K7QlA7wvFduzc,34208
364
+ truss_chains/deployment/deployment_client.py,sha256=4cHuvaynVCclJ6M9pw8ukhO1E2NRKohIRxftvOfNvOE,34499
362
365
  truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
363
366
  truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
364
367
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -371,8 +374,8 @@ truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,307
371
374
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
372
375
  truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
373
376
  truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
374
- truss-0.11.16.dist-info/METADATA,sha256=ykIbKzLWO3FqVipKHGs7Ad8yNkqGXoHm74AchWFVgHA,6678
375
- truss-0.11.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
376
- truss-0.11.16.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
377
- truss-0.11.16.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
378
- truss-0.11.16.dist-info/RECORD,,
377
+ truss-0.11.18.dist-info/METADATA,sha256=vlaTGYYSaK0iMt1NvmUEuF6uhy-hygDj2MLDbQQpJO0,6678
378
+ truss-0.11.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
379
+ truss-0.11.18.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
380
+ truss-0.11.18.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
381
+ truss-0.11.18.dist-info/RECORD,,
@@ -516,14 +516,21 @@ def _create_baseten_chain(
516
516
 
517
517
  _create_chains_secret_if_missing(remote_provider)
518
518
 
519
+ # Get chain root for raw chain artifact upload
520
+ chain_root = None
521
+ if entrypoint_descriptor is not None:
522
+ chain_root = _get_chain_root(entrypoint_descriptor.chainlet_cls)
523
+
519
524
  chain_deployment_handle = remote_provider.push_chain_atomic(
520
525
  baseten_options.chain_name,
521
526
  entrypoint_artifact,
522
527
  dependency_artifacts,
523
528
  truss_user_env,
529
+ chain_root=chain_root,
524
530
  publish=baseten_options.publish,
525
531
  environment=baseten_options.environment,
526
532
  progress_bar=progress_bar,
533
+ disable_chain_download=baseten_options.disable_chain_download,
527
534
  )
528
535
  return BasetenChainService(
529
536
  baseten_options.chain_name,
@@ -265,6 +265,7 @@ class PushOptionsBaseten(PushOptions):
265
265
  environment: Optional[str]
266
266
  include_git_info: bool
267
267
  working_dir: pathlib.Path
268
+ disable_chain_download: bool = False
268
269
 
269
270
  @classmethod
270
271
  def create(
@@ -277,6 +278,7 @@ class PushOptionsBaseten(PushOptions):
277
278
  include_git_info: bool,
278
279
  working_dir: pathlib.Path,
279
280
  environment: Optional[str] = None,
281
+ disable_chain_download: bool = False,
280
282
  ) -> "PushOptionsBaseten":
281
283
  if promote and not environment:
282
284
  environment = PRODUCTION_ENVIRONMENT_NAME
@@ -290,6 +292,7 @@ class PushOptionsBaseten(PushOptions):
290
292
  environment=environment,
291
293
  include_git_info=include_git_info,
292
294
  working_dir=working_dir,
295
+ disable_chain_download=disable_chain_download,
293
296
  )
294
297
 
295
298
 
@@ -151,6 +151,7 @@ def push(
151
151
  environment: Optional[str] = None,
152
152
  progress_bar: Optional[Type["progress.Progress"]] = None,
153
153
  include_git_info: bool = False,
154
+ disable_chain_download: bool = False,
154
155
  ) -> deployment_client.BasetenChainService:
155
156
  """
156
157
  Deploys a chain remotely (with all dependent chainlets).
@@ -172,6 +173,7 @@ def push(
172
173
  include_git_info: Whether to attach git versioning info (sha, branch, tag) to
173
174
  deployments made from within a git repo. If set to True in `.trussrc`, it
174
175
  will always be attached.
176
+ disable_chain_download: Disable downloading of pushed chain source code from the UI.
175
177
 
176
178
  Returns:
177
179
  A chain service handle to the deployed chain.
@@ -186,6 +188,7 @@ def push(
186
188
  environment=environment,
187
189
  include_git_info=include_git_info,
188
190
  working_dir=pathlib.Path(inspect.getfile(entrypoint)).parent,
191
+ disable_chain_download=disable_chain_download,
189
192
  )
190
193
  service = deployment_client.push(entrypoint, options, progress_bar=progress_bar)
191
194
  assert isinstance(