truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- truss/{patch → truss_handle/patch}/constants.py +0 -0
- truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
truss/tests/patch/test_dir_signature.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from truss.patch.dir_signature import directory_content_signature
|
|
1
|
+
from truss.truss_handle.patch.dir_signature import directory_content_signature
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def test_directory_content_signature(tmp_path):
|
|
@@ -12,12 +12,7 @@ def test_directory_content_signature(tmp_path):
|
|
|
12
12
|
|
|
13
13
|
content_sign = directory_content_signature(root)
|
|
14
14
|
|
|
15
|
-
assert content_sign.keys() == {
|
|
16
|
-
"dir",
|
|
17
|
-
"dir/file3",
|
|
18
|
-
"file1",
|
|
19
|
-
"file2",
|
|
20
|
-
}
|
|
15
|
+
assert content_sign.keys() == {"dir", "dir/file3", "file1", "file2"}
|
|
21
16
|
|
|
22
17
|
|
|
23
18
|
def test_directory_content_signature_ignore_patterns(tmp_path):
|
|
@@ -40,8 +35,4 @@ def test_directory_content_signature_ignore_patterns(tmp_path):
|
|
|
40
35
|
root=root, ignore_patterns=["data/*", ".git"]
|
|
41
36
|
)
|
|
42
37
|
|
|
43
|
-
assert content_sign.keys() == {
|
|
44
|
-
"data",
|
|
45
|
-
"file1",
|
|
46
|
-
"file2",
|
|
47
|
-
}
|
|
38
|
+
assert content_sign.keys() == {"data", "file1", "file2"}
|
truss/tests/patch/test_truss_dir_patch_applier.py
CHANGED
|
@@ -2,17 +2,18 @@ import logging
|
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
|
|
4
4
|
import yaml
|
|
5
|
-
from truss.
|
|
6
|
-
from truss.
|
|
5
|
+
from truss.base.truss_config import TrussConfig
|
|
6
|
+
from truss.templates.control.control.helpers.custom_types import (
|
|
7
7
|
Action,
|
|
8
8
|
ConfigPatch,
|
|
9
9
|
ModelCodePatch,
|
|
10
|
+
PackagePatch,
|
|
10
11
|
Patch,
|
|
11
12
|
PatchType,
|
|
12
13
|
PythonRequirementPatch,
|
|
13
14
|
SystemPackagePatch,
|
|
14
15
|
)
|
|
15
|
-
from truss.
|
|
16
|
+
from truss.truss_handle.patch.truss_dir_patch_applier import TrussDirPatchApplier
|
|
16
17
|
|
|
17
18
|
TEST_LOGGER = logging.getLogger("test_logger")
|
|
18
19
|
|
|
@@ -32,6 +33,23 @@ def test_model_code_patch(custom_model_truss_dir: Path):
|
|
|
32
33
|
assert (custom_model_truss_dir / "model" / "model.py").read_text() == "test_content"
|
|
33
34
|
|
|
34
35
|
|
|
36
|
+
def test_packages_patch(custom_model_truss_dir: Path):
|
|
37
|
+
applier = TrussDirPatchApplier(custom_model_truss_dir, TEST_LOGGER)
|
|
38
|
+
applier(
|
|
39
|
+
[
|
|
40
|
+
Patch(
|
|
41
|
+
type=PatchType.PACKAGE,
|
|
42
|
+
body=PackagePatch(
|
|
43
|
+
action=Action.UPDATE, path="user_package.py", content="import sys"
|
|
44
|
+
),
|
|
45
|
+
)
|
|
46
|
+
]
|
|
47
|
+
)
|
|
48
|
+
assert (
|
|
49
|
+
custom_model_truss_dir / "packages" / "user_package.py"
|
|
50
|
+
).read_text() == "import sys"
|
|
51
|
+
|
|
52
|
+
|
|
35
53
|
def test_python_requirement_patch(custom_model_truss_dir: Path):
|
|
36
54
|
req = "git+https://github.com/huggingface/transformers.git"
|
|
37
55
|
applier = TrussDirPatchApplier(custom_model_truss_dir, TEST_LOGGER)
|
|
@@ -41,10 +59,7 @@ def test_python_requirement_patch(custom_model_truss_dir: Path):
|
|
|
41
59
|
[
|
|
42
60
|
Patch(
|
|
43
61
|
type=PatchType.PYTHON_REQUIREMENT,
|
|
44
|
-
body=PythonRequirementPatch(
|
|
45
|
-
action=Action.ADD,
|
|
46
|
-
requirement=req,
|
|
47
|
-
),
|
|
62
|
+
body=PythonRequirementPatch(action=Action.ADD, requirement=req),
|
|
48
63
|
),
|
|
49
64
|
Patch(
|
|
50
65
|
type=PatchType.CONFIG,
|
|
@@ -73,10 +88,7 @@ def test_system_requirement_patch(custom_model_truss_dir: Path):
|
|
|
73
88
|
),
|
|
74
89
|
Patch(
|
|
75
90
|
type=PatchType.SYSTEM_PACKAGE,
|
|
76
|
-
body=SystemPackagePatch(
|
|
77
|
-
action=Action.ADD,
|
|
78
|
-
package="curl",
|
|
79
|
-
),
|
|
91
|
+
body=SystemPackagePatch(action=Action.ADD, package="curl"),
|
|
80
92
|
),
|
|
81
93
|
]
|
|
82
94
|
)
|
truss/tests/patch/test_types.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
from truss.patch.
|
|
2
|
-
from truss.patch.
|
|
1
|
+
from truss.truss_handle.patch.custom_types import TrussSignature
|
|
2
|
+
from truss.truss_handle.patch.signature import calc_truss_signature
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
def test_truss_signature_type(custom_model_truss_dir):
|
|
truss/tests/remote/baseten/test_api.py
CHANGED
|
@@ -4,9 +4,11 @@ import pytest
|
|
|
4
4
|
import requests
|
|
5
5
|
from requests import Response
|
|
6
6
|
from truss.remote.baseten.api import BasetenApi
|
|
7
|
+
from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
|
|
7
8
|
from truss.remote.baseten.error import ApiError
|
|
8
9
|
|
|
9
10
|
|
|
11
|
+
@pytest.fixture
|
|
10
12
|
def mock_auth_service():
|
|
11
13
|
auth_service = mock.Mock()
|
|
12
14
|
auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
|
|
@@ -52,45 +54,62 @@ def mock_create_model_response():
|
|
|
52
54
|
return response
|
|
53
55
|
|
|
54
56
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
57
|
+
def mock_create_development_model_response():
|
|
58
|
+
response = Response()
|
|
59
|
+
response.status_code = 200
|
|
60
|
+
response.json = mock.Mock(
|
|
61
|
+
return_value={"data": {"deploy_draft_truss": {"id": "12345"}}}
|
|
62
|
+
)
|
|
63
|
+
return response
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def mock_deploy_chain_deployment_response():
|
|
67
|
+
response = Response()
|
|
68
|
+
response.status_code = 200
|
|
69
|
+
response.json = mock.Mock(
|
|
70
|
+
return_value={
|
|
71
|
+
"data": {
|
|
72
|
+
"deploy_chain_atomic": {
|
|
73
|
+
"chain_id": "12345",
|
|
74
|
+
"chain_deployment_id": "54321",
|
|
75
|
+
"entrypoint_model_id": "67890",
|
|
76
|
+
"entrypoint_model_version_id": "09876",
|
|
77
|
+
}
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
)
|
|
81
|
+
return response
|
|
82
|
+
|
|
60
83
|
|
|
84
|
+
@pytest.fixture
|
|
85
|
+
def baseten_api(mock_auth_service):
|
|
86
|
+
return BasetenApi("https://app.test.com", mock_auth_service)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@mock.patch("requests.post", return_value=mock_successful_response())
|
|
90
|
+
def test_post_graphql_query_success(mock_post, baseten_api):
|
|
61
91
|
response_data = {"data": {"status": "success"}}
|
|
62
92
|
|
|
63
|
-
result =
|
|
93
|
+
result = baseten_api._post_graphql_query("sample_query_string")
|
|
64
94
|
|
|
65
95
|
assert result == response_data
|
|
66
96
|
|
|
67
97
|
|
|
68
|
-
@mock.patch("truss.remote.baseten.auth.AuthService")
|
|
69
98
|
@mock.patch("requests.post", return_value=mock_graphql_error_response())
|
|
70
|
-
def test_post_graphql_query_error(mock_post,
|
|
71
|
-
api_url = "https://test.com/api"
|
|
72
|
-
api = BasetenApi(api_url, mock_auth_service)
|
|
73
|
-
|
|
99
|
+
def test_post_graphql_query_error(mock_post, baseten_api):
|
|
74
100
|
with pytest.raises(ApiError):
|
|
75
|
-
|
|
101
|
+
baseten_api._post_graphql_query("sample_query_string")
|
|
76
102
|
|
|
77
103
|
|
|
78
|
-
@mock.patch("truss.remote.baseten.auth.AuthService")
|
|
79
104
|
@mock.patch("requests.post", return_value=mock_unsuccessful_response())
|
|
80
|
-
def test_post_requests_error(mock_post,
|
|
81
|
-
api_url = "https://test.com/api"
|
|
82
|
-
api = BasetenApi(api_url, mock_auth_service)
|
|
105
|
+
def test_post_requests_error(mock_post, baseten_api):
|
|
83
106
|
with pytest.raises(requests.exceptions.HTTPError):
|
|
84
|
-
|
|
107
|
+
baseten_api._post_graphql_query("sample_query_string")
|
|
85
108
|
|
|
86
109
|
|
|
87
|
-
@mock.patch("truss.remote.baseten.auth.AuthService")
|
|
88
110
|
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
89
|
-
def test_create_model_version_from_truss(mock_post,
|
|
90
|
-
|
|
91
|
-
api = BasetenApi(api_url, mock_auth_service)
|
|
92
|
-
|
|
93
|
-
api.create_model_version_from_truss(
|
|
111
|
+
def test_create_model_version_from_truss(mock_post, baseten_api):
|
|
112
|
+
baseten_api.create_model_version_from_truss(
|
|
94
113
|
"model_id",
|
|
95
114
|
"s3key",
|
|
96
115
|
"config_str",
|
|
@@ -98,8 +117,8 @@ def test_create_model_version_from_truss(mock_post, mock_auth_service):
|
|
|
98
117
|
"client_version",
|
|
99
118
|
False,
|
|
100
119
|
False,
|
|
101
|
-
False,
|
|
102
120
|
"deployment_name",
|
|
121
|
+
"production",
|
|
103
122
|
)
|
|
104
123
|
|
|
105
124
|
gql_mutation = mock_post.call_args[1]["data"]["query"]
|
|
@@ -109,27 +128,22 @@ def test_create_model_version_from_truss(mock_post, mock_auth_service):
|
|
|
109
128
|
assert 'semver_bump: "semver_bump"' in gql_mutation
|
|
110
129
|
assert 'client_version: "client_version"' in gql_mutation
|
|
111
130
|
assert "is_trusted: false" in gql_mutation
|
|
112
|
-
assert "promote_after_deploy: false" in gql_mutation
|
|
113
131
|
assert "scale_down_old_production: true" in gql_mutation
|
|
114
132
|
assert 'name: "deployment_name"' in gql_mutation
|
|
133
|
+
assert 'environment_name: "production"' in gql_mutation
|
|
115
134
|
|
|
116
135
|
|
|
117
|
-
@mock.patch("truss.remote.baseten.auth.AuthService")
|
|
118
136
|
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
119
137
|
def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_specified(
|
|
120
|
-
mock_post,
|
|
138
|
+
mock_post, baseten_api
|
|
121
139
|
):
|
|
122
|
-
|
|
123
|
-
api = BasetenApi(api_url, mock_auth_service)
|
|
124
|
-
|
|
125
|
-
api.create_model_version_from_truss(
|
|
140
|
+
baseten_api.create_model_version_from_truss(
|
|
126
141
|
"model_id",
|
|
127
142
|
"s3key",
|
|
128
143
|
"config_str",
|
|
129
144
|
"semver_bump",
|
|
130
145
|
"client_version",
|
|
131
146
|
True,
|
|
132
|
-
True,
|
|
133
147
|
False,
|
|
134
148
|
deployment_name=None,
|
|
135
149
|
)
|
|
@@ -141,20 +155,16 @@ def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_sp
|
|
|
141
155
|
assert 'semver_bump: "semver_bump"' in gql_mutation
|
|
142
156
|
assert 'client_version: "client_version"' in gql_mutation
|
|
143
157
|
assert "is_trusted: true" in gql_mutation
|
|
144
|
-
assert "promote_after_deploy: true" in gql_mutation
|
|
145
158
|
assert "scale_down_old_production: true" in gql_mutation
|
|
146
|
-
assert "name: " not in gql_mutation
|
|
159
|
+
assert " name: " not in gql_mutation
|
|
160
|
+
assert "environment_name: " not in gql_mutation
|
|
147
161
|
|
|
148
162
|
|
|
149
|
-
@mock.patch("truss.remote.baseten.auth.AuthService")
|
|
150
163
|
@mock.patch("requests.post", return_value=mock_create_model_version_response())
|
|
151
164
|
def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep_previous_prod_settings(
|
|
152
|
-
mock_post,
|
|
165
|
+
mock_post, baseten_api
|
|
153
166
|
):
|
|
154
|
-
|
|
155
|
-
api = BasetenApi(api_url, mock_auth_service)
|
|
156
|
-
|
|
157
|
-
api.create_model_version_from_truss(
|
|
167
|
+
baseten_api.create_model_version_from_truss(
|
|
158
168
|
"model_id",
|
|
159
169
|
"s3key",
|
|
160
170
|
"config_str",
|
|
@@ -162,8 +172,8 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
|
|
|
162
172
|
"client_version",
|
|
163
173
|
True,
|
|
164
174
|
True,
|
|
165
|
-
True,
|
|
166
175
|
deployment_name=None,
|
|
176
|
+
environment="staging",
|
|
167
177
|
)
|
|
168
178
|
|
|
169
179
|
gql_mutation = mock_post.call_args[1]["data"]["query"]
|
|
@@ -173,25 +183,21 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
|
|
|
173
183
|
assert 'semver_bump: "semver_bump"' in gql_mutation
|
|
174
184
|
assert 'client_version: "client_version"' in gql_mutation
|
|
175
185
|
assert "is_trusted: true" in gql_mutation
|
|
176
|
-
assert "promote_after_deploy: true" in gql_mutation
|
|
177
186
|
assert "scale_down_old_production: false" in gql_mutation
|
|
178
|
-
assert "name: " not in gql_mutation
|
|
187
|
+
assert " name: " not in gql_mutation
|
|
188
|
+
assert 'environment_name: "staging"' in gql_mutation
|
|
179
189
|
|
|
180
190
|
|
|
181
|
-
@mock.patch("truss.remote.baseten.auth.AuthService")
|
|
182
191
|
@mock.patch("requests.post", return_value=mock_create_model_response())
|
|
183
|
-
def test_create_model_from_truss(mock_post,
|
|
184
|
-
|
|
185
|
-
api = BasetenApi(api_url, mock_auth_service)
|
|
186
|
-
|
|
187
|
-
api.create_model_from_truss(
|
|
192
|
+
def test_create_model_from_truss(mock_post, baseten_api):
|
|
193
|
+
baseten_api.create_model_from_truss(
|
|
188
194
|
"model_name",
|
|
189
195
|
"s3key",
|
|
190
196
|
"config_str",
|
|
191
197
|
"semver_bump",
|
|
192
198
|
"client_version",
|
|
193
|
-
False,
|
|
194
|
-
"deployment_name",
|
|
199
|
+
is_trusted=False,
|
|
200
|
+
deployment_name="deployment_name",
|
|
195
201
|
)
|
|
196
202
|
|
|
197
203
|
gql_mutation = mock_post.call_args[1]["data"]["query"]
|
|
@@ -204,21 +210,17 @@ def test_create_model_from_truss(mock_post, mock_auth_service):
|
|
|
204
210
|
assert 'version_name: "deployment_name"' in gql_mutation
|
|
205
211
|
|
|
206
212
|
|
|
207
|
-
@mock.patch("truss.remote.baseten.auth.AuthService")
|
|
208
213
|
@mock.patch("requests.post", return_value=mock_create_model_response())
|
|
209
214
|
def test_create_model_from_truss_does_not_send_deployment_name_if_not_specified(
|
|
210
|
-
mock_post,
|
|
215
|
+
mock_post, baseten_api
|
|
211
216
|
):
|
|
212
|
-
|
|
213
|
-
api = BasetenApi(api_url, mock_auth_service)
|
|
214
|
-
|
|
215
|
-
api.create_model_from_truss(
|
|
217
|
+
baseten_api.create_model_from_truss(
|
|
216
218
|
"model_name",
|
|
217
219
|
"s3key",
|
|
218
220
|
"config_str",
|
|
219
221
|
"semver_bump",
|
|
220
222
|
"client_version",
|
|
221
|
-
True,
|
|
223
|
+
is_trusted=True,
|
|
222
224
|
deployment_name=None,
|
|
223
225
|
)
|
|
224
226
|
|
|
@@ -230,3 +232,96 @@ def test_create_model_from_truss_does_not_send_deployment_name_if_not_specified(
|
|
|
230
232
|
assert 'client_version: "client_version"' in gql_mutation
|
|
231
233
|
assert "is_trusted: true" in gql_mutation
|
|
232
234
|
assert "version_name: " not in gql_mutation
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@mock.patch("requests.post", return_value=mock_create_model_response())
|
|
238
|
+
def test_create_model_from_truss_with_allow_truss_download(mock_post, baseten_api):
|
|
239
|
+
baseten_api.create_model_from_truss(
|
|
240
|
+
"model_name",
|
|
241
|
+
"s3key",
|
|
242
|
+
"config_str",
|
|
243
|
+
"semver_bump",
|
|
244
|
+
"client_version",
|
|
245
|
+
is_trusted=True,
|
|
246
|
+
allow_truss_download=False,
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
gql_mutation = mock_post.call_args[1]["data"]["query"]
|
|
250
|
+
assert 'name: "model_name"' in gql_mutation
|
|
251
|
+
assert 's3_key: "s3key"' in gql_mutation
|
|
252
|
+
assert 'config: "config_str"' in gql_mutation
|
|
253
|
+
assert 'semver_bump: "semver_bump"' in gql_mutation
|
|
254
|
+
assert 'client_version: "client_version"' in gql_mutation
|
|
255
|
+
assert "is_trusted: true" in gql_mutation
|
|
256
|
+
assert "allow_truss_download: false" in gql_mutation
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@mock.patch("requests.post", return_value=mock_create_development_model_response())
|
|
260
|
+
def test_create_development_model_from_truss_with_allow_truss_download(
|
|
261
|
+
mock_post, baseten_api
|
|
262
|
+
):
|
|
263
|
+
baseten_api.create_development_model_from_truss(
|
|
264
|
+
"model_name",
|
|
265
|
+
"s3key",
|
|
266
|
+
"config_str",
|
|
267
|
+
"client_version",
|
|
268
|
+
is_trusted=True,
|
|
269
|
+
allow_truss_download=False,
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
gql_mutation = mock_post.call_args[1]["data"]["query"]
|
|
273
|
+
assert 'name: "model_name"' in gql_mutation
|
|
274
|
+
assert 's3_key: "s3key"' in gql_mutation
|
|
275
|
+
assert 'config: "config_str"' in gql_mutation
|
|
276
|
+
assert 'client_version: "client_version"' in gql_mutation
|
|
277
|
+
assert "is_trusted: true" in gql_mutation
|
|
278
|
+
assert "allow_truss_download: false" in gql_mutation
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
282
|
+
def test_deploy_chain_deployment(mock_post, baseten_api):
|
|
283
|
+
baseten_api.deploy_chain_atomic(
|
|
284
|
+
environment="production",
|
|
285
|
+
chain_id="chain_id",
|
|
286
|
+
dependencies=[],
|
|
287
|
+
entrypoint=ChainletDataAtomic(
|
|
288
|
+
name="chainlet-1",
|
|
289
|
+
oracle=OracleData(
|
|
290
|
+
model_name="model-1",
|
|
291
|
+
s3_key="s3-key-1",
|
|
292
|
+
encoded_config_str="encoded-config-str-1",
|
|
293
|
+
is_trusted=True,
|
|
294
|
+
),
|
|
295
|
+
),
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
gql_mutation = mock_post.call_args[1]["data"]["query"]
|
|
299
|
+
|
|
300
|
+
assert 'environment: "production"' in gql_mutation
|
|
301
|
+
assert 'chain_id: "chain_id"' in gql_mutation
|
|
302
|
+
assert "dependencies:" in gql_mutation
|
|
303
|
+
assert "entrypoint:" in gql_mutation
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
|
|
307
|
+
def test_deploy_chain_deployment_no_environment(mock_post, baseten_api):
|
|
308
|
+
baseten_api.deploy_chain_atomic(
|
|
309
|
+
chain_id="chain_id",
|
|
310
|
+
dependencies=[],
|
|
311
|
+
entrypoint=ChainletDataAtomic(
|
|
312
|
+
name="chainlet-1",
|
|
313
|
+
oracle=OracleData(
|
|
314
|
+
model_name="model-1",
|
|
315
|
+
s3_key="s3-key-1",
|
|
316
|
+
encoded_config_str="encoded-config-str-1",
|
|
317
|
+
is_trusted=True,
|
|
318
|
+
),
|
|
319
|
+
),
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
gql_mutation = mock_post.call_args[1]["data"]["query"]
|
|
323
|
+
|
|
324
|
+
assert 'chain_id: "chain_id"' in gql_mutation
|
|
325
|
+
assert "environment" not in gql_mutation
|
|
326
|
+
assert "dependencies:" in gql_mutation
|
|
327
|
+
assert "entrypoint:" in gql_mutation
|
|
@@ -9,7 +9,8 @@ def test_api_key():
|
|
|
9
9
|
assert key.header() == {"Authorization": "Api-Key test_key"}
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def test_auth_service_no_key():
|
|
12
|
+
def test_auth_service_no_key(monkeypatch: pytest.MonkeyPatch):
|
|
13
|
+
monkeypatch.delenv("BASETEN_API_KEY", raising=False)
|
|
13
14
|
auth_service = AuthService()
|
|
14
15
|
with pytest.raises(AuthorizationError):
|
|
15
16
|
auth_service.authenticate()
|
|
@@ -1,8 +1,13 @@
|
|
|
1
|
+
import json
|
|
1
2
|
from tempfile import NamedTemporaryFile
|
|
2
3
|
from unittest.mock import MagicMock
|
|
3
4
|
|
|
5
|
+
import pytest
|
|
6
|
+
from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
|
|
7
|
+
from truss.base.errors import ValidationError
|
|
4
8
|
from truss.remote.baseten import core
|
|
5
9
|
from truss.remote.baseten.api import BasetenApi
|
|
10
|
+
from truss.remote.baseten.core import create_truss_service
|
|
6
11
|
from truss.remote.baseten.error import ApiError
|
|
7
12
|
|
|
8
13
|
|
|
@@ -35,31 +40,23 @@ def test_upload_truss():
|
|
|
35
40
|
core.multipart_upload_boto3 = MagicMock()
|
|
36
41
|
core.multipart_upload_boto3.return_value = None
|
|
37
42
|
test_file = NamedTemporaryFile()
|
|
38
|
-
assert core.upload_truss(api, test_file) == "key"
|
|
43
|
+
assert core.upload_truss(api, test_file, None) == "key"
|
|
39
44
|
|
|
40
45
|
|
|
41
46
|
def test_get_dev_version_from_versions():
|
|
42
|
-
versions = [
|
|
43
|
-
{"id": "1", "is_draft": False},
|
|
44
|
-
{"id": "2", "is_draft": True},
|
|
45
|
-
]
|
|
47
|
+
versions = [{"id": "1", "is_draft": False}, {"id": "2", "is_draft": True}]
|
|
46
48
|
dev_version = core.get_dev_version_from_versions(versions)
|
|
47
49
|
assert dev_version["id"] == "2"
|
|
48
50
|
|
|
49
51
|
|
|
50
52
|
def test_get_dev_version_from_versions_error():
|
|
51
|
-
versions = [
|
|
52
|
-
{"id": "1", "is_draft": False},
|
|
53
|
-
]
|
|
53
|
+
versions = [{"id": "1", "is_draft": False}]
|
|
54
54
|
dev_version = core.get_dev_version_from_versions(versions)
|
|
55
55
|
assert dev_version is None
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
def test_get_dev_version():
|
|
59
|
-
versions = [
|
|
60
|
-
{"id": "1", "is_draft": False},
|
|
61
|
-
{"id": "2", "is_draft": True},
|
|
62
|
-
]
|
|
59
|
+
versions = [{"id": "1", "is_draft": False}, {"id": "2", "is_draft": True}]
|
|
63
60
|
api = MagicMock()
|
|
64
61
|
api.get_model.return_value = {"model": {"versions": versions}}
|
|
65
62
|
|
|
@@ -84,3 +81,154 @@ def test_get_prod_version_from_versions_error():
|
|
|
84
81
|
]
|
|
85
82
|
prod_version = core.get_prod_version_from_versions(versions)
|
|
86
83
|
assert prod_version is None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@pytest.mark.parametrize("environment", [None, PRODUCTION_ENVIRONMENT_NAME])
|
|
87
|
+
def test_create_truss_service_handles_eligible_environment_values(environment):
|
|
88
|
+
api = MagicMock()
|
|
89
|
+
return_value = {"id": "id", "version_id": "model_version_id"}
|
|
90
|
+
api.create_model_from_truss.return_value = return_value
|
|
91
|
+
model_id, model_version_id = create_truss_service(
|
|
92
|
+
api,
|
|
93
|
+
"model_name",
|
|
94
|
+
"s3_key",
|
|
95
|
+
"config",
|
|
96
|
+
is_trusted=False,
|
|
97
|
+
preserve_previous_prod_deployment=False,
|
|
98
|
+
is_draft=False,
|
|
99
|
+
model_id=None,
|
|
100
|
+
deployment_name="deployment_name",
|
|
101
|
+
environment=environment,
|
|
102
|
+
)
|
|
103
|
+
assert model_id == return_value["id"]
|
|
104
|
+
assert model_version_id == return_value["version_id"]
|
|
105
|
+
api.create_model_from_truss.assert_called_once()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@pytest.mark.parametrize("model_id", ["some_model_id", None])
|
|
109
|
+
def test_create_truss_services_handles_is_draft(model_id):
|
|
110
|
+
api = MagicMock()
|
|
111
|
+
return_value = {"id": "id", "version_id": "model_version_id"}
|
|
112
|
+
api.create_development_model_from_truss.return_value = return_value
|
|
113
|
+
model_id, model_version_id = create_truss_service(
|
|
114
|
+
api,
|
|
115
|
+
"model_name",
|
|
116
|
+
"s3_key",
|
|
117
|
+
"config",
|
|
118
|
+
is_trusted=False,
|
|
119
|
+
preserve_previous_prod_deployment=False,
|
|
120
|
+
is_draft=True,
|
|
121
|
+
model_id=model_id,
|
|
122
|
+
deployment_name="deployment_name",
|
|
123
|
+
)
|
|
124
|
+
assert model_id == return_value["id"]
|
|
125
|
+
assert model_version_id == return_value["version_id"]
|
|
126
|
+
api.create_development_model_from_truss.assert_called_once()
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@pytest.mark.parametrize(
|
|
130
|
+
"inputs",
|
|
131
|
+
[
|
|
132
|
+
{
|
|
133
|
+
"environment": None,
|
|
134
|
+
"deployment_name": "some deployment",
|
|
135
|
+
"is_trusted": True,
|
|
136
|
+
"preserve_previous_prod_deployment": False,
|
|
137
|
+
},
|
|
138
|
+
{
|
|
139
|
+
"environment": PRODUCTION_ENVIRONMENT_NAME,
|
|
140
|
+
"deployment_name": None,
|
|
141
|
+
"is_trusted": True,
|
|
142
|
+
"preserve_previous_prod_deployment": False,
|
|
143
|
+
},
|
|
144
|
+
{
|
|
145
|
+
"environment": "staging",
|
|
146
|
+
"deployment_name": "some_deployment_name",
|
|
147
|
+
"is_trusted": False,
|
|
148
|
+
"preserve_previous_prod_deployment": True,
|
|
149
|
+
},
|
|
150
|
+
],
|
|
151
|
+
)
|
|
152
|
+
def test_create_truss_service_handles_existing_model(inputs):
|
|
153
|
+
api = MagicMock()
|
|
154
|
+
return_value = {"id": "model_version_id"}
|
|
155
|
+
api.create_model_version_from_truss.return_value = return_value
|
|
156
|
+
model_id, model_version_id = create_truss_service(
|
|
157
|
+
api,
|
|
158
|
+
"model_name",
|
|
159
|
+
"s3_key",
|
|
160
|
+
"config",
|
|
161
|
+
is_draft=False,
|
|
162
|
+
model_id="model_id",
|
|
163
|
+
**inputs,
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
assert model_id == "model_id"
|
|
167
|
+
assert model_version_id == return_value["id"]
|
|
168
|
+
api.create_model_version_from_truss.assert_called_once()
|
|
169
|
+
_, kwargs = api.create_model_version_from_truss.call_args
|
|
170
|
+
for k, v in inputs.items():
|
|
171
|
+
assert kwargs[k] == v
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@pytest.mark.parametrize("allow_truss_download", [True, False])
|
|
175
|
+
@pytest.mark.parametrize("is_draft", [True, False])
|
|
176
|
+
def test_create_truss_service_handles_allow_truss_download_for_new_models(
|
|
177
|
+
is_draft, allow_truss_download
|
|
178
|
+
):
|
|
179
|
+
api = MagicMock()
|
|
180
|
+
return_value = {"id": "id", "version_id": "model_version_id"}
|
|
181
|
+
api.create_model_from_truss.return_value = return_value
|
|
182
|
+
api.create_development_model_from_truss.return_value = return_value
|
|
183
|
+
|
|
184
|
+
model_id = None
|
|
185
|
+
model_id, model_version_id = create_truss_service(
|
|
186
|
+
api,
|
|
187
|
+
"model_name",
|
|
188
|
+
"s3_key",
|
|
189
|
+
"config",
|
|
190
|
+
is_trusted=False,
|
|
191
|
+
preserve_previous_prod_deployment=False,
|
|
192
|
+
is_draft=is_draft,
|
|
193
|
+
model_id=model_id,
|
|
194
|
+
deployment_name="deployment_name",
|
|
195
|
+
allow_truss_download=allow_truss_download,
|
|
196
|
+
)
|
|
197
|
+
assert model_id == return_value["id"]
|
|
198
|
+
assert model_version_id == return_value["version_id"]
|
|
199
|
+
|
|
200
|
+
create_model_mock = (
|
|
201
|
+
api.create_development_model_from_truss
|
|
202
|
+
if is_draft
|
|
203
|
+
else api.create_model_from_truss
|
|
204
|
+
)
|
|
205
|
+
create_model_mock.assert_called_once()
|
|
206
|
+
_, kwargs = create_model_mock.call_args
|
|
207
|
+
assert kwargs["allow_truss_download"] is allow_truss_download
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def test_validate_truss_config():
|
|
211
|
+
def mock_validate_truss(client_version, config):
|
|
212
|
+
if config == {}:
|
|
213
|
+
return {"success": True, "details": json.dumps({})}
|
|
214
|
+
elif "hi" in config:
|
|
215
|
+
return {"success": False, "details": json.dumps({"errors": ["error"]})}
|
|
216
|
+
else:
|
|
217
|
+
return {
|
|
218
|
+
"success": False,
|
|
219
|
+
"details": json.dumps({"errors": ["error", "and another one"]}),
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
api = MagicMock()
|
|
223
|
+
api.validate_truss.side_effect = mock_validate_truss
|
|
224
|
+
|
|
225
|
+
assert core.validate_truss_config(api, {}) is None
|
|
226
|
+
with pytest.raises(
|
|
227
|
+
ValidationError, match="Validation failed with the following errors:\n error"
|
|
228
|
+
):
|
|
229
|
+
core.validate_truss_config(api, {"hi": "hi"})
|
|
230
|
+
with pytest.raises(
|
|
231
|
+
ValidationError,
|
|
232
|
+
match="Validation failed with the following errors:\n error\n and another one",
|
|
233
|
+
):
|
|
234
|
+
core.validate_truss_config(api, {"should_error": "hi"})
|