truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -1,29 +1,44 @@
|
|
|
1
|
-
import
|
|
1
|
+
from urllib import parse
|
|
2
|
+
|
|
2
3
|
import pytest
|
|
3
4
|
import requests_mock
|
|
4
|
-
|
|
5
|
+
import truss
|
|
6
|
+
from truss.remote.baseten.core import (
|
|
7
|
+
ModelId,
|
|
8
|
+
ModelName,
|
|
9
|
+
ModelVersionId,
|
|
10
|
+
create_chain_atomic,
|
|
11
|
+
)
|
|
12
|
+
from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
|
|
13
|
+
from truss.remote.baseten.error import RemoteError
|
|
5
14
|
from truss.remote.baseten.remote import BasetenRemote
|
|
6
|
-
from truss.truss_handle import TrussHandle
|
|
15
|
+
from truss.truss_handle.truss_handle import TrussHandle
|
|
7
16
|
|
|
8
17
|
_TEST_REMOTE_URL = "http://test_remote.com"
|
|
18
|
+
_TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def assert_request_matches_expected_query(request, expected_query) -> None:
|
|
22
|
+
unescaped_content = parse.unquote_plus(request.text)
|
|
23
|
+
actual_lines = tuple(
|
|
24
|
+
line.strip()
|
|
25
|
+
for line in unescaped_content.replace("query=", "").strip().split("\n")
|
|
26
|
+
if line.strip()
|
|
27
|
+
)
|
|
28
|
+
expected_lines = tuple(
|
|
29
|
+
line.strip() for line in expected_query.split("\n") if line.strip()
|
|
30
|
+
)
|
|
31
|
+
assert actual_lines == expected_lines
|
|
9
32
|
|
|
10
33
|
|
|
11
34
|
def test_get_service_by_version_id():
|
|
12
35
|
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
13
36
|
|
|
14
|
-
version = {
|
|
15
|
-
"id": "version_id",
|
|
16
|
-
"oracle": {
|
|
17
|
-
"id": "model_id",
|
|
18
|
-
},
|
|
19
|
-
}
|
|
37
|
+
version = {"id": "version_id", "oracle": {"id": "model_id"}}
|
|
20
38
|
model_version_response = {"data": {"model_version": version}}
|
|
21
39
|
|
|
22
40
|
with requests_mock.Mocker() as m:
|
|
23
|
-
m.post(
|
|
24
|
-
remote._api._api_url,
|
|
25
|
-
json=model_version_response,
|
|
26
|
-
)
|
|
41
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_version_response)
|
|
27
42
|
service = remote.get_service(model_identifier=ModelVersionId("version_id"))
|
|
28
43
|
|
|
29
44
|
assert service.model_id == "model_id"
|
|
@@ -34,11 +49,8 @@ def test_get_service_by_version_id_no_version():
|
|
|
34
49
|
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
35
50
|
model_version_response = {"errors": [{"message": "error"}]}
|
|
36
51
|
with requests_mock.Mocker() as m:
|
|
37
|
-
m.post(
|
|
38
|
-
|
|
39
|
-
json=model_version_response,
|
|
40
|
-
)
|
|
41
|
-
with pytest.raises(click.UsageError):
|
|
52
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_version_response)
|
|
53
|
+
with pytest.raises(RemoteError):
|
|
42
54
|
remote.get_service(model_identifier=ModelVersionId("version_id"))
|
|
43
55
|
|
|
44
56
|
|
|
@@ -52,19 +64,12 @@ def test_get_service_by_model_name():
|
|
|
52
64
|
]
|
|
53
65
|
model_response = {
|
|
54
66
|
"data": {
|
|
55
|
-
"model": {
|
|
56
|
-
"name": "model_name",
|
|
57
|
-
"id": "model_id",
|
|
58
|
-
"versions": versions,
|
|
59
|
-
}
|
|
67
|
+
"model": {"name": "model_name", "id": "model_id", "versions": versions}
|
|
60
68
|
}
|
|
61
69
|
}
|
|
62
70
|
|
|
63
71
|
with requests_mock.Mocker() as m:
|
|
64
|
-
m.post(
|
|
65
|
-
remote._api._api_url,
|
|
66
|
-
json=model_response,
|
|
67
|
-
)
|
|
72
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
68
73
|
|
|
69
74
|
# Check that the production version is returned when published is True.
|
|
70
75
|
service = remote.get_service(
|
|
@@ -84,24 +89,15 @@ def test_get_service_by_model_name():
|
|
|
84
89
|
def test_get_service_by_model_name_no_dev_version():
|
|
85
90
|
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
86
91
|
|
|
87
|
-
versions = [
|
|
88
|
-
{"id": "1", "is_draft": False, "is_primary": True},
|
|
89
|
-
]
|
|
92
|
+
versions = [{"id": "1", "is_draft": False, "is_primary": True}]
|
|
90
93
|
model_response = {
|
|
91
94
|
"data": {
|
|
92
|
-
"model": {
|
|
93
|
-
"name": "model_name",
|
|
94
|
-
"id": "model_id",
|
|
95
|
-
"versions": versions,
|
|
96
|
-
}
|
|
95
|
+
"model": {"name": "model_name", "id": "model_id", "versions": versions}
|
|
97
96
|
}
|
|
98
97
|
}
|
|
99
98
|
|
|
100
99
|
with requests_mock.Mocker() as m:
|
|
101
|
-
m.post(
|
|
102
|
-
remote._api._api_url,
|
|
103
|
-
json=model_response,
|
|
104
|
-
)
|
|
100
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
105
101
|
|
|
106
102
|
# Check that the production version is returned when published is True.
|
|
107
103
|
service = remote.get_service(
|
|
@@ -112,7 +108,7 @@ def test_get_service_by_model_name_no_dev_version():
|
|
|
112
108
|
|
|
113
109
|
# Since no development version exists, calling get_service with
|
|
114
110
|
# published=False should raise an error.
|
|
115
|
-
with pytest.raises(
|
|
111
|
+
with pytest.raises(RemoteError):
|
|
116
112
|
remote.get_service(
|
|
117
113
|
model_identifier=ModelName("model_name"), published=False
|
|
118
114
|
)
|
|
@@ -121,28 +117,19 @@ def test_get_service_by_model_name_no_dev_version():
|
|
|
121
117
|
def test_get_service_by_model_name_no_prod_version():
|
|
122
118
|
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
123
119
|
|
|
124
|
-
versions = [
|
|
125
|
-
{"id": "1", "is_draft": True, "is_primary": False},
|
|
126
|
-
]
|
|
120
|
+
versions = [{"id": "1", "is_draft": True, "is_primary": False}]
|
|
127
121
|
model_response = {
|
|
128
122
|
"data": {
|
|
129
|
-
"model": {
|
|
130
|
-
"name": "model_name",
|
|
131
|
-
"id": "model_id",
|
|
132
|
-
"versions": versions,
|
|
133
|
-
}
|
|
123
|
+
"model": {"name": "model_name", "id": "model_id", "versions": versions}
|
|
134
124
|
}
|
|
135
125
|
}
|
|
136
126
|
|
|
137
127
|
with requests_mock.Mocker() as m:
|
|
138
|
-
m.post(
|
|
139
|
-
remote._api._api_url,
|
|
140
|
-
json=model_response,
|
|
141
|
-
)
|
|
128
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
142
129
|
|
|
143
130
|
# Since no production version exists, calling get_service with
|
|
144
131
|
# published=True should raise an error.
|
|
145
|
-
with pytest.raises(
|
|
132
|
+
with pytest.raises(RemoteError):
|
|
146
133
|
remote.get_service(model_identifier=ModelName("model_name"), published=True)
|
|
147
134
|
|
|
148
135
|
# Check that the development version is returned when published is False.
|
|
@@ -167,10 +154,7 @@ def test_get_service_by_model_id():
|
|
|
167
154
|
}
|
|
168
155
|
|
|
169
156
|
with requests_mock.Mocker() as m:
|
|
170
|
-
m.post(
|
|
171
|
-
remote._api._api_url,
|
|
172
|
-
json=model_response,
|
|
173
|
-
)
|
|
157
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
174
158
|
|
|
175
159
|
service = remote.get_service(model_identifier=ModelId("model_id"))
|
|
176
160
|
assert service.model_id == "model_id"
|
|
@@ -181,11 +165,8 @@ def test_get_service_by_model_id_no_model():
|
|
|
181
165
|
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
182
166
|
model_response = {"errors": [{"message": "error"}]}
|
|
183
167
|
with requests_mock.Mocker() as m:
|
|
184
|
-
m.post(
|
|
185
|
-
|
|
186
|
-
json=model_response,
|
|
187
|
-
)
|
|
188
|
-
with pytest.raises(click.UsageError):
|
|
168
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
169
|
+
with pytest.raises(RemoteError):
|
|
189
170
|
remote.get_service(model_identifier=ModelId("model_id"))
|
|
190
171
|
|
|
191
172
|
|
|
@@ -203,17 +184,22 @@ def test_push_raised_value_error_when_deployment_name_and_not_publish(
|
|
|
203
184
|
}
|
|
204
185
|
}
|
|
205
186
|
with requests_mock.Mocker() as m:
|
|
206
|
-
m.post(
|
|
207
|
-
remote._api._api_url,
|
|
208
|
-
json=model_response,
|
|
209
|
-
)
|
|
187
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
210
188
|
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
|
|
211
189
|
|
|
212
190
|
with pytest.raises(
|
|
213
191
|
ValueError,
|
|
214
192
|
match="Deployment name cannot be used for development deployment",
|
|
215
193
|
):
|
|
216
|
-
remote.push(
|
|
194
|
+
remote.push(
|
|
195
|
+
th,
|
|
196
|
+
"model_name",
|
|
197
|
+
publish=False,
|
|
198
|
+
trusted=False,
|
|
199
|
+
promote=False,
|
|
200
|
+
preserve_previous_prod_deployment=False,
|
|
201
|
+
deployment_name="dep_name",
|
|
202
|
+
)
|
|
217
203
|
|
|
218
204
|
|
|
219
205
|
def test_push_raised_value_error_when_deployment_name_is_not_valid(
|
|
@@ -230,17 +216,22 @@ def test_push_raised_value_error_when_deployment_name_is_not_valid(
|
|
|
230
216
|
}
|
|
231
217
|
}
|
|
232
218
|
with requests_mock.Mocker() as m:
|
|
233
|
-
m.post(
|
|
234
|
-
remote._api._api_url,
|
|
235
|
-
json=model_response,
|
|
236
|
-
)
|
|
219
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
237
220
|
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
|
|
238
221
|
|
|
239
222
|
with pytest.raises(
|
|
240
223
|
ValueError,
|
|
241
224
|
match="Deployment name must only contain alphanumeric, -, _ and . characters",
|
|
242
225
|
):
|
|
243
|
-
remote.push(
|
|
226
|
+
remote.push(
|
|
227
|
+
th,
|
|
228
|
+
"model_name",
|
|
229
|
+
publish=True,
|
|
230
|
+
trusted=False,
|
|
231
|
+
promote=False,
|
|
232
|
+
preserve_previous_prod_deployment=False,
|
|
233
|
+
deployment_name="dep//name",
|
|
234
|
+
)
|
|
244
235
|
|
|
245
236
|
|
|
246
237
|
def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promote(
|
|
@@ -257,14 +248,435 @@ def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promot
|
|
|
257
248
|
}
|
|
258
249
|
}
|
|
259
250
|
with requests_mock.Mocker() as m:
|
|
260
|
-
m.post(
|
|
261
|
-
remote._api._api_url,
|
|
262
|
-
json=model_response,
|
|
263
|
-
)
|
|
251
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
264
252
|
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
|
|
265
253
|
|
|
266
254
|
with pytest.raises(
|
|
267
255
|
ValueError,
|
|
268
256
|
match="preserve-previous-production-deployment can only be used with the '--promote' option",
|
|
269
257
|
):
|
|
270
|
-
remote.push(
|
|
258
|
+
remote.push(
|
|
259
|
+
th,
|
|
260
|
+
"model_name",
|
|
261
|
+
publish=False,
|
|
262
|
+
trusted=False,
|
|
263
|
+
promote=False,
|
|
264
|
+
preserve_previous_prod_deployment=True,
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def test_create_chain_with_no_publish():
|
|
269
|
+
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
270
|
+
|
|
271
|
+
with requests_mock.Mocker() as m:
|
|
272
|
+
m.post(
|
|
273
|
+
_TEST_REMOTE_GRAPHQL_PATH,
|
|
274
|
+
[
|
|
275
|
+
{"json": {"data": {"chains": []}}},
|
|
276
|
+
{
|
|
277
|
+
"json": {
|
|
278
|
+
"data": {
|
|
279
|
+
"deploy_chain_atomic": {
|
|
280
|
+
"chain_id": "new-chain-id",
|
|
281
|
+
"chain_deployment_id": "new-chain-deployment-id",
|
|
282
|
+
"entrypoint_model_id": "new-entrypoint-model-id",
|
|
283
|
+
"entrypoint_model_version_id": "new-entrypoint-model-version-id",
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
}
|
|
287
|
+
},
|
|
288
|
+
],
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
deployment_handle = create_chain_atomic(
|
|
292
|
+
api=remote.api,
|
|
293
|
+
chain_name="draft_chain",
|
|
294
|
+
entrypoint=ChainletDataAtomic(
|
|
295
|
+
name="chainlet-1",
|
|
296
|
+
oracle=OracleData(
|
|
297
|
+
model_name="model-1",
|
|
298
|
+
s3_key="s3-key-1",
|
|
299
|
+
encoded_config_str="encoded-config-str-1",
|
|
300
|
+
is_trusted=True,
|
|
301
|
+
),
|
|
302
|
+
),
|
|
303
|
+
dependencies=[],
|
|
304
|
+
is_draft=True,
|
|
305
|
+
environment=None,
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
get_chains_graphql_request = m.request_history[0]
|
|
309
|
+
create_chain_graphql_request = m.request_history[1]
|
|
310
|
+
|
|
311
|
+
expected_get_chains_query = """
|
|
312
|
+
{
|
|
313
|
+
chains {
|
|
314
|
+
id
|
|
315
|
+
name
|
|
316
|
+
}
|
|
317
|
+
}
|
|
318
|
+
""".strip()
|
|
319
|
+
|
|
320
|
+
assert_request_matches_expected_query(
|
|
321
|
+
get_chains_graphql_request, expected_get_chains_query
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
chainlets_string = """
|
|
325
|
+
{
|
|
326
|
+
name: "chainlet-1",
|
|
327
|
+
oracle: {
|
|
328
|
+
model_name: "model-1",
|
|
329
|
+
s3_key: "s3-key-1",
|
|
330
|
+
encoded_config_str: "encoded-config-str-1",
|
|
331
|
+
is_trusted: true,
|
|
332
|
+
semver_bump: "MINOR"
|
|
333
|
+
}
|
|
334
|
+
}
|
|
335
|
+
""".strip()
|
|
336
|
+
|
|
337
|
+
# Note that if publish=False and promote=True, we set publish to True and create
|
|
338
|
+
# a non-draft deployment
|
|
339
|
+
expected_create_chain_mutation = f"""
|
|
340
|
+
mutation {{
|
|
341
|
+
deploy_chain_atomic(
|
|
342
|
+
chain_name: "draft_chain"
|
|
343
|
+
is_draft: true
|
|
344
|
+
entrypoint: {chainlets_string}
|
|
345
|
+
dependencies: []
|
|
346
|
+
client_version: "{truss.version()}"
|
|
347
|
+
) {{
|
|
348
|
+
chain_id
|
|
349
|
+
chain_deployment_id
|
|
350
|
+
entrypoint_model_id
|
|
351
|
+
entrypoint_model_version_id
|
|
352
|
+
}}
|
|
353
|
+
}}
|
|
354
|
+
""".strip()
|
|
355
|
+
|
|
356
|
+
assert_request_matches_expected_query(
|
|
357
|
+
create_chain_graphql_request, expected_create_chain_mutation
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
assert deployment_handle.chain_id == "new-chain-id"
|
|
361
|
+
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def test_create_chain_no_existing_chain():
|
|
365
|
+
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
366
|
+
|
|
367
|
+
with requests_mock.Mocker() as m:
|
|
368
|
+
m.post(
|
|
369
|
+
_TEST_REMOTE_GRAPHQL_PATH,
|
|
370
|
+
[
|
|
371
|
+
{"json": {"data": {"chains": []}}},
|
|
372
|
+
{
|
|
373
|
+
"json": {
|
|
374
|
+
"data": {
|
|
375
|
+
"deploy_chain_atomic": {
|
|
376
|
+
"chain_id": "new-chain-id",
|
|
377
|
+
"chain_deployment_id": "new-chain-deployment-id",
|
|
378
|
+
"entrypoint_model_id": "new-entrypoint-model-id",
|
|
379
|
+
"entrypoint_model_version_id": "new-entrypoint-model-version-id",
|
|
380
|
+
}
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
},
|
|
384
|
+
],
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
deployment_handle = create_chain_atomic(
|
|
388
|
+
api=remote.api,
|
|
389
|
+
chain_name="new_chain",
|
|
390
|
+
entrypoint=ChainletDataAtomic(
|
|
391
|
+
name="chainlet-1",
|
|
392
|
+
oracle=OracleData(
|
|
393
|
+
model_name="model-1",
|
|
394
|
+
s3_key="s3-key-1",
|
|
395
|
+
encoded_config_str="encoded-config-str-1",
|
|
396
|
+
is_trusted=True,
|
|
397
|
+
),
|
|
398
|
+
),
|
|
399
|
+
dependencies=[],
|
|
400
|
+
is_draft=False,
|
|
401
|
+
environment=None,
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
get_chains_graphql_request = m.request_history[0]
|
|
405
|
+
create_chain_graphql_request = m.request_history[1]
|
|
406
|
+
|
|
407
|
+
expected_get_chains_query = """
|
|
408
|
+
{
|
|
409
|
+
chains {
|
|
410
|
+
id
|
|
411
|
+
name
|
|
412
|
+
}
|
|
413
|
+
}
|
|
414
|
+
""".strip()
|
|
415
|
+
|
|
416
|
+
assert_request_matches_expected_query(
|
|
417
|
+
get_chains_graphql_request, expected_get_chains_query
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
chainlets_string = """
|
|
421
|
+
{
|
|
422
|
+
name: "chainlet-1",
|
|
423
|
+
oracle: {
|
|
424
|
+
model_name: "model-1",
|
|
425
|
+
s3_key: "s3-key-1",
|
|
426
|
+
encoded_config_str: "encoded-config-str-1",
|
|
427
|
+
is_trusted: true,
|
|
428
|
+
semver_bump: "MINOR"
|
|
429
|
+
}
|
|
430
|
+
}
|
|
431
|
+
""".strip()
|
|
432
|
+
|
|
433
|
+
expected_create_chain_mutation = f"""
|
|
434
|
+
mutation {{
|
|
435
|
+
deploy_chain_atomic(
|
|
436
|
+
chain_name: "new_chain"
|
|
437
|
+
is_draft: false
|
|
438
|
+
entrypoint: {chainlets_string}
|
|
439
|
+
dependencies: []
|
|
440
|
+
client_version: "{truss.version()}"
|
|
441
|
+
) {{
|
|
442
|
+
chain_id
|
|
443
|
+
chain_deployment_id
|
|
444
|
+
entrypoint_model_id
|
|
445
|
+
entrypoint_model_version_id
|
|
446
|
+
}}
|
|
447
|
+
}}
|
|
448
|
+
""".strip()
|
|
449
|
+
|
|
450
|
+
assert_request_matches_expected_query(
|
|
451
|
+
create_chain_graphql_request, expected_create_chain_mutation
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
assert deployment_handle.chain_id == "new-chain-id"
|
|
455
|
+
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def test_create_chain_with_existing_chain_promote_to_environment_publish_false():
|
|
459
|
+
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
460
|
+
|
|
461
|
+
with requests_mock.Mocker() as m:
|
|
462
|
+
m.post(
|
|
463
|
+
_TEST_REMOTE_GRAPHQL_PATH,
|
|
464
|
+
[
|
|
465
|
+
{
|
|
466
|
+
"json": {
|
|
467
|
+
"data": {
|
|
468
|
+
"chains": [{"id": "old-chain-id", "name": "old_chain"}]
|
|
469
|
+
}
|
|
470
|
+
}
|
|
471
|
+
},
|
|
472
|
+
{
|
|
473
|
+
"json": {
|
|
474
|
+
"data": {
|
|
475
|
+
"deploy_chain_atomic": {
|
|
476
|
+
"chain_id": "new-chain-id",
|
|
477
|
+
"chain_deployment_id": "new-chain-deployment-id",
|
|
478
|
+
"entrypoint_model_id": "new-entrypoint-model-id",
|
|
479
|
+
"entrypoint_model_version_id": "new-entrypoint-model-version-id",
|
|
480
|
+
}
|
|
481
|
+
}
|
|
482
|
+
}
|
|
483
|
+
},
|
|
484
|
+
],
|
|
485
|
+
)
|
|
486
|
+
|
|
487
|
+
deployment_handle = create_chain_atomic(
|
|
488
|
+
api=remote.api,
|
|
489
|
+
chain_name="old_chain",
|
|
490
|
+
entrypoint=ChainletDataAtomic(
|
|
491
|
+
name="chainlet-1",
|
|
492
|
+
oracle=OracleData(
|
|
493
|
+
model_name="model-1",
|
|
494
|
+
s3_key="s3-key-1",
|
|
495
|
+
encoded_config_str="encoded-config-str-1",
|
|
496
|
+
is_trusted=True,
|
|
497
|
+
),
|
|
498
|
+
),
|
|
499
|
+
dependencies=[],
|
|
500
|
+
is_draft=True,
|
|
501
|
+
environment="production",
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
get_chains_graphql_request = m.request_history[0]
|
|
505
|
+
create_chain_graphql_request = m.request_history[1]
|
|
506
|
+
|
|
507
|
+
expected_get_chains_query = """
|
|
508
|
+
{
|
|
509
|
+
chains {
|
|
510
|
+
id
|
|
511
|
+
name
|
|
512
|
+
}
|
|
513
|
+
}
|
|
514
|
+
""".strip()
|
|
515
|
+
|
|
516
|
+
assert_request_matches_expected_query(
|
|
517
|
+
get_chains_graphql_request, expected_get_chains_query
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
# Note that if publish=False and environment!=None, we set publish to True and create
|
|
521
|
+
# a non-draft deployment
|
|
522
|
+
chainlets_string = """
|
|
523
|
+
{
|
|
524
|
+
name: "chainlet-1",
|
|
525
|
+
oracle: {
|
|
526
|
+
model_name: "model-1",
|
|
527
|
+
s3_key: "s3-key-1",
|
|
528
|
+
encoded_config_str: "encoded-config-str-1",
|
|
529
|
+
is_trusted: true,
|
|
530
|
+
semver_bump: "MINOR"
|
|
531
|
+
}
|
|
532
|
+
}
|
|
533
|
+
""".strip()
|
|
534
|
+
|
|
535
|
+
expected_create_chain_mutation = f"""
|
|
536
|
+
mutation {{
|
|
537
|
+
deploy_chain_atomic(
|
|
538
|
+
chain_id: "old-chain-id"
|
|
539
|
+
environment: "production"
|
|
540
|
+
is_draft: false
|
|
541
|
+
entrypoint: {chainlets_string}
|
|
542
|
+
dependencies: []
|
|
543
|
+
client_version: "{truss.version()}"
|
|
544
|
+
) {{
|
|
545
|
+
chain_id
|
|
546
|
+
chain_deployment_id
|
|
547
|
+
entrypoint_model_id
|
|
548
|
+
entrypoint_model_version_id
|
|
549
|
+
}}
|
|
550
|
+
}}
|
|
551
|
+
""".strip()
|
|
552
|
+
|
|
553
|
+
assert_request_matches_expected_query(
|
|
554
|
+
create_chain_graphql_request, expected_create_chain_mutation
|
|
555
|
+
)
|
|
556
|
+
|
|
557
|
+
assert deployment_handle.chain_id == "new-chain-id"
|
|
558
|
+
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def test_create_chain_existing_chain_publish_true_no_promotion():
|
|
562
|
+
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
563
|
+
|
|
564
|
+
with requests_mock.Mocker() as m:
|
|
565
|
+
m.post(
|
|
566
|
+
_TEST_REMOTE_GRAPHQL_PATH,
|
|
567
|
+
[
|
|
568
|
+
{
|
|
569
|
+
"json": {
|
|
570
|
+
"data": {
|
|
571
|
+
"chains": [{"id": "old-chain-id", "name": "old_chain"}]
|
|
572
|
+
}
|
|
573
|
+
}
|
|
574
|
+
},
|
|
575
|
+
{
|
|
576
|
+
"json": {
|
|
577
|
+
"data": {
|
|
578
|
+
"deploy_chain_atomic": {
|
|
579
|
+
"chain_id": "new-chain-id",
|
|
580
|
+
"chain_deployment_id": "new-chain-deployment-id",
|
|
581
|
+
"entrypoint_model_id": "new-entrypoint-model-id",
|
|
582
|
+
"entrypoint_model_version_id": "new-entrypoint-model-version-id",
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
}
|
|
586
|
+
},
|
|
587
|
+
],
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
deployment_handle = create_chain_atomic(
|
|
591
|
+
api=remote.api,
|
|
592
|
+
chain_name="old_chain",
|
|
593
|
+
entrypoint=ChainletDataAtomic(
|
|
594
|
+
name="chainlet-1",
|
|
595
|
+
oracle=OracleData(
|
|
596
|
+
model_name="model-1",
|
|
597
|
+
s3_key="s3-key-1",
|
|
598
|
+
encoded_config_str="encoded-config-str-1",
|
|
599
|
+
is_trusted=True,
|
|
600
|
+
),
|
|
601
|
+
),
|
|
602
|
+
dependencies=[],
|
|
603
|
+
is_draft=False,
|
|
604
|
+
environment=None,
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
get_chains_graphql_request = m.request_history[0]
|
|
608
|
+
create_chain_graphql_request = m.request_history[1]
|
|
609
|
+
|
|
610
|
+
expected_get_chains_query = """
|
|
611
|
+
{
|
|
612
|
+
chains {
|
|
613
|
+
id
|
|
614
|
+
name
|
|
615
|
+
}
|
|
616
|
+
}
|
|
617
|
+
""".strip()
|
|
618
|
+
|
|
619
|
+
assert_request_matches_expected_query(
|
|
620
|
+
get_chains_graphql_request, expected_get_chains_query
|
|
621
|
+
)
|
|
622
|
+
|
|
623
|
+
chainlets_string = """
|
|
624
|
+
{
|
|
625
|
+
name: "chainlet-1",
|
|
626
|
+
oracle: {
|
|
627
|
+
model_name: "model-1",
|
|
628
|
+
s3_key: "s3-key-1",
|
|
629
|
+
encoded_config_str: "encoded-config-str-1",
|
|
630
|
+
is_trusted: true,
|
|
631
|
+
semver_bump: "MINOR"
|
|
632
|
+
}
|
|
633
|
+
}
|
|
634
|
+
""".strip()
|
|
635
|
+
|
|
636
|
+
expected_create_chain_mutation = f"""
|
|
637
|
+
mutation {{
|
|
638
|
+
deploy_chain_atomic(
|
|
639
|
+
chain_id: "old-chain-id"
|
|
640
|
+
is_draft: false
|
|
641
|
+
entrypoint: {chainlets_string}
|
|
642
|
+
dependencies: []
|
|
643
|
+
client_version: "{truss.version()}"
|
|
644
|
+
) {{
|
|
645
|
+
chain_id
|
|
646
|
+
chain_deployment_id
|
|
647
|
+
entrypoint_model_id
|
|
648
|
+
entrypoint_model_version_id
|
|
649
|
+
}}
|
|
650
|
+
}}
|
|
651
|
+
""".strip()
|
|
652
|
+
|
|
653
|
+
assert_request_matches_expected_query(
|
|
654
|
+
create_chain_graphql_request, expected_create_chain_mutation
|
|
655
|
+
)
|
|
656
|
+
|
|
657
|
+
assert deployment_handle.chain_id == "new-chain-id"
|
|
658
|
+
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
@pytest.mark.parametrize("publish", [True, False])
|
|
662
|
+
def test_push_raised_value_error_when_disable_truss_download_for_existing_model(
|
|
663
|
+
publish, custom_model_truss_dir_with_pre_and_post
|
|
664
|
+
):
|
|
665
|
+
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
|
|
666
|
+
model_response = {
|
|
667
|
+
"data": {
|
|
668
|
+
"model": {
|
|
669
|
+
"name": "model_name",
|
|
670
|
+
"id": "model_id",
|
|
671
|
+
"primary_version": {"id": "version_id"},
|
|
672
|
+
}
|
|
673
|
+
}
|
|
674
|
+
}
|
|
675
|
+
with requests_mock.Mocker() as m:
|
|
676
|
+
m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
|
|
677
|
+
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
|
|
678
|
+
|
|
679
|
+
with pytest.raises(
|
|
680
|
+
ValueError, match="disable-truss-download can only be used for new models"
|
|
681
|
+
):
|
|
682
|
+
remote.push(th, "model_name", publish=publish, disable_truss_download=True)
|