truss 0.11.15rc12__py3-none-any.whl → 0.11.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. See this release's page on the package registry for more details (the original link is not preserved in this text).

@@ -211,14 +211,6 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
211
211
  default=False,
212
212
  help=common.INCLUDE_GIT_INFO_DOC,
213
213
  )
214
- @click.option(
215
- "--disable-chain-download",
216
- "disable_chain_download",
217
- is_flag=True,
218
- required=False,
219
- default=False,
220
- help="Disable downloading of pushed chain source code from the UI.",
221
- )
222
214
  @click.pass_context
223
215
  @common.common_options()
224
216
  def push_chain(
@@ -235,7 +227,6 @@ def push_chain(
235
227
  environment: Optional[str],
236
228
  experimental_watch_chainlet_names: Optional[str],
237
229
  include_git_info: bool = False,
238
- disable_chain_download: bool = False,
239
230
  ) -> None:
240
231
  """
241
232
  Deploys a chain remotely.
@@ -293,7 +284,6 @@ def push_chain(
293
284
  environment=environment,
294
285
  include_git_info=include_git_info,
295
286
  working_dir=source.parent if source.is_file() else source.resolve(),
296
- disable_chain_download=disable_chain_download,
297
287
  )
298
288
  service = deployment_client.push(
299
289
  entrypoint_cls, options, progress_bar=progress.Progress
@@ -11,11 +11,10 @@ from typing import Optional, Type
11
11
  import boto3
12
12
  from botocore import UNSIGNED
13
13
  from botocore.client import Config
14
+ from botocore.exceptions import ClientError, NoCredentialsError
14
15
  from google.cloud import storage
15
16
  from huggingface_hub import hf_hub_download
16
17
 
17
- from truss.util.error_utils import handle_client_error
18
-
19
18
  B10CP_PATH_TRUSS_ENV_VAR_NAME = "B10CP_PATH_TRUSS"
20
19
 
21
20
  GCS_CREDENTIALS = "/app/data/service_account.json"
@@ -189,14 +188,24 @@ class S3File(RepositoryFile):
189
188
  if not dst_file.parent.exists():
190
189
  dst_file.parent.mkdir(parents=True)
191
190
 
192
- with handle_client_error(
193
- f"accessing S3 bucket {bucket_name} for file {file_name}"
194
- ):
191
+ try:
195
192
  url = client.generate_presigned_url(
196
193
  "get_object",
197
194
  Params={"Bucket": bucket_name, "Key": file_name},
198
195
  ExpiresIn=3600,
199
196
  )
197
+ except NoCredentialsError as nce:
198
+ raise RuntimeError(
199
+ f"No AWS credentials found\nOriginal exception: {str(nce)}"
200
+ )
201
+ except ClientError as ce:
202
+ raise RuntimeError(
203
+ f"Client error when accessing the S3 bucket (check your credentials): {str(ce)}"
204
+ )
205
+ except Exception as exc:
206
+ raise RuntimeError(
207
+ f"File not found on S3 bucket: {file_name}\nOriginal exception: {str(exc)}"
208
+ )
200
209
 
201
210
  download_file_using_b10cp(url, dst_file, self.file_name)
202
211
 
@@ -5,7 +5,6 @@ from typing import Any, Dict, List, Mapping, Optional
5
5
  import requests
6
6
  from pydantic import BaseModel, Field
7
7
 
8
- from truss.base.custom_types import SafeModel
9
8
  from truss.remote.baseten import custom_types as b10_types
10
9
  from truss.remote.baseten.auth import ApiKey, AuthService
11
10
  from truss.remote.baseten.custom_types import APIKeyCategory
@@ -14,29 +13,6 @@ from truss.remote.baseten.rest_client import RestAPIClient
14
13
  from truss.remote.baseten.utils.transfer import base64_encoded_json_str
15
14
 
16
15
  logger = logging.getLogger(__name__)
17
- PARAMS_INDENT = "\n "
18
-
19
-
20
- class ChainAWSCredential(SafeModel):
21
- aws_access_key_id: str
22
- aws_secret_access_key: str
23
- aws_session_token: str
24
-
25
-
26
- class ChainUploadCredentials(SafeModel):
27
- s3_bucket: str
28
- s3_key: str
29
- aws_access_key_id: str
30
- aws_secret_access_key: str
31
- aws_session_token: str
32
-
33
- @property
34
- def aws_credentials(self) -> ChainAWSCredential:
35
- return ChainAWSCredential(
36
- aws_access_key_id=self.aws_access_key_id,
37
- aws_secret_access_key=self.aws_secret_access_key,
38
- aws_session_token=self.aws_session_token,
39
- )
40
16
 
41
17
 
42
18
  class InstanceTypeV1(BaseModel):
@@ -323,11 +299,7 @@ class BasetenApi:
323
299
  chain_name: Optional[str] = None,
324
300
  environment: Optional[str] = None,
325
301
  is_draft: bool = False,
326
- original_source_artifact_s3_key: Optional[str] = None,
327
- allow_truss_download: Optional[bool] = True,
328
302
  ):
329
- if allow_truss_download is None:
330
- allow_truss_download = True
331
303
  entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
332
304
 
333
305
  dependencies_str = ", ".join(
@@ -337,28 +309,13 @@ class BasetenApi:
337
309
  ]
338
310
  )
339
311
 
340
- params = []
341
- if chain_id:
342
- params.append(f'chain_id: "{chain_id}"')
343
- if chain_name:
344
- params.append(f'chain_name: "{chain_name}"')
345
- if environment:
346
- params.append(f'environment: "{environment}"')
347
- if original_source_artifact_s3_key:
348
- params.append(
349
- f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
350
- )
351
-
352
- params.append(f"is_draft: {str(is_draft).lower()}")
353
- if allow_truss_download is False:
354
- params.append("allow_truss_download: false")
355
-
356
- params_str = PARAMS_INDENT.join(params)
357
-
358
312
  query_string = f"""
359
313
  mutation ($trussUserEnv: String) {{
360
314
  deploy_chain_atomic(
361
- {params_str}
315
+ {f'chain_id: "{chain_id}"' if chain_id else ""}
316
+ {f'chain_name: "{chain_name}"' if chain_name else ""}
317
+ {f'environment: "{environment}"' if environment else ""}
318
+ is_draft: {str(is_draft).lower()}
362
319
  entrypoint: {entrypoint_str}
363
320
  dependencies: [{dependencies_str}]
364
321
  truss_user_env: $trussUserEnv
@@ -700,29 +657,8 @@ class BasetenApi:
700
657
  return resp_json["training_projects"]
701
658
 
702
659
  def get_blob_credentials(self, blob_type: b10_types.BlobType):
703
- if blob_type == b10_types.BlobType.CHAIN:
704
- return self.get_chain_s3_upload_credentials()
705
660
  return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
706
661
 
707
- def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
708
- """Get chain artifact credentials using GraphQL query."""
709
- query = """
710
- query {
711
- chain_s3_upload_credentials {
712
- s3_bucket
713
- s3_key
714
- aws_access_key_id
715
- aws_secret_access_key
716
- aws_session_token
717
- }
718
- }
719
- """
720
- response = self._post_graphql_query(query)
721
-
722
- return ChainUploadCredentials.model_validate(
723
- response["data"]["chain_s3_upload_credentials"]
724
- )
725
-
726
662
  def get_training_job_metrics(
727
663
  self,
728
664
  project_id: str,
@@ -8,7 +8,6 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
8
8
  import requests
9
9
 
10
10
  from truss.base.errors import ValidationError
11
- from truss.util.error_utils import handle_client_error
12
11
 
13
12
  if TYPE_CHECKING:
14
13
  from rich import progress
@@ -81,8 +80,6 @@ class ChainDeploymentHandleAtomic(NamedTuple):
81
80
  chain_id: str
82
81
  chain_deployment_id: str
83
82
  is_draft: bool
84
- original_source_artifact_s3_key: Optional[str] = None
85
- allow_truss_download: Optional[bool] = True
86
83
 
87
84
 
88
85
  class ModelVersionHandle(NamedTuple):
@@ -130,8 +127,6 @@ def create_chain_atomic(
130
127
  is_draft: bool,
131
128
  truss_user_env: b10_types.TrussUserEnv,
132
129
  environment: Optional[str],
133
- original_source_artifact_s3_key: Optional[str] = None,
134
- allow_truss_download: bool = True,
135
130
  ) -> ChainDeploymentHandleAtomic:
136
131
  if environment and is_draft:
137
132
  logging.info(
@@ -154,8 +149,6 @@ def create_chain_atomic(
154
149
  chain_name=chain_name,
155
150
  is_draft=True,
156
151
  truss_user_env=truss_user_env,
157
- original_source_artifact_s3_key=original_source_artifact_s3_key,
158
- allow_truss_download=allow_truss_download,
159
152
  )
160
153
  elif chain_id:
161
154
  # This is the only case where promote has relevance, since
@@ -169,8 +162,6 @@ def create_chain_atomic(
169
162
  chain_id=chain_id,
170
163
  environment=environment,
171
164
  truss_user_env=truss_user_env,
172
- original_source_artifact_s3_key=original_source_artifact_s3_key,
173
- allow_truss_download=allow_truss_download,
174
165
  )
175
166
  except ApiError as e:
176
167
  if (
@@ -191,8 +182,6 @@ def create_chain_atomic(
191
182
  dependencies=dependencies,
192
183
  chain_name=chain_name,
193
184
  truss_user_env=truss_user_env,
194
- original_source_artifact_s3_key=original_source_artifact_s3_key,
195
- allow_truss_download=allow_truss_download,
196
185
  )
197
186
 
198
187
  return ChainDeploymentHandleAtomic(
@@ -200,8 +189,6 @@ def create_chain_atomic(
200
189
  chain_id=res["chain_deployment"]["chain"]["id"],
201
190
  hostname=res["chain_deployment"]["chain"]["hostname"],
202
191
  is_draft=is_draft,
203
- original_source_artifact_s3_key=original_source_artifact_s3_key,
204
- allow_truss_download=allow_truss_download,
205
192
  )
206
193
 
207
194
 
@@ -355,33 +342,6 @@ def upload_truss(
355
342
  return s3_key
356
343
 
357
344
 
358
- def upload_chain_artifact(
359
- api: BasetenApi,
360
- serialize_file: IO,
361
- progress_bar: Optional[Type["progress.Progress"]],
362
- ) -> str:
363
- """
364
- Upload a chain artifact to the Baseten remote.
365
-
366
- Args:
367
- api: BasetenApi instance
368
- serialize_file: File-like object containing the serialized chain artifact
369
-
370
- Returns:
371
- The S3 key of the uploaded file
372
- """
373
- credentials = api.get_chain_s3_upload_credentials()
374
- with handle_client_error("Uploading chain source"):
375
- multipart_upload_boto3(
376
- serialize_file.name,
377
- credentials.s3_bucket,
378
- credentials.s3_key,
379
- credentials.aws_credentials.model_dump(),
380
- progress_bar,
381
- )
382
- return credentials.s3_key
383
-
384
-
385
345
  def create_truss_service(
386
346
  api: BasetenApi,
387
347
  model_name: str,
@@ -120,7 +120,6 @@ class TrussUserEnv(pydantic.BaseModel):
120
120
  class BlobType(Enum):
121
121
  MODEL = "model"
122
122
  TRAIN = "train"
123
- CHAIN = "chain"
124
123
 
125
124
 
126
125
  class FileSummary(pydantic.BaseModel):
@@ -31,7 +31,6 @@ from truss.remote.baseten.core import (
31
31
  get_model_and_versions,
32
32
  get_prod_version_from_versions,
33
33
  get_truss_watch_state,
34
- upload_chain_artifact,
35
34
  upload_truss,
36
35
  validate_truss_config_against_backend,
37
36
  )
@@ -264,11 +263,9 @@ class BasetenRemote(TrussRemote):
264
263
  entrypoint_artifact: custom_types.ChainletArtifact,
265
264
  dependency_artifacts: List[custom_types.ChainletArtifact],
266
265
  truss_user_env: b10_types.TrussUserEnv,
267
- chain_root: Optional[Path] = None,
268
266
  publish: bool = False,
269
267
  environment: Optional[str] = None,
270
268
  progress_bar: Optional[Type["progress.Progress"]] = None,
271
- disable_chain_download: bool = False,
272
269
  ) -> ChainDeploymentHandleAtomic:
273
270
  # If we are promoting a model to an environment after deploy, it must be published.
274
271
  # Draft models cannot be promoted.
@@ -288,7 +285,6 @@ class BasetenRemote(TrussRemote):
288
285
  publish=publish,
289
286
  origin=custom_types.ModelOrigin.CHAINS,
290
287
  progress_bar=progress_bar,
291
- disable_truss_download=disable_chain_download,
292
288
  )
293
289
  oracle_data = custom_types.OracleData(
294
290
  model_name=push_data.model_name,
@@ -304,18 +300,6 @@ class BasetenRemote(TrussRemote):
304
300
  )
305
301
  )
306
302
 
307
- # Upload raw chain artifact if chain_root is provided
308
- raw_chain_s3_key = None
309
- if chain_root is not None:
310
- logging.info("Uploading source artifact")
311
- # Create a tar file from the chain root directory
312
- original_source_tar = archive_dir(dir=chain_root, progress_bar=progress_bar)
313
- # Upload the chain artifact to S3
314
- raw_chain_s3_key = upload_chain_artifact(
315
- api=self._api,
316
- serialize_file=original_source_tar,
317
- progress_bar=progress_bar,
318
- )
319
303
  chain_deployment_handle = create_chain_atomic(
320
304
  api=self._api,
321
305
  chain_name=chain_name,
@@ -324,8 +308,6 @@ class BasetenRemote(TrussRemote):
324
308
  is_draft=not publish,
325
309
  truss_user_env=truss_user_env,
326
310
  environment=environment,
327
- original_source_artifact_s3_key=raw_chain_s3_key,
328
- allow_truss_download=not disable_chain_download,
329
311
  )
330
312
  logging.info("Successfully pushed to baseten. Chain is building and deploying.")
331
313
  return chain_deployment_handle
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional, Type
7
7
  import boto3
8
8
  from boto3.s3.transfer import TransferConfig
9
9
 
10
- from truss.util.env_vars import override_env_vars
10
+ from truss.util.env_vars import modify_env_vars
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from rich import progress
@@ -26,7 +26,10 @@ def multipart_upload_boto3(
26
26
  ) -> None:
27
27
  # In the CLI flow, ignore any local ~/.aws/config files,
28
28
  # which can interfere with uploading the Truss to S3.
29
- with override_env_vars({"AWS_CONFIG_FILE": ""}):
29
+ aws_env_vars = set(
30
+ env_var for env_var in os.environ.keys() if env_var.startswith("AWS_")
31
+ )
32
+ with modify_env_vars(deletions=aws_env_vars):
30
33
  s3_resource = boto3.resource("s3", **credentials)
31
34
  filesize = os.stat(file_path).st_size
32
35
 
@@ -7,7 +7,7 @@ logfile_maxbytes=0 ; No size limit on logfile (since logging is disabl
7
7
  [program:model-server]
8
8
  command={{start_command}} ; Command to start the model server (provided by Jinja variable)
9
9
  startsecs=30 ; Wait 30 seconds before assuming the server is running
10
- startretries=1 ; Do not retry if server fails to start
10
+ startretries=0 ; Do not retry if server fails to start
11
11
  autostart=true ; Automatically start the program when supervisord starts
12
12
  autorestart=false ; Don't restart the program
13
13
  stdout_logfile=/dev/fd/1 ; Send stdout to the first file descriptor (stdout)
@@ -1,14 +1,19 @@
1
1
  import os
2
2
 
3
- from truss.util.env_vars import override_env_vars
3
+ from truss.util.env_vars import modify_env_vars
4
4
 
5
5
 
6
- def test_override_env_vars():
6
+ def test_modify_env_vars():
7
7
  os.environ["API_KEY"] = "original_key"
8
+ os.environ["AWS_CONFIG_FILE"] = "original_config_file"
8
9
 
9
- with override_env_vars({"API_KEY": "new_key", "DEBUG": "true"}):
10
+ with modify_env_vars(
11
+ overrides={"API_KEY": "new_key", "DEBUG": "true"}, deletions={"AWS_CONFIG_FILE"}
12
+ ):
10
13
  assert os.environ["API_KEY"] == "new_key"
11
14
  assert os.environ["DEBUG"] == "true"
15
+ assert "AWS_CONFIG_FILE" not in os.environ
12
16
 
13
17
  assert os.environ["API_KEY"] == "original_key"
14
18
  assert "DEBUG" not in os.environ
19
+ assert os.environ["AWS_CONFIG_FILE"] == "original_config_file"
truss/util/env_vars.py CHANGED
@@ -1,32 +1,43 @@
1
1
  import os
2
- from typing import Dict, Optional
2
+ from typing import Dict, Optional, Set
3
3
 
4
4
 
5
- class override_env_vars:
5
+ class modify_env_vars:
6
6
  """A context manager for temporarily overwriting environment variables.
7
7
 
8
8
  Usage:
9
- with override_env_vars({'API_KEY': 'test_key', 'DEBUG': 'true'}):
9
+ with modify_env_vars(overrides={'API_KEY': 'test_key', 'DEBUG': 'true'}, deletions={'AWS_CONFIG_FILE'}):
10
10
  # Environment variables are modified here
11
11
  ...
12
12
  # Original environment is restored here
13
13
  """
14
14
 
15
- def __init__(self, env_vars: Dict[str, str]):
15
+ def __init__(
16
+ self,
17
+ overrides: Optional[Dict[str, str]] = None,
18
+ deletions: Optional[Set[str]] = None,
19
+ ):
16
20
  """
17
21
  Args:
18
- env_vars: Dictionary of environment variables to set
22
+ overrides: Dictionary of environment variables to set
23
+ deletions: Set of environment variables to delete
19
24
  """
20
- self.env_vars = env_vars
25
+ self.overrides: Dict[str, str] = overrides or dict()
26
+ self.deletions: Set[str] = deletions or set()
21
27
  self.original_vars: Dict[str, Optional[str]] = {}
22
28
 
23
29
  def __enter__(self):
24
- for key in self.env_vars:
30
+ all_keys = set(self.overrides.keys()) | self.deletions
31
+ for key in all_keys:
25
32
  self.original_vars[key] = os.environ.get(key)
26
33
 
27
- for key, value in self.env_vars.items():
34
+ for key, value in self.overrides.items():
28
35
  os.environ[key] = value
29
36
 
37
+ for key in self.deletions:
38
+ if key in os.environ:
39
+ del os.environ[key]
40
+
30
41
  return self
31
42
 
32
43
  def __exit__(self, exc_type, exc_val, exc_tb):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.15rc12
3
+ Version: 0.11.16
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -8,7 +8,7 @@ truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
9
9
  truss/base/truss_config.py,sha256=s39Xc1e20s8IV07YLl_aVnp-uRS18ZQ2TV-3FILx4nY,28416
10
10
  truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
- truss/cli/chains_commands.py,sha256=QijtACpuAt2O1RV_qhTNPw0jcFg-u0dX9PP-ct0t-rs,17716
11
+ truss/cli/chains_commands.py,sha256=Kpa5mCg6URAJQE2ZmZfVQFhjBHEitKT28tKiW0H6XAI,17406
12
12
  truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
13
13
  truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
14
  truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
@@ -34,7 +34,7 @@ truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
34
34
  truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
35
35
  truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomUDpalvKo,3925
36
36
  truss/contexts/truss_context.py,sha256=uS6L-ACHxNk0BsJwESOHh1lA0OGGw0pb33aFKGsASj4,436
37
- truss/contexts/image_builder/cache_warmer.py,sha256=EETFAgZk7C6rQezzFxz4XqjS5LIyF7uM1VVscQt_cBA,6959
37
+ truss/contexts/image_builder/cache_warmer.py,sha256=TGMV1Mh87n2e_dSowH0sf0rZhZraDOR-LVapZL3a5r8,7377
38
38
  truss/contexts/image_builder/image_builder.py,sha256=IuRgDeeoHVLzIkJvKtX3807eeqEyaroCs_KWDcIHZUg,1461
39
39
  truss/contexts/image_builder/serving_image_builder.py,sha256=1PfHtkTEdNPhSQAX8Ajk_0LN3KR2EfLKwOJsnECtKXQ,33958
40
40
  truss/contexts/image_builder/util.py,sha256=y2-CjUKv0XV-0w2sr1fUCflysDJLsoU4oPp6tvvoFnk,1203
@@ -52,18 +52,18 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
52
52
  truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
53
53
  truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
54
54
  truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
55
- truss/remote/baseten/api.py,sha256=2Es2afWKnz7OlQJHIbvYAKoSrb1dn9SnsAY--uHXbTs,30210
55
+ truss/remote/baseten/api.py,sha256=5B5IXNy0v8hRHNH2ar3rldDa47kwt5s1PtKZQ9_pfmE,28263
56
56
  truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
57
- truss/remote/baseten/core.py,sha256=69utHGGFRw1ZQUobj80TSmaBgU3plnsfHZfiR15dPrY,23502
58
- truss/remote/baseten/custom_types.py,sha256=g7yWkE8p6uIAG5JqgfELFGHzjFLvO7vLPzbe-yl1nYs,4735
57
+ truss/remote/baseten/core.py,sha256=uxtmBI9RAVHu1glIEJb5Q4ccJYLeZM1Cp5Svb9W68Yw,21965
58
+ truss/remote/baseten/custom_types.py,sha256=bYrfTzGgYr6FDoya0omyadCLSTcTc-83U2scQORyUj0,4715
59
59
  truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
60
- truss/remote/baseten/remote.py,sha256=aKG1BODtrnmuRV-M8T3F3pw8oHawGwI09caKANJ19BM,23420
60
+ truss/remote/baseten/remote.py,sha256=Se8AES5mk8jxa8S9fN2DSG7wnsaV7ftRjJ4Uwc_w_S0,22544
61
61
  truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
62
62
  truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
63
63
  truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
64
64
  truss/remote/baseten/utils/tar.py,sha256=pMUv--YkwXDngUx1WUOK-KmAIKMcOg2E-CD5x4heh3s,2514
65
65
  truss/remote/baseten/utils/time.py,sha256=Ry9GMjYnbIGYVIGwtmv4V8ljWjvdcaCf5NOQzlNeGxI,397
66
- truss/remote/baseten/utils/transfer.py,sha256=d3VptuQb6M1nyS6kz0BAfeOYDLkMKUjatJXpY-mp-As,1548
66
+ truss/remote/baseten/utils/transfer.py,sha256=vMI-Dcd_HaRkqVjWI02y3eK9DjiqgG0ULayS0KxIIvA,1652
67
67
  truss/templates/README.md.jinja,sha256=N7CJdyldZuJamj5jLh47le0hFBdu9irVsTBqoxhPNPQ,2476
68
68
  truss/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
69
  truss/templates/base.Dockerfile.jinja,sha256=tdMmK5TeiQuYbz4gqbACM3R-l-mazqL9tAZtJ4sxC4g,5331
@@ -93,7 +93,7 @@ truss/templates/custom/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
93
93
  truss/templates/custom/model/model.py,sha256=J04rLxK09Pwt2F4GoKOLKL-H-CqZUdYIM-PL2CE9PoE,1079
94
94
  truss/templates/custom_python_dx/my_model.py,sha256=NG75mQ6wxzB1BYUemDFZvRLBET-UrzuUK4FuHjqI29U,910
95
95
  truss/templates/docker_server/proxy.conf.jinja,sha256=Axily8EtznvrF7mUCgy2VFY99BYRt4BycZ0p9uWfd0s,2025
96
- truss/templates/docker_server/supervisord.conf.jinja,sha256=h7yhO1IM9xr8JMn3K8arwRkYSaEL_dzD6U6rPGKmnIY,1835
96
+ truss/templates/docker_server/supervisord.conf.jinja,sha256=AliMMd6bNn-oCYIB8GumTvt8L2JshkjFaoqIcyBzQmc,1835
97
97
  truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
98
98
  truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
99
99
  truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
@@ -141,7 +141,6 @@ truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX
141
141
  truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
142
142
  truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
143
143
  truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
144
- truss/tests/cli/test_chains_cli.py,sha256=l9GTQrhRm9SRZn43WkMY4tdRslLmdsVyiydRPa1_Ja4,3162
145
144
  truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
146
145
  truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
147
146
  truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
@@ -163,7 +162,6 @@ truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0a
163
162
  truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
164
163
  truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
165
164
  truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
166
- truss/tests/remote/baseten/test_chain_upload.py,sha256=XaaF1ocovkBYsLMJ8EpXB9FUGfQZAwu4iyOWqoVn7tc,10886
167
165
  truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
168
166
  truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
169
167
  truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
@@ -318,7 +316,7 @@ truss/tests/test_data/test_truss_with_error/packages/helpers_2.py,sha256=q_UpVfX
318
316
  truss/tests/trt_llm/test_trt_llm_config.py,sha256=lNQ4EEkOsiT17KvnvW1snCeEBd7K_cl9_Y0dko3qpn8,8505
319
317
  truss/tests/trt_llm/test_validation.py,sha256=dmax2EHxRfqxJvWzV8uubkTef50833KBBHw-WkHufL8,2120
320
318
  truss/tests/util/test_config_checks.py,sha256=aoZF_Q-eRd3qz5wjUqa8Cr_7qF2SxodXbBIY_DBuFWg,522
321
- truss/tests/util/test_env_vars.py,sha256=hthgB1mU0bJb1H4Jugc-0khArlLZ3x6tLE82cDaa-J0,390
319
+ truss/tests/util/test_env_vars.py,sha256=kz5FlynWvXhNR9bhf-xC-0cXuyYyAOvtKLLxKLDUGf4,616
322
320
  truss/tests/util/test_path.py,sha256=YfW3-IM_7iRsdR1Cb26KB1BkDsG_53_BUGBzoxY2Nog,7408
323
321
  truss/trt_llm/config_checks.py,sha256=Efxb5l7vRNveDglse78untq9V-IgtLxejk3_8JKxN5I,4671
324
322
  truss/trt_llm/validation.py,sha256=cse-EnmuHmRpwBuSc3IvmSnl-StSQCIFM1nssgnaRUQ,1848
@@ -341,8 +339,7 @@ truss/util/.truss_ignore,sha256=jpQA9ou-r_JEIcEHsUqGLHhir_m3d4IPGNyzKXtS-2g,3131
341
339
  truss/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
342
340
  truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
343
341
  truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
344
- truss/util/env_vars.py,sha256=7Bv686eER71Barrs6fNamk_TrTJGmu9yV2TxaVmupn0,1232
345
- truss/util/error_utils.py,sha256=aO76Vf8LMlvhM28jRJ1qzNl4E5ZyvKK4TFQq_UhbQrk,1095
342
+ truss/util/env_vars.py,sha256=PmKsXdN-PX2-_xk9XcdHTFuRRWiFaMw2iNUKxE8B1Ro,1671
346
343
  truss/util/gpu.py,sha256=YiEF_JZyzur0MDMJOebMuJBQxrHD9ApGI0aPpWdb5BU,553
347
344
  truss/util/jinja.py,sha256=7KbuYNq55I3DGtImAiCvBwR0K9-z1Jo6gMhmsy4lNZE,333
348
345
  truss/util/log_utils.py,sha256=LwSgRh2K7KFjKKqBxr-IirFxGIzHi1mUM7YEvujvHsE,1985
@@ -352,8 +349,8 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
352
349
  truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
353
350
  truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
354
351
  truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
355
- truss_chains/private_types.py,sha256=vdcl8FuVsL9JGIu_9K7fd2EW9Ytzoq8nfEx5pmuMKTA,9063
356
- truss_chains/public_api.py,sha256=civY8juJU92jSGBI7zM1qMnA7hlUdCq7L8o4IOo5meA,9722
352
+ truss_chains/private_types.py,sha256=6CaQEPawFLXjEbJ-01lqfexJtUIekF_q61LNENWegFo,8917
353
+ truss_chains/public_api.py,sha256=0AXV6UdZIFAMycUNG_klgo4aLFmBZeKGfrulZEWzR0M,9532
357
354
  truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
358
355
  truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
359
356
  truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
@@ -361,7 +358,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
361
358
  truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
362
359
  truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
363
360
  truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
364
- truss_chains/deployment/deployment_client.py,sha256=4cHuvaynVCclJ6M9pw8ukhO1E2NRKohIRxftvOfNvOE,34499
361
+ truss_chains/deployment/deployment_client.py,sha256=OoqkO3daktYzR2YsIcDvsuGfjR05X2K7QlA7wvFduzc,34208
365
362
  truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
366
363
  truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
367
364
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -374,8 +371,8 @@ truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,307
374
371
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
375
372
  truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
376
373
  truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
377
- truss-0.11.15rc12.dist-info/METADATA,sha256=xsPYMgpuuqJKv31YfroCshYgoeEd-GvyYo7lh1_ONBo,6682
378
- truss-0.11.15rc12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
379
- truss-0.11.15rc12.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
380
- truss-0.11.15rc12.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
381
- truss-0.11.15rc12.dist-info/RECORD,,
374
+ truss-0.11.16.dist-info/METADATA,sha256=ykIbKzLWO3FqVipKHGs7Ad8yNkqGXoHm74AchWFVgHA,6678
375
+ truss-0.11.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
376
+ truss-0.11.16.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
377
+ truss-0.11.16.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
378
+ truss-0.11.16.dist-info/RECORD,,
@@ -516,21 +516,14 @@ def _create_baseten_chain(
516
516
 
517
517
  _create_chains_secret_if_missing(remote_provider)
518
518
 
519
- # Get chain root for raw chain artifact upload
520
- chain_root = None
521
- if entrypoint_descriptor is not None:
522
- chain_root = _get_chain_root(entrypoint_descriptor.chainlet_cls)
523
-
524
519
  chain_deployment_handle = remote_provider.push_chain_atomic(
525
520
  baseten_options.chain_name,
526
521
  entrypoint_artifact,
527
522
  dependency_artifacts,
528
523
  truss_user_env,
529
- chain_root=chain_root,
530
524
  publish=baseten_options.publish,
531
525
  environment=baseten_options.environment,
532
526
  progress_bar=progress_bar,
533
- disable_chain_download=baseten_options.disable_chain_download,
534
527
  )
535
528
  return BasetenChainService(
536
529
  baseten_options.chain_name,
@@ -265,7 +265,6 @@ class PushOptionsBaseten(PushOptions):
265
265
  environment: Optional[str]
266
266
  include_git_info: bool
267
267
  working_dir: pathlib.Path
268
- disable_chain_download: bool = False
269
268
 
270
269
  @classmethod
271
270
  def create(
@@ -278,7 +277,6 @@ class PushOptionsBaseten(PushOptions):
278
277
  include_git_info: bool,
279
278
  working_dir: pathlib.Path,
280
279
  environment: Optional[str] = None,
281
- disable_chain_download: bool = False,
282
280
  ) -> "PushOptionsBaseten":
283
281
  if promote and not environment:
284
282
  environment = PRODUCTION_ENVIRONMENT_NAME
@@ -292,7 +290,6 @@ class PushOptionsBaseten(PushOptions):
292
290
  environment=environment,
293
291
  include_git_info=include_git_info,
294
292
  working_dir=working_dir,
295
- disable_chain_download=disable_chain_download,
296
293
  )
297
294
 
298
295
 
@@ -151,7 +151,6 @@ def push(
151
151
  environment: Optional[str] = None,
152
152
  progress_bar: Optional[Type["progress.Progress"]] = None,
153
153
  include_git_info: bool = False,
154
- disable_chain_download: bool = False,
155
154
  ) -> deployment_client.BasetenChainService:
156
155
  """
157
156
  Deploys a chain remotely (with all dependent chainlets).
@@ -173,7 +172,6 @@ def push(
173
172
  include_git_info: Whether to attach git versioning info (sha, branch, tag) to
174
173
  deployments made from within a git repo. If set to True in `.trussrc`, it
175
174
  will always be attached.
176
- disable_chain_download: Disable downloading of pushed chain source code from the UI.
177
175
 
178
176
  Returns:
179
177
  A chain service handle to the deployed chain.
@@ -188,7 +186,6 @@ def push(
188
186
  environment=environment,
189
187
  include_git_info=include_git_info,
190
188
  working_dir=pathlib.Path(inspect.getfile(entrypoint)).parent,
191
- disable_chain_download=disable_chain_download,
192
189
  )
193
190
  service = deployment_client.push(entrypoint, options, progress_bar=progress_bar)
194
191
  assert isinstance(
@@ -1,100 +0,0 @@
1
- """Tests for truss chains CLI commands."""
2
-
3
- from unittest.mock import Mock, patch
4
-
5
- from click.testing import CliRunner
6
-
7
- from truss.cli.cli import truss_cli
8
-
9
-
10
- def test_chains_push_with_disable_chain_download_flag():
11
- """Test that --disable-chain-download flag is properly parsed and passed through."""
12
- runner = CliRunner()
13
-
14
- mock_entrypoint_cls = Mock()
15
- mock_entrypoint_cls.meta_data.chain_name = "test_chain"
16
- mock_entrypoint_cls.display_name = "TestChain"
17
-
18
- mock_service = Mock()
19
- mock_service.run_remote_url = "http://test.com/run_remote"
20
- mock_service.is_websocket = False
21
-
22
- with patch(
23
- "truss_chains.framework.ChainletImporter.import_target"
24
- ) as mock_importer:
25
- with patch("truss_chains.deployment.deployment_client.push") as mock_push:
26
- mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
27
- mock_push.return_value = mock_service
28
-
29
- result = runner.invoke(
30
- truss_cli,
31
- [
32
- "chains",
33
- "push",
34
- "test_chain.py",
35
- "--disable-chain-download",
36
- "--remote",
37
- "test_remote",
38
- "--dryrun",
39
- ],
40
- )
41
-
42
- assert result.exit_code == 0
43
-
44
- mock_push.assert_called_once()
45
- call_args = mock_push.call_args
46
- options = call_args[0][1]
47
-
48
- assert hasattr(options, "disable_chain_download")
49
- assert options.disable_chain_download is True
50
-
51
-
52
- def test_chains_push_without_disable_chain_download_flag():
53
- """Test that disable_chain_download defaults to False when flag is not provided."""
54
- runner = CliRunner()
55
-
56
- mock_entrypoint_cls = Mock()
57
- mock_entrypoint_cls.meta_data.chain_name = "test_chain"
58
- mock_entrypoint_cls.display_name = "TestChain"
59
-
60
- mock_service = Mock()
61
- mock_service.run_remote_url = "http://test.com/run_remote"
62
- mock_service.is_websocket = False
63
-
64
- with patch(
65
- "truss_chains.framework.ChainletImporter.import_target"
66
- ) as mock_importer:
67
- with patch("truss_chains.deployment.deployment_client.push") as mock_push:
68
- mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
69
- mock_push.return_value = mock_service
70
-
71
- result = runner.invoke(
72
- truss_cli,
73
- [
74
- "chains",
75
- "push",
76
- "test_chain.py",
77
- "--remote",
78
- "test_remote",
79
- "--dryrun",
80
- ],
81
- )
82
-
83
- assert result.exit_code == 0
84
-
85
- mock_push.assert_called_once()
86
- call_args = mock_push.call_args
87
- options = call_args[0][1]
88
-
89
- assert hasattr(options, "disable_chain_download")
90
- assert options.disable_chain_download is False
91
-
92
-
93
- def test_chains_push_help_includes_disable_chain_download():
94
- """Test that --disable-chain-download appears in the help output."""
95
- runner = CliRunner()
96
-
97
- result = runner.invoke(truss_cli, ["chains", "push", "--help"])
98
-
99
- assert result.exit_code == 0
100
- assert "--disable-chain-download" in result.output
@@ -1,285 +0,0 @@
1
- import pathlib
2
- import tempfile
3
- from unittest.mock import Mock, patch
4
-
5
- import pytest
6
-
7
- from truss.remote.baseten import custom_types as b10_types
8
- from truss.remote.baseten.api import BasetenApi
9
- from truss.remote.baseten.core import upload_chain_artifact
10
- from truss.remote.baseten.remote import BasetenRemote
11
-
12
-
13
- @pytest.fixture
14
- def mock_push_data():
15
- """Fixture providing mock push data for tests."""
16
- mock_push_data = Mock()
17
- mock_push_data.model_name = "test-model"
18
- mock_push_data.s3_key = "models/test-key"
19
- mock_push_data.encoded_config_str = "encoded_config"
20
- mock_push_data.is_draft = False
21
- mock_push_data.model_id = "model-id"
22
- mock_push_data.version_name = None
23
- return mock_push_data
24
-
25
-
26
- @pytest.fixture
27
- def mock_remote_context():
28
- """Fixture providing mock remote and context managers for tests."""
29
- api = Mock(spec=BasetenApi)
30
-
31
- remote = BasetenRemote("https://test.baseten.co", "test-api-key")
32
- remote._api = api
33
-
34
- chain_name = "test-chain"
35
- entrypoint_artifact = Mock()
36
- entrypoint_artifact.truss_dir = "/path/to/truss"
37
- entrypoint_artifact.display_name = "entrypoint"
38
-
39
- dependency_artifacts = []
40
- truss_user_env = Mock()
41
- chain_root = pathlib.Path("/path/to/chain")
42
-
43
- with patch.object(remote, "_prepare_push") as mock_prepare_push:
44
- with patch("truss.remote.baseten.remote.truss_build.load") as mock_load:
45
- mock_truss_handle = Mock()
46
- mock_truss_handle.spec.config.model_name = "test-model"
47
- mock_load.return_value = mock_truss_handle
48
-
49
- yield {
50
- "remote": remote,
51
- "api": api,
52
- "chain_name": chain_name,
53
- "entrypoint_artifact": entrypoint_artifact,
54
- "dependency_artifacts": dependency_artifacts,
55
- "truss_user_env": truss_user_env,
56
- "chain_root": chain_root,
57
- "mock_prepare_push": mock_prepare_push,
58
- "mock_load": mock_load,
59
- "mock_truss_handle": mock_truss_handle,
60
- }
61
-
62
-
63
- def test_get_blob_credentials_for_chain():
64
- """Test that get_blob_credentials works correctly for chain blob type using GraphQL."""
65
- mock_graphql_response = {
66
- "data": {
67
- "chain_s3_upload_credentials": {
68
- "s3_bucket": "test-chain-bucket",
69
- "s3_key": "chains/test-uuid/chain.tgz",
70
- "aws_access_key_id": "test_access_key",
71
- "aws_secret_access_key": "test_secret_key",
72
- "aws_session_token": "test_session_token",
73
- }
74
- }
75
- }
76
-
77
- # Create a real API instance and mock the GraphQL call
78
- mock_auth_service = Mock()
79
- mock_auth_service.authenticate.return_value = Mock(value="test-token")
80
- api = BasetenApi("https://test.baseten.co", mock_auth_service)
81
- with patch.object(api, "_post_graphql_query") as mock_graphql:
82
- mock_graphql.return_value = mock_graphql_response
83
-
84
- result = api.get_chain_s3_upload_credentials()
85
-
86
- assert result.s3_bucket == "test-chain-bucket"
87
- assert result.s3_key == "chains/test-uuid/chain.tgz"
88
- assert result.aws_access_key_id == "test_access_key"
89
- assert result.aws_secret_access_key == "test_secret_key"
90
- assert result.aws_session_token == "test_session_token"
91
-
92
- mock_graphql.assert_called_once()
93
- call_args = mock_graphql.call_args
94
- assert "chain_s3_upload_credentials" in call_args[0][0]
95
-
96
-
97
- def test_get_blob_credentials_for_other_types_uses_rest():
98
- """Test that get_blob_credentials uses REST API for non-chain blob types."""
99
- mock_response = {
100
- "s3_bucket": "test-bucket",
101
- "s3_key": "test-key",
102
- "creds": {
103
- "aws_access_key_id": "test_access_key",
104
- "aws_secret_access_key": "test_secret_key",
105
- "aws_session_token": "test_session_token",
106
- },
107
- }
108
-
109
- mock_auth_service = Mock()
110
- mock_auth_service.authenticate.return_value = Mock(value="test-token")
111
- api = BasetenApi("https://test.baseten.co", mock_auth_service)
112
- with patch.object(api, "_rest_api_client") as mock_client, patch.object(
113
- api, "_post_graphql_query"
114
- ) as mock_graphql:
115
- mock_client.get.return_value = mock_response
116
-
117
- result = api.get_blob_credentials(b10_types.BlobType.MODEL)
118
-
119
- assert result["s3_bucket"] == "test-bucket"
120
- assert result["s3_key"] == "test-key"
121
-
122
- mock_client.get.assert_called_once_with("v1/blobs/credentials/model")
123
- mock_graphql.assert_not_called()
124
-
125
-
126
- @patch("truss.remote.baseten.core.multipart_upload_boto3")
127
- def test_upload_chain_artifact_function(mock_multipart_upload):
128
- """Test the upload_chain_artifact function."""
129
- # Mock ChainUploadCredentials object
130
- mock_credentials = Mock()
131
- mock_credentials.s3_bucket = "test-chain-bucket"
132
- mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
133
- mock_credentials.aws_credentials = Mock()
134
- mock_credentials.aws_credentials.model_dump.return_value = {
135
- "aws_access_key_id": "test_access_key",
136
- "aws_secret_access_key": "test_secret_key",
137
- "aws_session_token": "test_session_token",
138
- }
139
-
140
- api = Mock(spec=BasetenApi)
141
- api.get_chain_s3_upload_credentials.return_value = mock_credentials
142
-
143
- with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as temp_file:
144
- temp_file.write(b"test chain content")
145
- temp_file.flush()
146
-
147
- result = upload_chain_artifact(api, temp_file, None)
148
-
149
- assert result == "chains/test-uuid/chain.tgz"
150
-
151
- api.get_chain_s3_upload_credentials.assert_called_once_with()
152
-
153
- mock_multipart_upload.assert_called_once()
154
- call_args = mock_multipart_upload.call_args
155
- assert call_args[0][0] == temp_file.name # file path
156
- assert call_args[0][1] == "test-chain-bucket" # bucket
157
- assert call_args[0][2] == "chains/test-uuid/chain.tgz" # key
158
- assert call_args[0][3] == { # credentials
159
- "aws_access_key_id": "test_access_key",
160
- "aws_secret_access_key": "test_secret_key",
161
- "aws_session_token": "test_session_token",
162
- }
163
-
164
-
165
- @patch("truss.remote.baseten.remote.upload_chain_artifact")
166
- @patch("truss.remote.baseten.remote.archive_dir")
167
- @patch("truss.remote.baseten.remote.create_chain_atomic")
168
- def test_push_chain_atomic_with_chain_upload(
169
- mock_create_chain_atomic,
170
- mock_archive_dir,
171
- mock_upload_chain_artifact,
172
- mock_push_data,
173
- mock_remote_context,
174
- ):
175
- """Test that push_chain_atomic uploads raw chain artifact when chain_root is provided."""
176
- mock_create_chain_atomic.return_value = Mock()
177
- mock_archive_dir.return_value = Mock()
178
- mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"
179
-
180
- context = mock_remote_context
181
- remote = context["remote"]
182
- chain_name = context["chain_name"]
183
- entrypoint_artifact = context["entrypoint_artifact"]
184
- dependency_artifacts = context["dependency_artifacts"]
185
- truss_user_env = context["truss_user_env"]
186
- chain_root = context["chain_root"]
187
-
188
- context["mock_prepare_push"].return_value = mock_push_data
189
-
190
- result = remote.push_chain_atomic(
191
- chain_name=chain_name,
192
- entrypoint_artifact=entrypoint_artifact,
193
- dependency_artifacts=dependency_artifacts,
194
- truss_user_env=truss_user_env,
195
- chain_root=chain_root,
196
- publish=True,
197
- )
198
- assert result == mock_create_chain_atomic.return_value
199
-
200
- mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
201
- mock_upload_chain_artifact.assert_called_once()
202
-
203
- mock_create_chain_atomic.assert_called_once()
204
-
205
-
206
- @patch("truss.remote.baseten.remote.create_chain_atomic")
207
- def test_push_chain_atomic_without_chain_upload(
208
- mock_create_chain_atomic, mock_push_data, mock_remote_context
209
- ):
210
- """Test that push_chain_atomic skips chain upload when chain_root is None."""
211
- mock_create_chain_atomic.return_value = Mock()
212
-
213
- context = mock_remote_context
214
- remote = context["remote"]
215
- chain_name = context["chain_name"]
216
- entrypoint_artifact = context["entrypoint_artifact"]
217
- dependency_artifacts = context["dependency_artifacts"]
218
- truss_user_env = context["truss_user_env"]
219
-
220
- context["mock_prepare_push"].return_value = mock_push_data
221
-
222
- with patch("truss.remote.baseten.remote.upload_chain_artifact") as mock_upload:
223
- with patch(
224
- "truss.remote.baseten.core.create_tar_with_progress_bar"
225
- ) as mock_tar:
226
- # Call push_chain_atomic without chain_root
227
- result = remote.push_chain_atomic(
228
- chain_name=chain_name,
229
- entrypoint_artifact=entrypoint_artifact,
230
- dependency_artifacts=dependency_artifacts,
231
- truss_user_env=truss_user_env,
232
- chain_root=None, # No chain root
233
- publish=True,
234
- )
235
-
236
- assert result
237
- # Verify chain artifact upload was NOT called
238
- mock_tar.assert_not_called()
239
- mock_upload.assert_not_called()
240
-
241
- mock_create_chain_atomic.assert_called_once()
242
-
243
-
244
- @patch("truss.remote.baseten.core.multipart_upload_boto3")
245
- def test_upload_chain_artifact_error_handling(mock_multipart_upload):
246
- """Test error handling in upload_chain_artifact."""
247
- # Mock API to raise an exception
248
- api = Mock(spec=BasetenApi)
249
- api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")
250
-
251
- with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
252
- # Should raise the exception
253
- with pytest.raises(Exception, match="API Error"):
254
- upload_chain_artifact(api, temp_file, None)
255
-
256
-
257
- def test_upload_chain_artifact_credentials_extraction():
258
- """Test that credentials are properly extracted from API response."""
259
- # Mock ChainUploadCredentials object
260
- mock_credentials = Mock()
261
- mock_credentials.s3_bucket = "test-bucket"
262
- mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
263
- mock_credentials.aws_credentials = Mock()
264
- mock_credentials.aws_credentials.model_dump.return_value = {
265
- "aws_access_key_id": "access_key",
266
- "aws_secret_access_key": "secret_key",
267
- "aws_session_token": "session_token",
268
- }
269
-
270
- api = Mock(spec=BasetenApi)
271
- api.get_chain_s3_upload_credentials.return_value = mock_credentials
272
-
273
- with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
274
- with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
275
- upload_chain_artifact(api, temp_file, None)
276
-
277
- call_args = mock_upload.call_args
278
- credentials = call_args[0][3]
279
-
280
- assert credentials == {
281
- "aws_access_key_id": "access_key",
282
- "aws_secret_access_key": "secret_key",
283
- "aws_session_token": "session_token",
284
- }
285
- assert "extra_field" not in credentials
truss/util/error_utils.py DELETED
@@ -1,34 +0,0 @@
1
- from contextlib import contextmanager
2
- from typing import Generator
3
-
4
- from botocore.exceptions import ClientError, NoCredentialsError
5
-
6
-
7
- @contextmanager
8
- def handle_client_error(
9
- operation_description: str = "AWS operation",
10
- ) -> Generator[None, None, None]:
11
- """
12
- Context manager to handle common boto3 errors and convert them to RuntimeError.
13
-
14
- Args:
15
- operation_description: Description of the operation being performed for error messages
16
-
17
- Raises:
18
- RuntimeError: For NoCredentialsError, ClientError, and other exceptions
19
- """
20
- try:
21
- yield
22
- except NoCredentialsError as nce:
23
- raise RuntimeError(
24
- f"No AWS credentials found for {operation_description}\nOriginal exception: {str(nce)}"
25
- )
26
- except ClientError as ce:
27
- raise RuntimeError(
28
- f"AWS client error when {operation_description} (check your credentials): {str(ce)}"
29
- )
30
-
31
- except Exception as exc:
32
- raise RuntimeError(
33
- f"Unexpected error `{exc}` during `{operation_description}`\nOriginal exception: {str(exc)}"
34
- )