truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -165,9 +165,9 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
|
|
|
165
165
|
"config_str",
|
|
166
166
|
"semver_bump",
|
|
167
167
|
b10_types.TrussUserEnv.collect(),
|
|
168
|
-
False,
|
|
169
|
-
"deployment_name",
|
|
170
|
-
"production",
|
|
168
|
+
preserve_previous_prod_deployment=False,
|
|
169
|
+
deployment_name="deployment_name",
|
|
170
|
+
environment="production",
|
|
171
171
|
)
|
|
172
172
|
|
|
173
173
|
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
@@ -182,6 +182,7 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
|
|
|
182
182
|
assert 'name: "deployment_name"' in gql_mutation
|
|
183
183
|
assert 'environment_name: "production"' in gql_mutation
|
|
184
184
|
assert "preserve_env_instance_type: true" in gql_mutation
|
|
185
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
185
186
|
|
|
186
187
|
|
|
187
188
|
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
@@ -211,6 +212,7 @@ def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_sp
|
|
|
211
212
|
assert " name: " not in gql_mutation
|
|
212
213
|
assert "environment_name: " not in gql_mutation
|
|
213
214
|
assert "preserve_env_instance_type: false" in gql_mutation
|
|
215
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
214
216
|
|
|
215
217
|
|
|
216
218
|
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
@@ -242,6 +244,57 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
|
|
|
242
244
|
assert " name: " not in gql_mutation
|
|
243
245
|
assert 'environment_name: "staging"' in gql_mutation
|
|
244
246
|
assert "preserve_env_instance_type: true" in gql_mutation
|
|
247
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
251
|
+
def test_create_model_version_from_truss_with_deploy_timeout_minutes(
|
|
252
|
+
mock_post, baseten_api
|
|
253
|
+
):
|
|
254
|
+
baseten_api.create_model_version_from_truss(
|
|
255
|
+
"model_id",
|
|
256
|
+
"s3key",
|
|
257
|
+
"config_str",
|
|
258
|
+
"semver_bump",
|
|
259
|
+
b10_types.TrussUserEnv.collect(),
|
|
260
|
+
preserve_previous_prod_deployment=False,
|
|
261
|
+
deployment_name="deployment_name",
|
|
262
|
+
environment="production",
|
|
263
|
+
deploy_timeout_minutes=300,
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
267
|
+
assert 'model_id: "model_id"' in gql_mutation
|
|
268
|
+
assert 's3_key: "s3key"' in gql_mutation
|
|
269
|
+
assert 'config: "config_str"' in gql_mutation
|
|
270
|
+
assert 'semver_bump: "semver_bump"' in gql_mutation
|
|
271
|
+
assert {
|
|
272
|
+
"trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
|
|
273
|
+
} == mock_post.call_args[1]["json"]["variables"]
|
|
274
|
+
assert "scale_down_old_production: true" in gql_mutation
|
|
275
|
+
assert 'name: "deployment_name"' in gql_mutation
|
|
276
|
+
assert 'environment_name: "production"' in gql_mutation
|
|
277
|
+
assert "preserve_env_instance_type: true" in gql_mutation
|
|
278
|
+
assert "deploy_timeout_minutes: 300" in gql_mutation
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
282
|
+
def test_create_model_version_from_truss_with_deploy_timeout_minutes_zero(
|
|
283
|
+
mock_post, baseten_api
|
|
284
|
+
):
|
|
285
|
+
"""Test that deploy_timeout_minutes of 0 is handled correctly"""
|
|
286
|
+
baseten_api.create_model_version_from_truss(
|
|
287
|
+
"model_id",
|
|
288
|
+
"s3key",
|
|
289
|
+
"config_str",
|
|
290
|
+
"semver_bump",
|
|
291
|
+
b10_types.TrussUserEnv.collect(),
|
|
292
|
+
preserve_previous_prod_deployment=False,
|
|
293
|
+
deploy_timeout_minutes=0,
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
297
|
+
assert "deploy_timeout_minutes: 0" in gql_mutation
|
|
245
298
|
|
|
246
299
|
|
|
247
300
|
@mock.patch("requests.post", return_value=mock_create_model_response())
|
|
@@ -332,6 +385,48 @@ def test_create_development_model_from_truss_with_allow_truss_download(
|
|
|
332
385
|
"trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
|
|
333
386
|
} == mock_post.call_args[1]["json"]["variables"]
|
|
334
387
|
assert "allow_truss_download: false" in gql_mutation
|
|
388
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@mock.patch("requests.post", return_value=mock_create_development_model_response())
|
|
392
|
+
def test_create_development_model_from_truss_with_deploy_timeout_minutes(
|
|
393
|
+
mock_post, baseten_api
|
|
394
|
+
):
|
|
395
|
+
baseten_api.create_development_model_from_truss(
|
|
396
|
+
"model_name",
|
|
397
|
+
"s3key",
|
|
398
|
+
"config_str",
|
|
399
|
+
b10_types.TrussUserEnv.collect(),
|
|
400
|
+
allow_truss_download=False,
|
|
401
|
+
deploy_timeout_minutes=300,
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
405
|
+
assert 'name: "model_name"' in gql_mutation
|
|
406
|
+
assert 's3_key: "s3key"' in gql_mutation
|
|
407
|
+
assert 'config: "config_str"' in gql_mutation
|
|
408
|
+
assert {
|
|
409
|
+
"trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
|
|
410
|
+
} == mock_post.call_args[1]["json"]["variables"]
|
|
411
|
+
assert "allow_truss_download: false" in gql_mutation
|
|
412
|
+
assert "deploy_timeout_minutes: 300" in gql_mutation
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
@mock.patch("requests.post", return_value=mock_create_development_model_response())
|
|
416
|
+
def test_create_development_model_from_truss_with_deploy_timeout_minutes_zero(
|
|
417
|
+
mock_post, baseten_api
|
|
418
|
+
):
|
|
419
|
+
"""Test that deploy_timeout_minutes of 0 is handled correctly"""
|
|
420
|
+
baseten_api.create_development_model_from_truss(
|
|
421
|
+
"model_name",
|
|
422
|
+
"s3key",
|
|
423
|
+
"config_str",
|
|
424
|
+
b10_types.TrussUserEnv.collect(),
|
|
425
|
+
deploy_timeout_minutes=0,
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
429
|
+
assert "deploy_timeout_minutes: 0" in gql_mutation
|
|
335
430
|
|
|
336
431
|
|
|
337
432
|
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
@@ -357,6 +452,30 @@ def test_deploy_chain_deployment(mock_post, baseten_api):
|
|
|
357
452
|
assert 'chain_id: "chain_id"' in gql_mutation
|
|
358
453
|
assert "dependencies:" in gql_mutation
|
|
359
454
|
assert "entrypoint:" in gql_mutation
|
|
455
|
+
assert "deployment_name" not in gql_mutation
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
459
|
+
def test_deploy_chain_deployment_with_deployment_name(mock_post, baseten_api):
|
|
460
|
+
baseten_api.deploy_chain_atomic(
|
|
461
|
+
environment="production",
|
|
462
|
+
chain_id="chain_id",
|
|
463
|
+
dependencies=[],
|
|
464
|
+
entrypoint=ChainletDataAtomic(
|
|
465
|
+
name="chainlet-1",
|
|
466
|
+
oracle=OracleData(
|
|
467
|
+
model_name="model-1",
|
|
468
|
+
s3_key="s3-key-1",
|
|
469
|
+
encoded_config_str="encoded-config-str-1",
|
|
470
|
+
),
|
|
471
|
+
),
|
|
472
|
+
truss_user_env=b10_types.TrussUserEnv.collect(),
|
|
473
|
+
deployment_name="chain-deployment-name",
|
|
474
|
+
)
|
|
475
|
+
|
|
476
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
477
|
+
|
|
478
|
+
assert 'deployment_name: "chain-deployment-name"' in gql_mutation
|
|
360
479
|
|
|
361
480
|
|
|
362
481
|
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
import pathlib
|
|
2
|
+
import tempfile
|
|
3
|
+
from unittest.mock import Mock, patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from truss.remote.baseten import custom_types as b10_types
|
|
8
|
+
from truss.remote.baseten.api import BasetenApi
|
|
9
|
+
from truss.remote.baseten.core import upload_chain_artifact
|
|
10
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.fixture
|
|
14
|
+
def mock_push_data():
|
|
15
|
+
"""Fixture providing mock push data for tests."""
|
|
16
|
+
mock_push_data = Mock()
|
|
17
|
+
mock_push_data.model_name = "test-model"
|
|
18
|
+
mock_push_data.s3_key = "models/test-key"
|
|
19
|
+
mock_push_data.encoded_config_str = "encoded_config"
|
|
20
|
+
mock_push_data.is_draft = False
|
|
21
|
+
mock_push_data.model_id = "model-id"
|
|
22
|
+
mock_push_data.version_name = None
|
|
23
|
+
return mock_push_data
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture
|
|
27
|
+
def mock_remote_context():
|
|
28
|
+
"""Fixture providing mock remote and context managers for tests."""
|
|
29
|
+
api = Mock(spec=BasetenApi)
|
|
30
|
+
|
|
31
|
+
remote = BasetenRemote("https://test.baseten.co", "test-api-key")
|
|
32
|
+
remote._api = api
|
|
33
|
+
|
|
34
|
+
chain_name = "test-chain"
|
|
35
|
+
entrypoint_artifact = Mock()
|
|
36
|
+
entrypoint_artifact.truss_dir = "/path/to/truss"
|
|
37
|
+
entrypoint_artifact.display_name = "entrypoint"
|
|
38
|
+
|
|
39
|
+
dependency_artifacts = []
|
|
40
|
+
truss_user_env = Mock()
|
|
41
|
+
chain_root = pathlib.Path("/path/to/chain")
|
|
42
|
+
|
|
43
|
+
with patch.object(remote, "_prepare_push") as mock_prepare_push:
|
|
44
|
+
with patch("truss.remote.baseten.remote.truss_build.load") as mock_load:
|
|
45
|
+
mock_truss_handle = Mock()
|
|
46
|
+
mock_truss_handle.spec.config.model_name = "test-model"
|
|
47
|
+
mock_load.return_value = mock_truss_handle
|
|
48
|
+
|
|
49
|
+
yield {
|
|
50
|
+
"remote": remote,
|
|
51
|
+
"api": api,
|
|
52
|
+
"chain_name": chain_name,
|
|
53
|
+
"entrypoint_artifact": entrypoint_artifact,
|
|
54
|
+
"dependency_artifacts": dependency_artifacts,
|
|
55
|
+
"truss_user_env": truss_user_env,
|
|
56
|
+
"chain_root": chain_root,
|
|
57
|
+
"mock_prepare_push": mock_prepare_push,
|
|
58
|
+
"mock_load": mock_load,
|
|
59
|
+
"mock_truss_handle": mock_truss_handle,
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_get_blob_credentials_for_chain():
|
|
64
|
+
"""Test that get_blob_credentials works correctly for chain blob type using GraphQL."""
|
|
65
|
+
mock_graphql_response = {
|
|
66
|
+
"data": {
|
|
67
|
+
"chain_s3_upload_credentials": {
|
|
68
|
+
"s3_bucket": "test-chain-bucket",
|
|
69
|
+
"s3_key": "chains/test-uuid/chain.tgz",
|
|
70
|
+
"aws_access_key_id": "test_access_key",
|
|
71
|
+
"aws_secret_access_key": "test_secret_key",
|
|
72
|
+
"aws_session_token": "test_session_token",
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
# Create a real API instance and mock the GraphQL call
|
|
78
|
+
mock_auth_service = Mock()
|
|
79
|
+
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
80
|
+
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
81
|
+
with patch.object(api, "_post_graphql_query") as mock_graphql:
|
|
82
|
+
mock_graphql.return_value = mock_graphql_response
|
|
83
|
+
|
|
84
|
+
result = api.get_chain_s3_upload_credentials()
|
|
85
|
+
|
|
86
|
+
assert result.s3_bucket == "test-chain-bucket"
|
|
87
|
+
assert result.s3_key == "chains/test-uuid/chain.tgz"
|
|
88
|
+
assert result.aws_access_key_id == "test_access_key"
|
|
89
|
+
assert result.aws_secret_access_key == "test_secret_key"
|
|
90
|
+
assert result.aws_session_token == "test_session_token"
|
|
91
|
+
|
|
92
|
+
mock_graphql.assert_called_once()
|
|
93
|
+
call_args = mock_graphql.call_args
|
|
94
|
+
assert "chain_s3_upload_credentials" in call_args[0][0]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def test_get_blob_credentials_for_other_types_uses_rest():
|
|
98
|
+
"""Test that get_blob_credentials uses REST API for non-chain blob types."""
|
|
99
|
+
mock_response = {
|
|
100
|
+
"s3_bucket": "test-bucket",
|
|
101
|
+
"s3_key": "test-key",
|
|
102
|
+
"creds": {
|
|
103
|
+
"aws_access_key_id": "test_access_key",
|
|
104
|
+
"aws_secret_access_key": "test_secret_key",
|
|
105
|
+
"aws_session_token": "test_session_token",
|
|
106
|
+
},
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
mock_auth_service = Mock()
|
|
110
|
+
mock_auth_service.authenticate.return_value = Mock(value="test-token")
|
|
111
|
+
api = BasetenApi("https://test.baseten.co", mock_auth_service)
|
|
112
|
+
with patch.object(api, "_rest_api_client") as mock_client, patch.object(
|
|
113
|
+
api, "_post_graphql_query"
|
|
114
|
+
) as mock_graphql:
|
|
115
|
+
mock_client.get.return_value = mock_response
|
|
116
|
+
|
|
117
|
+
result = api.get_blob_credentials(b10_types.BlobType.MODEL)
|
|
118
|
+
|
|
119
|
+
assert result["s3_bucket"] == "test-bucket"
|
|
120
|
+
assert result["s3_key"] == "test-key"
|
|
121
|
+
|
|
122
|
+
mock_client.get.assert_called_once_with("v1/blobs/credentials/model")
|
|
123
|
+
mock_graphql.assert_not_called()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
127
|
+
def test_upload_chain_artifact_function(mock_multipart_upload):
|
|
128
|
+
"""Test the upload_chain_artifact function."""
|
|
129
|
+
# Mock ChainUploadCredentials object
|
|
130
|
+
mock_credentials = Mock()
|
|
131
|
+
mock_credentials.s3_bucket = "test-chain-bucket"
|
|
132
|
+
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
133
|
+
mock_credentials.aws_credentials = Mock()
|
|
134
|
+
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
135
|
+
"aws_access_key_id": "test_access_key",
|
|
136
|
+
"aws_secret_access_key": "test_secret_key",
|
|
137
|
+
"aws_session_token": "test_session_token",
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
api = Mock(spec=BasetenApi)
|
|
141
|
+
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
142
|
+
|
|
143
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as temp_file:
|
|
144
|
+
temp_file.write(b"test chain content")
|
|
145
|
+
temp_file.flush()
|
|
146
|
+
|
|
147
|
+
result = upload_chain_artifact(api, temp_file, None)
|
|
148
|
+
|
|
149
|
+
assert result == "chains/test-uuid/chain.tgz"
|
|
150
|
+
|
|
151
|
+
api.get_chain_s3_upload_credentials.assert_called_once_with()
|
|
152
|
+
|
|
153
|
+
mock_multipart_upload.assert_called_once()
|
|
154
|
+
call_args = mock_multipart_upload.call_args
|
|
155
|
+
assert call_args[0][0] == temp_file.name # file path
|
|
156
|
+
assert call_args[0][1] == "test-chain-bucket" # bucket
|
|
157
|
+
assert call_args[0][2] == "chains/test-uuid/chain.tgz" # key
|
|
158
|
+
assert call_args[0][3] == { # credentials
|
|
159
|
+
"aws_access_key_id": "test_access_key",
|
|
160
|
+
"aws_secret_access_key": "test_secret_key",
|
|
161
|
+
"aws_session_token": "test_session_token",
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@patch("truss.remote.baseten.remote.upload_chain_artifact")
|
|
166
|
+
@patch("truss.remote.baseten.remote.archive_dir")
|
|
167
|
+
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
168
|
+
def test_push_chain_atomic_with_chain_upload(
|
|
169
|
+
mock_create_chain_atomic,
|
|
170
|
+
mock_archive_dir,
|
|
171
|
+
mock_upload_chain_artifact,
|
|
172
|
+
mock_push_data,
|
|
173
|
+
mock_remote_context,
|
|
174
|
+
):
|
|
175
|
+
"""Test that push_chain_atomic uploads raw chain artifact when chain_root is provided."""
|
|
176
|
+
mock_create_chain_atomic.return_value = Mock()
|
|
177
|
+
mock_archive_dir.return_value = Mock()
|
|
178
|
+
mock_upload_chain_artifact.return_value = "chains/test-uuid/chain.tgz"
|
|
179
|
+
|
|
180
|
+
context = mock_remote_context
|
|
181
|
+
remote = context["remote"]
|
|
182
|
+
chain_name = context["chain_name"]
|
|
183
|
+
entrypoint_artifact = context["entrypoint_artifact"]
|
|
184
|
+
dependency_artifacts = context["dependency_artifacts"]
|
|
185
|
+
truss_user_env = context["truss_user_env"]
|
|
186
|
+
chain_root = context["chain_root"]
|
|
187
|
+
|
|
188
|
+
context["mock_prepare_push"].return_value = mock_push_data
|
|
189
|
+
deployment_name = "custom_deployment"
|
|
190
|
+
|
|
191
|
+
result = remote.push_chain_atomic(
|
|
192
|
+
chain_name=chain_name,
|
|
193
|
+
entrypoint_artifact=entrypoint_artifact,
|
|
194
|
+
dependency_artifacts=dependency_artifacts,
|
|
195
|
+
truss_user_env=truss_user_env,
|
|
196
|
+
chain_root=chain_root,
|
|
197
|
+
publish=True,
|
|
198
|
+
deployment_name=deployment_name,
|
|
199
|
+
)
|
|
200
|
+
assert result == mock_create_chain_atomic.return_value
|
|
201
|
+
|
|
202
|
+
mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
|
|
203
|
+
mock_upload_chain_artifact.assert_called_once()
|
|
204
|
+
mock_create_chain_atomic.assert_called_once()
|
|
205
|
+
create_kwargs = mock_create_chain_atomic.call_args.kwargs
|
|
206
|
+
assert create_kwargs["deployment_name"] == deployment_name
|
|
207
|
+
|
|
208
|
+
prepare_kwargs = context["mock_prepare_push"].call_args.kwargs
|
|
209
|
+
assert prepare_kwargs["deployment_name"] == deployment_name
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
213
|
+
def test_push_chain_atomic_without_chain_upload(
|
|
214
|
+
mock_create_chain_atomic, mock_push_data, mock_remote_context
|
|
215
|
+
):
|
|
216
|
+
"""Test that push_chain_atomic skips chain upload when chain_root is None."""
|
|
217
|
+
mock_create_chain_atomic.return_value = Mock()
|
|
218
|
+
|
|
219
|
+
context = mock_remote_context
|
|
220
|
+
remote = context["remote"]
|
|
221
|
+
chain_name = context["chain_name"]
|
|
222
|
+
entrypoint_artifact = context["entrypoint_artifact"]
|
|
223
|
+
dependency_artifacts = context["dependency_artifacts"]
|
|
224
|
+
truss_user_env = context["truss_user_env"]
|
|
225
|
+
|
|
226
|
+
context["mock_prepare_push"].return_value = mock_push_data
|
|
227
|
+
|
|
228
|
+
with patch("truss.remote.baseten.remote.upload_chain_artifact") as mock_upload:
|
|
229
|
+
with patch(
|
|
230
|
+
"truss.remote.baseten.core.create_tar_with_progress_bar"
|
|
231
|
+
) as mock_tar:
|
|
232
|
+
# Call push_chain_atomic without chain_root
|
|
233
|
+
result = remote.push_chain_atomic(
|
|
234
|
+
chain_name=chain_name,
|
|
235
|
+
entrypoint_artifact=entrypoint_artifact,
|
|
236
|
+
dependency_artifacts=dependency_artifacts,
|
|
237
|
+
truss_user_env=truss_user_env,
|
|
238
|
+
chain_root=None, # No chain root
|
|
239
|
+
publish=True,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
assert result
|
|
243
|
+
# Verify chain artifact upload was NOT called
|
|
244
|
+
mock_tar.assert_not_called()
|
|
245
|
+
mock_upload.assert_not_called()
|
|
246
|
+
|
|
247
|
+
mock_create_chain_atomic.assert_called_once()
|
|
248
|
+
create_kwargs = mock_create_chain_atomic.call_args.kwargs
|
|
249
|
+
assert "deployment_name" in create_kwargs
|
|
250
|
+
assert create_kwargs["deployment_name"] is None
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
254
|
+
def test_upload_chain_artifact_error_handling(mock_multipart_upload):
|
|
255
|
+
"""Test error handling in upload_chain_artifact."""
|
|
256
|
+
# Mock API to raise an exception
|
|
257
|
+
api = Mock(spec=BasetenApi)
|
|
258
|
+
api.get_chain_s3_upload_credentials.side_effect = Exception("API Error")
|
|
259
|
+
|
|
260
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
261
|
+
# Should raise the exception
|
|
262
|
+
with pytest.raises(Exception, match="API Error"):
|
|
263
|
+
upload_chain_artifact(api, temp_file, None)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def test_upload_chain_artifact_credentials_extraction():
|
|
267
|
+
"""Test that credentials are properly extracted from API response."""
|
|
268
|
+
# Mock ChainUploadCredentials object
|
|
269
|
+
mock_credentials = Mock()
|
|
270
|
+
mock_credentials.s3_bucket = "test-bucket"
|
|
271
|
+
mock_credentials.s3_key = "chains/test-uuid/chain.tgz"
|
|
272
|
+
mock_credentials.aws_credentials = Mock()
|
|
273
|
+
mock_credentials.aws_credentials.model_dump.return_value = {
|
|
274
|
+
"aws_access_key_id": "access_key",
|
|
275
|
+
"aws_secret_access_key": "secret_key",
|
|
276
|
+
"aws_session_token": "session_token",
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
api = Mock(spec=BasetenApi)
|
|
280
|
+
api.get_chain_s3_upload_credentials.return_value = mock_credentials
|
|
281
|
+
|
|
282
|
+
with patch("truss.remote.baseten.core.multipart_upload_boto3") as mock_upload:
|
|
283
|
+
with tempfile.NamedTemporaryFile(suffix=".tgz") as temp_file:
|
|
284
|
+
upload_chain_artifact(api, temp_file, None)
|
|
285
|
+
|
|
286
|
+
call_args = mock_upload.call_args
|
|
287
|
+
credentials = call_args[0][3]
|
|
288
|
+
|
|
289
|
+
assert credentials == {
|
|
290
|
+
"aws_access_key_id": "access_key",
|
|
291
|
+
"aws_secret_access_key": "secret_key",
|
|
292
|
+
"aws_session_token": "session_token",
|
|
293
|
+
}
|
|
294
|
+
assert "extra_field" not in credentials
|
|
@@ -187,6 +187,7 @@ def test_create_truss_service_handles_existing_model(inputs):
|
|
|
187
187
|
_, kwargs = api.create_model_version_from_truss.call_args
|
|
188
188
|
for k, v in inputs.items():
|
|
189
189
|
assert kwargs[k] == v
|
|
190
|
+
assert kwargs.get("deploy_timeout_minutes") is None
|
|
190
191
|
|
|
191
192
|
|
|
192
193
|
@pytest.mark.parametrize("allow_truss_download", [True, False])
|
|
@@ -761,3 +762,88 @@ def test_get_training_job_logs_with_pagination_default_batch_size(baseten_api):
|
|
|
761
762
|
query_params = call_args[0][2] # query_params
|
|
762
763
|
|
|
763
764
|
assert query_params["limit"] == MAX_BATCH_SIZE
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def test_create_truss_service_passes_deploy_timeout_minutes():
|
|
768
|
+
"""Test that deploy_timeout_minutes is passed through to create_model_version_from_truss"""
|
|
769
|
+
api = MagicMock()
|
|
770
|
+
return_value = {
|
|
771
|
+
"id": "model_version_id",
|
|
772
|
+
"oracle": {"id": "model_id", "hostname": "hostname"},
|
|
773
|
+
"instance_type": {"name": "1x2"},
|
|
774
|
+
}
|
|
775
|
+
api.create_model_version_from_truss.return_value = return_value
|
|
776
|
+
version_handle = create_truss_service(
|
|
777
|
+
api,
|
|
778
|
+
"model_name",
|
|
779
|
+
"s3_key",
|
|
780
|
+
"config",
|
|
781
|
+
b10_types.TrussUserEnv.collect(),
|
|
782
|
+
is_draft=False,
|
|
783
|
+
model_id="model_id",
|
|
784
|
+
environment="staging",
|
|
785
|
+
deploy_timeout_minutes=600,
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
assert version_handle.version_id == "model_version_id"
|
|
789
|
+
assert version_handle.model_id == "model_id"
|
|
790
|
+
api.create_model_version_from_truss.assert_called_once()
|
|
791
|
+
_, kwargs = api.create_model_version_from_truss.call_args
|
|
792
|
+
assert kwargs["deploy_timeout_minutes"] == 600
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
def test_create_truss_service_passes_deploy_timeout_minutes_with_other_params():
|
|
796
|
+
"""Test that deploy_timeout_minutes works correctly with other parameters like preserve_env_instance_type"""
|
|
797
|
+
api = MagicMock()
|
|
798
|
+
return_value = {
|
|
799
|
+
"id": "model_version_id",
|
|
800
|
+
"oracle": {"id": "model_id", "hostname": "hostname"},
|
|
801
|
+
"instance_type": {"name": "1x2"},
|
|
802
|
+
}
|
|
803
|
+
api.create_model_version_from_truss.return_value = return_value
|
|
804
|
+
version_handle = create_truss_service(
|
|
805
|
+
api,
|
|
806
|
+
"model_name",
|
|
807
|
+
"s3_key",
|
|
808
|
+
"config",
|
|
809
|
+
b10_types.TrussUserEnv.collect(),
|
|
810
|
+
is_draft=False,
|
|
811
|
+
model_id="model_id",
|
|
812
|
+
environment="production",
|
|
813
|
+
preserve_env_instance_type=False,
|
|
814
|
+
deploy_timeout_minutes=900,
|
|
815
|
+
)
|
|
816
|
+
|
|
817
|
+
assert version_handle.version_id == "model_version_id"
|
|
818
|
+
api.create_model_version_from_truss.assert_called_once()
|
|
819
|
+
_, kwargs = api.create_model_version_from_truss.call_args
|
|
820
|
+
assert kwargs["deploy_timeout_minutes"] == 900
|
|
821
|
+
assert kwargs["preserve_env_instance_type"] is False
|
|
822
|
+
assert kwargs["environment"] == "production"
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def test_create_truss_service_passes_deploy_timeout_minutes_for_development_model():
|
|
826
|
+
"""Test that deploy_timeout_minutes is passed through to create_development_model_from_truss"""
|
|
827
|
+
api = MagicMock()
|
|
828
|
+
return_value = {
|
|
829
|
+
"id": "model_version_id",
|
|
830
|
+
"oracle": {"id": "model_id", "hostname": "hostname"},
|
|
831
|
+
"instance_type": {"name": "1x2"},
|
|
832
|
+
}
|
|
833
|
+
api.create_development_model_from_truss.return_value = return_value
|
|
834
|
+
version_handle = create_truss_service(
|
|
835
|
+
api,
|
|
836
|
+
"model_name",
|
|
837
|
+
"s3_key",
|
|
838
|
+
"config",
|
|
839
|
+
b10_types.TrussUserEnv.collect(),
|
|
840
|
+
is_draft=True,
|
|
841
|
+
model_id=None,
|
|
842
|
+
deploy_timeout_minutes=600,
|
|
843
|
+
)
|
|
844
|
+
|
|
845
|
+
assert version_handle.version_id == "model_version_id"
|
|
846
|
+
assert version_handle.model_id == "model_id"
|
|
847
|
+
api.create_development_model_from_truss.assert_called_once()
|
|
848
|
+
_, kwargs = api.create_development_model_from_truss.call_args
|
|
849
|
+
assert kwargs["deploy_timeout_minutes"] == 600
|