truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of truss might be problematic.
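A diff like the one below can be reproduced locally by downloading both wheels and comparing their extracted trees. The following is an illustrative sketch only, not part of the registry tooling; it assumes pip is on PATH, and note that filecmp.dircmp reports top-level differences only:

import filecmp
import pathlib
import subprocess
import tempfile
import zipfile

work = pathlib.Path(tempfile.mkdtemp())
for version in ("0.10.0rc1", "0.60.0"):
    # Fetch just the wheel (no dependencies) into its own directory.
    subprocess.run(
        ["pip", "download", f"truss=={version}", "--no-deps", "-d", str(work / version)],
        check=True,
    )
    # A wheel is a zip archive; extract it for comparison.
    wheel = next((work / version).glob("*.whl"))
    with zipfile.ZipFile(wheel) as zf:
        zf.extractall(work / f"src-{version}")

# Report top-level additions, removals, and changed files between the two trees.
cmp = filecmp.dircmp(str(work / "src-0.10.0rc1"), str(work / "src-0.60.0"))
print("removed:", cmp.left_only)
print("added:", cmp.right_only)
print("changed:", cmp.diff_files)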
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- truss/{patch → truss_handle/patch}/constants.py +0 -0
- truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
--- a/truss/tests/test_model_inference.py
+++ b/truss/tests/test_model_inference.py
@@ -1,46 +1,78 @@
+import asyncio
 import concurrent
+import contextlib
+import dataclasses
 import inspect
 import json
 import logging
+import pathlib
+import sys
 import tempfile
 import textwrap
 import time
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from threading import Thread
+from typing import Iterator, Mapping, Optional
 
+import httpx
+import opentelemetry.trace.propagation.tracecontext as tracecontext
 import pytest
 import requests
+from opentelemetry import context, trace
+from python_on_whales import Container
 from requests.exceptions import RequestException
+
+from truss.base.truss_config import map_to_supported_python_version
 from truss.local.local_config_handler import LocalConfigHandler
-from truss.model_inference import map_to_supported_python_version
 from truss.tests.helpers import create_truss
 from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all
-from truss.truss_handle import TrussHandle
+from truss.truss_handle.truss_handle import TrussHandle, wait_for_truss
 
 logger = logging.getLogger(__name__)
 
 DEFAULT_LOG_ERROR = "Internal Server Error"
+PREDICT_URL = "http://localhost:8090/v1/models/model:predict"
+COMPLETIONS_URL = "http://localhost:8090/v1/completions"
+CHAT_COMPLETIONS_URL = "http://localhost:8090/v1/chat/completions"
+
 
+@pytest.fixture
+def anyio_backend():
+    return "asyncio"
 
-def _log_contains_error(line: dict, error: str, message: str):
+
+def _log_contains_line(
+    line: dict, message: str, level: str, error: Optional[str] = None
+):
     return (
-        line["levelname"] == "ERROR"
-        and line["message"] == message
-        and error in line["exc_info"]
+        line["levelname"] == level
+        and message in line["message"]
+        and (error is None or error in line["exc_info"])
     )
 
 
-def assert_logs_contain_error(logs: str, error: str, message=DEFAULT_LOG_ERROR):
-    loglines = logs.splitlines()
+def _assert_logs_contain_error(logs: str, error: str, message=DEFAULT_LOG_ERROR):
+    loglines = [json.loads(line) for line in logs.splitlines()]
     assert any(
-        _log_contains_error(json.loads(line), error, message) for line in loglines
+        _log_contains_line(line, message, "ERROR", error) for line in loglines
+    ), (
+        f"Did not find expected error in logs.\nExpected error: {error}\n"
+        f"Expected message: {message}\nActual logs:\n{loglines}"
+    )
+
+
+def _assert_logs_contain(logs: str, message: str, level: str = "INFO"):
+    loglines = [json.loads(line) for line in logs.splitlines()]
+    assert any(_log_contains_line(line, message, level) for line in loglines), (
+        f"Did not find expected logs.\n"
+        f"Expected message: {message}\nActual logs:\n{loglines}"
     )
 
 
-class PropagatingThread(Thread):
+class _PropagatingThread(Thread):
     """
-    PropagatingThread allows us to run threads and keep track of exceptions
+    _PropagatingThread allows us to run threads and keep track of exceptions
     thrown.
     """
@@ -52,22 +84,31 @@ class PropagatingThread(Thread):
             self.exc = e
 
     def join(self, timeout=None):
-        super(PropagatingThread, self).join(timeout)
+        super(_PropagatingThread, self).join(timeout)
         if self.exc:
             raise self.exc
         return self.ret
 
 
+@contextlib.contextmanager
+def _temp_truss(model_src: str, config_src: str = "") -> Iterator[TrussHandle]:
+    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
+        truss_dir = Path(tmp_work_dir, "truss")
+        create_truss(truss_dir, config_src, textwrap.dedent(model_src))
+        yield TrussHandle(truss_dir)
+
+
+# Test Cases ###########################################################################
+
+
 @pytest.mark.parametrize(
     "python_version, expected_python_version",
     [
-        ("py37", "py38"),
         ("py38", "py38"),
         ("py39", "py39"),
         ("py310", "py310"),
         ("py311", "py311"),
         ("py312", "py311"),
-        ("py36", "py38"),
     ],
 )
 def test_map_to_supported_python_version(python_version, expected_python_version):
@@ -75,11 +116,54 @@ def test_map_to_supported_python_version(python_version, expected_python_version
     assert out_python_version == expected_python_version
 
 
+def test_not_supported_python_minor_versions():
+    with pytest.raises(
+        ValueError,
+        match="Mapping python version 3.6 to 3.8, "
+        "the lowest version that Truss currently supports.",
+    ):
+        map_to_supported_python_version("py36")
+    with pytest.raises(
+        ValueError,
+        match="Mapping python version 3.7 to 3.8, "
+        "the lowest version that Truss currently supports.",
+    ):
+        map_to_supported_python_version("py37")
+
+
+def test_not_supported_python_major_versions():
+    with pytest.raises(NotImplementedError, match="Only python version 3 is supported"):
+        map_to_supported_python_version("py211")
+
+
 @pytest.mark.integration
-def test_model_load_failure_truss():
+def test_model_load_logs(test_data_path):
+    model = """
+    from typing import Optional
+    import logging
+    class Model:
+        def load(self):
+            logging.info(f"User Load Message")
+
+        def predict(self, model_input):
+            return self.environment_name
+    """
+    config = "model_name: init-environment-truss"
+    with ensure_kill_all(), _temp_truss(model, config) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        logs = container.logs()
+        _assert_logs_contain(logs, message="Executing model.load()")
+        _assert_logs_contain(logs, message="Loading truss model from file")
+        _assert_logs_contain(logs, message="Completed model.load()")
+        _assert_logs_contain(logs, message="User Load Message")
+
+
+@pytest.mark.integration
+def test_model_load_failure_truss(test_data_path):
     with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
-        truss_dir = truss_root / "test_data" / "model_load_failure_test"
+        truss_dir = test_data_path / "model_load_failure_test"
         tr = TrussHandle(truss_dir)
 
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)
@@ -110,6 +194,12 @@ def test_model_load_failure_truss():
             assert ready.status_code == expected_code
             return True
 
+        @handle_request_exception
+        def _test_is_loaded(expected_code):
+            ready = requests.get(f"{truss_server_addr}/v1/models/model/loaded")
+            assert ready.status_code == expected_code
+            return True
+
         @handle_request_exception
         def _test_ping(expected_code):
             ping = requests.get(f"{truss_server_addr}/ping")
@@ -118,43 +208,39 @@ def test_model_load_failure_truss():
 
         @handle_request_exception
         def _test_invocations(expected_code):
-            invocations = requests.post(f"{truss_server_addr}/invocations", json={})
+            invocations = requests.post(
+                f"{truss_server_addr}/v1/models/model:predict", json={}
+            )
             assert invocations.status_code == expected_code
             return True
 
        # The server should be completely down so all requests should result in a RequestException.
        # The decorator handle_request_exception catches the RequestException and returns False.
-        assert not _test_readiness_probe(expected_code=200)
        assert not _test_liveness_probe(expected_code=200)
+        assert not _test_readiness_probe(expected_code=200)
+        assert not _test_is_loaded(expected_code=200)
        assert not _test_ping(expected_code=200)
        assert not _test_invocations(expected_code=200)
 
 
 @pytest.mark.integration
-def test_concurrency_truss():
+def test_concurrency_truss(test_data_path):
     # Tests that concurrency limits work correctly
     with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
-
-        truss_dir = truss_root / "test_data" / "test_concurrency_truss"
-
+        truss_dir = test_data_path / "test_concurrency_truss"
         tr = TrussHandle(truss_dir)
-
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
 
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
-
         # Each request takes 2 seconds, for this thread, we allow
         # a concurrency of 2. This means the first two requests will
         # succeed within the 2 seconds, and the third will fail, since
         # it cannot start until the first two have completed.
         def make_request():
-            requests.post(full_url, json={}, timeout=3)
+            requests.post(PREDICT_URL, json={}, timeout=3)
 
-        successful_thread_1 = PropagatingThread(target=make_request)
-        successful_thread_2 = PropagatingThread(target=make_request)
-        failed_thread = PropagatingThread(target=make_request)
+        successful_thread_1 = _PropagatingThread(target=make_request)
+        successful_thread_2 = _PropagatingThread(target=make_request)
+        failed_thread = _PropagatingThread(target=make_request)
 
         successful_thread_1.start()
         successful_thread_2.start()
@@ -169,38 +255,40 @@ def test_concurrency_truss():
 
 
 @pytest.mark.integration
-def test_requirements_file_truss():
+def test_requirements_file_truss(test_data_path):
     with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
-
-        truss_dir = truss_root / "test_data" / "test_requirements_file_truss"
-
+        truss_dir = test_data_path / "test_requirements_file_truss"
         tr = TrussHandle(truss_dir)
-
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
+        time.sleep(3)  # Sleeping to allow the load to finish
 
         # The prediction imports torch which is specified in a requirements.txt and returns if GPU is available.
-        response = requests.post(full_url, json={})
+        response = requests.post(PREDICT_URL, json={})
         assert response.status_code == 200
         assert response.json() is False
 
 
 @pytest.mark.integration
-def test_async_truss():
+@pytest.mark.parametrize("pydantic_major_version", ["1", "2"])
+def test_requirements_pydantic(test_data_path, pydantic_major_version):
     with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
+        truss_dir = test_data_path / f"test_pyantic_v{pydantic_major_version}"
+        tr = TrussHandle(truss_dir)
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
 
-        truss_dir = truss_root / "test_data" / "test_async_truss"
+        response = requests.post(PREDICT_URL, json={})
+        assert response.status_code == 200
+        assert response.json() == '{\n  "foo": "bla",\n  "bar": 123\n}'
 
-        tr = TrussHandle(truss_dir)
 
+@pytest.mark.integration
+def test_async_truss(test_data_path):
+    with ensure_kill_all():
+        truss_dir = test_data_path / "test_async_truss"
+        tr = TrussHandle(truss_dir)
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
-        response = requests.post(full_url, json={})
+        response = requests.post(PREDICT_URL, json={})
         assert response.json() == {
             "preprocess_value": "value",
             "postprocess_value": "value",
@@ -208,58 +296,44 @@ def test_async_truss():
 
 
 @pytest.mark.integration
-def test_async_streaming():
+def test_async_streaming(test_data_path):
     with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
-
-        truss_dir = truss_root / "test_data" / "test_streaming_async_generator_truss"
-
+        truss_dir = test_data_path / "test_streaming_async_generator_truss"
         tr = TrussHandle(truss_dir)
-
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
-        response = requests.post(full_url, json={}, stream=True)
+        response = requests.post(PREDICT_URL, json={}, stream=True)
         assert response.headers.get("transfer-encoding") == "chunked"
         assert [
             byte_string.decode() for byte_string in list(response.iter_content())
         ] == ["0", "1", "2", "3", "4"]
 
         predict_non_stream_response = requests.post(
-            full_url,
-            json={},
-            stream=True,
-            headers={"accept": "application/json"},
+            PREDICT_URL, json={}, stream=True, headers={"accept": "application/json"}
         )
         assert "transfer-encoding" not in predict_non_stream_response.headers
         assert predict_non_stream_response.json() == "01234"
 
 
 @pytest.mark.integration
-def test_async_streaming_timeout():
+def test_async_streaming_timeout(test_data_path):
     with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
-
-        truss_dir = truss_root / "test_data" / "test_streaming_read_timeout"
-
+        truss_dir = test_data_path / "test_streaming_read_timeout"
         tr = TrussHandle(truss_dir)
-
         container = tr.docker_run(
             local_port=8090, detach=True, wait_for_server_ready=True
         )
-        truss_server_addr = "http://localhost:8090"
-        predict_url = f"{truss_server_addr}/v1/models/model:predict"
 
         # ChunkedEncodingError is raised when the chunk does not get processed due to streaming read timeout
         with pytest.raises(requests.exceptions.ChunkedEncodingError):
-            response = requests.post(predict_url, json={}, stream=True)
+            response = requests.post(PREDICT_URL, json={}, stream=True)
 
             for chunk in response.iter_content():
                 pass
 
         # Check to ensure the Timeout error is in the container logs
-        assert_logs_contain_error(
+        # TODO: maybe intercept this error better?
+        _assert_logs_contain_error(
             container.logs(),
             error="raise exceptions.TimeoutError()",
             message="Exception in ASGI application\n",
@@ -267,20 +341,16 @@ def test_async_streaming_timeout():
 
 
 @pytest.mark.integration
-def test_streaming_with_error():
+def test_streaming_with_error_and_stacktrace(test_data_path):
     with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
-
-        truss_dir = truss_root / "test_data" / "test_streaming_truss_with_error"
-
+        truss_dir = test_data_path / "test_streaming_truss_with_error"
         tr = TrussHandle(truss_dir)
-
-        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        predict_url = f"{truss_server_addr}/v1/models/model:predict"
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
 
         predict_error_response = requests.post(
-            predict_url, json={"throw_error": True}, stream=True, timeout=2
+            PREDICT_URL, json={"throw_error": True}, stream=True, timeout=2
         )
 
         # In error cases, the response will return whatever the stream returned,
@@ -293,73 +363,28 @@ def test_streaming_with_error():
 
         # Test that we are able to continue to make requests successfully
         predict_non_error_response = requests.post(
-            predict_url, json={"throw_error": False}, stream=True, timeout=2
+            PREDICT_URL, json={"throw_error": False}, stream=True, timeout=2
         )
 
         assert [
             byte_string.decode()
             for byte_string in predict_non_error_response.iter_content()
         ] == ["0", "1", "2", "3", "4"]
-
-
-@pytest.mark.integration
-def test_streaming_truss():
-    with ensure_kill_all():
-        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
-        truss_dir = truss_root / "test_data" / "test_streaming_truss"
-
-        tr = TrussHandle(truss_dir)
-
-        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        predict_url = f"{truss_server_addr}/v1/models/model:predict"
-
-        # A request for which response is not completely read
-        predict_response = requests.post(predict_url, json={}, stream=True)
-        # We just read the first part and leave it hanging here
-        next(predict_response.iter_content())
-
-        predict_response = requests.post(predict_url, json={}, stream=True)
-
-        assert predict_response.headers.get("transfer-encoding") == "chunked"
-        assert [
-            byte_string.decode()
-            for byte_string in list(predict_response.iter_content())
-        ] == ["0", "1", "2", "3", "4"]
-
-        # When accept is set to application/json, the response is not streamed.
-        predict_non_stream_response = requests.post(
-            predict_url,
-            json={},
-            stream=True,
-            headers={"accept": "application/json"},
+        expected_stack_trace = (
+            "Traceback (most recent call last):\n"
+            '  File "/app/model/model.py", line 12, in inner\n'
+            "    helpers_1.foo(123)\n"
+            '  File "/packages/helpers_1.py", line 5, in foo\n'
+            "    return helpers_2.bar(x)\n"
+            '  File "/packages/helpers_2.py", line 2, in bar\n'
+            '    raise Exception("Crashed in `bar`.")\n'
+            "Exception: Crashed in `bar`."
+        )
+        _assert_logs_contain_error(
+            container.logs(),
+            error=expected_stack_trace,
+            message="Exception while generating streamed response: Crashed in `bar`.",
         )
-        assert "transfer-encoding" not in predict_non_stream_response.headers
-        assert predict_non_stream_response.json() == "01234"
-
-        # Test that concurrency work correctly. The streaming Truss has a configured
-        # concurrency of 1, so only one request can be in flight at a time. Each request
-        # takes 2 seconds, so with a timeout of 3 seconds, we expect the first request to
-        # succeed and for the second to timeout.
-        #
-        # Note that with streamed requests, requests.post raises a ReadTimeout exception if
-        # `timeout` seconds has passed since receiving any data from the server.
-        def make_request(delay: int):
-            # For streamed responses, requests does not start receiving content from server until
-            # `iter_content` is called, so we must call this in order to get an actual timeout.
-            time.sleep(delay)
-            list(requests.post(predict_url, json={}, stream=True).iter_content())
-
-        with ThreadPoolExecutor() as e:
-            # We use concurrent.futures.wait instead of the timeout property
-            # on requests, since requests timeout property has a complex interaction
-            # with streaming.
-            first_request = e.submit(make_request, 0)
-            second_request = e.submit(make_request, 0.2)
-            futures = [first_request, second_request]
-            done, not_done = concurrent.futures.wait(futures, timeout=3)
-            assert first_request in done
-            assert second_request in not_done
 
 
 @pytest.mark.integration
@@ -378,106 +403,50 @@ secrets:
 
     config_with_no_secret = "model_name: secrets-truss"
     missing_secret_error_message = """Secret 'secret' not found. Please ensure that:
-  * Secret 'secret' is defined in the 'secrets' section of the Truss config file
-  * The model was pushed with the --trusted flag"""
+  * Secret 'secret' is defined in the 'secrets' section of the Truss config file"""
 
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
-
-        tr = TrussHandle(truss_dir)
+    with ensure_kill_all(), _temp_truss(inspect.getsource(Model), config) as tr:
         LocalConfigHandler.set_secret("secret", "secret_value")
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
-        response = requests.post(full_url, json={})
+        response = requests.post(PREDICT_URL, json={})
 
         assert response.json() == "secret_value"
 
-
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        # Case where the secret is not specified in the config
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(
-            truss_dir, config_with_no_secret, textwrap.dedent(inspect.getsource(Model))
-        )
-        tr = TrussHandle(truss_dir)
+    # Case where the secret is not specified in the config
+    with ensure_kill_all(), _temp_truss(
+        inspect.getsource(Model), config_with_no_secret
+    ) as tr:
         LocalConfigHandler.set_secret("secret", "secret_value")
         container = tr.docker_run(
             local_port=8090, detach=True, wait_for_server_ready=True
         )
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
-
-        response = requests.post(full_url, json={})
 
+        response = requests.post(PREDICT_URL, json={})
         assert "error" in response.json()
-
-        assert_logs_contain_error(container.logs(), missing_secret_error_message)
+        _assert_logs_contain_error(container.logs(), missing_secret_error_message)
         assert "Internal Server Error" in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
 
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        # Case where the secret is not mounted
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
-        tr = TrussHandle(truss_dir)
+    # Case where the secret is not mounted
+    with ensure_kill_all(), _temp_truss(inspect.getsource(Model), config) as tr:
         LocalConfigHandler.remove_secret("secret")
         container = tr.docker_run(
             local_port=8090, detach=True, wait_for_server_ready=True
        )
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
-        response = requests.post(full_url, json={})
+        response = requests.post(PREDICT_URL, json={})
         assert response.status_code == 500
-
-        assert_logs_contain_error(container.logs(), missing_secret_error_message)
+        _assert_logs_contain_error(container.logs(), missing_secret_error_message)
         assert "Internal Server Error" in response.json()["error"]
-
-
-@pytest.mark.integration
-def test_prints_captured_in_log():
-    class Model:
-        def predict(self, request):
-            print("This is a message from the Truss: Hello World!")
-            return {}
-
-    config = """model_name: printing-truss"""
-
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        # Case where the secret is not specified in the config
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
-        tr = TrussHandle(truss_dir)
-        container = tr.docker_run(
-            local_port=8090, detach=True, wait_for_server_ready=True
-        )
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
-
-        _ = requests.post(full_url, json={})
-
-        loglines = container.logs().splitlines()
-
-        relevant_line = None
-        for line in loglines:
-            logline = json.loads(line)
-            if logline["message"] == "This is a message from the Truss: Hello World!":
-                relevant_line = logline
-                break
-
-        # check that log line has other attributes and could be found
-        assert relevant_line is not None, "Relevant log line not found."
-        assert "asctime" in relevant_line
-        assert "levelname" in relevant_line
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
 
 
 @pytest.mark.integration
 def test_postprocess_with_streaming_predict():
+    # TODO: revisit the decision to forbid this. If so remove below comment.
     """
     Test a Truss that has streaming response from both predict and postprocess.
     In this case, the postprocess step continues to happen within the predict lock,
@@ -498,26 +467,26 @@ def test_postprocess_with_streaming_predict():
             yield str(i)
     """
 
-
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(model))
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
 
-        tr = TrussHandle(truss_dir)
-        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
-
-        response = requests.post(full_url, json={}, stream=True)
-
-        assert response.content == b"0 modified1 modified"
+        response = requests.post(PREDICT_URL, json={}, stream=True)
+        logging.info(response.content)
+        _assert_logs_contain_error(
+            container.logs(),
+            "ModelDefinitionError: If the predict function returns a generator (streaming), you cannot use postprocessing.",
+        )
+        assert "Internal Server Error" in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
 
 
 @pytest.mark.integration
 def test_streaming_postprocess():
     """
-    Tests a Truss where predict returns non-streaming, but postprocess is
+    Tests a Truss where predict returns non-streaming, but postprocess is streamed, and
     ensures that the postprocess step does not happen within the predict lock. To do this,
     we sleep for two seconds during the postprocess streaming process, and fire off two
     requests with a total timeout of 3 seconds, ensuring that if they were serialized
@@ -536,22 +505,14 @@ def test_streaming_postprocess():
             return ["0", "1"]
     """
 
-
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(model))
-
-        tr = TrussHandle(truss_dir)
+    with ensure_kill_all(), _temp_truss(model) as tr:
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
         def make_request(delay: int):
             # For streamed responses, requests does not start receiving content from server until
             # `iter_content` is called, so we must call this in order to get an actual timeout.
             time.sleep(delay)
-            response = requests.post(full_url, json={}, stream=True)
+            response = requests.post(PREDICT_URL, json={}, stream=True)
 
             assert response.status_code == 200
             assert response.content == b"0 modified1 modified"
@@ -599,20 +560,12 @@ def test_postprocess():
 
     """
 
-
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(model))
-
-        tr = TrussHandle(truss_dir)
+    with ensure_kill_all(), _temp_truss(model) as tr:
         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
         def make_request(delay: int):
             time.sleep(delay)
-            response = requests.post(full_url, json={})
+            response = requests.post(PREDICT_URL, json={})
             assert response.status_code == 200
             assert response.json() == ["0 modified", "1 modified"]
 
@@ -642,27 +595,20 @@ def test_truss_with_errors():
             raise ValueError("error")
     """
 
-
-
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(model))
-
-        tr = TrussHandle(truss_dir)
+    with ensure_kill_all(), _temp_truss(model) as tr:
         container = tr.docker_run(
             local_port=8090, detach=True, wait_for_server_ready=True
         )
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
-        response = requests.post(full_url, json={})
+        response = requests.post(PREDICT_URL, json={})
         assert response.status_code == 500
         assert "error" in response.json()
 
-        assert_logs_contain_error(container.logs(), "ValueError: error")
+        _assert_logs_contain_error(container.logs(), "ValueError: error")
 
         assert "Internal Server Error" in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
 
     model_preprocess_error = """
     class Model:
@@ -673,24 +619,19 @@ def test_truss_with_errors():
             return {"a": "b"}
     """
 
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(model_preprocess_error))
-
-        tr = TrussHandle(truss_dir)
+    with ensure_kill_all(), _temp_truss(model_preprocess_error) as tr:
         container = tr.docker_run(
             local_port=8090, detach=True, wait_for_server_ready=True
         )
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
-        response = requests.post(full_url, json={})
+        response = requests.post(PREDICT_URL, json={})
         assert response.status_code == 500
         assert "error" in response.json()
 
-        assert_logs_contain_error(container.logs(), "ValueError: error")
+        _assert_logs_contain_error(container.logs(), "ValueError: error")
         assert "Internal Server Error" in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
 
     model_postprocess_error = """
     class Model:
@@ -701,23 +642,18 @@ def test_truss_with_errors():
             raise ValueError("error")
     """
 
-    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
-        truss_dir = Path(tmp_work_dir, "truss")
-
-        create_truss(truss_dir, config, textwrap.dedent(model_postprocess_error))
-
-        tr = TrussHandle(truss_dir)
+    with ensure_kill_all(), _temp_truss(model_postprocess_error) as tr:
         container = tr.docker_run(
             local_port=8090, detach=True, wait_for_server_ready=True
         )
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"
 
-        response = requests.post(full_url, json={})
+        response = requests.post(PREDICT_URL, json={})
         assert response.status_code == 500
         assert "error" in response.json()
-        assert_logs_contain_error(container.logs(), "ValueError: error")
+        _assert_logs_contain_error(container.logs(), "ValueError: error")
         assert "Internal Server Error" in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
 
     model_async = """
     class Model:
@@ -725,32 +661,93 @@ def test_truss_with_errors():
         raise ValueError("error")
     """

-    with ensure_kill_all():
-
+    with ensure_kill_all(), _temp_truss(model_async) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )

-
+        response = requests.post(PREDICT_URL, json={})
+        assert response.status_code == 500
+        assert "error" in response.json()

-
+        _assert_logs_contain_error(container.logs(), "ValueError: error")
+
+        assert "Internal Server Error" in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
+
+
+@pytest.mark.integration
+def test_truss_with_user_errors():
+    """Test that user-code raised `fastapi.HTTPExceptions` are passed through as is."""
+    model = """
+import fastapi
+
+class Model:
+    def predict(self, request):
+        raise fastapi.HTTPException(status_code=500, detail="My custom message.")
+"""
+
+    with ensure_kill_all(), _temp_truss(model) as tr:
         container = tr.docker_run(
             local_port=8090, detach=True, wait_for_server_ready=True
         )
-        truss_server_addr = "http://localhost:8090"
-        full_url = f"{truss_server_addr}/v1/models/model:predict"

-        response = requests.post(
+        response = requests.post(PREDICT_URL, json={})
         assert response.status_code == 500
         assert "error" in response.json()
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
+
+        _assert_logs_contain_error(
+            container.logs(),
+            "HTTPException: 500: My custom message.",
+            "Model raised HTTPException",
+        )
+
+        assert "My custom message." in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
+
+
+@pytest.mark.integration
+def test_truss_with_error_stacktrace(test_data_path):
+    with ensure_kill_all():
+        truss_dir = test_data_path / "test_truss_with_error"
+        tr = TrussHandle(truss_dir)
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )

-
+        response = requests.post(PREDICT_URL, json={})
+        assert response.status_code == 500
+        assert "error" in response.json()

         assert "Internal Server Error" in response.json()["error"]
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
+
+        expected_stack_trace = (
+            "Traceback (most recent call last):\n"
+            '  File "/app/model/model.py", line 8, in predict\n'
+            "    return helpers_1.foo(123)\n"
+            '  File "/packages/helpers_1.py", line 5, in foo\n'
+            "    return helpers_2.bar(x)\n"
+            '  File "/packages/helpers_2.py", line 2, in bar\n'
+            '    raise Exception("Crashed in `bar`.")\n'
+            "Exception: Crashed in `bar`."
+        )
+        _assert_logs_contain_error(
+            container.logs(),
+            error=expected_stack_trace,
+            message="Internal Server Error",
+        )


 @pytest.mark.integration
-def test_slow_truss():
+def test_slow_truss(test_data_path):
     with ensure_kill_all():
-
-        truss_dir = truss_root / "test_data" / "server_conformance_test_truss"
+        truss_dir = test_data_path / "server_conformance_test_truss"
         tr = TrussHandle(truss_dir)

         _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)
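Note: each error path now also pins the x-baseten-error-source/x-baseten-error-code response headers; judging purely by these tests, "04" marks the model container as the error source and "600" a user-code exception. The `_assert_logs_contain_error` helper is likewise not shown in this hunk; a plausible shape inferred from its call sites (signature and default are assumptions):

    def _assert_logs_contain_error(logs: str, error, message="Internal Server Error"):
        # The server is expected to log both the high-level message and the
        # underlying error/stack trace for each failed request.
        assert message in logs
        assert error in logs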
@@ -765,6 +762,10 @@ def test_slow_truss():
             ready = requests.get(f"{truss_server_addr}/v1/models/model")
             assert ready.status_code == expected_code

+        def _test_is_loaded(expected_code):
+            ready = requests.get(f"{truss_server_addr}/v1/models/model/loaded")
+            assert ready.status_code == expected_code
+
         def _test_ping(expected_code):
             ping = requests.get(f"{truss_server_addr}/ping")
             assert ping.status_code == expected_code
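Note: `_test_is_loaded` probes a new /v1/models/model/loaded route alongside the existing readiness probe. A usage sketch against a locally running container (host and port taken from these tests):

    import requests

    resp = requests.get("http://localhost:8090/v1/models/model/loaded")
    # 503 while the model is still loading, 200 once load() has finished.
    print(resp.status_code)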
@@ -786,6 +787,7 @@ def test_slow_truss():
         for _ in range(LOAD_TEST_TIME):
             _test_liveness_probe(200)
             _test_readiness_probe(503)
+            _test_is_loaded(503)
             _test_ping(503)
             _test_invocations(503)
             time.sleep(1)
@@ -793,6 +795,7 @@ def test_slow_truss():
         time.sleep(LOAD_BUFFER_TIME)
         _test_liveness_probe(200)
         _test_readiness_probe(200)
+        _test_is_loaded(200)
         _test_ping(200)

         predict_call = Thread(
@@ -805,9 +808,1054 @@ def test_slow_truss():
         for _ in range(PREDICT_TEST_TIME):
             _test_liveness_probe(200)
             _test_readiness_probe(200)
+            _test_is_loaded(200)
             _test_ping(200)
             time.sleep(1)

         predict_call.join()

         _test_invocations(200)
+
+
+@pytest.mark.integration
+def test_init_environment_parameter():
+    # Test a truss deployment that is associated with an environment
+    model = """
+from typing import Optional
+class Model:
+    def __init__(self, **kwargs):
+        self._config = kwargs["config"]
+        self._environment = kwargs["environment"]
+        self.environment_name = self._environment.get("name") if self._environment else None
+
+    def load(self):
+        print(f"Executing model.load with environment: {self.environment_name}")
+
+    def predict(self, model_input):
+        return self.environment_name
+"""
+    config = "model_name: init-environment-truss"
+    with ensure_kill_all(), _temp_truss(model, config) as tr:
+        # Mimic environment changing to staging
+        staging_env = {"name": "staging"}
+        staging_env_str = json.dumps(staging_env)
+        LocalConfigHandler.set_dynamic_config("environment", staging_env_str)
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        assert "Executing model.load with environment: staging" in container.logs()
+        response = requests.post(PREDICT_URL, json={})
+        assert response.json() == "staging"
+        assert response.status_code == 200
+        container.execute(["bash", "-c", "rm -f /etc/b10_dynamic_config/environment"])
+
+    # Test a truss deployment with no associated environment
+    config = "model_name: init-no-environment-truss"
+    with ensure_kill_all(), _temp_truss(model, config) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        assert "Executing model.load with environment: None" in container.logs()
+        response = requests.post(PREDICT_URL, json={})
+        assert response.json() is None
+        assert response.status_code == 200
+
+
+@pytest.mark.integration
+def test_setup_environment():
+    # Test truss that uses setup_environment() without load()
+    model = """
+from typing import Optional
+class Model:
+    def setup_environment(self, environment: Optional[dict]):
+        print("setup_environment called with", environment)
+        self.environment_name = environment.get("name") if environment else None
+        print(f"in {self.environment_name} environment")
+
+    def predict(self, model_input):
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        # Mimic environment changing to beta
+        beta_env = {"name": "beta"}
+        beta_env_str = json.dumps(beta_env)
+        container.execute(
+            [
+                "bash",
+                "-c",
+                f"echo '{beta_env_str}' > /etc/b10_dynamic_config/environment",
+            ]
+        )
+        time.sleep(30)
+        assert (
+            f"Executing model.setup_environment with environment: {beta_env}"
+            in container.logs()
+        )
+        single_quote_beta_env_str = beta_env_str.replace('"', "'")
+        assert (
+            f"setup_environment called with {single_quote_beta_env_str}"
+            in container.logs()
+        )
+        assert "in beta environment" in container.logs()
+        container.execute(["bash", "-c", "rm -f /etc/b10_dynamic_config/environment"])
+
+    # Test a truss that uses the environment in load()
+    model = """
+from typing import Optional
+class Model:
+    def setup_environment(self, environment: Optional[dict]):
+        print("setup_environment called with", environment)
+        self.environment_name = environment.get("name") if environment else None
+        print(f"in {self.environment_name} environment")
+
+    def load(self):
+        print("loading in environment", self.environment_name)
+
+    def predict(self, model_input):
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        # Mimic environment changing to staging
+        staging_env = {"name": "staging"}
+        staging_env_str = json.dumps(staging_env)
+        LocalConfigHandler.set_dynamic_config("environment", staging_env_str)
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        # Don't need to wait here because we explicitly grab the environment from dynamic_config_resolver before calling user's load()
+        assert (
+            f"Executing model.setup_environment with environment: {staging_env}"
+            in container.logs()
+        )
+        single_quote_staging_env_str = staging_env_str.replace('"', "'")
+        assert (
+            f"setup_environment called with {single_quote_staging_env_str}"
+            in container.logs()
+        )
+        assert "in staging environment" in container.logs()
+        assert "loading in environment staging" in container.logs()
+        # Set environment to None
+        no_env = None
+        no_env_str = json.dumps(no_env)
+        container.execute(
+            ["bash", "-c", f"echo '{no_env_str}' > /etc/b10_dynamic_config/environment"]
+        )
+        time.sleep(30)
+        assert (
+            f"Executing model.setup_environment with environment: {no_env}"
+            in container.logs()
+        )
+        assert "setup_environment called with None" in container.logs()
+        container.execute(["bash", "-c", "rm -f /etc/b10_dynamic_config/environment"])
+
+
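Note: both environment tests communicate with the running server by writing JSON to /etc/b10_dynamic_config/environment inside the container. A minimal sketch of the reader side that would produce the observed behavior; the name echoes the `dynamic_config_resolver` mentioned in the comment above, everything else is an assumption:

    import json, pathlib
    from typing import Optional

    DYNAMIC_CONFIG_DIR = pathlib.Path("/etc/b10_dynamic_config")  # path from the tests

    def get_dynamic_config(key: str) -> Optional[str]:
        # Return the raw value for a dynamic config key, if the file exists.
        path = DYNAMIC_CONFIG_DIR / key
        return path.read_text() if path.exists() else None

    environment = json.loads(get_dynamic_config("environment") or "null")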
+@pytest.mark.integration
+def test_health_check_configuration():
+    model = """
+class Model:
+    def predict(self, model_input):
+        return model_input
+"""
+
+    config = """runtime:
+  health_checks:
+    restart_check_delay_seconds: 100
+    restart_threshold_seconds: 1700
+"""
+
+    with ensure_kill_all(), _temp_truss(model, config) as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds == 100
+        assert tr.spec.config.runtime.health_checks.restart_threshold_seconds == 1700
+        assert (
+            tr.spec.config.runtime.health_checks.stop_traffic_threshold_seconds is None
+        )
+
+    config = """runtime:
+  health_checks:
+    restart_check_delay_seconds: 1200
+    restart_threshold_seconds: 90
+    stop_traffic_threshold_seconds: 50
+"""
+
+    with ensure_kill_all(), _temp_truss(model, config) as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds == 1200
+        assert tr.spec.config.runtime.health_checks.restart_threshold_seconds == 90
+        assert tr.spec.config.runtime.health_checks.stop_traffic_threshold_seconds == 50
+
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds is None
+        assert tr.spec.config.runtime.health_checks.restart_threshold_seconds is None
+        assert (
+            tr.spec.config.runtime.health_checks.stop_traffic_threshold_seconds is None
+        )
+
+
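Note: the assertions pin three optional fields on `runtime.health_checks`, all defaulting to None. A sketch of what the corresponding config model might look like, assuming a pydantic-style definition (the real one lives in the truss config module and may differ):

    from typing import Optional
    import pydantic

    class HealthChecks(pydantic.BaseModel):
        # Unset values mean platform defaults apply.
        restart_check_delay_seconds: Optional[int] = None
        restart_threshold_seconds: Optional[int] = None
        stop_traffic_threshold_seconds: Optional[int] = None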
+@pytest.mark.integration
+def test_is_healthy():
+    model = """
+class Model:
+    def load(self):
+        raise Exception("not loaded")
+
+    def is_healthy(self) -> bool:
+        return True
+
+    def predict(self, model_input):
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+
+        truss_server_addr = "http://localhost:8090"
+        for _ in range(5):
+            time.sleep(1)
+            healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+            if healthy.status_code == 503:
+                break
+            assert healthy.status_code == 200
+        assert healthy.status_code == 503
+        diff = container.diff()
+        assert "/root/inference_server_crashed.txt" in diff
+        assert diff["/root/inference_server_crashed.txt"] == "A"
+
+    model = """
+class Model:
+    def is_healthy(self, argument) -> bool:
+        pass
+
+    def predict(self, model_input):
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        time.sleep(1)
+        _assert_logs_contain_error(
+            container.logs(),
+            message="Exception while loading model",
+            error="`is_healthy` must have only one argument: `self`",
+        )
+
+    model = """
+class Model:
+    def is_healthy(self) -> bool:
+        raise Exception("not healthy")
+
+    def predict(self, model_input):
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+
+        # Sleep a few seconds to give the server some time to wake up
+        time.sleep(10)
+
+        truss_server_addr = "http://localhost:8090"
+
+        healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+        assert healthy.status_code == 503
+        assert (
+            "Exception while checking if model is healthy: not healthy"
+            in container.logs()
+        )
+        assert "Health check failed." in container.logs()
+
+    model = """
+import time
+
+class Model:
+    def load(self):
+        time.sleep(10)
+
+    def is_healthy(self) -> bool:
+        return False
+
+    def predict(self, model_input):
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        truss_server_addr = "http://localhost:8090"
+
+        time.sleep(5)
+        healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+        assert healthy.status_code == 503
+        # Ensure we only log after model.load is complete
+        assert "Health check failed." not in container.logs()
+
+        # Sleep a few seconds to give the server some time to wake up
+        time.sleep(10)
+
+        healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+        assert healthy.status_code == 503
+        assert container.logs().count("Health check failed.") == 1
+        healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+        assert healthy.status_code == 503
+        assert container.logs().count("Health check failed.") == 2
+
+    model = """
+import time
+
+class Model:
+    def __init__(self, **kwargs):
+        self._healthy = False
+
+    def load(self):
+        time.sleep(10)
+        self._healthy = True
+
+    def is_healthy(self):
+        return self._healthy
+
+    def predict(self, model_input):
+        self._healthy = model_input["healthy"]
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)
+        time.sleep(5)
+        truss_server_addr = "http://localhost:8090"
+        healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+        assert healthy.status_code == 503
+        time.sleep(10)
+        healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+        assert healthy.status_code == 200
+
+        healthy_responses = [True, "yessss", 34, {"woo": "hoo"}]
+        for response in healthy_responses:
+            predict_response = requests.post(PREDICT_URL, json={"healthy": response})
+            assert predict_response.status_code == 200
+            healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+            assert healthy.status_code == 200
+
+        not_healthy_responses = [False, "", 0, {}]
+        for response in not_healthy_responses:
+            predict_response = requests.post(PREDICT_URL, json={"healthy": response})
+            assert predict_response.status_code == 200
+            healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+            assert healthy.status_code == 503
+
+    model = """
+class Model:
+    def is_healthy(self) -> bool:
+        return True
+
+    def predict(self, model_input):
+        return model_input
+"""
+    with ensure_kill_all(), _temp_truss(model, "") as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        truss_server_addr = "http://localhost:8090"
+
+        healthy = requests.get(f"{truss_server_addr}/v1/models/model")
+        assert healthy.status_code == 200
+
+
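Note: taken together, these cases imply a simple health gate: any truthy return from the user's `is_healthy()` keeps the health route at 200, while a falsy return, a raised exception, or an unfinished `load()` yields 503. A sketch of that assumed server-side logic:

    def check_health(model) -> bool:
        # Assumed shape; per the tests, the real server also defers checks until
        # load() completes and logs "Health check failed." on each falsy result.
        try:
            return bool(model.is_healthy())
        except Exception:
            # Surfaces as "Exception while checking if model is healthy: ..."
            return False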
+def _patch_termination_timeout(container: Container, seconds: int, truss_container_fs):
+    app_path = truss_container_fs / "app"
+    sys.path.append(str(app_path))
+    import truss_server
+
+    local_server_source = pathlib.Path(truss_server.__file__)
+    container_server_source = "/app/truss_server.py"
+    modified_content = local_server_source.read_text().replace(
+        "TIMEOUT_GRACEFUL_SHUTDOWN = 120", f"TIMEOUT_GRACEFUL_SHUTDOWN = {seconds}"
+    )
+    with tempfile.NamedTemporaryFile() as patched_file:
+        patched_file.write(modified_content.encode("utf-8"))
+        patched_file.flush()
+        container.copy_to(patched_file.name, container_server_source)
+
+
+@pytest.mark.anyio
+@pytest.mark.integration
+async def test_graceful_shutdown(truss_container_fs):
+    model = """
+import time
+class Model:
+    def predict(self, request):
+        print(f"Received {request}")
+        time.sleep(request["seconds"])
+        print(f"Done {request}")
+        return request
+"""
+
+    async def predict_request(data: dict):
+        async with httpx.AsyncClient() as client:
+            response = await client.post(PREDICT_URL, json=data)
+            response.raise_for_status()
+            return response.json()
+
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        await predict_request({"seconds": 0, "task": 0})  # Warm up server.
+
+        # Test starting two requests, each taking 2 seconds, then terminating server.
+        # They should both finish successfully since the server grace period is 120 s.
+        task_0 = asyncio.create_task(predict_request({"seconds": 2, "task": 0}))
+        await asyncio.sleep(0.1)  # Yield to event loop to make above task run.
+        task_1 = asyncio.create_task(predict_request({"seconds": 2, "task": 1}))
+        await asyncio.sleep(0.1)  # Yield to event loop to make above task run.
+
+        t0 = time.perf_counter()
+        # Even though the server has a 120 s grace period, we expect to finish much
+        # faster in the test here, so use 10 s.
+        container.stop(10)
+        stop_time = time.perf_counter() - t0
+        print(f"Stopped in {stop_time} seconds.")
+
+        assert 3 < stop_time < 5
+        assert (await task_0) == {"seconds": 2, "task": 0}
+        assert (await task_1) == {"seconds": 2, "task": 1}
+
+        # Now mess around in the docker container to reduce the grace period to 3 s.
+        # (There's no nice way to patch this...)
+        _patch_termination_timeout(container, 3, truss_container_fs)
+        # Now only one request should complete.
+        container.restart()
+        wait_for_truss("http://localhost:8090", container, True)
+        await predict_request({"seconds": 0, "task": 0})  # Warm up server.
+
+        task_2 = asyncio.create_task(predict_request({"seconds": 2, "task": 2}))
+        await asyncio.sleep(0.1)  # Yield to event loop to make above task run.
+        task_3 = asyncio.create_task(predict_request({"seconds": 2, "task": 3}))
+        await asyncio.sleep(0.1)  # Yield to event loop to make above task run.
+        t0 = time.perf_counter()
+        container.stop(10)
+        stop_time = time.perf_counter() - t0
+        print(f"Stopped in {stop_time} seconds.")
+        assert 3 < stop_time < 5
+        assert (await task_2) == {"seconds": 2, "task": 2}
+        with pytest.raises(httpx.HTTPStatusError):
+            await task_3
+
+
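Note: the asserted 3-5 second stop window is consistent with the two staggered 2-second requests being drained largely back-to-back before the server exits. If the server is uvicorn-based, TIMEOUT_GRACEFUL_SHUTDOWN plausibly maps onto uvicorn's own shutdown timeout; a hedged sketch (the parameter name comes from uvicorn's public Config, the wiring is an assumption):

    import uvicorn

    # Sketch: a 120 s grace period for in-flight requests on SIGTERM.
    config = uvicorn.Config("app:app", port=8090, timeout_graceful_shutdown=120)
    server = uvicorn.Server(config)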
+# Tracing ##############################################################################
+
+
+def _make_otel_headers() -> Mapping[str, str]:
+    """
+    Create and return a mapping with OpenTelemetry trace context headers.
+
+    This function starts a new span and injects the trace context into the headers,
+    which can be used to propagate tracing information in outgoing HTTP requests.
+
+    Returns:
+        Mapping[str, str]: A mapping containing the trace context headers.
+    """
+    # Initialize a tracer
+    tracer = trace.get_tracer(__name__)
+
+    # Create a dictionary to hold the headers
+    headers: dict[str, str] = {}
+
+    # Start a new span
+    with tracer.start_as_current_span("outgoing-request-span"):
+        # Use the TraceContextTextMapPropagator to inject the trace context into the headers
+        propagator = tracecontext.TraceContextTextMapPropagator()
+        propagator.inject(headers, context=context.get_current())
+
+    return headers
+
+
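Note: a quick check of what the helper yields; the W3C trace-context propagator writes a traceparent entry (and possibly tracestate):

    headers = _make_otel_headers()
    assert "traceparent" in headers  # e.g. "00-<trace_id>-<span_id>-01"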
+@pytest.mark.integration
+@pytest.mark.parametrize("enable_tracing_data", [True, False])
+def test_streaming_truss_with_user_tracing(test_data_path, enable_tracing_data):
+    with ensure_kill_all():
+        truss_dir = test_data_path / "test_streaming_truss_with_tracing"
+        tr = TrussHandle(truss_dir)
+
+        def enable_gpu_fn(conf):
+            new_runtime = dataclasses.replace(
+                conf.runtime, enable_tracing_data=enable_tracing_data
+            )
+            return dataclasses.replace(conf, runtime=new_runtime)
+
+        tr._update_config(enable_gpu_fn)
+
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+
+        # A request for which response is not completely read
+        headers_0 = _make_otel_headers()
+        predict_response = requests.post(
+            PREDICT_URL, json={}, stream=True, headers=headers_0
+        )
+        # We just read the first part and leave it hanging here
+        next(predict_response.iter_content())
+
+        headers_1 = _make_otel_headers()
+        predict_response = requests.post(
+            PREDICT_URL, json={}, stream=True, headers=headers_1
+        )
+        assert predict_response.headers.get("transfer-encoding") == "chunked"
+
+        # When accept is set to application/json, the response is not streamed.
+        headers_2 = _make_otel_headers()
+        predict_non_stream_response = requests.post(
+            PREDICT_URL,
+            json={},
+            stream=True,
+            headers={**headers_2, "accept": "application/json"},
+        )
+        assert "transfer-encoding" not in predict_non_stream_response.headers
+        assert predict_non_stream_response.json() == "01234"
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            truss_traces_file = pathlib.Path(tmp_dir) / "otel_traces.ndjson"
+            container.copy_from("/tmp/otel_traces.ndjson", truss_traces_file)
+            truss_traces = [
+                json.loads(s) for s in truss_traces_file.read_text().splitlines()
+            ]
+
+            user_traces_file = pathlib.Path(tmp_dir) / "otel_user_traces.ndjson"
+            container.copy_from("/tmp/otel_user_traces.ndjson", user_traces_file)
+            user_traces = [
+                json.loads(s) for s in user_traces_file.read_text().splitlines()
+            ]
+
+        if not enable_tracing_data:
+            assert len(truss_traces) == 0
+            assert len(user_traces) > 0
+            return
+
+        assert sum(1 for x in truss_traces if x["name"] == "predict-endpoint") == 3
+        assert sum(1 for x in user_traces if x["name"] == "load_model") == 1
+        assert sum(1 for x in user_traces if x["name"] == "predict") == 3
+
+        user_parents = set(x["parent_id"] for x in user_traces)
+        truss_spans = set(x["context"]["span_id"] for x in truss_traces)
+        truss_parents = set(x["parent_id"] for x in truss_traces)
+        # Make sure there is no context creep into user traces. No user trace should
+        # have a truss trace as parent.
+        assert user_parents & truss_spans == set()
+        # But make sure traces have parents at all.
+        assert len(user_parents) > 3
+        assert len(truss_parents) > 3
+
+
+# Returning Response Objects ###########################################################
+
+
+@pytest.mark.integration
+def test_truss_with_response():
+    """Test that user-code can set a custom status code."""
+    model = """
+from fastapi.responses import Response
+
+class Model:
+    def predict(self, inputs):
+        return Response(status_code=inputs["code"])
+"""
+    from fastapi import status
+
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"code": status.HTTP_204_NO_CONTENT})
+        assert response.status_code == 204
+        assert "x-baseten-error-source" not in response.headers
+        assert "x-baseten-error-code" not in response.headers
+
+        response = requests.post(
+            PREDICT_URL, json={"code": status.HTTP_500_INTERNAL_SERVER_ERROR}
+        )
+        assert response.status_code == 500
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "700"
+
+
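Note: a user-returned 5xx Response is tagged with error code "700", whereas exceptions raised in user code were tagged "600" earlier in this file. A client-side triage sketch built only on the codes asserted in these tests (the mapping is illustrative, not documented API):

    def classify_error(response):
        if response.headers.get("x-baseten-error-source") != "04":
            return None  # not attributed to the model container
        code = response.headers.get("x-baseten-error-code")
        return {"600": "user code raised", "700": "user code returned 5xx"}.get(code)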
+@pytest.mark.integration
+def test_truss_with_streaming_response():
+    # TODO: one issue with this is that (unlike our "builtin" streaming), this keeps
+    # the semaphore claimed potentially longer if the client drops.
+
+    model = """from starlette.responses import StreamingResponse
+class Model:
+    def predict(self, model_input):
+        def text_generator():
+            for i in range(3):
+                yield f"data: {i}\\n\\n"
+        return StreamingResponse(text_generator(), media_type="text/event-stream")
+"""
+
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        # A request for which response is not completely read.
+        predict_response = requests.post(PREDICT_URL, json={}, stream=True)
+        assert (
+            predict_response.headers["Content-Type"]
+            == "text/event-stream; charset=utf-8"
+        )
+
+        lines = predict_response.text.strip().split("\n")
+        assert lines == ["data: 0", "", "data: 1", "", "data: 2"]
+
+
+# Using Request in Model ###############################################################
+
+
+@pytest.mark.integration
+def test_truss_with_request():
+    model = """
+import fastapi
+class Model:
+    async def preprocess(self, request: fastapi.Request):
+        return await request.json()
+
+    async def predict(self, inputs, request: fastapi.Request):
+        inputs["request_size"] = len(await request.body())
+        return inputs
+
+    def postprocess(self, inputs):
+        return {**inputs, "postprocess": "was here"}
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"test": 123})
+        assert response.status_code == 200
+        assert response.json() == {
+            "test": 123,
+            "request_size": 13,
+            "postprocess": "was here",
+        }
+
+
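Note: the failure cases below outline which `predict` shapes the server accepts. A summary sketch (the comments only restate the asserted error messages, they are not the validator itself):

    import fastapi

    class Model:
        # Accepted: predict(self, inputs)
        # Accepted: predict(self, inputs, request: fastapi.Request)
        # Accepted: predict(self, request: fastapi.Request)  # request-only
        # Rejected: request first with a second argument, or three-plus arguments.
        async def predict(self, inputs, request: fastapi.Request):
            return inputs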
+@pytest.mark.integration
+def test_truss_with_requests_and_invalid_signatures():
+    model = """
+class Model:
+    def predict(self, inputs, invalid_arg): ...
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        time.sleep(1.0)  # Wait for logs.
+        _assert_logs_contain_error(
+            container.logs(),
+            "`predict` method with two arguments must have request as second argument",
+            "Exception while loading model",
+        )
+
+    model = """
+import fastapi
+
+class Model:
+    def predict(self, request: fastapi.Request, invalid_arg): ...
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        time.sleep(1.0)  # Wait for logs.
+        _assert_logs_contain_error(
+            container.logs(),
+            "`predict` method with two arguments is not allowed to have request as "
+            "first argument",
+            "Exception while loading model",
+        )
+
+    model = """
+import fastapi
+
+class Model:
+    def predict(self, inputs, request: fastapi.Request, something): ...
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        time.sleep(1.0)  # Wait for logs.
+        _assert_logs_contain_error(
+            container.logs(),
+            "`predict` method cannot have more than two arguments",
+            "Exception while loading model",
+        )
+
+
+@pytest.mark.integration
+def test_truss_with_requests_and_invalid_argument_combinations():
+    model = """
+import fastapi
+class Model:
+    async def preprocess(self, inputs): ...
+
+    def predict(self, request: fastapi.Request): ...
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        time.sleep(1.0)  # Wait for logs.
+        _assert_logs_contain_error(
+            container.logs(),
+            "When using `preprocess`, the predict method cannot only have the request argument",
+            "Exception while loading model",
+        )
+
+    model = """
+import fastapi
+class Model:
+    def preprocess(self, inputs): ...
+
+    async def predict(self, inputs, request: fastapi.Request): ...
+
+    def postprocess(self, request: fastapi.Request): ...
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        time.sleep(1.0)  # Wait for logs.
+        _assert_logs_contain_error(
+            container.logs(),
+            "The `postprocess` method cannot only have the request argument",
+            "Exception while loading model",
+        )
+
+    model = """
+import fastapi
+class Model:
+    def preprocess(self, inputs): ...
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=False
+        )
+        time.sleep(1.0)  # Wait for logs.
+        _assert_logs_contain_error(
+            container.logs(),
+            "Truss model must have a `predict` method.",
+            "Exception while loading model",
+        )
+
+
+@pytest.mark.integration
+def test_truss_forbid_postprocessing_with_response():
+    model = """
+import fastapi, json
+class Model:
+    def predict(self, inputs):
+        return fastapi.Response(content=json.dumps(inputs), status_code=200)
+
+    def postprocess(self, inputs):
+        return inputs
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+
+        response = requests.post(PREDICT_URL, json={})
+        assert response.status_code == 500
+        assert response.headers["x-baseten-error-source"] == "04"
+        assert response.headers["x-baseten-error-code"] == "600"
+        _assert_logs_contain_error(
+            container.logs(),
+            "If the predict function returns a response object, you cannot "
+            "use postprocessing.",
+        )
+
+
+@pytest.mark.integration
+def test_async_streaming_with_cancellation():
+    model = """
+import fastapi, asyncio, logging
+
+class Model:
+    async def predict(self, inputs, request: fastapi.Request):
+        await asyncio.sleep(1)
+        if await request.is_disconnected():
+            logging.warning("Cancelled (before gen).")
+            return
+
+        for i in range(5):
+            await asyncio.sleep(1.0)
+            logging.warning(i)
+            yield str(i)
+            if await request.is_disconnected():
+                logging.warning("Cancelled (during gen).")
+                return
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        # For hard cancellation we need to use httpx; requests' timeouts don't work.
+        with pytest.raises(httpx.ReadTimeout):
+            with httpx.Client(
+                timeout=httpx.Timeout(1.0, connect=1.0, read=1.0)
+            ) as client:
+                response = client.post(PREDICT_URL, json={}, timeout=1.0)
+                response.raise_for_status()
+
+        time.sleep(2)  # Wait a bit to get all logs.
+        assert "Cancelled (during gen)." in container.logs()
+
+
+@pytest.mark.integration
+def test_async_non_streaming_with_cancellation():
+    model = """
+import fastapi, asyncio, logging
+
+class Model:
+    async def predict(self, inputs, request: fastapi.Request):
+        logging.info("Start sleep")
+        await asyncio.sleep(2)
+        logging.info("done sleep, check request.")
+        if await request.is_disconnected():
+            logging.warning("Cancelled (before gen).")
+            return
+        logging.info("Not cancelled.")
+        return "Done"
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        # For hard cancellation we need to use httpx; requests' timeouts don't work.
+        with pytest.raises(httpx.ReadTimeout):
+            with httpx.Client(
+                timeout=httpx.Timeout(1.0, connect=1.0, read=1.0)
+            ) as client:
+                response = client.post(PREDICT_URL, json={}, timeout=1.0)
+                response.raise_for_status()
+
+        time.sleep(2)  # Wait a bit to get all logs.
+        assert "Cancelled (before gen)." in container.logs()
+
+
+@pytest.mark.integration
+def test_limit_concurrency_with_sse():
+    # It seems that the "builtin" functionality of the FastAPI server already buffers
+    # the generator, so that it doesn't keep hanging around if the client doesn't
+    # consume data. `_buffered_response_generator` might be redundant.
+    # This can be observed by waiting for a long time in `make_request`: the server will
+    # print `Done` for the tasks, while we still wait and hold the unconsumed response.
+    # For testing we need to have actually slow generation to keep the server busy.
+    model = """
+import asyncio
+
+class Model:
+    async def predict(self, request):
+        print(f"Starting {request}")
+        for i in range(5):
+            await asyncio.sleep(0.1)
+            yield str(i)
+        print(f"Done {request}")
+"""
+
+    config = """runtime:
+  predict_concurrency: 2"""
+
+    def make_request(consume_chunks, timeout, task_id):
+        t0 = time.time()
+        with httpx.Client() as client:
+            with client.stream(
+                "POST", PREDICT_URL, json={"task_id": task_id}
+            ) as response:
+                assert response.status_code == 200
+                if consume_chunks:
+                    chunks = [chunk for chunk in response.iter_text()]
+                    print(f"consumed chunks ({task_id}): {chunks}")
+                    assert len(chunks) > 0
+                    t1 = time.time()
+                    if t1 - t0 > timeout:
+                        raise httpx.ReadTimeout("Timeout")
+                    return chunks
+                else:
+                    print(f"waiting ({task_id})")
+                    time.sleep(0.5)  # Hold the connection.
+                    print(f"waiting done ({task_id})")
+
+    with ensure_kill_all(), _temp_truss(model, config) as tr:
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+        # Processing a full request takes 0.5 s.
+        print("Make warmup request")
+        make_request(consume_chunks=True, timeout=0.55, task_id=0)
+
+        with ThreadPoolExecutor() as executor:
+            # Start two requests and hold them without consuming all chunks.
+            # Each takes 0.5 s. Semaphore should be claimed, with 0 remaining.
+            print("Start two tasks.")
+            task1 = executor.submit(make_request, False, 0.55, 1)
+            task2 = executor.submit(make_request, False, 0.55, 2)
+            print("Wait for tasks to start.")
+            time.sleep(0.05)
+            print("Make a request while server is busy.")
+            with pytest.raises(httpx.ReadTimeout):
+                make_request(True, timeout=0.55, task_id=3)
+
+            task1.result()
+            task2.result()
+            print("Task 1 and 2 completed. Server should be free again.")
+
+            result = make_request(True, timeout=0.55, task_id=4)
+            print(f"Final chunks: {result}")
+
+
+@pytest.mark.integration
+def test_custom_openai_endpoints():
+    """
+    Test a Truss that exposes an OpenAI compatible endpoint.
+    """
+    model = """
+from typing import Dict
+
+class Model:
+    def __init__(self):
+        pass
+
+    def load(self):
+        self._predict_count = 0
+        self._completions_count = 0
+
+    async def predict(self, inputs: Dict) -> int:
+        self._predict_count += inputs["increment"]
+        return self._predict_count
+
+    async def completions(self, inputs: Dict) -> int:
+        self._completions_count += inputs["increment"]
+        return self._completions_count
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"increment": 1})
+        assert response.status_code == 200
+        assert response.json() == 1
+
+        response = requests.post(COMPLETIONS_URL, json={"increment": 2})
+        assert response.status_code == 200
+        assert response.json() == 2
+
+        response = requests.post(CHAT_COMPLETIONS_URL, json={"increment": 3})
+        assert response.status_code == 404
+
+
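Note: COMPLETIONS_URL and CHAT_COMPLETIONS_URL are not defined in this hunk. Plausible values, mirroring PREDICT_URL's host/port and OpenAI's route conventions (assumed, not confirmed by the diff):

    COMPLETIONS_URL = "http://localhost:8090/v1/completions"
    CHAT_COMPLETIONS_URL = "http://localhost:8090/v1/chat/completions"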
+@pytest.mark.integration
+def test_postprocess_async_generator_streaming():
+    """
+    Test a Truss whose `postprocess` is an async generator that streams results.
+    """
+    model = """
+from typing import Dict, List, Generator
+
+class Model:
+    def __init__(self):
+        pass
+
+    def load(self):
+        pass
+
+    async def predict(self, inputs: Dict) -> List[str]:
+        nums: List[int] = inputs["nums"]
+        return nums
+
+    async def postprocess(self, nums: List[str]) -> Generator[str, None, None]:
+        for num in nums:
+            yield num
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]}, stream=True)
+        assert response.headers.get("transfer-encoding") == "chunked"
+        assert [
+            byte_string.decode() for byte_string in list(response.iter_content())
+        ] == ["1", "2"]
+
+
+@pytest.mark.integration
+def test_preprocess_async_generator():
+    """
+    Test a Truss whose `preprocess` is an async generator feeding `predict`.
+    """
+    model = """
+from typing import Dict, List, AsyncGenerator
+
+class Model:
+    def __init__(self):
+        pass
+
+    def load(self):
+        pass
+
+    async def preprocess(self, inputs: Dict) -> AsyncGenerator[str, None]:
+        for num in inputs["nums"]:
+            yield num
+
+    async def predict(self, nums: AsyncGenerator[str, None]) -> List[str]:
+        return [num async for num in nums]
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]})
+        assert response.status_code == 200
+        assert response.json() == ["1", "2"]
+
+
+@pytest.mark.integration
+def test_openai_client_streaming():
+    """
+    Test a Truss that streams chat completions to an OpenAI-style client.
+    """
+    model = """
+from typing import Dict, AsyncGenerator
+
+class Model:
+    def __init__(self):
+        pass
+
+    def load(self):
+        pass
+
+    async def chat_completions(self, inputs: Dict) -> AsyncGenerator[str, None]:
+        for num in inputs["nums"]:
+            yield num
+
+    async def predict(self, inputs: Dict):
+        pass
+"""
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(
+            CHAT_COMPLETIONS_URL,
+            json={"nums": ["1", "2"]},
+            stream=True,
+            # Despite requesting json, we should still stream results back.
+            headers={
+                "accept": "application/json",
+                "user-agent": "OpenAI/Python 1.61.0",
+            },
+        )
+        assert response.headers.get("transfer-encoding") == "chunked"
+        assert [
+            byte_string.decode() for byte_string in list(response.iter_content())
+        ] == ["1", "2"]