truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/truss_config.py +10 -3
- truss/cli/chains_commands.py +39 -1
- truss/cli/cli.py +35 -5
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +19 -143
- truss/cli/train_commands.py +69 -11
- truss/cli/utils/common.py +40 -3
- truss/remote/baseten/api.py +58 -5
- truss/remote/baseten/core.py +22 -4
- truss/remote/baseten/remote.py +24 -2
- truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +10 -10
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +44 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +10 -1
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +9 -4
- truss_chains/private_types.py +15 -0
- truss_train/definitions.py +3 -1
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from unittest import mock
|
|
2
|
+
|
|
1
3
|
import pydantic
|
|
2
4
|
import pytest
|
|
3
5
|
import requests_mock
|
|
@@ -11,16 +13,17 @@ from truss.remote.baseten.core import (
|
|
|
11
13
|
)
|
|
12
14
|
from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
|
|
13
15
|
from truss.remote.baseten.error import RemoteError
|
|
14
|
-
from truss.remote.baseten.remote import BasetenRemote
|
|
15
16
|
from truss.truss_handle.truss_handle import TrussHandle
|
|
16
17
|
|
|
17
18
|
_TEST_REMOTE_URL = "http://test_remote.com"
|
|
18
19
|
_TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/"
|
|
19
20
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
21
|
+
TRUSS_RC_CONTENT = """
|
|
22
|
+
[baseten]
|
|
23
|
+
remote_provider = baseten
|
|
24
|
+
api_key = test_key
|
|
25
|
+
remote_url = http://test.com
|
|
26
|
+
""".strip()
|
|
24
27
|
|
|
25
28
|
|
|
26
29
|
def assert_request_matches_expected_query(request, expected_query) -> None:
|
|
@@ -269,30 +272,41 @@ def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promot
|
|
|
269
272
|
)
|
|
270
273
|
|
|
271
274
|
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
}
|
|
292
|
-
},
|
|
293
|
-
],
|
|
275
|
+
@pytest.mark.parametrize("deploy_timeout_minutes", [9, 1441])
|
|
276
|
+
def test_push_raised_value_error_when_deploy_timeout_minutes_is_invalid(
|
|
277
|
+
deploy_timeout_minutes, custom_model_truss_dir_with_pre_and_post, remote
|
|
278
|
+
):
|
|
279
|
+
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
|
|
280
|
+
|
|
281
|
+
with pytest.raises(
|
|
282
|
+
ValueError,
|
|
283
|
+
match="deploy-timeout-minutes must be between 10 minutes and 1440 minutes \(24 hours\)",
|
|
284
|
+
):
|
|
285
|
+
remote.push(
|
|
286
|
+
th,
|
|
287
|
+
"model_name",
|
|
288
|
+
th.truss_dir,
|
|
289
|
+
publish=True,
|
|
290
|
+
promote=False,
|
|
291
|
+
preserve_previous_prod_deployment=False,
|
|
292
|
+
deployment_name="dep_name",
|
|
293
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
294
294
|
)
|
|
295
295
|
|
|
296
|
+
|
|
297
|
+
def test_create_chain_with_no_publish(remote):
|
|
298
|
+
mock_deploy_response = {
|
|
299
|
+
"chain_deployment": {
|
|
300
|
+
"id": "new-chain-deployment-id",
|
|
301
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
with mock.patch.object(
|
|
306
|
+
remote.api, "get_chains", return_value=[]
|
|
307
|
+
) as mock_get_chains, mock.patch.object(
|
|
308
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
309
|
+
) as mock_deploy:
|
|
296
310
|
deployment_handle = create_chain_atomic(
|
|
297
311
|
api=remote.api,
|
|
298
312
|
chain_name="draft_chain",
|
|
@@ -310,64 +324,61 @@ def test_create_chain_with_no_publish(remote):
|
|
|
310
324
|
environment=None,
|
|
311
325
|
)
|
|
312
326
|
|
|
313
|
-
|
|
314
|
-
|
|
327
|
+
mock_get_chains.assert_called_once()
|
|
328
|
+
mock_deploy.assert_called_once()
|
|
315
329
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
name
|
|
321
|
-
}
|
|
322
|
-
}
|
|
323
|
-
""".strip()
|
|
330
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
331
|
+
assert call_kwargs["chain_name"] == "draft_chain"
|
|
332
|
+
assert call_kwargs.get("is_draft") is True
|
|
333
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
324
334
|
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
)
|
|
335
|
+
assert deployment_handle.chain_id == "new-chain-id"
|
|
336
|
+
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
328
337
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
""
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
}}
|
|
359
|
-
}}
|
|
360
|
-
}}
|
|
361
|
-
""".strip()
|
|
362
|
-
|
|
363
|
-
assert_request_matches_expected_query(
|
|
364
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
338
|
+
|
|
339
|
+
def test_create_chain_no_existing_chain(remote):
|
|
340
|
+
mock_deploy_response = {
|
|
341
|
+
"chain_deployment": {
|
|
342
|
+
"id": "new-chain-deployment-id",
|
|
343
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
with mock.patch.object(
|
|
348
|
+
remote.api, "get_chains", return_value=[]
|
|
349
|
+
) as mock_get_chains, mock.patch.object(
|
|
350
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
351
|
+
) as mock_deploy:
|
|
352
|
+
deployment_handle = create_chain_atomic(
|
|
353
|
+
api=remote.api,
|
|
354
|
+
chain_name="new_chain",
|
|
355
|
+
entrypoint=ChainletDataAtomic(
|
|
356
|
+
name="chainlet-1",
|
|
357
|
+
oracle=OracleData(
|
|
358
|
+
model_name="model-1",
|
|
359
|
+
s3_key="s3-key-1",
|
|
360
|
+
encoded_config_str="encoded-config-str-1",
|
|
361
|
+
),
|
|
362
|
+
),
|
|
363
|
+
dependencies=[],
|
|
364
|
+
truss_user_env=b10_types.TrussUserEnv.collect(),
|
|
365
|
+
is_draft=False,
|
|
366
|
+
environment=None,
|
|
365
367
|
)
|
|
368
|
+
|
|
369
|
+
mock_get_chains.assert_called_once()
|
|
370
|
+
mock_deploy.assert_called_once()
|
|
371
|
+
|
|
372
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
373
|
+
assert call_kwargs["chain_name"] == "new_chain"
|
|
374
|
+
assert call_kwargs.get("is_draft") is not True
|
|
375
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
376
|
+
|
|
366
377
|
assert deployment_handle.chain_id == "new-chain-id"
|
|
367
378
|
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
368
379
|
|
|
369
380
|
|
|
370
|
-
def
|
|
381
|
+
def test_create_chain_with_deployment_name(remote):
|
|
371
382
|
with requests_mock.Mocker() as m:
|
|
372
383
|
m.post(
|
|
373
384
|
_TEST_REMOTE_GRAPHQL_PATH,
|
|
@@ -391,7 +402,8 @@ def test_create_chain_no_existing_chain(remote):
|
|
|
391
402
|
],
|
|
392
403
|
)
|
|
393
404
|
|
|
394
|
-
|
|
405
|
+
deployment_name = "chain-deployment"
|
|
406
|
+
create_chain_atomic(
|
|
395
407
|
api=remote.api,
|
|
396
408
|
chain_name="new_chain",
|
|
397
409
|
entrypoint=ChainletDataAtomic(
|
|
@@ -406,94 +418,32 @@ def test_create_chain_no_existing_chain(remote):
|
|
|
406
418
|
truss_user_env=b10_types.TrussUserEnv.collect(),
|
|
407
419
|
is_draft=False,
|
|
408
420
|
environment=None,
|
|
421
|
+
deployment_name=deployment_name,
|
|
409
422
|
)
|
|
410
423
|
|
|
411
|
-
get_chains_graphql_request = m.request_history[0]
|
|
412
424
|
create_chain_graphql_request = m.request_history[1]
|
|
413
425
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
id
|
|
418
|
-
name
|
|
419
|
-
}
|
|
420
|
-
}
|
|
421
|
-
""".strip()
|
|
422
|
-
|
|
423
|
-
assert_request_matches_expected_query(
|
|
424
|
-
get_chains_graphql_request, expected_get_chains_query
|
|
425
|
-
)
|
|
426
|
-
|
|
427
|
-
chainlets_string = """
|
|
428
|
-
{
|
|
429
|
-
name: "chainlet-1",
|
|
430
|
-
oracle: {
|
|
431
|
-
model_name: "model-1",
|
|
432
|
-
s3_key: "s3-key-1",
|
|
433
|
-
encoded_config_str: "encoded-config-str-1",
|
|
434
|
-
semver_bump: "MINOR"
|
|
435
|
-
}
|
|
436
|
-
}
|
|
437
|
-
""".strip()
|
|
438
|
-
|
|
439
|
-
expected_create_chain_mutation = f"""
|
|
440
|
-
mutation ($trussUserEnv: String) {{
|
|
441
|
-
deploy_chain_atomic(
|
|
442
|
-
chain_name: "new_chain"
|
|
443
|
-
is_draft: false
|
|
444
|
-
entrypoint: {chainlets_string}
|
|
445
|
-
dependencies: []
|
|
446
|
-
truss_user_env: $trussUserEnv
|
|
447
|
-
) {{
|
|
448
|
-
chain_deployment {{
|
|
449
|
-
id
|
|
450
|
-
chain {{
|
|
451
|
-
id
|
|
452
|
-
hostname
|
|
453
|
-
}}
|
|
454
|
-
}}
|
|
455
|
-
}}
|
|
456
|
-
}}
|
|
457
|
-
""".strip()
|
|
458
|
-
|
|
459
|
-
assert_request_matches_expected_query(
|
|
460
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
426
|
+
assert (
|
|
427
|
+
'deployment_name: "chain-deployment"'
|
|
428
|
+
in create_chain_graphql_request.json()["query"]
|
|
461
429
|
)
|
|
462
430
|
|
|
463
|
-
assert deployment_handle.chain_id == "new-chain-id"
|
|
464
|
-
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
465
|
-
|
|
466
431
|
|
|
467
432
|
def test_create_chain_with_existing_chain_promote_to_environment_publish_false(remote):
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
"data": {
|
|
475
|
-
"chains": [{"id": "old-chain-id", "name": "old_chain"}]
|
|
476
|
-
}
|
|
477
|
-
}
|
|
478
|
-
},
|
|
479
|
-
{
|
|
480
|
-
"json": {
|
|
481
|
-
"data": {
|
|
482
|
-
"deploy_chain_atomic": {
|
|
483
|
-
"chain_deployment": {
|
|
484
|
-
"id": "new-chain-deployment-id",
|
|
485
|
-
"chain": {
|
|
486
|
-
"id": "new-chain-id",
|
|
487
|
-
"hostname": "hostname",
|
|
488
|
-
},
|
|
489
|
-
}
|
|
490
|
-
}
|
|
491
|
-
}
|
|
492
|
-
}
|
|
493
|
-
},
|
|
494
|
-
],
|
|
495
|
-
)
|
|
433
|
+
mock_deploy_response = {
|
|
434
|
+
"chain_deployment": {
|
|
435
|
+
"id": "new-chain-deployment-id",
|
|
436
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
437
|
+
}
|
|
438
|
+
}
|
|
496
439
|
|
|
440
|
+
with mock.patch.object(
|
|
441
|
+
remote.api,
|
|
442
|
+
"get_chains",
|
|
443
|
+
return_value=[{"id": "old-chain-id", "name": "old_chain"}],
|
|
444
|
+
) as mock_get_chains, mock.patch.object(
|
|
445
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
446
|
+
) as mock_deploy:
|
|
497
447
|
deployment_handle = create_chain_atomic(
|
|
498
448
|
api=remote.api,
|
|
499
449
|
chain_name="old_chain",
|
|
@@ -511,95 +461,34 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false(r
|
|
|
511
461
|
environment="production",
|
|
512
462
|
)
|
|
513
463
|
|
|
514
|
-
|
|
515
|
-
|
|
464
|
+
mock_get_chains.assert_called_once()
|
|
465
|
+
mock_deploy.assert_called_once()
|
|
516
466
|
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
}
|
|
523
|
-
}
|
|
524
|
-
""".strip()
|
|
525
|
-
|
|
526
|
-
assert_request_matches_expected_query(
|
|
527
|
-
get_chains_graphql_request, expected_get_chains_query
|
|
528
|
-
)
|
|
529
|
-
|
|
530
|
-
# Note that if publish=False and environment!=None, we set publish to True and create
|
|
531
|
-
# a non-draft deployment
|
|
532
|
-
chainlets_string = """
|
|
533
|
-
{
|
|
534
|
-
name: "chainlet-1",
|
|
535
|
-
oracle: {
|
|
536
|
-
model_name: "model-1",
|
|
537
|
-
s3_key: "s3-key-1",
|
|
538
|
-
encoded_config_str: "encoded-config-str-1",
|
|
539
|
-
semver_bump: "MINOR"
|
|
540
|
-
}
|
|
541
|
-
}
|
|
542
|
-
""".strip()
|
|
543
|
-
|
|
544
|
-
expected_create_chain_mutation = f"""
|
|
545
|
-
mutation ($trussUserEnv: String) {{
|
|
546
|
-
deploy_chain_atomic(
|
|
547
|
-
chain_id: "old-chain-id"
|
|
548
|
-
environment: "production"
|
|
549
|
-
is_draft: false
|
|
550
|
-
entrypoint: {chainlets_string}
|
|
551
|
-
dependencies: []
|
|
552
|
-
truss_user_env: $trussUserEnv
|
|
553
|
-
) {{
|
|
554
|
-
chain_deployment {{
|
|
555
|
-
id
|
|
556
|
-
chain {{
|
|
557
|
-
id
|
|
558
|
-
hostname
|
|
559
|
-
}}
|
|
560
|
-
}}
|
|
561
|
-
}}
|
|
562
|
-
}}
|
|
563
|
-
""".strip()
|
|
564
|
-
|
|
565
|
-
assert_request_matches_expected_query(
|
|
566
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
567
|
-
)
|
|
467
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
468
|
+
assert call_kwargs["chain_id"] == "old-chain-id"
|
|
469
|
+
assert call_kwargs["environment"] == "production"
|
|
470
|
+
assert call_kwargs.get("is_draft") is not True
|
|
471
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
568
472
|
|
|
569
473
|
assert deployment_handle.chain_id == "new-chain-id"
|
|
570
474
|
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
571
475
|
|
|
572
476
|
|
|
573
477
|
def test_create_chain_existing_chain_publish_true_no_promotion(remote):
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
"data": {
|
|
581
|
-
"chains": [{"id": "old-chain-id", "name": "old_chain"}]
|
|
582
|
-
}
|
|
583
|
-
}
|
|
584
|
-
},
|
|
585
|
-
{
|
|
586
|
-
"json": {
|
|
587
|
-
"data": {
|
|
588
|
-
"deploy_chain_atomic": {
|
|
589
|
-
"chain_deployment": {
|
|
590
|
-
"id": "new-chain-deployment-id",
|
|
591
|
-
"chain": {
|
|
592
|
-
"id": "new-chain-id",
|
|
593
|
-
"hostname": "hostname",
|
|
594
|
-
},
|
|
595
|
-
}
|
|
596
|
-
}
|
|
597
|
-
}
|
|
598
|
-
}
|
|
599
|
-
},
|
|
600
|
-
],
|
|
601
|
-
)
|
|
478
|
+
mock_deploy_response = {
|
|
479
|
+
"chain_deployment": {
|
|
480
|
+
"id": "new-chain-deployment-id",
|
|
481
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
482
|
+
}
|
|
483
|
+
}
|
|
602
484
|
|
|
485
|
+
with mock.patch.object(
|
|
486
|
+
remote.api,
|
|
487
|
+
"get_chains",
|
|
488
|
+
return_value=[{"id": "old-chain-id", "name": "old_chain"}],
|
|
489
|
+
) as mock_get_chains, mock.patch.object(
|
|
490
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
491
|
+
) as mock_deploy:
|
|
603
492
|
deployment_handle = create_chain_atomic(
|
|
604
493
|
api=remote.api,
|
|
605
494
|
chain_name="old_chain",
|
|
@@ -617,57 +506,13 @@ def test_create_chain_existing_chain_publish_true_no_promotion(remote):
|
|
|
617
506
|
environment=None,
|
|
618
507
|
)
|
|
619
508
|
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
expected_get_chains_query = """
|
|
624
|
-
{
|
|
625
|
-
chains {
|
|
626
|
-
id
|
|
627
|
-
name
|
|
628
|
-
}
|
|
629
|
-
}
|
|
630
|
-
""".strip()
|
|
631
|
-
|
|
632
|
-
assert_request_matches_expected_query(
|
|
633
|
-
get_chains_graphql_request, expected_get_chains_query
|
|
634
|
-
)
|
|
509
|
+
mock_get_chains.assert_called_once()
|
|
510
|
+
mock_deploy.assert_called_once()
|
|
635
511
|
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
model_name: "model-1",
|
|
641
|
-
s3_key: "s3-key-1",
|
|
642
|
-
encoded_config_str: "encoded-config-str-1",
|
|
643
|
-
semver_bump: "MINOR"
|
|
644
|
-
}
|
|
645
|
-
}
|
|
646
|
-
""".strip()
|
|
647
|
-
|
|
648
|
-
expected_create_chain_mutation = f"""
|
|
649
|
-
mutation ($trussUserEnv: String) {{
|
|
650
|
-
deploy_chain_atomic(
|
|
651
|
-
chain_id: "old-chain-id"
|
|
652
|
-
is_draft: false
|
|
653
|
-
entrypoint: {chainlets_string}
|
|
654
|
-
dependencies: []
|
|
655
|
-
truss_user_env: $trussUserEnv
|
|
656
|
-
) {{
|
|
657
|
-
chain_deployment {{
|
|
658
|
-
id
|
|
659
|
-
chain {{
|
|
660
|
-
id
|
|
661
|
-
hostname
|
|
662
|
-
}}
|
|
663
|
-
}}
|
|
664
|
-
}}
|
|
665
|
-
}}
|
|
666
|
-
""".strip()
|
|
667
|
-
|
|
668
|
-
assert_request_matches_expected_query(
|
|
669
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
670
|
-
)
|
|
512
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
513
|
+
assert call_kwargs["chain_id"] == "old-chain-id"
|
|
514
|
+
assert call_kwargs.get("is_draft") is not True
|
|
515
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
671
516
|
|
|
672
517
|
assert deployment_handle.chain_id == "new-chain-id"
|
|
673
518
|
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
@@ -726,3 +571,86 @@ def test_push_raised_validation_error_for_extra_fields(tmp_path, remote):
|
|
|
726
571
|
match="Extra fields not allowed: \[extra_field, who_am_i\]",
|
|
727
572
|
):
|
|
728
573
|
remote.push(th, "model_name", th.truss_dir)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def test_push_passes_deploy_timeout_minutes_to_create_truss_service(
|
|
577
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
578
|
+
remote,
|
|
579
|
+
mock_baseten_requests,
|
|
580
|
+
mock_upload_truss,
|
|
581
|
+
mock_create_truss_service,
|
|
582
|
+
mock_truss_handle,
|
|
583
|
+
):
|
|
584
|
+
remote.push(
|
|
585
|
+
mock_truss_handle,
|
|
586
|
+
"model_name",
|
|
587
|
+
mock_truss_handle.truss_dir,
|
|
588
|
+
publish=True,
|
|
589
|
+
deploy_timeout_minutes=450,
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
mock_create_truss_service.assert_called_once()
|
|
593
|
+
_, kwargs = mock_create_truss_service.call_args
|
|
594
|
+
assert kwargs["deploy_timeout_minutes"] == 450
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def test_push_passes_none_deploy_timeout_minutes_when_not_specified(
|
|
598
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
599
|
+
remote,
|
|
600
|
+
mock_baseten_requests,
|
|
601
|
+
mock_upload_truss,
|
|
602
|
+
mock_create_truss_service,
|
|
603
|
+
mock_truss_handle,
|
|
604
|
+
):
|
|
605
|
+
remote.push(
|
|
606
|
+
mock_truss_handle, "model_name", mock_truss_handle.truss_dir, publish=True
|
|
607
|
+
)
|
|
608
|
+
|
|
609
|
+
mock_create_truss_service.assert_called_once()
|
|
610
|
+
_, kwargs = mock_create_truss_service.call_args
|
|
611
|
+
assert kwargs.get("deploy_timeout_minutes") is None
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def test_push_integration_deploy_timeout_minutes_propagated(
|
|
615
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
616
|
+
remote,
|
|
617
|
+
mock_baseten_requests,
|
|
618
|
+
mock_upload_truss,
|
|
619
|
+
mock_create_truss_service,
|
|
620
|
+
mock_truss_handle,
|
|
621
|
+
):
|
|
622
|
+
remote.push(
|
|
623
|
+
mock_truss_handle,
|
|
624
|
+
"model_name",
|
|
625
|
+
mock_truss_handle.truss_dir,
|
|
626
|
+
publish=True,
|
|
627
|
+
environment="staging",
|
|
628
|
+
deploy_timeout_minutes=750,
|
|
629
|
+
)
|
|
630
|
+
|
|
631
|
+
mock_create_truss_service.assert_called_once()
|
|
632
|
+
_, kwargs = mock_create_truss_service.call_args
|
|
633
|
+
assert kwargs["deploy_timeout_minutes"] == 750
|
|
634
|
+
assert kwargs["environment"] == "staging"
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def test_api_push_integration_deploy_timeout_minutes_propagated(
|
|
638
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
639
|
+
mock_remote_factory,
|
|
640
|
+
temp_trussrc_dir,
|
|
641
|
+
mock_available_config_names,
|
|
642
|
+
mock_truss_handle,
|
|
643
|
+
):
|
|
644
|
+
from truss.api import push
|
|
645
|
+
|
|
646
|
+
push(
|
|
647
|
+
str(mock_truss_handle.truss_dir),
|
|
648
|
+
remote="baseten",
|
|
649
|
+
model_name="test_model",
|
|
650
|
+
deploy_timeout_minutes=1200,
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
# Verify the remote.push was called with deploy_timeout_minutes
|
|
654
|
+
mock_remote_factory.push.assert_called_once()
|
|
655
|
+
_, push_kwargs = mock_remote_factory.push.call_args
|
|
656
|
+
assert push_kwargs.get("deploy_timeout_minutes") == 1200
|
truss/tests/test_config.py
CHANGED
|
@@ -23,6 +23,7 @@ from truss.base.truss_config import (
|
|
|
23
23
|
HTTPOptions,
|
|
24
24
|
ModelCache,
|
|
25
25
|
ModelRepo,
|
|
26
|
+
ModelRepoCacheInternal,
|
|
26
27
|
Resources,
|
|
27
28
|
Runtime,
|
|
28
29
|
TransportKind,
|
|
@@ -292,7 +293,10 @@ def test_cache_internal_with_models(default_config):
|
|
|
292
293
|
config = TrussConfig(
|
|
293
294
|
python_version="py39",
|
|
294
295
|
cache_internal=CacheInternal(
|
|
295
|
-
[
|
|
296
|
+
[
|
|
297
|
+
ModelRepoCacheInternal(repo_id="test/model"),
|
|
298
|
+
ModelRepoCacheInternal(repo_id="test/model2"),
|
|
299
|
+
]
|
|
296
300
|
),
|
|
297
301
|
)
|
|
298
302
|
new_config = default_config
|
|
@@ -305,11 +309,12 @@ def test_cache_internal_with_models(default_config):
|
|
|
305
309
|
|
|
306
310
|
def test_huggingface_cache_single_model_default_revision(default_config):
|
|
307
311
|
config = TrussConfig(
|
|
308
|
-
python_version="py39",
|
|
312
|
+
python_version="py39",
|
|
313
|
+
model_cache=ModelCache([ModelRepo(repo_id="test/model", use_volume=False)]),
|
|
309
314
|
)
|
|
310
315
|
|
|
311
316
|
new_config = default_config
|
|
312
|
-
new_config["model_cache"] = [{"repo_id": "test/model"}]
|
|
317
|
+
new_config["model_cache"] = [{"repo_id": "test/model", "use_volume": False}]
|
|
313
318
|
|
|
314
319
|
assert new_config == config.to_dict(verbose=False)
|
|
315
320
|
assert config.to_dict(verbose=True)["model_cache"][0].get("revision") is None
|
|
@@ -319,7 +324,9 @@ def test_huggingface_cache_single_model_non_default_revision_v1():
|
|
|
319
324
|
config = TrussConfig(
|
|
320
325
|
python_version="py39",
|
|
321
326
|
requirements=[],
|
|
322
|
-
model_cache=ModelCache(
|
|
327
|
+
model_cache=ModelCache(
|
|
328
|
+
[ModelRepo(repo_id="test/model", revision="not-main", use_volume=False)]
|
|
329
|
+
),
|
|
323
330
|
)
|
|
324
331
|
|
|
325
332
|
assert config.to_dict(verbose=False)["model_cache"][0].get("revision") == "not-main"
|
|
@@ -330,16 +337,16 @@ def test_huggingface_cache_multiple_models_default_revision(default_config):
|
|
|
330
337
|
python_version="py39",
|
|
331
338
|
model_cache=ModelCache(
|
|
332
339
|
[
|
|
333
|
-
ModelRepo(repo_id="test/model1", revision="main"),
|
|
334
|
-
ModelRepo(repo_id="test/model2"),
|
|
340
|
+
ModelRepo(repo_id="test/model1", revision="main", use_volume=False),
|
|
341
|
+
ModelRepo(repo_id="test/model2", use_volume=False),
|
|
335
342
|
]
|
|
336
343
|
),
|
|
337
344
|
)
|
|
338
345
|
|
|
339
346
|
new_config = default_config
|
|
340
347
|
new_config["model_cache"] = [
|
|
341
|
-
{"repo_id": "test/model1", "revision": "main"},
|
|
342
|
-
{"repo_id": "test/model2"},
|
|
348
|
+
{"repo_id": "test/model1", "revision": "main", "use_volume": False},
|
|
349
|
+
{"repo_id": "test/model2", "use_volume": False},
|
|
343
350
|
]
|
|
344
351
|
|
|
345
352
|
assert new_config == config.to_dict(verbose=False)
|
|
@@ -355,16 +362,18 @@ def test_huggingface_cache_multiple_models_mixed_revision(default_config):
|
|
|
355
362
|
python_version="py39",
|
|
356
363
|
model_cache=ModelCache(
|
|
357
364
|
[
|
|
358
|
-
ModelRepo(repo_id="test/model1"),
|
|
359
|
-
ModelRepo(
|
|
365
|
+
ModelRepo(repo_id="test/model1", use_volume=False),
|
|
366
|
+
ModelRepo(
|
|
367
|
+
repo_id="test/model2", revision="not-main2", use_volume=False
|
|
368
|
+
),
|
|
360
369
|
]
|
|
361
370
|
),
|
|
362
371
|
)
|
|
363
372
|
|
|
364
373
|
new_config = default_config
|
|
365
374
|
new_config["model_cache"] = [
|
|
366
|
-
{"repo_id": "test/model1"},
|
|
367
|
-
{"repo_id": "test/model2", "revision": "not-main2"},
|
|
375
|
+
{"repo_id": "test/model1", "use_volume": False},
|
|
376
|
+
{"repo_id": "test/model2", "revision": "not-main2", "use_volume": False},
|
|
368
377
|
]
|
|
369
378
|
|
|
370
379
|
assert new_config == config.to_dict(verbose=False)
|
|
File without changes
|