truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -32,6 +32,7 @@ from truss.base.constants import (
32
32
  FILENAME_CONSTANTS_MAP,
33
33
  MODEL_CACHE_PATH,
34
34
  MODEL_DOCKERFILE_NAME,
35
+ NO_BUILD_DOCKERFILE_TEMPLATE_NAME,
35
36
  REQUIREMENTS_TXT_FILENAME,
36
37
  SERVER_CODE_DIR,
37
38
  SERVER_DOCKERFILE_TEMPLATE_NAME,
@@ -577,7 +578,10 @@ class ServingImageBuilder(ImageBuilder):
577
578
  else:
578
579
  self.prepare_trtllm_decoder_build_dir(build_dir=build_dir)
579
580
 
580
- if config.docker_server is not None:
581
+ if (
582
+ config.docker_server is not None
583
+ and config.docker_server.no_build is not True
584
+ ):
581
585
  self._copy_into_build_dir(
582
586
  TEMPLATES_DIR / "docker_server_requirements.txt",
583
587
  build_dir,
@@ -750,12 +754,21 @@ class ServingImageBuilder(ImageBuilder):
750
754
  build_commands: List[str],
751
755
  ):
752
756
  config = self._spec.config
757
+
753
758
  data_dir = build_dir / config.data_dir
754
759
  model_dir = build_dir / config.model_module_dir
755
760
  bundled_packages_dir = build_dir / config.bundled_packages_dir
756
- dockerfile_template = read_template_from_fs(
757
- TEMPLATES_DIR, SERVER_DOCKERFILE_TEMPLATE_NAME
758
- )
761
+
762
+ # Note: no-build deployment template doesn't use most of the template variables,
763
+ # because it tries to run the base image as-is to the extent possible.
764
+ if config.docker_server and config.docker_server.no_build:
765
+ dockerfile_template = read_template_from_fs(
766
+ TEMPLATES_DIR, NO_BUILD_DOCKERFILE_TEMPLATE_NAME
767
+ )
768
+ else:
769
+ dockerfile_template = read_template_from_fs(
770
+ TEMPLATES_DIR, SERVER_DOCKERFILE_TEMPLATE_NAME
771
+ )
759
772
  python_version = truss_config.to_dotted_python_version(config.python_version)
760
773
  if config.base_image:
761
774
  base_image_name_and_tag = config.base_image.image
@@ -3,7 +3,9 @@ from enum import Enum
3
3
  from typing import Any, Dict, List, Mapping, Optional
4
4
 
5
5
  import requests
6
+ from pydantic import BaseModel, Field
6
7
 
8
+ from truss.base.custom_types import SafeModel
7
9
  from truss.remote.baseten import custom_types as b10_types
8
10
  from truss.remote.baseten.auth import ApiKey, AuthService
9
11
  from truss.remote.baseten.custom_types import APIKeyCategory
@@ -12,6 +14,62 @@ from truss.remote.baseten.rest_client import RestAPIClient
12
14
  from truss.remote.baseten.utils.transfer import base64_encoded_json_str
13
15
 
14
16
  logger = logging.getLogger(__name__)
17
+ PARAMS_INDENT = "\n "
18
+
19
+
20
+ class ChainAWSCredential(SafeModel):
21
+ aws_access_key_id: str
22
+ aws_secret_access_key: str
23
+ aws_session_token: str
24
+
25
+
26
+ class ChainUploadCredentials(SafeModel):
27
+ s3_bucket: str
28
+ s3_key: str
29
+ aws_access_key_id: str
30
+ aws_secret_access_key: str
31
+ aws_session_token: str
32
+
33
+ @property
34
+ def aws_credentials(self) -> ChainAWSCredential:
35
+ return ChainAWSCredential(
36
+ aws_access_key_id=self.aws_access_key_id,
37
+ aws_secret_access_key=self.aws_secret_access_key,
38
+ aws_session_token=self.aws_session_token,
39
+ )
40
+
41
+
42
+ class InstanceTypeV1(BaseModel):
43
+ """An instance type."""
44
+
45
+ id: str = Field(description="Identifier string for the instance type")
46
+ name: str = Field(description="Name of the instance type")
47
+ display_name: str = Field(
48
+ alias="displayName", description="Display name of the instance type"
49
+ )
50
+ gpu_count: int = Field(
51
+ alias="gpuCount", description="Number of GPUs on the instance type"
52
+ )
53
+ default: bool = Field(description="Whether this is the default instance type")
54
+ gpu_memory: Optional[int] = Field(alias="gpuMemory", description="GPU memory in MB")
55
+ node_count: int = Field(alias="nodeCount", description="Number of nodes")
56
+ gpu_type: Optional[str] = Field(
57
+ alias="gpuType", description="Type of GPU on the instance type"
58
+ )
59
+ millicpu_limit: int = Field(
60
+ alias="millicpuLimit", description="CPU limit of the instance type in millicpu"
61
+ )
62
+ memory_limit: int = Field(
63
+ alias="memoryLimit", description="Memory limit of the instance type in MB"
64
+ )
65
+ price: Optional[float] = Field(description="Price of the instance type")
66
+ limited_capacity: Optional[bool] = Field(
67
+ alias="limitedCapacity",
68
+ description="Whether this instance type has limited capacity",
69
+ )
70
+
71
+ class Config:
72
+ populate_by_name = True
15
73
 
16
74
 
17
75
  API_URL_MAPPING = {
@@ -141,6 +199,9 @@ class BasetenApi:
141
199
  allow_truss_download: bool = True,
142
200
  deployment_name: Optional[str] = None,
143
201
  origin: Optional[b10_types.ModelOrigin] = None,
202
+ environment: Optional[str] = None,
203
+ deploy_timeout_minutes: Optional[int] = None,
204
+ team_id: Optional[str] = None,
144
205
  ):
145
206
  query_string = f"""
146
207
  mutation ($trussUserEnv: String) {{
@@ -153,6 +214,9 @@ class BasetenApi:
153
214
  allow_truss_download: {"true" if allow_truss_download else "false"}
154
215
  {f'version_name: "{deployment_name}"' if deployment_name else ""}
155
216
  {f"model_origin: {origin.value}" if origin else ""}
217
+ {f'environment_name: "{environment}"' if environment else ""}
218
+ {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
219
+ {f'team_id: "{team_id}"' if team_id else ""}
156
220
  ) {{
157
221
  model_version {{
158
222
  id
@@ -184,6 +248,7 @@ class BasetenApi:
184
248
  deployment_name: Optional[str] = None,
185
249
  environment: Optional[str] = None,
186
250
  preserve_env_instance_type: bool = True,
251
+ deploy_timeout_minutes: Optional[int] = None,
187
252
  ):
188
253
  query_string = f"""
189
254
  mutation ($trussUserEnv: String) {{
@@ -197,6 +262,7 @@ class BasetenApi:
197
262
  preserve_env_instance_type: {"true" if preserve_env_instance_type else "false"}
198
263
  {f'name: "{deployment_name}"' if deployment_name else ""}
199
264
  {f'environment_name: "{environment}"' if environment else ""}
265
+ {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
200
266
  ) {{
201
267
  model_version {{
202
268
  id
@@ -226,6 +292,8 @@ class BasetenApi:
226
292
  truss_user_env: b10_types.TrussUserEnv,
227
293
  allow_truss_download=True,
228
294
  origin: Optional[b10_types.ModelOrigin] = None,
295
+ deploy_timeout_minutes: Optional[int] = None,
296
+ team_id: Optional[str] = None,
229
297
  ):
230
298
  query_string = f"""
231
299
  mutation ($trussUserEnv: String) {{
@@ -235,6 +303,8 @@ class BasetenApi:
235
303
  truss_user_env: $trussUserEnv
236
304
  allow_truss_download: {"true" if allow_truss_download else "false"}
237
305
  {f"model_origin: {origin.value}" if origin else ""}
306
+ {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
307
+ {f'team_id: "{team_id}"' if team_id else ""}
238
308
  ) {{
239
309
  model_version {{
240
310
  id
@@ -265,7 +335,13 @@ class BasetenApi:
265
335
  chain_name: Optional[str] = None,
266
336
  environment: Optional[str] = None,
267
337
  is_draft: bool = False,
338
+ original_source_artifact_s3_key: Optional[str] = None,
339
+ allow_truss_download: Optional[bool] = True,
340
+ deployment_name: Optional[str] = None,
341
+ team_id: Optional[str] = None,
268
342
  ):
343
+ if allow_truss_download is None:
344
+ allow_truss_download = True
269
345
  entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
270
346
 
271
347
  dependencies_str = ", ".join(
@@ -275,13 +351,32 @@ class BasetenApi:
275
351
  ]
276
352
  )
277
353
 
354
+ params = []
355
+ if chain_id:
356
+ params.append(f'chain_id: "{chain_id}"')
357
+ if chain_name:
358
+ params.append(f'chain_name: "{chain_name}"')
359
+ if environment:
360
+ params.append(f'environment: "{environment}"')
361
+ if original_source_artifact_s3_key:
362
+ params.append(
363
+ f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
364
+ )
365
+ if team_id:
366
+ params.append(f'team_id: "{team_id}"')
367
+
368
+ params.append(f"is_draft: {str(is_draft).lower()}")
369
+ if allow_truss_download is False:
370
+ params.append("allow_truss_download: false")
371
+ if deployment_name:
372
+ params.append(f'deployment_name: "{deployment_name}"')
373
+
374
+ params_str = PARAMS_INDENT.join(params)
375
+
278
376
  query_string = f"""
279
377
  mutation ($trussUserEnv: String) {{
280
378
  deploy_chain_atomic(
281
- {f'chain_id: "{chain_id}"' if chain_id else ""}
282
- {f'chain_name: "{chain_name}"' if chain_name else ""}
283
- {f'environment: "{environment}"' if environment else ""}
284
- is_draft: {str(is_draft).lower()}
379
+ {params_str}
285
380
  entrypoint: {entrypoint_str}
286
381
  dependencies: [{dependencies_str}]
287
382
  truss_user_env: $trussUserEnv
@@ -303,18 +398,24 @@ class BasetenApi:
303
398
 
304
399
  return resp["data"]["deploy_chain_atomic"]
305
400
 
306
- def get_chains(self):
401
+ def get_chains(self, team_id: Optional[str] = None):
307
402
  query_string = """
308
403
  {
309
404
  chains {
310
405
  id
311
406
  name
407
+ team {
408
+ name
409
+ }
312
410
  }
313
411
  }
314
412
  """
315
413
 
316
414
  resp = self._post_graphql_query(query_string)
317
- return resp["data"]["chains"]
415
+ chains = resp["data"]["chains"]
416
+
417
+ # TODO(COR-492): Filter by team_id in the backend
418
+ return chains
318
419
 
319
420
  def get_chain_deployments(self, chain_id: str):
320
421
  query_string = f"""
@@ -377,6 +478,10 @@ class BasetenApi:
377
478
  models {
378
479
  id,
379
480
  name
481
+ team {
482
+ id
483
+ name
484
+ }
380
485
  versions{
381
486
  id,
382
487
  semver,
@@ -416,6 +521,10 @@ class BasetenApi:
416
521
  id
417
522
  name
418
523
  hostname
524
+ team {{
525
+ id
526
+ name
527
+ }}
419
528
  versions {{
420
529
  id
421
530
  semver
@@ -568,10 +677,14 @@ class BasetenApi:
568
677
  "v1/api_keys", body={"type": api_key_type.value, "name": name}
569
678
  )
570
679
 
571
- def upsert_training_project(self, training_project):
680
+ def upsert_training_project(self, training_project, team_id: Optional[str] = None):
681
+ if team_id:
682
+ endpoint = f"v1/teams/{team_id}/training_projects"
683
+ else:
684
+ endpoint = "v1/training_projects"
572
685
  resp_json = self._rest_api_client.post(
573
- "v1/training_projects",
574
- body={"training_project": training_project.model_dump()},
686
+ endpoint,
687
+ body={"training_project": training_project.model_dump(exclude_none=True)},
575
688
  )
576
689
  return resp_json["training_project"]
577
690
 
@@ -623,8 +736,29 @@ class BasetenApi:
623
736
  return resp_json["training_projects"]
624
737
 
625
738
  def get_blob_credentials(self, blob_type: b10_types.BlobType):
739
+ if blob_type == b10_types.BlobType.CHAIN:
740
+ return self.get_chain_s3_upload_credentials()
626
741
  return self._rest_api_client.get(f"v1/blobs/credentials/{blob_type.value}")
627
742
 
743
+ def get_chain_s3_upload_credentials(self) -> ChainUploadCredentials:
744
+ """Get chain artifact credentials using GraphQL query."""
745
+ query = """
746
+ query {
747
+ chain_s3_upload_credentials {
748
+ s3_bucket
749
+ s3_key
750
+ aws_access_key_id
751
+ aws_secret_access_key
752
+ aws_session_token
753
+ }
754
+ }
755
+ """
756
+ response = self._post_graphql_query(query)
757
+
758
+ return ChainUploadCredentials.model_validate(
759
+ response["data"]["chain_s3_upload_credentials"]
760
+ )
761
+
628
762
  def get_training_job_metrics(
629
763
  self,
630
764
  project_id: str,
@@ -750,3 +884,75 @@ class BasetenApi:
750
884
 
751
885
  # NB(nikhil): reverse order so latest logs are at the end
752
886
  return resp_json["logs"][::-1]
887
+
888
+ def create_model_version_from_inference_template(self, request_data: dict):
889
+ """
890
+ Create a model version from an inference template using GraphQL mutation.
891
+
892
+ Args:
893
+ request_data: Dictionary containing the request structure with metadata,
894
+ weights_sources, inference_stack, and instance_type_id
895
+ """
896
+ query_string = """
897
+ mutation ($request: CreateModelVersionFromInferenceTemplateRequest!) {
898
+ create_model_version_from_inference_template(request: $request) {
899
+ model_version {
900
+ id
901
+ name
902
+ }
903
+ truss_config
904
+ }
905
+ }
906
+ """
907
+ resp = self._post_graphql_query(
908
+ query_string, variables={"request": request_data}
909
+ )
910
+ return resp["data"]["create_model_version_from_inference_template"]
911
+
912
+ def get_instance_types(self) -> List[InstanceTypeV1]:
913
+ """
914
+ Get all available instance types via GraphQL API.
915
+ """
916
+ query_string = """
917
+ query Instances {
918
+ listedInstances: listed_instances {
919
+ id
920
+ name
921
+ millicpuLimit: millicpu_limit
922
+ memoryLimit: memory_limit
923
+ gpuCount: gpu_count
924
+ gpuType: gpu_type
925
+ gpuMemory: gpu_memory
926
+ default
927
+ displayName: display_name
928
+ nodeCount: node_count
929
+ price
930
+ limitedCapacity: limited_capacity
931
+ }
932
+ }
933
+ """
934
+
935
+ resp = self._post_graphql_query(query_string)
936
+ instance_types_data = resp["data"]["listedInstances"]
937
+ return [
938
+ InstanceTypeV1(**instance_type) for instance_type in instance_types_data
939
+ ]
940
+
941
+ def get_teams(self) -> Dict[str, Dict[str, str]]:
942
+ """
943
+ Get all available teams via GraphQL API.
944
+ Returns a dictionary mapping team name to team data (with 'id' and 'name' keys).
945
+ """
946
+ query_string = """
947
+ query Teams {
948
+ teams {
949
+ id
950
+ name
951
+ }
952
+ }
953
+ """
954
+
955
+ resp = self._post_graphql_query(query_string)
956
+ teams_data = resp["data"]["teams"]
957
+ # Convert list to dict mapping team_name -> team
958
+ return {team["name"]: team for team in teams_data}
@@ -8,6 +8,7 @@ from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tup
8
8
  import requests
9
9
 
10
10
  from truss.base.errors import ValidationError
11
+ from truss.util.error_utils import handle_client_error
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from rich import progress
@@ -80,6 +81,8 @@ class ChainDeploymentHandleAtomic(NamedTuple):
80
81
  chain_id: str
81
82
  chain_deployment_id: str
82
83
  is_draft: bool
84
+ original_source_artifact_s3_key: Optional[str] = None
85
+ allow_truss_download: Optional[bool] = True
83
86
 
84
87
 
85
88
  class ModelVersionHandle(NamedTuple):
@@ -89,19 +92,21 @@ class ModelVersionHandle(NamedTuple):
89
92
  instance_type_name: Optional[str] = None
90
93
 
91
94
 
92
- def get_chain_id_by_name(api: BasetenApi, chain_name: str) -> Optional[str]:
95
+ def get_chain_id_by_name(
96
+ api: BasetenApi, chain_name: str, team_id: Optional[str] = None
97
+ ) -> Optional[str]:
93
98
  """
94
99
  Check if a chain with the given name exists in the Baseten remote.
95
100
 
96
101
  Args:
97
102
  api: BasetenApi instance
98
103
  chain_name: Name of the chain to check for existence
104
+ team_id: Optional team_id to filter chains by team
99
105
 
100
106
  Returns:
101
107
  chain_id if present, otherwise None
102
108
  """
103
- chains = api.get_chains()
104
-
109
+ chains = api.get_chains(team_id=team_id)
105
110
  chain_name_id_mapping = {chain["name"]: chain["id"] for chain in chains}
106
111
  return chain_name_id_mapping.get(chain_name)
107
112
 
@@ -127,6 +132,10 @@ def create_chain_atomic(
127
132
  is_draft: bool,
128
133
  truss_user_env: b10_types.TrussUserEnv,
129
134
  environment: Optional[str],
135
+ original_source_artifact_s3_key: Optional[str] = None,
136
+ allow_truss_download: bool = True,
137
+ deployment_name: Optional[str] = None,
138
+ team_id: Optional[str] = None,
130
139
  ) -> ChainDeploymentHandleAtomic:
131
140
  if environment and is_draft:
132
141
  logging.info(
@@ -135,7 +144,7 @@ def create_chain_atomic(
135
144
  )
136
145
  is_draft = False
137
146
 
138
- chain_id = get_chain_id_by_name(api, chain_name)
147
+ chain_id = get_chain_id_by_name(api, chain_name, team_id=team_id)
139
148
 
140
149
  # TODO(Tyron): Refactor for better readability:
141
150
  # 1. Prepare all arguments for `deploy_chain_atomic`.
@@ -149,6 +158,10 @@ def create_chain_atomic(
149
158
  chain_name=chain_name,
150
159
  is_draft=True,
151
160
  truss_user_env=truss_user_env,
161
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
162
+ allow_truss_download=allow_truss_download,
163
+ deployment_name=deployment_name,
164
+ team_id=team_id,
152
165
  )
153
166
  elif chain_id:
154
167
  # This is the only case where promote has relevance, since
@@ -162,6 +175,10 @@ def create_chain_atomic(
162
175
  chain_id=chain_id,
163
176
  environment=environment,
164
177
  truss_user_env=truss_user_env,
178
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
179
+ allow_truss_download=allow_truss_download,
180
+ deployment_name=deployment_name,
181
+ team_id=team_id,
165
182
  )
166
183
  except ApiError as e:
167
184
  if (
@@ -182,6 +199,10 @@ def create_chain_atomic(
182
199
  dependencies=dependencies,
183
200
  chain_name=chain_name,
184
201
  truss_user_env=truss_user_env,
202
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
203
+ allow_truss_download=allow_truss_download,
204
+ deployment_name=deployment_name,
205
+ team_id=team_id,
185
206
  )
186
207
 
187
208
  return ChainDeploymentHandleAtomic(
@@ -189,6 +210,8 @@ def create_chain_atomic(
189
210
  chain_id=res["chain_deployment"]["chain"]["id"],
190
211
  hostname=res["chain_deployment"]["chain"]["hostname"],
191
212
  is_draft=is_draft,
213
+ original_source_artifact_s3_key=original_source_artifact_s3_key,
214
+ allow_truss_download=allow_truss_download,
192
215
  )
193
216
 
194
217
 
@@ -342,6 +365,33 @@ def upload_truss(
342
365
  return s3_key
343
366
 
344
367
 
368
+ def upload_chain_artifact(
369
+ api: BasetenApi,
370
+ serialize_file: IO,
371
+ progress_bar: Optional[Type["progress.Progress"]],
372
+ ) -> str:
373
+ """
374
+ Upload a chain artifact to the Baseten remote.
375
+
376
+ Args:
377
+ api: BasetenApi instance
378
+ serialize_file: File-like object containing the serialized chain artifact
379
+
380
+ Returns:
381
+ The S3 key of the uploaded file
382
+ """
383
+ credentials = api.get_chain_s3_upload_credentials()
384
+ with handle_client_error("Uploading chain source"):
385
+ multipart_upload_boto3(
386
+ serialize_file.name,
387
+ credentials.s3_bucket,
388
+ credentials.s3_key,
389
+ credentials.aws_credentials.model_dump(),
390
+ progress_bar,
391
+ )
392
+ return credentials.s3_key
393
+
394
+
345
395
  def create_truss_service(
346
396
  api: BasetenApi,
347
397
  model_name: str,
@@ -357,6 +407,8 @@ def create_truss_service(
357
407
  origin: Optional[b10_types.ModelOrigin] = None,
358
408
  environment: Optional[str] = None,
359
409
  preserve_env_instance_type: bool = True,
410
+ deploy_timeout_minutes: Optional[int] = None,
411
+ team_id: Optional[str] = None,
360
412
  ) -> ModelVersionHandle:
361
413
  """
362
414
  Create a model in the Baseten remote.
@@ -372,6 +424,7 @@ def create_truss_service(
372
424
  to zero.
373
425
  deployment_name: Name to apply to the created deployment. Not applied to
374
426
  development model.
427
+ team_id: ID of the team to create the model in.
375
428
 
376
429
  Returns:
377
430
  A Model Version handle.
@@ -384,6 +437,8 @@ def create_truss_service(
384
437
  truss_user_env,
385
438
  allow_truss_download=allow_truss_download,
386
439
  origin=origin,
440
+ deploy_timeout_minutes=deploy_timeout_minutes,
441
+ team_id=team_id,
387
442
  )
388
443
 
389
444
  return ModelVersionHandle(
@@ -398,9 +453,6 @@ def create_truss_service(
398
453
  )
399
454
 
400
455
  if model_id is None:
401
- if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
402
- raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
403
-
404
456
  model_version_json = api.create_model_from_truss(
405
457
  model_name,
406
458
  s3_key,
@@ -410,6 +462,9 @@ def create_truss_service(
410
462
  allow_truss_download=allow_truss_download,
411
463
  deployment_name=deployment_name,
412
464
  origin=origin,
465
+ environment=environment,
466
+ deploy_timeout_minutes=deploy_timeout_minutes,
467
+ team_id=team_id,
413
468
  )
414
469
 
415
470
  return ModelVersionHandle(
@@ -434,6 +489,7 @@ def create_truss_service(
434
489
  deployment_name=deployment_name,
435
490
  environment=environment,
436
491
  preserve_env_instance_type=preserve_env_instance_type,
492
+ deploy_timeout_minutes=deploy_timeout_minutes,
437
493
  )
438
494
  except ApiError as e:
439
495
  if (
@@ -120,6 +120,7 @@ class TrussUserEnv(pydantic.BaseModel):
120
120
  class BlobType(Enum):
121
121
  MODEL = "model"
122
122
  TRAIN = "train"
123
+ CHAIN = "chain"
123
124
 
124
125
 
125
126
  class FileSummary(pydantic.BaseModel):