truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. truss/api/__init__.py +5 -2
  2. truss/base/truss_config.py +10 -3
  3. truss/cli/chains_commands.py +39 -1
  4. truss/cli/cli.py +35 -5
  5. truss/cli/remote_cli.py +29 -0
  6. truss/cli/resolvers/chain_team_resolver.py +82 -0
  7. truss/cli/resolvers/model_team_resolver.py +90 -0
  8. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  9. truss/cli/train/cache.py +332 -0
  10. truss/cli/train/core.py +19 -143
  11. truss/cli/train_commands.py +69 -11
  12. truss/cli/utils/common.py +40 -3
  13. truss/remote/baseten/api.py +58 -5
  14. truss/remote/baseten/core.py +22 -4
  15. truss/remote/baseten/remote.py +24 -2
  16. truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
  17. truss/templates/server/requirements.txt +1 -1
  18. truss/templates/server.Dockerfile.jinja +10 -10
  19. truss/templates/shared/util.py +6 -5
  20. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  21. truss/tests/cli/test_chains_cli.py +44 -0
  22. truss/tests/cli/test_cli.py +134 -1
  23. truss/tests/cli/test_cli_utils_common.py +11 -0
  24. truss/tests/cli/test_model_team_resolver.py +279 -0
  25. truss/tests/cli/train/test_cache_view.py +240 -3
  26. truss/tests/cli/train/test_train_cli_core.py +2 -2
  27. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  28. truss/tests/conftest.py +187 -0
  29. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  30. truss/tests/remote/baseten/test_api.py +122 -3
  31. truss/tests/remote/baseten/test_chain_upload.py +10 -1
  32. truss/tests/remote/baseten/test_core.py +86 -0
  33. truss/tests/remote/baseten/test_remote.py +216 -288
  34. truss/tests/test_config.py +21 -12
  35. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  36. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  37. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  38. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  39. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  40. truss/tests/test_model_inference.py +13 -0
  41. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
  42. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
  43. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  44. truss_chains/deployment/deployment_client.py +9 -4
  45. truss_chains/private_types.py +15 -0
  46. truss_train/definitions.py +3 -1
  47. truss_train/deployment.py +43 -21
  48. truss_train/public_api.py +4 -2
  49. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  50. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -165,9 +165,9 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
165
165
  "config_str",
166
166
  "semver_bump",
167
167
  b10_types.TrussUserEnv.collect(),
168
- False,
169
- "deployment_name",
170
- "production",
168
+ preserve_previous_prod_deployment=False,
169
+ deployment_name="deployment_name",
170
+ environment="production",
171
171
  )
172
172
 
173
173
  gql_mutation = mock_post.call_args[1]["json"]["query"]
@@ -182,6 +182,7 @@ def test_create_model_version_from_truss(mock_post, baseten_api):
182
182
  assert 'name: "deployment_name"' in gql_mutation
183
183
  assert 'environment_name: "production"' in gql_mutation
184
184
  assert "preserve_env_instance_type: true" in gql_mutation
185
+ assert "deploy_timeout_minutes: " not in gql_mutation
185
186
 
186
187
 
187
188
  @mock.patch("requests.post", return_value=mock_create_model_version_response())
@@ -211,6 +212,7 @@ def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_sp
211
212
  assert " name: " not in gql_mutation
212
213
  assert "environment_name: " not in gql_mutation
213
214
  assert "preserve_env_instance_type: false" in gql_mutation
215
+ assert "deploy_timeout_minutes: " not in gql_mutation
214
216
 
215
217
 
216
218
  @mock.patch("requests.post", return_value=mock_create_model_version_response())
@@ -242,6 +244,57 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
242
244
  assert " name: " not in gql_mutation
243
245
  assert 'environment_name: "staging"' in gql_mutation
244
246
  assert "preserve_env_instance_type: true" in gql_mutation
247
+ assert "deploy_timeout_minutes: " not in gql_mutation
248
+
249
+
250
+ @mock.patch("requests.post", return_value=mock_create_model_version_response())
251
+ def test_create_model_version_from_truss_with_deploy_timeout_minutes(
252
+ mock_post, baseten_api
253
+ ):
254
+ baseten_api.create_model_version_from_truss(
255
+ "model_id",
256
+ "s3key",
257
+ "config_str",
258
+ "semver_bump",
259
+ b10_types.TrussUserEnv.collect(),
260
+ preserve_previous_prod_deployment=False,
261
+ deployment_name="deployment_name",
262
+ environment="production",
263
+ deploy_timeout_minutes=300,
264
+ )
265
+
266
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
267
+ assert 'model_id: "model_id"' in gql_mutation
268
+ assert 's3_key: "s3key"' in gql_mutation
269
+ assert 'config: "config_str"' in gql_mutation
270
+ assert 'semver_bump: "semver_bump"' in gql_mutation
271
+ assert {
272
+ "trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
273
+ } == mock_post.call_args[1]["json"]["variables"]
274
+ assert "scale_down_old_production: true" in gql_mutation
275
+ assert 'name: "deployment_name"' in gql_mutation
276
+ assert 'environment_name: "production"' in gql_mutation
277
+ assert "preserve_env_instance_type: true" in gql_mutation
278
+ assert "deploy_timeout_minutes: 300" in gql_mutation
279
+
280
+
281
+ @mock.patch("requests.post", return_value=mock_create_model_version_response())
282
+ def test_create_model_version_from_truss_with_deploy_timeout_minutes_zero(
283
+ mock_post, baseten_api
284
+ ):
285
+ """Test that deploy_timeout_minutes of 0 is handled correctly"""
286
+ baseten_api.create_model_version_from_truss(
287
+ "model_id",
288
+ "s3key",
289
+ "config_str",
290
+ "semver_bump",
291
+ b10_types.TrussUserEnv.collect(),
292
+ preserve_previous_prod_deployment=False,
293
+ deploy_timeout_minutes=0,
294
+ )
295
+
296
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
297
+ assert "deploy_timeout_minutes: 0" in gql_mutation
245
298
 
246
299
 
247
300
  @mock.patch("requests.post", return_value=mock_create_model_response())
@@ -332,6 +385,48 @@ def test_create_development_model_from_truss_with_allow_truss_download(
332
385
  "trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
333
386
  } == mock_post.call_args[1]["json"]["variables"]
334
387
  assert "allow_truss_download: false" in gql_mutation
388
+ assert "deploy_timeout_minutes: " not in gql_mutation
389
+
390
+
391
+ @mock.patch("requests.post", return_value=mock_create_development_model_response())
392
+ def test_create_development_model_from_truss_with_deploy_timeout_minutes(
393
+ mock_post, baseten_api
394
+ ):
395
+ baseten_api.create_development_model_from_truss(
396
+ "model_name",
397
+ "s3key",
398
+ "config_str",
399
+ b10_types.TrussUserEnv.collect(),
400
+ allow_truss_download=False,
401
+ deploy_timeout_minutes=300,
402
+ )
403
+
404
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
405
+ assert 'name: "model_name"' in gql_mutation
406
+ assert 's3_key: "s3key"' in gql_mutation
407
+ assert 'config: "config_str"' in gql_mutation
408
+ assert {
409
+ "trussUserEnv": b10_types.TrussUserEnv.collect().model_dump_json()
410
+ } == mock_post.call_args[1]["json"]["variables"]
411
+ assert "allow_truss_download: false" in gql_mutation
412
+ assert "deploy_timeout_minutes: 300" in gql_mutation
413
+
414
+
415
+ @mock.patch("requests.post", return_value=mock_create_development_model_response())
416
+ def test_create_development_model_from_truss_with_deploy_timeout_minutes_zero(
417
+ mock_post, baseten_api
418
+ ):
419
+ """Test that deploy_timeout_minutes of 0 is handled correctly"""
420
+ baseten_api.create_development_model_from_truss(
421
+ "model_name",
422
+ "s3key",
423
+ "config_str",
424
+ b10_types.TrussUserEnv.collect(),
425
+ deploy_timeout_minutes=0,
426
+ )
427
+
428
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
429
+ assert "deploy_timeout_minutes: 0" in gql_mutation
335
430
 
336
431
 
337
432
  @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
@@ -357,6 +452,30 @@ def test_deploy_chain_deployment(mock_post, baseten_api):
357
452
  assert 'chain_id: "chain_id"' in gql_mutation
358
453
  assert "dependencies:" in gql_mutation
359
454
  assert "entrypoint:" in gql_mutation
455
+ assert "deployment_name" not in gql_mutation
456
+
457
+
458
+ @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
459
+ def test_deploy_chain_deployment_with_deployment_name(mock_post, baseten_api):
460
+ baseten_api.deploy_chain_atomic(
461
+ environment="production",
462
+ chain_id="chain_id",
463
+ dependencies=[],
464
+ entrypoint=ChainletDataAtomic(
465
+ name="chainlet-1",
466
+ oracle=OracleData(
467
+ model_name="model-1",
468
+ s3_key="s3-key-1",
469
+ encoded_config_str="encoded-config-str-1",
470
+ ),
471
+ ),
472
+ truss_user_env=b10_types.TrussUserEnv.collect(),
473
+ deployment_name="chain-deployment-name",
474
+ )
475
+
476
+ gql_mutation = mock_post.call_args[1]["json"]["query"]
477
+
478
+ assert 'deployment_name: "chain-deployment-name"' in gql_mutation
360
479
 
361
480
 
362
481
  @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
@@ -186,6 +186,7 @@ def test_push_chain_atomic_with_chain_upload(
186
186
  chain_root = context["chain_root"]
187
187
 
188
188
  context["mock_prepare_push"].return_value = mock_push_data
189
+ deployment_name = "custom_deployment"
189
190
 
190
191
  result = remote.push_chain_atomic(
191
192
  chain_name=chain_name,
@@ -194,13 +195,18 @@ def test_push_chain_atomic_with_chain_upload(
194
195
  truss_user_env=truss_user_env,
195
196
  chain_root=chain_root,
196
197
  publish=True,
198
+ deployment_name=deployment_name,
197
199
  )
198
200
  assert result == mock_create_chain_atomic.return_value
199
201
 
200
202
  mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
201
203
  mock_upload_chain_artifact.assert_called_once()
202
-
203
204
  mock_create_chain_atomic.assert_called_once()
205
+ create_kwargs = mock_create_chain_atomic.call_args.kwargs
206
+ assert create_kwargs["deployment_name"] == deployment_name
207
+
208
+ prepare_kwargs = context["mock_prepare_push"].call_args.kwargs
209
+ assert prepare_kwargs["deployment_name"] == deployment_name
204
210
 
205
211
 
206
212
  @patch("truss.remote.baseten.remote.create_chain_atomic")
@@ -239,6 +245,9 @@ def test_push_chain_atomic_without_chain_upload(
239
245
  mock_upload.assert_not_called()
240
246
 
241
247
  mock_create_chain_atomic.assert_called_once()
248
+ create_kwargs = mock_create_chain_atomic.call_args.kwargs
249
+ assert "deployment_name" in create_kwargs
250
+ assert create_kwargs["deployment_name"] is None
242
251
 
243
252
 
244
253
  @patch("truss.remote.baseten.core.multipart_upload_boto3")
@@ -187,6 +187,7 @@ def test_create_truss_service_handles_existing_model(inputs):
187
187
  _, kwargs = api.create_model_version_from_truss.call_args
188
188
  for k, v in inputs.items():
189
189
  assert kwargs[k] == v
190
+ assert kwargs.get("deploy_timeout_minutes") is None
190
191
 
191
192
 
192
193
  @pytest.mark.parametrize("allow_truss_download", [True, False])
@@ -761,3 +762,88 @@ def test_get_training_job_logs_with_pagination_default_batch_size(baseten_api):
761
762
  query_params = call_args[0][2] # query_params
762
763
 
763
764
  assert query_params["limit"] == MAX_BATCH_SIZE
765
+
766
+
767
+ def test_create_truss_service_passes_deploy_timeout_minutes():
768
+ """Test that deploy_timeout_minutes is passed through to create_model_version_from_truss"""
769
+ api = MagicMock()
770
+ return_value = {
771
+ "id": "model_version_id",
772
+ "oracle": {"id": "model_id", "hostname": "hostname"},
773
+ "instance_type": {"name": "1x2"},
774
+ }
775
+ api.create_model_version_from_truss.return_value = return_value
776
+ version_handle = create_truss_service(
777
+ api,
778
+ "model_name",
779
+ "s3_key",
780
+ "config",
781
+ b10_types.TrussUserEnv.collect(),
782
+ is_draft=False,
783
+ model_id="model_id",
784
+ environment="staging",
785
+ deploy_timeout_minutes=600,
786
+ )
787
+
788
+ assert version_handle.version_id == "model_version_id"
789
+ assert version_handle.model_id == "model_id"
790
+ api.create_model_version_from_truss.assert_called_once()
791
+ _, kwargs = api.create_model_version_from_truss.call_args
792
+ assert kwargs["deploy_timeout_minutes"] == 600
793
+
794
+
795
+ def test_create_truss_service_passes_deploy_timeout_minutes_with_other_params():
796
+ """Test that deploy_timeout_minutes works correctly with other parameters like preserve_env_instance_type"""
797
+ api = MagicMock()
798
+ return_value = {
799
+ "id": "model_version_id",
800
+ "oracle": {"id": "model_id", "hostname": "hostname"},
801
+ "instance_type": {"name": "1x2"},
802
+ }
803
+ api.create_model_version_from_truss.return_value = return_value
804
+ version_handle = create_truss_service(
805
+ api,
806
+ "model_name",
807
+ "s3_key",
808
+ "config",
809
+ b10_types.TrussUserEnv.collect(),
810
+ is_draft=False,
811
+ model_id="model_id",
812
+ environment="production",
813
+ preserve_env_instance_type=False,
814
+ deploy_timeout_minutes=900,
815
+ )
816
+
817
+ assert version_handle.version_id == "model_version_id"
818
+ api.create_model_version_from_truss.assert_called_once()
819
+ _, kwargs = api.create_model_version_from_truss.call_args
820
+ assert kwargs["deploy_timeout_minutes"] == 900
821
+ assert kwargs["preserve_env_instance_type"] is False
822
+ assert kwargs["environment"] == "production"
823
+
824
+
825
+ def test_create_truss_service_passes_deploy_timeout_minutes_for_development_model():
826
+ """Test that deploy_timeout_minutes is passed through to create_development_model_from_truss"""
827
+ api = MagicMock()
828
+ return_value = {
829
+ "id": "model_version_id",
830
+ "oracle": {"id": "model_id", "hostname": "hostname"},
831
+ "instance_type": {"name": "1x2"},
832
+ }
833
+ api.create_development_model_from_truss.return_value = return_value
834
+ version_handle = create_truss_service(
835
+ api,
836
+ "model_name",
837
+ "s3_key",
838
+ "config",
839
+ b10_types.TrussUserEnv.collect(),
840
+ is_draft=True,
841
+ model_id=None,
842
+ deploy_timeout_minutes=600,
843
+ )
844
+
845
+ assert version_handle.version_id == "model_version_id"
846
+ assert version_handle.model_id == "model_id"
847
+ api.create_development_model_from_truss.assert_called_once()
848
+ _, kwargs = api.create_development_model_from_truss.call_args
849
+ assert kwargs["deploy_timeout_minutes"] == 600