truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -165,9 +165,9 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
165
165
  "config_str",
166
166
  "semver_bump",
167
167
  b10_types.TrussUserEnv.collect(),
168
- False,
169
- "deployment_name",
170
- "production",
168
+ preserve_previous_prod_deployment=False,
169
+ deployment_name="deployment_name",
170
+ environment="production",
171
171
  )
172
172
 
173
173
  gql_mutation = mock_post.call_args[1]["json"]["query"]
@@ -182,6 +182,7 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
182
182
  assert 'name: "deployment_name"' in gql_mutation
183
183
  assert 'environment_name: "production"' in gql_mutation
184
184
  assert "preserve_env_instance_type: true" in gql_mutation
185
+ assert "deploy_timeout_minutes: " not in gql_mutation
185
186
 
186
187
 
187
188
  @mock.patch("requests.post", return_value=mock_create_model_version_response())
@@ -211,6 +212,7 @@ def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_sp
211
212
  assert " name: " not in gql_mutation
212
213
  assert "environment_name: " not in gql_mutation
213
214
  assert "preserve_env_instance_type: false" in gql_mutation
215
+ assert "deploy_timeout_minutes: " not in gql_mutation
214
216
 
215
217
 
216
218
  @mock.patch("requests.post", return_value=mock_create_model_version_response())
@@ -242,6 +244,57 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
242
244
  assert " name: " not in gql_mutation
243
245
  assert 'environment_name: "staging"' in gql_mutation
244
246
  assert "preserve_env_instance_type: true" in gql_mutation
247
+ assert "deploy_timeout_minutes: " not in gql_mutation
248
+
249
+
250
+ @mock.patch("requests.post", return_value=mock_create_model_version_response())
251
+ def test_create_model_version_from_truss_with_deploy_timeout_minutes(
252
+ mock_post, baseten_api
253
+ ):
254
+ baseten_api.create_model_version_from_truss(
255
+ "model_id",
256
+ "s3key",
257
+ "config_str",
258
+ "semver_bump",
259
+ b10_types.TrussUserEnv.collect(),
260
+ preserve_previous_prod_deployment=False,
261
+ deployment_name="deployment_name",
262
+ environment="production",
263
+ deploy_timeout_minutes=300,
264
+ )
265
+
266
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
267
+ assert 'model_id: "model_id"' in gql_mutation
268
+ assert 's3_key: "s3key"' in gql_mutation
269
+ assert 'config: "config_str"' in gql_mutation
270
+ assert 'semver_bump: "semver_bump"' in gql_mutation
271
+ assert {
272
+ "trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
273
+ } == mock_post.call_args[1]["json"]["variables"]
274
+ assert "scale_down_old_production: true" in gql_mutation
275
+ assert 'name: "deployment_name"' in gql_mutation
276
+ assert 'environment_name: "production"' in gql_mutation
277
+ assert "preserve_env_instance_type: true" in gql_mutation
278
+ assert "deploy_timeout_minutes: 300" in gql_mutation
279
+
280
+
281
+ @mock.patch("requests.post", return_value=mock_create_model_version_response())
282
+ def test_create_model_version_from_truss_with_deploy_timeout_minutes_zero(
283
+ mock_post, baseten_api
284
+ ):
285
+ """Test that deploy_timeout_minutes of 0 is handled correctly"""
286
+ baseten_api.create_model_version_from_truss(
287
+ "model_id",
288
+ "s3key",
289
+ "config_str",
290
+ "semver_bump",
291
+ b10_types.TrussUserEnv.collect(),
292
+ preserve_previous_prod_deployment=False,
293
+ deploy_timeout_minutes=0,
294
+ )
295
+
296
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
297
+ assert "deploy_timeout_minutes: 0" in gql_mutation
245
298
 
246
299
 
247
300
  @mock.patch("requests.post", return_value=mock_create_model_response())
@@ -332,6 +385,48 @@ def test_create_development_model_from_truss_with_allow_truss_download(
332
385
  "trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
333
386
  } == mock_post.call_args[1]["json"]["variables"]
334
387
  assert "allow_truss_download: false" in gql_mutation
388
+ assert "deploy_timeout_minutes: " not in gql_mutation
389
+
390
+
391
+ @mock.patch("requests.post", return_value=mock_create_development_model_response())
392
+ def test_create_development_model_from_truss_with_deploy_timeout_minutes(
393
+ mock_post, baseten_api
394
+ ):
395
+ baseten_api.create_development_model_from_truss(
396
+ "model_name",
397
+ "s3key",
398
+ "config_str",
399
+ b10_types.TrussUserEnv.collect(),
400
+ allow_truss_download=False,
401
+ deploy_timeout_minutes=300,
402
+ )
403
+
404
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
405
+ assert 'name: "model_name"' in gql_mutation
406
+ assert 's3_key: "s3key"' in gql_mutation
407
+ assert 'config: "config_str"' in gql_mutation
408
+ assert {
409
+ "trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
410
+ } == mock_post.call_args[1]["json"]["variables"]
411
+ assert "allow_truss_download: false" in gql_mutation
412
+ assert "deploy_timeout_minutes: 300" in gql_mutation
413
+
414
+
415
+ @mock.patch("requests.post", return_value=mock_create_development_model_response())
416
+ def test_create_development_model_from_truss_with_deploy_timeout_minutes_zero(
417
+ mock_post, baseten_api
418
+ ):
419
+ """Test that deploy_timeout_minutes of 0 is handled correctly"""
420
+ baseten_api.create_development_model_from_truss(
421
+ "model_name",
422
+ "s3key",
423
+ "config_str",
424
+ b10_types.TrussUserEnv.collect(),
425
+ deploy_timeout_minutes=0,
426
+ )
427
+
428
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
429
+ assert "deploy_timeout_minutes: 0" in gql_mutation
335
430
 
336
431
 
337
432
  @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
@@ -357,6 +452,30 @@ def test_deploy_chain_deployment(mock_post, baseten_api):
357
452
  assert 'chain_id: "chain_id"' in gql_mutation
358
453
  assert "dependencies:" in gql_mutation
359
454
  assert "entrypoint:" in gql_mutation
455
+ assert "deployment_name" not in gql_mutation
456
+
457
+
458
+ @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
459
+ def test_deploy_chain_deployment_with_deployment_name(mock_post, baseten_api):
460
+ baseten_api.deploy_chain_atomic(
461
+ environment="production",
462
+ chain_id="chain_id",
463
+ dependencies=[],
464
+ entrypoint=ChainletDataAtomic(
465
+ name="chainlet-1",
466
+ oracle=OracleData(
467
+ model_name="model-1",
468
+ s3_key="s3-key-1",
469
+ encoded_config_str="encoded-config-str-1",
470
+ ),
471
+ ),
472
+ truss_user_env=b10_types.TrussUserEnv.collect(),
473
+ deployment_name="chain-deployment-name",
474
+ )
475
+
476
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
477
+
478
+ assert 'deployment_name: "chain-deployment-name"' in gql_mutation
360
479
 
361
480
 
362
481
  @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
@@ -0,0 +1,294 @@
1
+ import pathlib
2
+ import tempfile
3
+ from unittest.mock import Mock, patch
4
+
5
+ import pytest
6
+
7
+ from truss.remote.baseten import custom_types as b10_types
8
+ from truss.remote.baseten.api import BasetenApi
9
+ from truss.remote.baseten.core import upload_chain_artifact
10
+ from truss.remote.baseten.remote import BasetenRemote
11
+
12
+
13
+ @pytest.fixture
14
+ def mock_push_data():
15
+ """Fixture providing mock push data for tests."""
16
+ mock_push_data = Mock()
17
+ mock_push_data.model_name = "test-model"
18
+ mock_push_data.s3_key = "models/test-key"
19
+ mock_push_data.encoded_config_str = "encoded_config"
20
+ mock_push_data.is_draft = False
21
+ mock_push_data.model_id = "model-id"
22
+ mock_push_data.version_name = None
23
+ return mock_push_data
24
+
25
+
26
+ @pytest.fixture
27
+ def mock_remote_context():
28
+ """Fixture providing mock remote and context managers for tests."""
29
+ api = Mock(spec=BasetenApi)
30
+
31
+ remote = BasetenRemote("https://test.baseten.co", "test-api-key")
32
+ remote._api = api
33
+
34
+ chain_name = "test-chain"
35
+ entrypoint_artifact = Mock()
36
+ entrypoint_artifact.truss_dir = "/path/to/truss"
37
+ entrypoint_artifact.display_name = "entrypoint"
38
+
39
+ dependency_artifacts = []
40
+ truss_user_env = Mock()
41
+ chain_root = pathlib.Path("/path/to/chain")
42
+
43
+ with patch.object(remote, "_prepare_push") as mock_prepare_push:
44
+ with patch("truss.remote.baseten.remote.truss_build.load") as mock_load:
45
+ mock_truss_handle = Mock()
46
+ mock_truss_handle.spec.config.model_name = "test-model"
47
+ mock_load.return_value = mock_truss_handle
48
+
49
+ yield {
50
+ "remote": remote,
51
+ "api": api,
52
+ "chain_name": chain_name,
53
+ "entrypoint_artifact": entrypoint_artifact,
54
+ "dependency_artifacts": dependency_artifacts,
55
+ "truss_user_env": truss_user_env,
56
+ "chain_root": chain_root,
57
+ "mock_prepare_push": mock_prepare_push,
58
+ "mock_load": mock_load,
59
+ "mock_truss_handle": mock_truss_handle,
60
+ }
61
+
62
+
63
+ def test_get_blob_credentials_for_chain():
64
+ """Test that get_blob_credentials works correctly for chain blob type using GraphQL."""
65
+ mock_graphql_response = {
66
+ "data": {
67
+ "chain_s3_upload_credentials": {
68
+ "s3_bucket": "test-chain-bucket",
69
+ "s3_key": "chains/test-uuid/chain.tgz",
70
+ "aws_access_key_id": "test_access_key",
71
+ "aws_secret_access_key": "test_secret_key",
72
+ "aws_session_token": "test_session_token",
73
+ }
74
+ }
75
+ }
76
+
77
+ # Create a real API instance and mock the GraphQL call
78
+ mock_auth_service = Mock()
79
+ mock_auth_service.authenticate.return_value = Mock(value="test-token")
80
+ api = BasetenApi("https://test.baseten.co", mock_auth_service)
81
+ with patch.object(api, "_post_graphql_query") as mock_graphql:
82
+ mock_graphql.return_value = mock_graphql_response
83
+
84
+ result = api.get_chain_s3_upload_credentials()
85
+
86
+ assert result.s3_bucket == "test-chain-bucket"
87
+ assert result.s3_key == "chains/test-uuid/chain.tgz"
88
+ assert result.aws_access_key_id == "test_access_key"
89
+ assert result.aws_secret_access_key == "test_secret_key"
90
+ assert result.aws_session_token == "test_session_token"
91
+
92
+ mock_graphql.assert_called_once()
93
+ call_args = mock_graphql.call_args
94
+ assert "chain_s3_upload_credentials" in call_args[0][0]
95
+
96
+
97
+ def test_get_blob_credentials_for_other_types_uses_rest():
98
+ """Test that get_blob_credentials uses REST API for non-chain blob types."""
99
+ mock_response = {
100
+ "s3_bucket": "test-bucket",
101
+ "s3_key": "test-key",
102
+ "creds": {
103
+ "aws_access_key_id": "test_access_key",
104
+ "aws_secret_access_key": "test_secret_key",
105
+ "aws_session_token": "test_session_token",
106
+ },
107
+ }
108
+
109
+ mock_auth_service = Mock()
110
+ mock_auth_service.authenticate.return_value = Mock(value="test-token")
111
+ api = BasetenApi("https://test.baseten.co", mock_auth_service)
112
+ with patch.object(api, "_rest_api_client") as mock_client, patch.object(
113
+ api, "_post_graphql_query"
114
+ ) as mock_graphql:
115
+ mock_client.get.return_value = mock_response
116
+
117
+ result = api.get_blob_credentials(b10_types.BlobType.MODEL)
118
+
119
+ assert result["s3_bucket"] == "test-bucket"
120
+ assert result["s3_key"] == "test-key"
121
+
122
+ mock_client.get.assert_called_once_with("v1/blobs/credentials/model")
123
+ mock_graphql.assert_not_called()
124
+
125
+
126
+ @patch("truss.remote.baseten.core.multipart_upload_boto3")
127
+ def test_upload_chain_artifact_function(mock_multipart_upload):
128
+ """Test the upload_chain_artifact function."""
129
+ # Mock ChainUploadCredentials object
130
+ mock_credentials = Mock()
131
+ mock_credentials.s3_bucket = "test-chain-bucket"
132
+ mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
133
+ mock_credentials.aws_credentials = Mock()
134
+ mock_credentials.aws_credentials.model_dump.return_value = {
135
+ "aws_access_key_id": "test_access_key",
136
+ "aws_secret_access_key": "test_secret_key",
137
+ "aws_session_token": "test_session_token",
138
+ }
139
+
140
+ api = Mock(spec=BasetenApi)
141
+ api.get_chain_s3_upload_credentials.return_value = mock_credentials
142
+
143
+ with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as temp_file:
144
+ temp_file.write(b"test chain content")
145
+ temp_file.flush()
146
+
147
+ result = upload_chain_artifact(api, temp_file, None)
148
+
149
+ assert result == "chains/test-uuid/chain.tgz"
150
+
151
+ api.get_chain_s3_upload_credentials.assert_called_once_with()
152
+
153
+ mock_multipart_upload.assert_called_once()
154
+ call_args = mock_multipart_upload.call_args
155
+ assert call_args[0][0] == temp_file.name # file path
156
+ assert call_args[0][1] == "test-chain-bucket" # bucket
157
+ assert call_args[0][2] == "chains/test-uuid/chain.tgz" # key
158
+ assert call_args[0][3] == { # credentials
159
+ "aws_access_key_id": "test_access_key",
160
+ "aws_secret_access_key": "test_secret_key",
161
+ "aws_session_token": "test_session_token",
162
+ }
163
+
164
+
165
+ @patch("truss.remote.baseten.remote.upload_chain_artifact")
166
+ @patch("truss.remote.baseten.remote.archive_dir")
167
+ @patch("truss.remote.baseten.remote.create_chain_atomic")
168
+ def test_push_chain_atomic_with_chain_upload(
169
+ mock_create_chain_atomic,
170
+ mock_archive_dir,
171
+ mock_upload_chain_artifact,
172
+ mock_push_data,
173
+ mock_remote_context,
174
+ ):
175
+ """Test that push_chain_atomic uploads raw chain artifact when chain_root is provided."""
176
+ mock_create_chain_atomic.return_value = Mock()
177
+ mock_archive_dir.return_value = Mock()
178
+ mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"
179
+
180
+ context = mock_remote_context
181
+ remote = context["remote"]
182
+ chain_name = context["chain_name"]
183
+ entrypoint_artifact = context["entrypoint_artifact"]
184
+ dependency_artifacts = context["dependency_artifacts"]
185
+ truss_user_env = context["truss_user_env"]
186
+ chain_root = context["chain_root"]
187
+
188
+ context["mock_prepare_push"].return_value = mock_push_data
189
+ deployment_name = "custom_deployment"
190
+
191
+ result = remote.push_chain_atomic(
192
+ chain_name=chain_name,
193
+ entrypoint_artifact=entrypoint_artifact,
194
+ dependency_artifacts=dependency_artifacts,
195
+ truss_user_env=truss_user_env,
196
+ chain_root=chain_root,
197
+ publish=True,
198
+ deployment_name=deployment_name,
199
+ )
200
+ assert result == mock_create_chain_atomic.return_value
201
+
202
+ mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
203
+ mock_upload_chain_artifact.assert_called_once()
204
+ mock_create_chain_atomic.assert_called_once()
205
+ create_kwargs = mock_create_chain_atomic.call_args.kwargs
206
+ assert create_kwargs["deployment_name"] == deployment_name
207
+
208
+ prepare_kwargs = context["mock_prepare_push"].call_args.kwargs
209
+ assert prepare_kwargs["deployment_name"] == deployment_name
210
+
211
+
212
+ @patch("truss.remote.baseten.remote.create_chain_atomic")
213
+ def test_push_chain_atomic_without_chain_upload(
214
+ mock_create_chain_atomic, mock_push_data, mock_remote_context
215
+ ):
216
+ """Test that push_chain_atomic skips chain upload when chain_root is None."""
217
+ mock_create_chain_atomic.return_value = Mock()
218
+
219
+ context = mock_remote_context
220
+ remote = context["remote"]
221
+ chain_name = context["chain_name"]
222
+ entrypoint_artifact = context["entrypoint_artifact"]
223
+ dependency_artifacts = context["dependency_artifacts"]
224
+ truss_user_env = context["truss_user_env"]
225
+
226
+ context["mock_prepare_push"].return_value = mock_push_data
227
+
228
+ with patch("truss.remote.baseten.remote.upload_chain_artifact") as mock_upload:
229
+ with patch(
230
+ "truss.remote.baseten.core.create_tar_with_progress_bar"
231
+ ) as mock_tar:
232
+ # Call push_chain_atomic without chain_root
233
+ result = remote.push_chain_atomic(
234
+ chain_name=chain_name,
235
+ entrypoint_artifact=entrypoint_artifact,
236
+ dependency_artifacts=dependency_artifacts,
237
+ truss_user_env=truss_user_env,
238
+ chain_root=None, # No chain root
239
+ publish=True,
240
+ )
241
+
242
+ assert result
243
+ # Verify chain artifact upload was NOT called
244
+ mock_tar.assert_not_called()
245
+ mock_upload.assert_not_called()
246
+
247
+ mock_create_chain_atomic.assert_called_once()
248
+ create_kwargs = mock_create_chain_atomic.call_args.kwargs
249
+ assert "deployment_name" in create_kwargs
250
+ assert create_kwargs["deployment_name"] is None
251
+
252
+
253
+ @patch("truss.remote.baseten.core.multipart_upload_boto3")
254
+ def test_upload_chain_artifact_error_handling(mock_multipart_upload):
255
+ """Test error handling in upload_chain_artifact."""
256
+ # Mock API to raise an exception
257
+ api = Mock(spec=BasetenApi)
258
+ api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")
259
+
260
+ with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
261
+ # Should raise the exception
262
+ with pytest.raises(Exception, match="API Error"):
263
+ upload_chain_artifact(api, temp_file, None)
264
+
265
+
266
+ def test_upload_chain_artifact_credentials_extraction():
267
+ """Test that credentials are properly extracted from API response."""
268
+ # Mock ChainUploadCredentials object
269
+ mock_credentials = Mock()
270
+ mock_credentials.s3_bucket = "test-bucket"
271
+ mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
272
+ mock_credentials.aws_credentials = Mock()
273
+ mock_credentials.aws_credentials.model_dump.return_value = {
274
+ "aws_access_key_id": "access_key",
275
+ "aws_secret_access_key": "secret_key",
276
+ "aws_session_token": "session_token",
277
+ }
278
+
279
+ api = Mock(spec=BasetenApi)
280
+ api.get_chain_s3_upload_credentials.return_value = mock_credentials
281
+
282
+ with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
283
+ with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
284
+ upload_chain_artifact(api, temp_file, None)
285
+
286
+ call_args = mock_upload.call_args
287
+ credentials = call_args[0][3]
288
+
289
+ assert credentials == {
290
+ "aws_access_key_id": "access_key",
291
+ "aws_secret_access_key": "secret_key",
292
+ "aws_session_token": "session_token",
293
+ }
294
+ assert "extra_field" not in credentials
@@ -187,6 +187,7 @@ def test_create_truss_service_handles_existing_model(inputs):
187
187
  _, kwargs = api.create_model_version_from_truss.call_args
188
188
  for k, v in inputs.items():
189
189
  assert kwargs[k] == v
190
+ assert kwargs.get("deploy_timeout_minutes") is None
190
191
 
191
192
 
192
193
  @pytest.mark.parametrize("allow_truss_download", [True, False])
@@ -761,3 +762,88 @@ def test_get_training_job_logs_with_pagination_default_batch_size(baseten_api):
761
762
  query_params = call_args[0][2] # query_params
762
763
 
763
764
  assert query_params["limit"] == MAX_BATCH_SIZE
765
+
766
+
767
+ def test_create_truss_service_passes_deploy_timeout_minutes():
768
+ """Test that deploy_timeout_minutes is passed through to create_model_version_from_truss"""
769
+ api = MagicMock()
770
+ return_value = {
771
+ "id": "model_version_id",
772
+ "oracle": {"id": "model_id", "hostname": "hostname"},
773
+ "instance_type": {"name": "1x2"},
774
+ }
775
+ api.create_model_version_from_truss.return_value = return_value
776
+ version_handle = create_truss_service(
777
+ api,
778
+ "model_name",
779
+ "s3_key",
780
+ "config",
781
+ b10_types.TrussUserEnv.collect(),
782
+ is_draft=False,
783
+ model_id="model_id",
784
+ environment="staging",
785
+ deploy_timeout_minutes=600,
786
+ )
787
+
788
+ assert version_handle.version_id == "model_version_id"
789
+ assert version_handle.model_id == "model_id"
790
+ api.create_model_version_from_truss.assert_called_once()
791
+ _, kwargs = api.create_model_version_from_truss.call_args
792
+ assert kwargs["deploy_timeout_minutes"] == 600
793
+
794
+
795
+ def test_create_truss_service_passes_deploy_timeout_minutes_with_other_params():
796
+ """Test that deploy_timeout_minutes works correctly with other parameters like preserve_env_instance_type"""
797
+ api = MagicMock()
798
+ return_value = {
799
+ "id": "model_version_id",
800
+ "oracle": {"id": "model_id", "hostname": "hostname"},
801
+ "instance_type": {"name": "1x2"},
802
+ }
803
+ api.create_model_version_from_truss.return_value = return_value
804
+ version_handle = create_truss_service(
805
+ api,
806
+ "model_name",
807
+ "s3_key",
808
+ "config",
809
+ b10_types.TrussUserEnv.collect(),
810
+ is_draft=False,
811
+ model_id="model_id",
812
+ environment="production",
813
+ preserve_env_instance_type=False,
814
+ deploy_timeout_minutes=900,
815
+ )
816
+
817
+ assert version_handle.version_id == "model_version_id"
818
+ api.create_model_version_from_truss.assert_called_once()
819
+ _, kwargs = api.create_model_version_from_truss.call_args
820
+ assert kwargs["deploy_timeout_minutes"] == 900
821
+ assert kwargs["preserve_env_instance_type"] is False
822
+ assert kwargs["environment"] == "production"
823
+
824
+
825
+ def test_create_truss_service_passes_deploy_timeout_minutes_for_development_model():
826
+ """Test that deploy_timeout_minutes is passed through to create_development_model_from_truss"""
827
+ api = MagicMock()
828
+ return_value = {
829
+ "id": "model_version_id",
830
+ "oracle": {"id": "model_id", "hostname": "hostname"},
831
+ "instance_type": {"name": "1x2"},
832
+ }
833
+ api.create_development_model_from_truss.return_value = return_value
834
+ version_handle = create_truss_service(
835
+ api,
836
+ "model_name",
837
+ "s3_key",
838
+ "config",
839
+ b10_types.TrussUserEnv.collect(),
840
+ is_draft=True,
841
+ model_id=None,
842
+ deploy_timeout_minutes=600,
843
+ )
844
+
845
+ assert version_handle.version_id == "model_version_id"
846
+ assert version_handle.model_id == "model_id"
847
+ api.create_development_model_from_truss.assert_called_once()
848
+ _, kwargs = api.create_development_model_from_truss.call_args
849
+ assert kwargs["deploy_timeout_minutes"] == 600