truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from unittest import mock
|
|
2
|
+
|
|
1
3
|
import pydantic
|
|
2
4
|
import pytest
|
|
3
5
|
import requests_mock
|
|
@@ -11,16 +13,17 @@ from truss.remote.baseten.core import (
|
|
|
11
13
|
)
|
|
12
14
|
from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
|
|
13
15
|
from truss.remote.baseten.error import RemoteError
|
|
14
|
-
from truss.remote.baseten.remote import BasetenRemote
|
|
15
16
|
from truss.truss_handle.truss_handle import TrussHandle
|
|
16
17
|
|
|
17
18
|
_TEST_REMOTE_URL = "http://test_remote.com"
|
|
18
19
|
_TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/"
|
|
19
20
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
21
|
+
TRUSS_RC_CONTENT = """
|
|
22
|
+
[baseten]
|
|
23
|
+
remote_provider = baseten
|
|
24
|
+
api_key = test_key
|
|
25
|
+
remote_url = http://test.com
|
|
26
|
+
""".strip()
|
|
24
27
|
|
|
25
28
|
|
|
26
29
|
def assert_request_matches_expected_query(request, expected_query) -> None:
|
|
@@ -269,30 +272,41 @@ def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promot
|
|
|
269
272
|
)
|
|
270
273
|
|
|
271
274
|
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
}
|
|
292
|
-
},
|
|
293
|
-
],
|
|
275
|
+
@pytest.mark.parametrize("deploy_timeout_minutes", [9, 1441])
|
|
276
|
+
def test_push_raised_value_error_when_deploy_timeout_minutes_is_invalid(
|
|
277
|
+
deploy_timeout_minutes, custom_model_truss_dir_with_pre_and_post, remote
|
|
278
|
+
):
|
|
279
|
+
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
|
|
280
|
+
|
|
281
|
+
with pytest.raises(
|
|
282
|
+
ValueError,
|
|
283
|
+
match="deploy-timeout-minutes must be between 10 minutes and 1440 minutes \(24 hours\)",
|
|
284
|
+
):
|
|
285
|
+
remote.push(
|
|
286
|
+
th,
|
|
287
|
+
"model_name",
|
|
288
|
+
th.truss_dir,
|
|
289
|
+
publish=True,
|
|
290
|
+
promote=False,
|
|
291
|
+
preserve_previous_prod_deployment=False,
|
|
292
|
+
deployment_name="dep_name",
|
|
293
|
+
deploy_timeout_minutes=deploy_timeout_minutes,
|
|
294
294
|
)
|
|
295
295
|
|
|
296
|
+
|
|
297
|
+
def test_create_chain_with_no_publish(remote):
|
|
298
|
+
mock_deploy_response = {
|
|
299
|
+
"chain_deployment": {
|
|
300
|
+
"id": "new-chain-deployment-id",
|
|
301
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
with mock.patch.object(
|
|
306
|
+
remote.api, "get_chains", return_value=[]
|
|
307
|
+
) as mock_get_chains, mock.patch.object(
|
|
308
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
309
|
+
) as mock_deploy:
|
|
296
310
|
deployment_handle = create_chain_atomic(
|
|
297
311
|
api=remote.api,
|
|
298
312
|
chain_name="draft_chain",
|
|
@@ -310,64 +324,61 @@ def test_create_chain_with_no_publish(remote):
|
|
|
310
324
|
environment=None,
|
|
311
325
|
)
|
|
312
326
|
|
|
313
|
-
|
|
314
|
-
|
|
327
|
+
mock_get_chains.assert_called_once()
|
|
328
|
+
mock_deploy.assert_called_once()
|
|
315
329
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
name
|
|
321
|
-
}
|
|
322
|
-
}
|
|
323
|
-
""".strip()
|
|
330
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
331
|
+
assert call_kwargs["chain_name"] == "draft_chain"
|
|
332
|
+
assert call_kwargs.get("is_draft") is True
|
|
333
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
324
334
|
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
)
|
|
335
|
+
assert deployment_handle.chain_id == "new-chain-id"
|
|
336
|
+
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
328
337
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
""
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
}}
|
|
359
|
-
}}
|
|
360
|
-
}}
|
|
361
|
-
""".strip()
|
|
362
|
-
|
|
363
|
-
assert_request_matches_expected_query(
|
|
364
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
338
|
+
|
|
339
|
+
def test_create_chain_no_existing_chain(remote):
|
|
340
|
+
mock_deploy_response = {
|
|
341
|
+
"chain_deployment": {
|
|
342
|
+
"id": "new-chain-deployment-id",
|
|
343
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
with mock.patch.object(
|
|
348
|
+
remote.api, "get_chains", return_value=[]
|
|
349
|
+
) as mock_get_chains, mock.patch.object(
|
|
350
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
351
|
+
) as mock_deploy:
|
|
352
|
+
deployment_handle = create_chain_atomic(
|
|
353
|
+
api=remote.api,
|
|
354
|
+
chain_name="new_chain",
|
|
355
|
+
entrypoint=ChainletDataAtomic(
|
|
356
|
+
name="chainlet-1",
|
|
357
|
+
oracle=OracleData(
|
|
358
|
+
model_name="model-1",
|
|
359
|
+
s3_key="s3-key-1",
|
|
360
|
+
encoded_config_str="encoded-config-str-1",
|
|
361
|
+
),
|
|
362
|
+
),
|
|
363
|
+
dependencies=[],
|
|
364
|
+
truss_user_env=b10_types.TrussUserEnv.collect(),
|
|
365
|
+
is_draft=False,
|
|
366
|
+
environment=None,
|
|
365
367
|
)
|
|
368
|
+
|
|
369
|
+
mock_get_chains.assert_called_once()
|
|
370
|
+
mock_deploy.assert_called_once()
|
|
371
|
+
|
|
372
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
373
|
+
assert call_kwargs["chain_name"] == "new_chain"
|
|
374
|
+
assert call_kwargs.get("is_draft") is not True
|
|
375
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
376
|
+
|
|
366
377
|
assert deployment_handle.chain_id == "new-chain-id"
|
|
367
378
|
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
368
379
|
|
|
369
380
|
|
|
370
|
-
def
|
|
381
|
+
def test_create_chain_with_deployment_name(remote):
|
|
371
382
|
with requests_mock.Mocker() as m:
|
|
372
383
|
m.post(
|
|
373
384
|
_TEST_REMOTE_GRAPHQL_PATH,
|
|
@@ -391,7 +402,8 @@ def test_create_chain_no_existing_chain(remote):
|
|
|
391
402
|
],
|
|
392
403
|
)
|
|
393
404
|
|
|
394
|
-
|
|
405
|
+
deployment_name = "chain-deployment"
|
|
406
|
+
create_chain_atomic(
|
|
395
407
|
api=remote.api,
|
|
396
408
|
chain_name="new_chain",
|
|
397
409
|
entrypoint=ChainletDataAtomic(
|
|
@@ -406,94 +418,32 @@ def test_create_chain_no_existing_chain(remote):
|
|
|
406
418
|
truss_user_env=b10_types.TrussUserEnv.collect(),
|
|
407
419
|
is_draft=False,
|
|
408
420
|
environment=None,
|
|
421
|
+
deployment_name=deployment_name,
|
|
409
422
|
)
|
|
410
423
|
|
|
411
|
-
get_chains_graphql_request = m.request_history[0]
|
|
412
424
|
create_chain_graphql_request = m.request_history[1]
|
|
413
425
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
id
|
|
418
|
-
name
|
|
419
|
-
}
|
|
420
|
-
}
|
|
421
|
-
""".strip()
|
|
422
|
-
|
|
423
|
-
assert_request_matches_expected_query(
|
|
424
|
-
get_chains_graphql_request, expected_get_chains_query
|
|
425
|
-
)
|
|
426
|
-
|
|
427
|
-
chainlets_string = """
|
|
428
|
-
{
|
|
429
|
-
name: "chainlet-1",
|
|
430
|
-
oracle: {
|
|
431
|
-
model_name: "model-1",
|
|
432
|
-
s3_key: "s3-key-1",
|
|
433
|
-
encoded_config_str: "encoded-config-str-1",
|
|
434
|
-
semver_bump: "MINOR"
|
|
435
|
-
}
|
|
436
|
-
}
|
|
437
|
-
""".strip()
|
|
438
|
-
|
|
439
|
-
expected_create_chain_mutation = f"""
|
|
440
|
-
mutation ($trussUserEnv: String) {{
|
|
441
|
-
deploy_chain_atomic(
|
|
442
|
-
chain_name: "new_chain"
|
|
443
|
-
is_draft: false
|
|
444
|
-
entrypoint: {chainlets_string}
|
|
445
|
-
dependencies: []
|
|
446
|
-
truss_user_env: $trussUserEnv
|
|
447
|
-
) {{
|
|
448
|
-
chain_deployment {{
|
|
449
|
-
id
|
|
450
|
-
chain {{
|
|
451
|
-
id
|
|
452
|
-
hostname
|
|
453
|
-
}}
|
|
454
|
-
}}
|
|
455
|
-
}}
|
|
456
|
-
}}
|
|
457
|
-
""".strip()
|
|
458
|
-
|
|
459
|
-
assert_request_matches_expected_query(
|
|
460
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
426
|
+
assert (
|
|
427
|
+
'deployment_name: "chain-deployment"'
|
|
428
|
+
in create_chain_graphql_request.json()["query"]
|
|
461
429
|
)
|
|
462
430
|
|
|
463
|
-
assert deployment_handle.chain_id == "new-chain-id"
|
|
464
|
-
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
465
|
-
|
|
466
431
|
|
|
467
432
|
def test_create_chain_with_existing_chain_promote_to_environment_publish_false(remote):
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
"data": {
|
|
475
|
-
"chains": [{"id": "old-chain-id", "name": "old_chain"}]
|
|
476
|
-
}
|
|
477
|
-
}
|
|
478
|
-
},
|
|
479
|
-
{
|
|
480
|
-
"json": {
|
|
481
|
-
"data": {
|
|
482
|
-
"deploy_chain_atomic": {
|
|
483
|
-
"chain_deployment": {
|
|
484
|
-
"id": "new-chain-deployment-id",
|
|
485
|
-
"chain": {
|
|
486
|
-
"id": "new-chain-id",
|
|
487
|
-
"hostname": "hostname",
|
|
488
|
-
},
|
|
489
|
-
}
|
|
490
|
-
}
|
|
491
|
-
}
|
|
492
|
-
}
|
|
493
|
-
},
|
|
494
|
-
],
|
|
495
|
-
)
|
|
433
|
+
mock_deploy_response = {
|
|
434
|
+
"chain_deployment": {
|
|
435
|
+
"id": "new-chain-deployment-id",
|
|
436
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
437
|
+
}
|
|
438
|
+
}
|
|
496
439
|
|
|
440
|
+
with mock.patch.object(
|
|
441
|
+
remote.api,
|
|
442
|
+
"get_chains",
|
|
443
|
+
return_value=[{"id": "old-chain-id", "name": "old_chain"}],
|
|
444
|
+
) as mock_get_chains, mock.patch.object(
|
|
445
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
446
|
+
) as mock_deploy:
|
|
497
447
|
deployment_handle = create_chain_atomic(
|
|
498
448
|
api=remote.api,
|
|
499
449
|
chain_name="old_chain",
|
|
@@ -511,95 +461,34 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false(r
|
|
|
511
461
|
environment="production",
|
|
512
462
|
)
|
|
513
463
|
|
|
514
|
-
|
|
515
|
-
|
|
464
|
+
mock_get_chains.assert_called_once()
|
|
465
|
+
mock_deploy.assert_called_once()
|
|
516
466
|
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
}
|
|
523
|
-
}
|
|
524
|
-
""".strip()
|
|
525
|
-
|
|
526
|
-
assert_request_matches_expected_query(
|
|
527
|
-
get_chains_graphql_request, expected_get_chains_query
|
|
528
|
-
)
|
|
529
|
-
|
|
530
|
-
# Note that if publish=False and environment!=None, we set publish to True and create
|
|
531
|
-
# a non-draft deployment
|
|
532
|
-
chainlets_string = """
|
|
533
|
-
{
|
|
534
|
-
name: "chainlet-1",
|
|
535
|
-
oracle: {
|
|
536
|
-
model_name: "model-1",
|
|
537
|
-
s3_key: "s3-key-1",
|
|
538
|
-
encoded_config_str: "encoded-config-str-1",
|
|
539
|
-
semver_bump: "MINOR"
|
|
540
|
-
}
|
|
541
|
-
}
|
|
542
|
-
""".strip()
|
|
543
|
-
|
|
544
|
-
expected_create_chain_mutation = f"""
|
|
545
|
-
mutation ($trussUserEnv: String) {{
|
|
546
|
-
deploy_chain_atomic(
|
|
547
|
-
chain_id: "old-chain-id"
|
|
548
|
-
environment: "production"
|
|
549
|
-
is_draft: false
|
|
550
|
-
entrypoint: {chainlets_string}
|
|
551
|
-
dependencies: []
|
|
552
|
-
truss_user_env: $trussUserEnv
|
|
553
|
-
) {{
|
|
554
|
-
chain_deployment {{
|
|
555
|
-
id
|
|
556
|
-
chain {{
|
|
557
|
-
id
|
|
558
|
-
hostname
|
|
559
|
-
}}
|
|
560
|
-
}}
|
|
561
|
-
}}
|
|
562
|
-
}}
|
|
563
|
-
""".strip()
|
|
564
|
-
|
|
565
|
-
assert_request_matches_expected_query(
|
|
566
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
567
|
-
)
|
|
467
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
468
|
+
assert call_kwargs["chain_id"] == "old-chain-id"
|
|
469
|
+
assert call_kwargs["environment"] == "production"
|
|
470
|
+
assert call_kwargs.get("is_draft") is not True
|
|
471
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
568
472
|
|
|
569
473
|
assert deployment_handle.chain_id == "new-chain-id"
|
|
570
474
|
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
571
475
|
|
|
572
476
|
|
|
573
477
|
def test_create_chain_existing_chain_publish_true_no_promotion(remote):
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
"data": {
|
|
581
|
-
"chains": [{"id": "old-chain-id", "name": "old_chain"}]
|
|
582
|
-
}
|
|
583
|
-
}
|
|
584
|
-
},
|
|
585
|
-
{
|
|
586
|
-
"json": {
|
|
587
|
-
"data": {
|
|
588
|
-
"deploy_chain_atomic": {
|
|
589
|
-
"chain_deployment": {
|
|
590
|
-
"id": "new-chain-deployment-id",
|
|
591
|
-
"chain": {
|
|
592
|
-
"id": "new-chain-id",
|
|
593
|
-
"hostname": "hostname",
|
|
594
|
-
},
|
|
595
|
-
}
|
|
596
|
-
}
|
|
597
|
-
}
|
|
598
|
-
}
|
|
599
|
-
},
|
|
600
|
-
],
|
|
601
|
-
)
|
|
478
|
+
mock_deploy_response = {
|
|
479
|
+
"chain_deployment": {
|
|
480
|
+
"id": "new-chain-deployment-id",
|
|
481
|
+
"chain": {"id": "new-chain-id", "hostname": "hostname"},
|
|
482
|
+
}
|
|
483
|
+
}
|
|
602
484
|
|
|
485
|
+
with mock.patch.object(
|
|
486
|
+
remote.api,
|
|
487
|
+
"get_chains",
|
|
488
|
+
return_value=[{"id": "old-chain-id", "name": "old_chain"}],
|
|
489
|
+
) as mock_get_chains, mock.patch.object(
|
|
490
|
+
remote.api, "deploy_chain_atomic", return_value=mock_deploy_response
|
|
491
|
+
) as mock_deploy:
|
|
603
492
|
deployment_handle = create_chain_atomic(
|
|
604
493
|
api=remote.api,
|
|
605
494
|
chain_name="old_chain",
|
|
@@ -617,57 +506,13 @@ def test_create_chain_existing_chain_publish_true_no_promotion(remote):
|
|
|
617
506
|
environment=None,
|
|
618
507
|
)
|
|
619
508
|
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
expected_get_chains_query = """
|
|
624
|
-
{
|
|
625
|
-
chains {
|
|
626
|
-
id
|
|
627
|
-
name
|
|
628
|
-
}
|
|
629
|
-
}
|
|
630
|
-
""".strip()
|
|
631
|
-
|
|
632
|
-
assert_request_matches_expected_query(
|
|
633
|
-
get_chains_graphql_request, expected_get_chains_query
|
|
634
|
-
)
|
|
509
|
+
mock_get_chains.assert_called_once()
|
|
510
|
+
mock_deploy.assert_called_once()
|
|
635
511
|
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
model_name: "model-1",
|
|
641
|
-
s3_key: "s3-key-1",
|
|
642
|
-
encoded_config_str: "encoded-config-str-1",
|
|
643
|
-
semver_bump: "MINOR"
|
|
644
|
-
}
|
|
645
|
-
}
|
|
646
|
-
""".strip()
|
|
647
|
-
|
|
648
|
-
expected_create_chain_mutation = f"""
|
|
649
|
-
mutation ($trussUserEnv: String) {{
|
|
650
|
-
deploy_chain_atomic(
|
|
651
|
-
chain_id: "old-chain-id"
|
|
652
|
-
is_draft: false
|
|
653
|
-
entrypoint: {chainlets_string}
|
|
654
|
-
dependencies: []
|
|
655
|
-
truss_user_env: $trussUserEnv
|
|
656
|
-
) {{
|
|
657
|
-
chain_deployment {{
|
|
658
|
-
id
|
|
659
|
-
chain {{
|
|
660
|
-
id
|
|
661
|
-
hostname
|
|
662
|
-
}}
|
|
663
|
-
}}
|
|
664
|
-
}}
|
|
665
|
-
}}
|
|
666
|
-
""".strip()
|
|
667
|
-
|
|
668
|
-
assert_request_matches_expected_query(
|
|
669
|
-
create_chain_graphql_request, expected_create_chain_mutation
|
|
670
|
-
)
|
|
512
|
+
call_kwargs = mock_deploy.call_args.kwargs
|
|
513
|
+
assert call_kwargs["chain_id"] == "old-chain-id"
|
|
514
|
+
assert call_kwargs.get("is_draft") is not True
|
|
515
|
+
assert call_kwargs.get("deploy_timeout_minutes") is None
|
|
671
516
|
|
|
672
517
|
assert deployment_handle.chain_id == "new-chain-id"
|
|
673
518
|
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
@@ -726,3 +571,86 @@ def test_push_raised_validation_error_for_extra_fields(tmp_path, remote):
|
|
|
726
571
|
match="Extra fields not allowed: \[extra_field, who_am_i\]",
|
|
727
572
|
):
|
|
728
573
|
remote.push(th, "model_name", th.truss_dir)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def test_push_passes_deploy_timeout_minutes_to_create_truss_service(
|
|
577
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
578
|
+
remote,
|
|
579
|
+
mock_baseten_requests,
|
|
580
|
+
mock_upload_truss,
|
|
581
|
+
mock_create_truss_service,
|
|
582
|
+
mock_truss_handle,
|
|
583
|
+
):
|
|
584
|
+
remote.push(
|
|
585
|
+
mock_truss_handle,
|
|
586
|
+
"model_name",
|
|
587
|
+
mock_truss_handle.truss_dir,
|
|
588
|
+
publish=True,
|
|
589
|
+
deploy_timeout_minutes=450,
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
mock_create_truss_service.assert_called_once()
|
|
593
|
+
_, kwargs = mock_create_truss_service.call_args
|
|
594
|
+
assert kwargs["deploy_timeout_minutes"] == 450
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def test_push_passes_none_deploy_timeout_minutes_when_not_specified(
|
|
598
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
599
|
+
remote,
|
|
600
|
+
mock_baseten_requests,
|
|
601
|
+
mock_upload_truss,
|
|
602
|
+
mock_create_truss_service,
|
|
603
|
+
mock_truss_handle,
|
|
604
|
+
):
|
|
605
|
+
remote.push(
|
|
606
|
+
mock_truss_handle, "model_name", mock_truss_handle.truss_dir, publish=True
|
|
607
|
+
)
|
|
608
|
+
|
|
609
|
+
mock_create_truss_service.assert_called_once()
|
|
610
|
+
_, kwargs = mock_create_truss_service.call_args
|
|
611
|
+
assert kwargs.get("deploy_timeout_minutes") is None
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def test_push_integration_deploy_timeout_minutes_propagated(
|
|
615
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
616
|
+
remote,
|
|
617
|
+
mock_baseten_requests,
|
|
618
|
+
mock_upload_truss,
|
|
619
|
+
mock_create_truss_service,
|
|
620
|
+
mock_truss_handle,
|
|
621
|
+
):
|
|
622
|
+
remote.push(
|
|
623
|
+
mock_truss_handle,
|
|
624
|
+
"model_name",
|
|
625
|
+
mock_truss_handle.truss_dir,
|
|
626
|
+
publish=True,
|
|
627
|
+
environment="staging",
|
|
628
|
+
deploy_timeout_minutes=750,
|
|
629
|
+
)
|
|
630
|
+
|
|
631
|
+
mock_create_truss_service.assert_called_once()
|
|
632
|
+
_, kwargs = mock_create_truss_service.call_args
|
|
633
|
+
assert kwargs["deploy_timeout_minutes"] == 750
|
|
634
|
+
assert kwargs["environment"] == "staging"
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def test_api_push_integration_deploy_timeout_minutes_propagated(
|
|
638
|
+
custom_model_truss_dir_with_pre_and_post,
|
|
639
|
+
mock_remote_factory,
|
|
640
|
+
temp_trussrc_dir,
|
|
641
|
+
mock_available_config_names,
|
|
642
|
+
mock_truss_handle,
|
|
643
|
+
):
|
|
644
|
+
from truss.api import push
|
|
645
|
+
|
|
646
|
+
push(
|
|
647
|
+
str(mock_truss_handle.truss_dir),
|
|
648
|
+
remote="baseten",
|
|
649
|
+
model_name="test_model",
|
|
650
|
+
deploy_timeout_minutes=1200,
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
# Verify the remote.push was called with deploy_timeout_minutes
|
|
654
|
+
mock_remote_factory.push.assert_called_once()
|
|
655
|
+
_, push_kwargs = mock_remote_factory.push.call_args
|
|
656
|
+
assert push_kwargs.get("deploy_timeout_minutes") == 1200
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
+
from unittest.mock import MagicMock
|
|
2
|
+
|
|
1
3
|
from truss.remote.baseten import service
|
|
4
|
+
from truss.remote.baseten.core import ModelVersionHandle
|
|
2
5
|
|
|
3
6
|
|
|
4
7
|
def test_model_invoke_url_prod():
|
|
@@ -65,3 +68,56 @@ def test_chain_logs_url():
|
|
|
65
68
|
"https://app.baseten.co", "abc", "666", "543"
|
|
66
69
|
)
|
|
67
70
|
assert url == "https://app.baseten.co/chains/abc/logs/666/543"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_predict_response_to_json():
|
|
74
|
+
"""Test that predict method returns JSON response for normal dict result."""
|
|
75
|
+
# Create a mock BasetenService
|
|
76
|
+
mock_handle = MagicMock(spec=ModelVersionHandle)
|
|
77
|
+
mock_handle.model_id = "test-model"
|
|
78
|
+
mock_handle.version_id = "test-version"
|
|
79
|
+
mock_handle.hostname = "https://model-test.api.baseten.co"
|
|
80
|
+
|
|
81
|
+
mock_api = MagicMock()
|
|
82
|
+
mock_api.app_url = "https://app.baseten.co"
|
|
83
|
+
|
|
84
|
+
service_instance = service.BasetenService(
|
|
85
|
+
model_version_handle=mock_handle,
|
|
86
|
+
is_draft=False,
|
|
87
|
+
api_key="test-key",
|
|
88
|
+
service_url="https://test.com",
|
|
89
|
+
api=mock_api,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
# Mock the _send_request method to return a successful JSON response
|
|
93
|
+
mock_response = MagicMock()
|
|
94
|
+
mock_response.json.return_value = {"result": "success"}
|
|
95
|
+
service_instance._send_request = MagicMock(return_value=mock_response)
|
|
96
|
+
|
|
97
|
+
# Test predict method
|
|
98
|
+
result = service_instance.predict({"input": "test"})
|
|
99
|
+
|
|
100
|
+
# Verify that the JSON response is returned directly
|
|
101
|
+
assert result == {"result": "success"}
|
|
102
|
+
|
|
103
|
+
# Test non-dict response types below
|
|
104
|
+
|
|
105
|
+
# With integer response
|
|
106
|
+
mock_response.json.return_value = 42
|
|
107
|
+
result = service_instance.predict({"input": "test"})
|
|
108
|
+
assert result == 42
|
|
109
|
+
|
|
110
|
+
# With string response
|
|
111
|
+
mock_response.json.return_value = "success"
|
|
112
|
+
result = service_instance.predict({"input": "test"})
|
|
113
|
+
assert result == "success"
|
|
114
|
+
|
|
115
|
+
# With list response
|
|
116
|
+
mock_response.json.return_value = [1, 2, 3, 4]
|
|
117
|
+
result = service_instance.predict({"input": "test"})
|
|
118
|
+
assert result == [1, 2, 3, 4]
|
|
119
|
+
|
|
120
|
+
# With boolean response
|
|
121
|
+
mock_response.json.return_value = True
|
|
122
|
+
result = service_instance.predict({"input": "test"})
|
|
123
|
+
assert result is True
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# NB(nikhil): Slightly hacky helpers needed to set up the path so relative imports work as they do in real environments
|
|
6
|
+
def setup_control_imports():
|
|
7
|
+
base_path = Path(__file__).parent.parent.parent.parent.parent
|
|
8
|
+
paths = [
|
|
9
|
+
base_path / "templates" / "control" / "control",
|
|
10
|
+
base_path / "templates",
|
|
11
|
+
base_path / "templates" / "shared",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
for path in paths:
|
|
15
|
+
if not path.exists():
|
|
16
|
+
raise FileNotFoundError(f"Expected control path does not exist: {path}")
|
|
17
|
+
|
|
18
|
+
path_str = str(path)
|
|
19
|
+
if path_str not in sys.path:
|
|
20
|
+
sys.path.insert(0, path_str)
|
|
@@ -8,6 +8,10 @@ from httpx_ws import AsyncWebSocketSession
|
|
|
8
8
|
from httpx_ws import _exceptions as httpx_ws_exceptions
|
|
9
9
|
from wsproto.events import BytesMessage, TextMessage
|
|
10
10
|
|
|
11
|
+
from truss.tests.templates.control.control.conftest import setup_control_imports
|
|
12
|
+
|
|
13
|
+
setup_control_imports()
|
|
14
|
+
|
|
11
15
|
from truss.templates.control.control.endpoints import proxy_ws
|
|
12
16
|
|
|
13
17
|
|