truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/truss_config.py +10 -3
- truss/cli/chains_commands.py +39 -1
- truss/cli/cli.py +35 -5
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +19 -143
- truss/cli/train_commands.py +69 -11
- truss/cli/utils/common.py +40 -3
- truss/remote/baseten/api.py +58 -5
- truss/remote/baseten/core.py +22 -4
- truss/remote/baseten/remote.py +24 -2
- truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +10 -10
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +44 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +10 -1
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +9 -4
- truss_chains/private_types.py +15 -0
- truss_train/definitions.py +3 -1
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -165,9 +165,9 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
|
|
|
165
165
|
"config_str",
|
|
166
166
|
"semver_bump",
|
|
167
167
|
b10_types.TrussUserEnv.collect(),
|
|
168
|
-
False,
|
|
169
|
-
"deployment_name",
|
|
170
|
-
"production",
|
|
168
|
+
preserve_previous_prod_deployment=False,
|
|
169
|
+
deployment_name="deployment_name",
|
|
170
|
+
environment="production",
|
|
171
171
|
)
|
|
172
172
|
|
|
173
173
|
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
@@ -182,6 +182,7 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
|
|
|
182
182
|
assert 'name: "deployment_name"' in gql_mutation
|
|
183
183
|
assert 'environment_name: "production"' in gql_mutation
|
|
184
184
|
assert "preserve_env_instance_type: true" in gql_mutation
|
|
185
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
185
186
|
|
|
186
187
|
|
|
187
188
|
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
@@ -211,6 +212,7 @@ def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_sp
|
|
|
211
212
|
assert " name: " not in gql_mutation
|
|
212
213
|
assert "environment_name: " not in gql_mutation
|
|
213
214
|
assert "preserve_env_instance_type: false" in gql_mutation
|
|
215
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
214
216
|
|
|
215
217
|
|
|
216
218
|
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
@@ -242,6 +244,57 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
|
|
|
242
244
|
assert " name: " not in gql_mutation
|
|
243
245
|
assert 'environment_name: "staging"' in gql_mutation
|
|
244
246
|
assert "preserve_env_instance_type: true" in gql_mutation
|
|
247
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
251
|
+
def test_create_model_version_from_truss_with_deploy_timeout_minutes(
|
|
252
|
+
mock_post, baseten_api
|
|
253
|
+
):
|
|
254
|
+
baseten_api.create_model_version_from_truss(
|
|
255
|
+
"model_id",
|
|
256
|
+
"s3key",
|
|
257
|
+
"config_str",
|
|
258
|
+
"semver_bump",
|
|
259
|
+
b10_types.TrussUserEnv.collect(),
|
|
260
|
+
preserve_previous_prod_deployment=False,
|
|
261
|
+
deployment_name="deployment_name",
|
|
262
|
+
environment="production",
|
|
263
|
+
deploy_timeout_minutes=300,
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
267
|
+
assert 'model_id: "model_id"' in gql_mutation
|
|
268
|
+
assert 's3_key: "s3key"' in gql_mutation
|
|
269
|
+
assert 'config: "config_str"' in gql_mutation
|
|
270
|
+
assert 'semver_bump: "semver_bump"' in gql_mutation
|
|
271
|
+
assert {
|
|
272
|
+
"trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
|
|
273
|
+
} == mock_post.call_args[1]["json"]["variables"]
|
|
274
|
+
assert "scale_down_old_production: true" in gql_mutation
|
|
275
|
+
assert 'name: "deployment_name"' in gql_mutation
|
|
276
|
+
assert 'environment_name: "production"' in gql_mutation
|
|
277
|
+
assert "preserve_env_instance_type: true" in gql_mutation
|
|
278
|
+
assert "deploy_timeout_minutes: 300" in gql_mutation
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
282
|
+
def test_create_model_version_from_truss_with_deploy_timeout_minutes_zero(
|
|
283
|
+
mock_post, baseten_api
|
|
284
|
+
):
|
|
285
|
+
"""Test that deploy_timeout_minutes of 0 is handled correctly"""
|
|
286
|
+
baseten_api.create_model_version_from_truss(
|
|
287
|
+
"model_id",
|
|
288
|
+
"s3key",
|
|
289
|
+
"config_str",
|
|
290
|
+
"semver_bump",
|
|
291
|
+
b10_types.TrussUserEnv.collect(),
|
|
292
|
+
preserve_previous_prod_deployment=False,
|
|
293
|
+
deploy_timeout_minutes=0,
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
297
|
+
assert "deploy_timeout_minutes: 0" in gql_mutation
|
|
245
298
|
|
|
246
299
|
|
|
247
300
|
@mock.patch("requests.post", return_value=mock_create_model_response())
|
|
@@ -332,6 +385,48 @@ def test_create_development_model_from_truss_with_allow_truss_download(
|
|
|
332
385
|
"trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
|
|
333
386
|
} == mock_post.call_args[1]["json"]["variables"]
|
|
334
387
|
assert "allow_truss_download: false" in gql_mutation
|
|
388
|
+
assert "deploy_timeout_minutes: " not in gql_mutation
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@mock.patch("requests.post", return_value=mock_create_development_model_response())
|
|
392
|
+
def test_create_development_model_from_truss_with_deploy_timeout_minutes(
|
|
393
|
+
mock_post, baseten_api
|
|
394
|
+
):
|
|
395
|
+
baseten_api.create_development_model_from_truss(
|
|
396
|
+
"model_name",
|
|
397
|
+
"s3key",
|
|
398
|
+
"config_str",
|
|
399
|
+
b10_types.TrussUserEnv.collect(),
|
|
400
|
+
allow_truss_download=False,
|
|
401
|
+
deploy_timeout_minutes=300,
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
405
|
+
assert 'name: "model_name"' in gql_mutation
|
|
406
|
+
assert 's3_key: "s3key"' in gql_mutation
|
|
407
|
+
assert 'config: "config_str"' in gql_mutation
|
|
408
|
+
assert {
|
|
409
|
+
"trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
|
|
410
|
+
} == mock_post.call_args[1]["json"]["variables"]
|
|
411
|
+
assert "allow_truss_download: false" in gql_mutation
|
|
412
|
+
assert "deploy_timeout_minutes: 300" in gql_mutation
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
@mock.patch("requests.post", return_value=mock_create_development_model_response())
|
|
416
|
+
def test_create_development_model_from_truss_with_deploy_timeout_minutes_zero(
|
|
417
|
+
mock_post, baseten_api
|
|
418
|
+
):
|
|
419
|
+
"""Test that deploy_timeout_minutes of 0 is handled correctly"""
|
|
420
|
+
baseten_api.create_development_model_from_truss(
|
|
421
|
+
"model_name",
|
|
422
|
+
"s3key",
|
|
423
|
+
"config_str",
|
|
424
|
+
b10_types.TrussUserEnv.collect(),
|
|
425
|
+
deploy_timeout_minutes=0,
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
429
|
+
assert "deploy_timeout_minutes: 0" in gql_mutation
|
|
335
430
|
|
|
336
431
|
|
|
337
432
|
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
@@ -357,6 +452,30 @@ def test_deploy_chain_deployment(mock_post, baseten_api):
|
|
|
357
452
|
assert 'chain_id: "chain_id"' in gql_mutation
|
|
358
453
|
assert "dependencies:" in gql_mutation
|
|
359
454
|
assert "entrypoint:" in gql_mutation
|
|
455
|
+
assert "deployment_name" not in gql_mutation
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
459
|
+
def test_deploy_chain_deployment_with_deployment_name(mock_post, baseten_api):
|
|
460
|
+
baseten_api.deploy_chain_atomic(
|
|
461
|
+
environment="production",
|
|
462
|
+
chain_id="chain_id",
|
|
463
|
+
dependencies=[],
|
|
464
|
+
entrypoint=ChainletDataAtomic(
|
|
465
|
+
name="chainlet-1",
|
|
466
|
+
oracle=OracleData(
|
|
467
|
+
model_name="model-1",
|
|
468
|
+
s3_key="s3-key-1",
|
|
469
|
+
encoded_config_str="encoded-config-str-1",
|
|
470
|
+
),
|
|
471
|
+
),
|
|
472
|
+
truss_user_env=b10_types.TrussUserEnv.collect(),
|
|
473
|
+
deployment_name="chain-deployment-name",
|
|
474
|
+
)
|
|
475
|
+
|
|
476
|
+
gql_mutation = mock_post.call_args[1]["json"]["query"]
|
|
477
|
+
|
|
478
|
+
assert 'deployment_name: "chain-deployment-name"' in gql_mutation
|
|
360
479
|
|
|
361
480
|
|
|
362
481
|
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
@@ -186,6 +186,7 @@ def test_push_chain_atomic_with_chain_upload(
|
|
|
186
186
|
chain_root = context["chain_root"]
|
|
187
187
|
|
|
188
188
|
context["mock_prepare_push"].return_value = mock_push_data
|
|
189
|
+
deployment_name = "custom_deployment"
|
|
189
190
|
|
|
190
191
|
result = remote.push_chain_atomic(
|
|
191
192
|
chain_name=chain_name,
|
|
@@ -194,13 +195,18 @@ def test_push_chain_atomic_with_chain_upload(
|
|
|
194
195
|
truss_user_env=truss_user_env,
|
|
195
196
|
chain_root=chain_root,
|
|
196
197
|
publish=True,
|
|
198
|
+
deployment_name=deployment_name,
|
|
197
199
|
)
|
|
198
200
|
assert result == mock_create_chain_atomic.return_value
|
|
199
201
|
|
|
200
202
|
mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
|
|
201
203
|
mock_upload_chain_artifact.assert_called_once()
|
|
202
|
-
|
|
203
204
|
mock_create_chain_atomic.assert_called_once()
|
|
205
|
+
create_kwargs = mock_create_chain_atomic.call_args.kwargs
|
|
206
|
+
assert create_kwargs["deployment_name"] == deployment_name
|
|
207
|
+
|
|
208
|
+
prepare_kwargs = context["mock_prepare_push"].call_args.kwargs
|
|
209
|
+
assert prepare_kwargs["deployment_name"] == deployment_name
|
|
204
210
|
|
|
205
211
|
|
|
206
212
|
@patch("truss.remote.baseten.remote.create_chain_atomic")
|
|
@@ -239,6 +245,9 @@ def test_push_chain_atomic_without_chain_upload(
|
|
|
239
245
|
mock_upload.assert_not_called()
|
|
240
246
|
|
|
241
247
|
mock_create_chain_atomic.assert_called_once()
|
|
248
|
+
create_kwargs = mock_create_chain_atomic.call_args.kwargs
|
|
249
|
+
assert "deployment_name" in create_kwargs
|
|
250
|
+
assert create_kwargs["deployment_name"] is None
|
|
242
251
|
|
|
243
252
|
|
|
244
253
|
@patch("truss.remote.baseten.core.multipart_upload_boto3")
|
|
@@ -187,6 +187,7 @@ def test_create_truss_service_handles_existing_model(inputs):
|
|
|
187
187
|
_, kwargs = api.create_model_version_from_truss.call_args
|
|
188
188
|
for k, v in inputs.items():
|
|
189
189
|
assert kwargs[k] == v
|
|
190
|
+
assert kwargs.get("deploy_timeout_minutes") is None
|
|
190
191
|
|
|
191
192
|
|
|
192
193
|
@pytest.mark.parametrize("allow_truss_download", [True, False])
|
|
@@ -761,3 +762,88 @@ def test_get_training_job_logs_with_pagination_default_batch_size(baseten_api):
|
|
|
761
762
|
query_params = call_args[0][2] # query_params
|
|
762
763
|
|
|
763
764
|
assert query_params["limit"] == MAX_BATCH_SIZE
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def test_create_truss_service_passes_deploy_timeout_minutes():
|
|
768
|
+
"""Test that deploy_timeout_minutes is passed through to create_model_version_from_truss"""
|
|
769
|
+
api = MagicMock()
|
|
770
|
+
return_value = {
|
|
771
|
+
"id": "model_version_id",
|
|
772
|
+
"oracle": {"id": "model_id", "hostname": "hostname"},
|
|
773
|
+
"instance_type": {"name": "1x2"},
|
|
774
|
+
}
|
|
775
|
+
api.create_model_version_from_truss.return_value = return_value
|
|
776
|
+
version_handle = create_truss_service(
|
|
777
|
+
api,
|
|
778
|
+
"model_name",
|
|
779
|
+
"s3_key",
|
|
780
|
+
"config",
|
|
781
|
+
b10_types.TrussUserEnv.collect(),
|
|
782
|
+
is_draft=False,
|
|
783
|
+
model_id="model_id",
|
|
784
|
+
environment="staging",
|
|
785
|
+
deploy_timeout_minutes=600,
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
assert version_handle.version_id == "model_version_id"
|
|
789
|
+
assert version_handle.model_id == "model_id"
|
|
790
|
+
api.create_model_version_from_truss.assert_called_once()
|
|
791
|
+
_, kwargs = api.create_model_version_from_truss.call_args
|
|
792
|
+
assert kwargs["deploy_timeout_minutes"] == 600
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
def test_create_truss_service_passes_deploy_timeout_minutes_with_other_params():
|
|
796
|
+
"""Test that deploy_timeout_minutes works correctly with other parameters like preserve_env_instance_type"""
|
|
797
|
+
api = MagicMock()
|
|
798
|
+
return_value = {
|
|
799
|
+
"id": "model_version_id",
|
|
800
|
+
"oracle": {"id": "model_id", "hostname": "hostname"},
|
|
801
|
+
"instance_type": {"name": "1x2"},
|
|
802
|
+
}
|
|
803
|
+
api.create_model_version_from_truss.return_value = return_value
|
|
804
|
+
version_handle = create_truss_service(
|
|
805
|
+
api,
|
|
806
|
+
"model_name",
|
|
807
|
+
"s3_key",
|
|
808
|
+
"config",
|
|
809
|
+
b10_types.TrussUserEnv.collect(),
|
|
810
|
+
is_draft=False,
|
|
811
|
+
model_id="model_id",
|
|
812
|
+
environment="production",
|
|
813
|
+
preserve_env_instance_type=False,
|
|
814
|
+
deploy_timeout_minutes=900,
|
|
815
|
+
)
|
|
816
|
+
|
|
817
|
+
assert version_handle.version_id == "model_version_id"
|
|
818
|
+
api.create_model_version_from_truss.assert_called_once()
|
|
819
|
+
_, kwargs = api.create_model_version_from_truss.call_args
|
|
820
|
+
assert kwargs["deploy_timeout_minutes"] == 900
|
|
821
|
+
assert kwargs["preserve_env_instance_type"] is False
|
|
822
|
+
assert kwargs["environment"] == "production"
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def test_create_truss_service_passes_deploy_timeout_minutes_for_development_model():
|
|
826
|
+
"""Test that deploy_timeout_minutes is passed through to create_development_model_from_truss"""
|
|
827
|
+
api = MagicMock()
|
|
828
|
+
return_value = {
|
|
829
|
+
"id": "model_version_id",
|
|
830
|
+
"oracle": {"id": "model_id", "hostname": "hostname"},
|
|
831
|
+
"instance_type": {"name": "1x2"},
|
|
832
|
+
}
|
|
833
|
+
api.create_development_model_from_truss.return_value = return_value
|
|
834
|
+
version_handle = create_truss_service(
|
|
835
|
+
api,
|
|
836
|
+
"model_name",
|
|
837
|
+
"s3_key",
|
|
838
|
+
"config",
|
|
839
|
+
b10_types.TrussUserEnv.collect(),
|
|
840
|
+
is_draft=True,
|
|
841
|
+
model_id=None,
|
|
842
|
+
deploy_timeout_minutes=600,
|
|
843
|
+
)
|
|
844
|
+
|
|
845
|
+
assert version_handle.version_id == "model_version_id"
|
|
846
|
+
assert version_handle.model_id == "model_id"
|
|
847
|
+
api.create_development_model_from_truss.assert_called_once()
|
|
848
|
+
_, kwargs = api.create_development_model_from_truss.call_args
|
|
849
|
+
assert kwargs["deploy_timeout_minutes"] == 600
|