truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. truss/api/__init__.py +5 -2
  2. truss/base/truss_config.py +10 -3
  3. truss/cli/chains_commands.py +39 -1
  4. truss/cli/cli.py +35 -5
  5. truss/cli/remote_cli.py +29 -0
  6. truss/cli/resolvers/chain_team_resolver.py +82 -0
  7. truss/cli/resolvers/model_team_resolver.py +90 -0
  8. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  9. truss/cli/train/cache.py +332 -0
  10. truss/cli/train/core.py +19 -143
  11. truss/cli/train_commands.py +69 -11
  12. truss/cli/utils/common.py +40 -3
  13. truss/remote/baseten/api.py +58 -5
  14. truss/remote/baseten/core.py +22 -4
  15. truss/remote/baseten/remote.py +24 -2
  16. truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
  17. truss/templates/server/requirements.txt +1 -1
  18. truss/templates/server.Dockerfile.jinja +10 -10
  19. truss/templates/shared/util.py +6 -5
  20. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  21. truss/tests/cli/test_chains_cli.py +44 -0
  22. truss/tests/cli/test_cli.py +134 -1
  23. truss/tests/cli/test_cli_utils_common.py +11 -0
  24. truss/tests/cli/test_model_team_resolver.py +279 -0
  25. truss/tests/cli/train/test_cache_view.py +240 -3
  26. truss/tests/cli/train/test_train_cli_core.py +2 -2
  27. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  28. truss/tests/conftest.py +187 -0
  29. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  30. truss/tests/remote/baseten/test_api.py +122 -3
  31. truss/tests/remote/baseten/test_chain_upload.py +10 -1
  32. truss/tests/remote/baseten/test_core.py +86 -0
  33. truss/tests/remote/baseten/test_remote.py +216 -288
  34. truss/tests/test_config.py +21 -12
  35. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  36. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  37. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  38. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  39. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  40. truss/tests/test_model_inference.py +13 -0
  41. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
  42. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
  43. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  44. truss_chains/deployment/deployment_client.py +9 -4
  45. truss_chains/private_types.py +15 -0
  46. truss_train/definitions.py +3 -1
  47. truss_train/deployment.py +43 -21
  48. truss_train/public_api.py +4 -2
  49. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  50. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,5 @@
1
+ from unittest import mock
2
+
1
3
  import pydantic
2
4
  import pytest
3
5
  import requests_mock
@@ -11,16 +13,17 @@ from truss.remote.baseten.core import (
11
13
  )
12
14
  from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
13
15
  from truss.remote.baseten.error import RemoteError
14
- from truss.remote.baseten.remote import BasetenRemote
15
16
  from truss.truss_handle.truss_handle import TrussHandle
16
17
 
17
18
  _TEST_REMOTE_URL = "http://test_remote.com"
18
19
  _TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/"
19
20
 
20
-
21
- @pytest.fixture
22
- def remote():
23
- return BasetenRemote(_TEST_REMOTE_URL, "api_key")
21
+ TRUSS_RC_CONTENT = """
22
+ [baseten]
23
+ remote_provider = baseten
24
+ api_key = test_key
25
+ remote_url = http://test.com
26
+ """.strip()
24
27
 
25
28
 
26
29
  def assert_request_matches_expected_query(request, expected_query) -> None:
@@ -269,30 +272,41 @@ def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promot
269
272
  )
270
273
 
271
274
 
272
- def test_create_chain_with_no_publish(remote):
273
- with requests_mock.Mocker() as m:
274
- m.post(
275
- _TEST_REMOTE_GRAPHQL_PATH,
276
- [
277
- {"json": {"data": {"chains": []}}},
278
- {
279
- "json": {
280
- "data": {
281
- "deploy_chain_atomic": {
282
- "chain_deployment": {
283
- "id": "new-chain-deployment-id",
284
- "chain": {
285
- "id": "new-chain-id",
286
- "hostname": "hostname",
287
- },
288
- }
289
- }
290
- }
291
- }
292
- },
293
- ],
275
+ @pytest.mark.parametrize("deploy_timeout_minutes", [9, 1441])
276
+ def test_push_raised_value_error_when_deploy_timeout_minutes_is_invalid(
277
+ deploy_timeout_minutes, custom_model_truss_dir_with_pre_and_post, remote
278
+ ):
279
+ th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
280
+
281
+ with pytest.raises(
282
+ ValueError,
283
+ match="deploy-timeout-minutes must be between 10 minutes and 1440 minutes \(24 hours\)",
284
+ ):
285
+ remote.push(
286
+ th,
287
+ "model_name",
288
+ th.truss_dir,
289
+ publish=True,
290
+ promote=False,
291
+ preserve_previous_prod_deployment=False,
292
+ deployment_name="dep_name",
293
+ deploy_timeout_minutes=deploy_timeout_minutes,
294
294
  )
295
295
 
296
+
297
+ def test_create_chain_with_no_publish(remote):
298
+ mock_deploy_response = {
299
+ "chain_deployment": {
300
+ "id": "new-chain-deployment-id",
301
+ "chain": {"id": "new-chain-id", "hostname": "hostname"},
302
+ }
303
+ }
304
+
305
+ with mock.patch.object(
306
+ remote.api, "get_chains", return_value=[]
307
+ ) as mock_get_chains, mock.patch.object(
308
+ remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
309
+ ) as mock_deploy:
296
310
  deployment_handle = create_chain_atomic(
297
311
  api=remote.api,
298
312
  chain_name="draft_chain",
@@ -310,64 +324,61 @@ def test_create_chain_with_no_publish(remote):
310
324
  environment=None,
311
325
  )
312
326
 
313
- get_chains_graphql_request = m.request_history[0]
314
- create_chain_graphql_request = m.request_history[1]
327
+ mock_get_chains.assert_called_once()
328
+ mock_deploy.assert_called_once()
315
329
 
316
- expected_get_chains_query = """
317
- {
318
- chains {
319
- id
320
- name
321
- }
322
- }
323
- """.strip()
330
+ call_kwargs = mock_deploy.call_args.kwargs
331
+ assert call_kwargs["chain_name"] == "draft_chain"
332
+ assert call_kwargs.get("is_draft") is True
333
+ assert call_kwargs.get("deploy_timeout_minutes") is None
324
334
 
325
- assert_request_matches_expected_query(
326
- get_chains_graphql_request, expected_get_chains_query
327
- )
335
+ assert deployment_handle.chain_id == "new-chain-id"
336
+ assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
328
337
 
329
- chainlets_string = """
330
- {
331
- name: "chainlet-1",
332
- oracle: {
333
- model_name: "model-1",
334
- s3_key: "s3-key-1",
335
- encoded_config_str: "encoded-config-str-1",
336
- semver_bump: "MINOR"
337
- }
338
- }
339
- """.strip()
340
-
341
- # Note that if publish=False and promote=True, we set publish to True and create
342
- # a non-draft deployment
343
- expected_create_chain_mutation = f"""
344
- mutation ($trussUserEnv: String) {{
345
- deploy_chain_atomic(
346
- chain_name: "draft_chain"
347
- is_draft: true
348
- entrypoint: {chainlets_string}
349
- dependencies: []
350
- truss_user_env: $trussUserEnv
351
- ) {{
352
- chain_deployment {{
353
- id
354
- chain {{
355
- id
356
- hostname
357
- }}
358
- }}
359
- }}
360
- }}
361
- """.strip()
362
-
363
- assert_request_matches_expected_query(
364
- create_chain_graphql_request, expected_create_chain_mutation
338
+
339
+ def test_create_chain_no_existing_chain(remote):
340
+ mock_deploy_response = {
341
+ "chain_deployment": {
342
+ "id": "new-chain-deployment-id",
343
+ "chain": {"id": "new-chain-id", "hostname": "hostname"},
344
+ }
345
+ }
346
+
347
+ with mock.patch.object(
348
+ remote.api, "get_chains", return_value=[]
349
+ ) as mock_get_chains, mock.patch.object(
350
+ remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
351
+ ) as mock_deploy:
352
+ deployment_handle = create_chain_atomic(
353
+ api=remote.api,
354
+ chain_name="new_chain",
355
+ entrypoint=ChainletDataAtomic(
356
+ name="chainlet-1",
357
+ oracle=OracleData(
358
+ model_name="model-1",
359
+ s3_key="s3-key-1",
360
+ encoded_config_str="encoded-config-str-1",
361
+ ),
362
+ ),
363
+ dependencies=[],
364
+ truss_user_env=b10_types.TrussUserEnv.collect(),
365
+ is_draft=False,
366
+ environment=None,
365
367
  )
368
+
369
+ mock_get_chains.assert_called_once()
370
+ mock_deploy.assert_called_once()
371
+
372
+ call_kwargs = mock_deploy.call_args.kwargs
373
+ assert call_kwargs["chain_name"] == "new_chain"
374
+ assert call_kwargs.get("is_draft") is not True
375
+ assert call_kwargs.get("deploy_timeout_minutes") is None
376
+
366
377
  assert deployment_handle.chain_id == "new-chain-id"
367
378
  assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
368
379
 
369
380
 
370
- def test_create_chain_no_existing_chain(remote):
381
+ def test_create_chain_with_deployment_name(remote):
371
382
  with requests_mock.Mocker() as m:
372
383
  m.post(
373
384
  _TEST_REMOTE_GRAPHQL_PATH,
@@ -391,7 +402,8 @@ def test_create_chain_no_existing_chain(remote):
391
402
  ],
392
403
  )
393
404
 
394
- deployment_handle = create_chain_atomic(
405
+ deployment_name = "chain-deployment"
406
+ create_chain_atomic(
395
407
  api=remote.api,
396
408
  chain_name="new_chain",
397
409
  entrypoint=ChainletDataAtomic(
@@ -406,94 +418,32 @@ def test_create_chain_no_existing_chain(remote):
406
418
  truss_user_env=b10_types.TrussUserEnv.collect(),
407
419
  is_draft=False,
408
420
  environment=None,
421
+ deployment_name=deployment_name,
409
422
  )
410
423
 
411
- get_chains_graphql_request = m.request_history[0]
412
424
  create_chain_graphql_request = m.request_history[1]
413
425
 
414
- expected_get_chains_query = """
415
- {
416
- chains {
417
- id
418
- name
419
- }
420
- }
421
- """.strip()
422
-
423
- assert_request_matches_expected_query(
424
- get_chains_graphql_request, expected_get_chains_query
425
- )
426
-
427
- chainlets_string = """
428
- {
429
- name: "chainlet-1",
430
- oracle: {
431
- model_name: "model-1",
432
- s3_key: "s3-key-1",
433
- encoded_config_str: "encoded-config-str-1",
434
- semver_bump: "MINOR"
435
- }
436
- }
437
- """.strip()
438
-
439
- expected_create_chain_mutation = f"""
440
- mutation ($trussUserEnv: String) {{
441
- deploy_chain_atomic(
442
- chain_name: "new_chain"
443
- is_draft: false
444
- entrypoint: {chainlets_string}
445
- dependencies: []
446
- truss_user_env: $trussUserEnv
447
- ) {{
448
- chain_deployment {{
449
- id
450
- chain {{
451
- id
452
- hostname
453
- }}
454
- }}
455
- }}
456
- }}
457
- """.strip()
458
-
459
- assert_request_matches_expected_query(
460
- create_chain_graphql_request, expected_create_chain_mutation
426
+ assert (
427
+ 'deployment_name: "chain-deployment"'
428
+ in create_chain_graphql_request.json()["query"]
461
429
  )
462
430
 
463
- assert deployment_handle.chain_id == "new-chain-id"
464
- assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
465
-
466
431
 
467
432
  def test_create_chain_with_existing_chain_promote_to_environment_publish_false(remote):
468
- with requests_mock.Mocker() as m:
469
- m.post(
470
- _TEST_REMOTE_GRAPHQL_PATH,
471
- [
472
- {
473
- "json": {
474
- "data": {
475
- "chains": [{"id": "old-chain-id", "name": "old_chain"}]
476
- }
477
- }
478
- },
479
- {
480
- "json": {
481
- "data": {
482
- "deploy_chain_atomic": {
483
- "chain_deployment": {
484
- "id": "new-chain-deployment-id",
485
- "chain": {
486
- "id": "new-chain-id",
487
- "hostname": "hostname",
488
- },
489
- }
490
- }
491
- }
492
- }
493
- },
494
- ],
495
- )
433
+ mock_deploy_response = {
434
+ "chain_deployment": {
435
+ "id": "new-chain-deployment-id",
436
+ "chain": {"id": "new-chain-id", "hostname": "hostname"},
437
+ }
438
+ }
496
439
 
440
+ with mock.patch.object(
441
+ remote.api,
442
+ "get_chains",
443
+ return_value=[{"id": "old-chain-id", "name": "old_chain"}],
444
+ ) as mock_get_chains, mock.patch.object(
445
+ remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
446
+ ) as mock_deploy:
497
447
  deployment_handle = create_chain_atomic(
498
448
  api=remote.api,
499
449
  chain_name="old_chain",
@@ -511,95 +461,34 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false(r
511
461
  environment="production",
512
462
  )
513
463
 
514
- get_chains_graphql_request = m.request_history[0]
515
- create_chain_graphql_request = m.request_history[1]
464
+ mock_get_chains.assert_called_once()
465
+ mock_deploy.assert_called_once()
516
466
 
517
- expected_get_chains_query = """
518
- {
519
- chains {
520
- id
521
- name
522
- }
523
- }
524
- """.strip()
525
-
526
- assert_request_matches_expected_query(
527
- get_chains_graphql_request, expected_get_chains_query
528
- )
529
-
530
- # Note that if publish=False and environment!=None, we set publish to True and create
531
- # a non-draft deployment
532
- chainlets_string = """
533
- {
534
- name: "chainlet-1",
535
- oracle: {
536
- model_name: "model-1",
537
- s3_key: "s3-key-1",
538
- encoded_config_str: "encoded-config-str-1",
539
- semver_bump: "MINOR"
540
- }
541
- }
542
- """.strip()
543
-
544
- expected_create_chain_mutation = f"""
545
- mutation ($trussUserEnv: String) {{
546
- deploy_chain_atomic(
547
- chain_id: "old-chain-id"
548
- environment: "production"
549
- is_draft: false
550
- entrypoint: {chainlets_string}
551
- dependencies: []
552
- truss_user_env: $trussUserEnv
553
- ) {{
554
- chain_deployment {{
555
- id
556
- chain {{
557
- id
558
- hostname
559
- }}
560
- }}
561
- }}
562
- }}
563
- """.strip()
564
-
565
- assert_request_matches_expected_query(
566
- create_chain_graphql_request, expected_create_chain_mutation
567
- )
467
+ call_kwargs = mock_deploy.call_args.kwargs
468
+ assert call_kwargs["chain_id"] == "old-chain-id"
469
+ assert call_kwargs["environment"] == "production"
470
+ assert call_kwargs.get("is_draft") is not True
471
+ assert call_kwargs.get("deploy_timeout_minutes") is None
568
472
 
569
473
  assert deployment_handle.chain_id == "new-chain-id"
570
474
  assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
571
475
 
572
476
 
573
477
  def test_create_chain_existing_chain_publish_true_no_promotion(remote):
574
- with requests_mock.Mocker() as m:
575
- m.post(
576
- _TEST_REMOTE_GRAPHQL_PATH,
577
- [
578
- {
579
- "json": {
580
- "data": {
581
- "chains": [{"id": "old-chain-id", "name": "old_chain"}]
582
- }
583
- }
584
- },
585
- {
586
- "json": {
587
- "data": {
588
- "deploy_chain_atomic": {
589
- "chain_deployment": {
590
- "id": "new-chain-deployment-id",
591
- "chain": {
592
- "id": "new-chain-id",
593
- "hostname": "hostname",
594
- },
595
- }
596
- }
597
- }
598
- }
599
- },
600
- ],
601
- )
478
+ mock_deploy_response = {
479
+ "chain_deployment": {
480
+ "id": "new-chain-deployment-id",
481
+ "chain": {"id": "new-chain-id", "hostname": "hostname"},
482
+ }
483
+ }
602
484
 
485
+ with mock.patch.object(
486
+ remote.api,
487
+ "get_chains",
488
+ return_value=[{"id": "old-chain-id", "name": "old_chain"}],
489
+ ) as mock_get_chains, mock.patch.object(
490
+ remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
491
+ ) as mock_deploy:
603
492
  deployment_handle = create_chain_atomic(
604
493
  api=remote.api,
605
494
  chain_name="old_chain",
@@ -617,57 +506,13 @@ def test_create_chain_existing_chain_publish_true_no_promotion(remote):
617
506
  environment=None,
618
507
  )
619
508
 
620
- get_chains_graphql_request = m.request_history[0]
621
- create_chain_graphql_request = m.request_history[1]
622
-
623
- expected_get_chains_query = """
624
- {
625
- chains {
626
- id
627
- name
628
- }
629
- }
630
- """.strip()
631
-
632
- assert_request_matches_expected_query(
633
- get_chains_graphql_request, expected_get_chains_query
634
- )
509
+ mock_get_chains.assert_called_once()
510
+ mock_deploy.assert_called_once()
635
511
 
636
- chainlets_string = """
637
- {
638
- name: "chainlet-1",
639
- oracle: {
640
- model_name: "model-1",
641
- s3_key: "s3-key-1",
642
- encoded_config_str: "encoded-config-str-1",
643
- semver_bump: "MINOR"
644
- }
645
- }
646
- """.strip()
647
-
648
- expected_create_chain_mutation = f"""
649
- mutation ($trussUserEnv: String) {{
650
- deploy_chain_atomic(
651
- chain_id: "old-chain-id"
652
- is_draft: false
653
- entrypoint: {chainlets_string}
654
- dependencies: []
655
- truss_user_env: $trussUserEnv
656
- ) {{
657
- chain_deployment {{
658
- id
659
- chain {{
660
- id
661
- hostname
662
- }}
663
- }}
664
- }}
665
- }}
666
- """.strip()
667
-
668
- assert_request_matches_expected_query(
669
- create_chain_graphql_request, expected_create_chain_mutation
670
- )
512
+ call_kwargs = mock_deploy.call_args.kwargs
513
+ assert call_kwargs["chain_id"] == "old-chain-id"
514
+ assert call_kwargs.get("is_draft") is not True
515
+ assert call_kwargs.get("deploy_timeout_minutes") is None
671
516
 
672
517
  assert deployment_handle.chain_id == "new-chain-id"
673
518
  assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
@@ -726,3 +571,86 @@ def test_push_raised_validation_error_for_extra_fields(tmp_path, remote):
726
571
  match="Extra fields not allowed: \[extra_field, who_am_i\]",
727
572
  ):
728
573
  remote.push(th, "model_name", th.truss_dir)
574
+
575
+
576
+ def test_push_passes_deploy_timeout_minutes_to_create_truss_service(
577
+ custom_model_truss_dir_with_pre_and_post,
578
+ remote,
579
+ mock_baseten_requests,
580
+ mock_upload_truss,
581
+ mock_create_truss_service,
582
+ mock_truss_handle,
583
+ ):
584
+ remote.push(
585
+ mock_truss_handle,
586
+ "model_name",
587
+ mock_truss_handle.truss_dir,
588
+ publish=True,
589
+ deploy_timeout_minutes=450,
590
+ )
591
+
592
+ mock_create_truss_service.assert_called_once()
593
+ _, kwargs = mock_create_truss_service.call_args
594
+ assert kwargs["deploy_timeout_minutes"] == 450
595
+
596
+
597
+ def test_push_passes_none_deploy_timeout_minutes_when_not_specified(
598
+ custom_model_truss_dir_with_pre_and_post,
599
+ remote,
600
+ mock_baseten_requests,
601
+ mock_upload_truss,
602
+ mock_create_truss_service,
603
+ mock_truss_handle,
604
+ ):
605
+ remote.push(
606
+ mock_truss_handle, "model_name", mock_truss_handle.truss_dir, publish=True
607
+ )
608
+
609
+ mock_create_truss_service.assert_called_once()
610
+ _, kwargs = mock_create_truss_service.call_args
611
+ assert kwargs.get("deploy_timeout_minutes") is None
612
+
613
+
614
+ def test_push_integration_deploy_timeout_minutes_propagated(
615
+ custom_model_truss_dir_with_pre_and_post,
616
+ remote,
617
+ mock_baseten_requests,
618
+ mock_upload_truss,
619
+ mock_create_truss_service,
620
+ mock_truss_handle,
621
+ ):
622
+ remote.push(
623
+ mock_truss_handle,
624
+ "model_name",
625
+ mock_truss_handle.truss_dir,
626
+ publish=True,
627
+ environment="staging",
628
+ deploy_timeout_minutes=750,
629
+ )
630
+
631
+ mock_create_truss_service.assert_called_once()
632
+ _, kwargs = mock_create_truss_service.call_args
633
+ assert kwargs["deploy_timeout_minutes"] == 750
634
+ assert kwargs["environment"] == "staging"
635
+
636
+
637
+ def test_api_push_integration_deploy_timeout_minutes_propagated(
638
+ custom_model_truss_dir_with_pre_and_post,
639
+ mock_remote_factory,
640
+ temp_trussrc_dir,
641
+ mock_available_config_names,
642
+ mock_truss_handle,
643
+ ):
644
+ from truss.api import push
645
+
646
+ push(
647
+ str(mock_truss_handle.truss_dir),
648
+ remote="baseten",
649
+ model_name="test_model",
650
+ deploy_timeout_minutes=1200,
651
+ )
652
+
653
+ # Verify the remote.push was called with deploy_timeout_minutes
654
+ mock_remote_factory.push.assert_called_once()
655
+ _, push_kwargs = mock_remote_factory.push.call_args
656
+ assert push_kwargs.get("deploy_timeout_minutes") == 1200
@@ -23,6 +23,7 @@ from truss.base.truss_config import (
23
23
  HTTPOptions,
24
24
  ModelCache,
25
25
  ModelRepo,
26
+ ModelRepoCacheInternal,
26
27
  Resources,
27
28
  Runtime,
28
29
  TransportKind,
@@ -292,7 +293,10 @@ def test_cache_internal_with_models(default_config):
292
293
  config = TrussConfig(
293
294
  python_version="py39",
294
295
  cache_internal=CacheInternal(
295
- [ModelRepo(repo_id="test/model"), ModelRepo(repo_id="test/model2")]
296
+ [
297
+ ModelRepoCacheInternal(repo_id="test/model"),
298
+ ModelRepoCacheInternal(repo_id="test/model2"),
299
+ ]
296
300
  ),
297
301
  )
298
302
  new_config = default_config
@@ -305,11 +309,12 @@ def test_cache_internal_with_models(default_config):
305
309
 
306
310
  def test_huggingface_cache_single_model_default_revision(default_config):
307
311
  config = TrussConfig(
308
- python_version="py39", model_cache=ModelCache([ModelRepo(repo_id="test/model")])
312
+ python_version="py39",
313
+ model_cache=ModelCache([ModelRepo(repo_id="test/model", use_volume=False)]),
309
314
  )
310
315
 
311
316
  new_config = default_config
312
- new_config["model_cache"] = [{"repo_id": "test/model"}]
317
+ new_config["model_cache"] = [{"repo_id": "test/model", "use_volume": False}]
313
318
 
314
319
  assert new_config == config.to_dict(verbose=False)
315
320
  assert config.to_dict(verbose=True)["model_cache"][0].get("revision") is None
@@ -319,7 +324,9 @@ def test_huggingface_cache_single_model_non_default_revision_v1():
319
324
  config = TrussConfig(
320
325
  python_version="py39",
321
326
  requirements=[],
322
- model_cache=ModelCache([ModelRepo(repo_id="test/model", revision="not-main")]),
327
+ model_cache=ModelCache(
328
+ [ModelRepo(repo_id="test/model", revision="not-main", use_volume=False)]
329
+ ),
323
330
  )
324
331
 
325
332
  assert config.to_dict(verbose=False)["model_cache"][0].get("revision") == "not-main"
@@ -330,16 +337,16 @@ def test_huggingface_cache_multiple_models_default_revision(default_config):
330
337
  python_version="py39",
331
338
  model_cache=ModelCache(
332
339
  [
333
- ModelRepo(repo_id="test/model1", revision="main"),
334
- ModelRepo(repo_id="test/model2"),
340
+ ModelRepo(repo_id="test/model1", revision="main", use_volume=False),
341
+ ModelRepo(repo_id="test/model2", use_volume=False),
335
342
  ]
336
343
  ),
337
344
  )
338
345
 
339
346
  new_config = default_config
340
347
  new_config["model_cache"] = [
341
- {"repo_id": "test/model1", "revision": "main"},
342
- {"repo_id": "test/model2"},
348
+ {"repo_id": "test/model1", "revision": "main", "use_volume": False},
349
+ {"repo_id": "test/model2", "use_volume": False},
343
350
  ]
344
351
 
345
352
  assert new_config == config.to_dict(verbose=False)
@@ -355,16 +362,18 @@ def test_huggingface_cache_multiple_models_mixed_revision(default_config):
355
362
  python_version="py39",
356
363
  model_cache=ModelCache(
357
364
  [
358
- ModelRepo(repo_id="test/model1"),
359
- ModelRepo(repo_id="test/model2", revision="not-main2"),
365
+ ModelRepo(repo_id="test/model1", use_volume=False),
366
+ ModelRepo(
367
+ repo_id="test/model2", revision="not-main2", use_volume=False
368
+ ),
360
369
  ]
361
370
  ),
362
371
  )
363
372
 
364
373
  new_config = default_config
365
374
  new_config["model_cache"] = [
366
- {"repo_id": "test/model1"},
367
- {"repo_id": "test/model2", "revision": "not-main2"},
375
+ {"repo_id": "test/model1", "use_volume": False},
376
+ {"repo_id": "test/model2", "revision": "not-main2", "use_volume": False},
368
377
  ]
369
378
 
370
379
  assert new_config == config.to_dict(verbose=False)