truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,961 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import dataclasses
|
|
3
|
+
import enum
|
|
4
|
+
import importlib
|
|
5
|
+
import importlib.util
|
|
6
|
+
import inspect
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import pathlib
|
|
11
|
+
import sys
|
|
12
|
+
import time
|
|
13
|
+
import weakref
|
|
14
|
+
from contextlib import asynccontextmanager
|
|
15
|
+
from functools import cached_property
|
|
16
|
+
from multiprocessing import Lock
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from threading import Thread
|
|
19
|
+
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union, cast
|
|
20
|
+
|
|
21
|
+
import opentelemetry.sdk.trace as sdk_trace
|
|
22
|
+
import pydantic
|
|
23
|
+
import starlette.requests
|
|
24
|
+
import starlette.responses
|
|
25
|
+
from anyio import Semaphore, to_thread
|
|
26
|
+
from common import errors, tracing
|
|
27
|
+
from common.patches import apply_patches
|
|
28
|
+
from common.retry import retry
|
|
29
|
+
from common.schema import TrussSchema
|
|
30
|
+
from opentelemetry import trace
|
|
31
|
+
from shared import dynamic_config_resolver, serialization
|
|
32
|
+
from shared.lazy_data_resolver import LazyDataResolver
|
|
33
|
+
from shared.secrets_resolver import SecretsResolver
|
|
34
|
+
|
|
35
|
+
if sys.version_info >= (3, 9):
|
|
36
|
+
from typing import AsyncGenerator, Generator
|
|
37
|
+
else:
|
|
38
|
+
from typing_extensions import AsyncGenerator, Generator
|
|
39
|
+
|
|
40
|
+
# Name of the module (without extension) that must contain the user's model class.
MODEL_BASENAME = "model"

# How often `load()` is retried before the server gives up; overridable via the
# NUM_LOAD_RETRIES_TRUSS environment variable (defaults to a single attempt).
NUM_LOAD_RETRIES = int(os.environ.get("NUM_LOAD_RETRIES_TRUSS", "1"))
# Max seconds to wait on the streaming-response queue before timing out a read.
STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS = 60
# Predict requests handled concurrently when the config does not say otherwise.
DEFAULT_PREDICT_CONCURRENCY = 1
# Layout conventions for optional server extensions (e.g. TRT-LLM).
EXTENSIONS_DIR_NAME = "extensions"
EXTENSION_CLASS_NAME = "Extension"
EXTENSION_FILE_NAME = "extension"
TRT_LLM_EXTENSION_NAME = "trt_llm"
# Timeout for each poll when watching for environment updates.
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class MethodName(str, enum.Enum):
    """Names of the user-defined methods the server looks up on a truss model.

    Mixes in `str` so members compare equal to, and format as, their plain
    string values (e.g. ``MethodName.PREDICT == "predict"``).
    """

    CHAT_COMPLETIONS = "chat_completions"
    COMPLETIONS = "completions"
    IS_HEALTHY = "is_healthy"
    POSTPROCESS = "postprocess"
    PREDICT = "predict"
    PREPROCESS = "preprocess"
    SETUP_ENVIRONMENT = "setup_environment"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Payloads the server can hand to a model method: JSON- or msgpack-deserialized
# data, or a pydantic model instance.
InputType = Union[serialization.JSONType, serialization.MsgPackType, pydantic.BaseModel]
# Everything a model method may produce: serializable data, a sync or async
# byte stream, a raw starlette response, or a pydantic model.
OutputType = Union[
    serialization.JSONType,
    serialization.MsgPackType,
    Generator[bytes, None, None],
    AsyncGenerator[bytes, None],
    "starlette.responses.Response",
    pydantic.BaseModel,
]
# A user-defined model method; may be sync (returns OutputType) or async
# (returns an awaitable of it).
ModelFn = Callable[..., Union[OutputType, Awaitable[OutputType]]]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@asynccontextmanager
async def deferred_semaphore_and_span(
    semaphore: Semaphore, span: trace.Span
) -> AsyncGenerator[Callable[[], Callable[[], None]], None]:
    """Acquire `semaphore` and enter `span`, with optionally deferred cleanup.

    Yields a `defer` function. If the caller invokes it, the caller receives
    the cleanup callable (which releases the semaphore and ends the span) and
    becomes responsible for running it later. If `defer` is never invoked, the
    resources are cleaned up automatically when the context exits.
    """
    await semaphore.acquire()
    # NOTE(review): `use_span` is called but its context manager is not
    # entered — presumably intentional so the span stays open past this
    # scope (`end_on_exit=False`); confirm against opentelemetry usage.
    trace.use_span(span, end_on_exit=False)
    cleanup_deferred = False

    def release_and_end() -> None:
        semaphore.release()
        span.end()

    def defer() -> Callable[[], None]:
        nonlocal cleanup_deferred
        cleanup_deferred = True
        return release_and_end

    try:
        yield defer
    finally:
        if not cleanup_deferred:
            release_and_end()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# The possible shapes of the positional-args tuple built for a model method,
# matching the `ArgConfig` calling conventions: inputs only, inputs plus
# request, or request only.
_ArgsType = Union[
    Tuple[Any],
    Tuple[Any, starlette.requests.Request],
    Tuple[starlette.requests.Request],
]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class _Sentinel:
    """Unique end-of-stream marker placed on internal queues."""

    def __repr__(self) -> str:
        return "<Sentinel End of Queue>"


# Module-level singleton; consumers compare against it by identity.
SENTINEL = _Sentinel()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _is_request_type(obj: Any) -> bool:
    """Return True if `obj` is the starlette `Request` class or a subclass.

    `issubclass` raises (rather than returning False) when `obj` is not a
    type at all, so any failure is treated as "not a request type".
    """
    try:
        is_request = issubclass(obj, starlette.requests.Request)
    except Exception:
        return False
    return is_request
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class ArgConfig(enum.Enum):
    """Calling convention of a user-defined model method.

    Distinguishes whether a method takes no arguments, only parsed inputs,
    only the raw request, or both (inputs first, request second).
    """

    NONE = enum.auto()
    INPUTS_ONLY = enum.auto()
    REQUEST_ONLY = enum.auto()
    INPUTS_AND_REQUEST = enum.auto()

    @classmethod
    def from_signature(
        cls, signature: inspect.Signature, method_name: str
    ) -> "ArgConfig":
        """Classify `signature` into one of the supported conventions.

        Raises `errors.ModelDefinitionError` for unsupported shapes: a request
        in first position of a two-argument method, a second argument that is
        not annotated as a request, or more than two arguments.
        """
        params = list(signature.parameters.values())
        num_params = len(params)

        if num_params == 0:
            return cls.NONE
        if num_params == 1:
            if _is_request_type(params[0].annotation):
                return cls.REQUEST_ONLY
            return cls.INPUTS_ONLY
        if num_params == 2:
            # The request may only appear in the second position; the first
            # argument carries the (pre-processed) inputs.
            first, second = params
            if first.annotation and _is_request_type(first.annotation):
                raise errors.ModelDefinitionError(
                    f"`{method_name}` method with two arguments is not allowed to "
                    "have request as first argument, request must be second. "
                    f"Got: {signature}"
                )
            if not (second.annotation and _is_request_type(second.annotation)):
                raise errors.ModelDefinitionError(
                    f"`{method_name}` method with two arguments must have request as "
                    f"second argument (type annotated). Got: {signature} "
                )
            return cls.INPUTS_AND_REQUEST
        raise errors.ModelDefinitionError(
            f"`{method_name}` method cannot have more than two arguments. "
            f"Got: {signature}"
        )

    @classmethod
    def prepare_args(
        cls,
        inputs: Any,
        request: starlette.requests.Request,
        descriptor: "MethodDescriptor",
    ) -> _ArgsType:
        """Assemble the positional-args tuple for invoking `descriptor.method`.

        Raises `NotImplementedError` for `NONE` (nothing to prepare) or any
        unrecognized configuration.
        """
        config = descriptor.arg_config
        if config == ArgConfig.INPUTS_ONLY:
            return (inputs,)
        if config == ArgConfig.REQUEST_ONLY:
            return (request,)
        if config == ArgConfig.INPUTS_AND_REQUEST:
            return (inputs, request)
        raise NotImplementedError(f"Arg config {descriptor.arg_config}.")
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@dataclasses.dataclass
class MethodDescriptor:
    """Introspected calling profile of a single user-defined model method."""

    is_async: bool
    is_generator: bool
    arg_config: ArgConfig
    method_name: MethodName
    method: ModelFn

    @classmethod
    def from_method(cls, method: Any, method_name: MethodName) -> "MethodDescriptor":
        """Inspect `method` and record how it must be invoked."""
        signature = inspect.signature(method)
        return cls(
            is_async=cls._is_async(method),
            is_generator=cls._is_generator(method),
            arg_config=ArgConfig.from_signature(signature, method_name),
            method_name=method_name,
            # ArgConfig ensures that the Callable has an appropriate signature.
            method=cast(ModelFn, method),
        )

    @classmethod
    def _is_async(cls, method: Any):
        # Async generator functions are deliberately excluded: they are
        # consumed with `async for`, not `await`.
        return inspect.iscoroutinefunction(method)

    @classmethod
    def _is_generator(cls, method: Any):
        # Covers both sync and async generator functions.
        return inspect.isgeneratorfunction(method) or inspect.isasyncgenfunction(method)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@dataclasses.dataclass
class ModelDescriptor:
    """Resolved view of a user model instance: which hook methods it defines,
    how each accepts arguments, and the derived request/response schema."""

    preprocess: Optional[MethodDescriptor]
    predict: MethodDescriptor
    postprocess: Optional[MethodDescriptor]
    truss_schema: Optional[TrussSchema]
    setup_environment: Optional[MethodDescriptor]
    is_healthy: Optional[MethodDescriptor]
    completions: Optional[MethodDescriptor]
    chat_completions: Optional[MethodDescriptor]

    @cached_property
    def skip_input_parsing(self) -> bool:
        """True when input parsing can be skipped: `predict` (and `preprocess`,
        if defined) consume only the raw request object."""
        if self.predict.arg_config != ArgConfig.REQUEST_ONLY:
            return False
        return (
            not self.preprocess
            or self.preprocess.arg_config == ArgConfig.REQUEST_ONLY
        )

    @classmethod
    def _gen_truss_schema(
        cls,
        model_cls: Any,
        predict: MethodDescriptor,
        preprocess: Optional[MethodDescriptor],
        postprocess: Optional[MethodDescriptor],
    ) -> TrussSchema:
        """Derive the I/O schema from the hook chain: input parameters come from
        the first hook (preprocess if present, else predict); the return
        annotation comes from the last hook (postprocess if present, else
        predict)."""
        input_fn = model_cls.preprocess if preprocess else model_cls.predict
        output_fn = model_cls.postprocess if postprocess else model_cls.predict
        parameters = inspect.signature(input_fn).parameters
        return_annotation = inspect.signature(output_fn).return_annotation
        return TrussSchema.from_signature(parameters, return_annotation)

    @classmethod
    def _safe_extract_descriptor(
        cls, model_cls: Any, method_name: MethodName
    ) -> Union[MethodDescriptor, None]:
        """Build a MethodDescriptor for `method_name` if the model defines it,
        else return None."""
        if not hasattr(model_cls, method_name):
            return None
        return MethodDescriptor.from_method(
            method=getattr(model_cls, method_name), method_name=method_name
        )

    @classmethod
    def from_model(cls, model_cls) -> "ModelDescriptor":
        """Inspect `model_cls` and assemble a validated descriptor.

        Raises:
            errors.ModelDefinitionError: if `predict` is missing, if a
                request-only `predict`/`postprocess` would discard the previous
                stage's result, or if `is_healthy` takes arguments.
        """
        preprocess = cls._safe_extract_descriptor(model_cls, MethodName.PREPROCESS)
        predict = cls._safe_extract_descriptor(model_cls, MethodName.PREDICT)
        if predict is None:
            raise errors.ModelDefinitionError(
                f"Truss model must have a `{MethodName.PREDICT}` method."
            )
        if preprocess and predict.arg_config == ArgConfig.REQUEST_ONLY:
            # A request-only predict would never see preprocess's output.
            raise errors.ModelDefinitionError(
                f"When using `{MethodName.PREPROCESS}`, the {MethodName.PREDICT} method "
                f"cannot only have the request argument (because the result of `{MethodName.PREPROCESS}` "
                "would be discarded)."
            )

        postprocess = cls._safe_extract_descriptor(model_cls, MethodName.POSTPROCESS)
        if postprocess and postprocess.arg_config == ArgConfig.REQUEST_ONLY:
            # Same reasoning: postprocess must receive predict's result.
            raise errors.ModelDefinitionError(
                f"The `{MethodName.POSTPROCESS}` method cannot only have the request "
                f"argument (because the result of `{MethodName.PREDICT}` would be discarded)."
            )
        setup = cls._safe_extract_descriptor(model_cls, MethodName.SETUP_ENVIRONMENT)
        completions = cls._safe_extract_descriptor(model_cls, MethodName.COMPLETIONS)
        chats = cls._safe_extract_descriptor(model_cls, MethodName.CHAT_COMPLETIONS)
        is_healthy = cls._safe_extract_descriptor(model_cls, MethodName.IS_HEALTHY)
        if is_healthy and is_healthy.arg_config != ArgConfig.NONE:
            raise errors.ModelDefinitionError(
                f"`{MethodName.IS_HEALTHY}` must have only one argument: `self`."
            )

        truss_schema = cls._gen_truss_schema(
            model_cls=model_cls,
            predict=predict,
            preprocess=preprocess,
            postprocess=postprocess,
        )
        return cls(
            preprocess=preprocess,
            predict=predict,
            postprocess=postprocess,
            truss_schema=truss_schema,
            setup_environment=setup,
            is_healthy=is_healthy,
            completions=completions,
            chat_completions=chats,
        )
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class ModelWrapper:
    """Hosts the user-defined truss model inside the server.

    Responsibilities visible in this class: loading the model class from file
    (or a TRT-LLM extension override), tracking lifecycle status, polling for
    environment updates, and running the preprocess -> predict -> postprocess
    pipeline, including streaming responses under a concurrency semaphore.
    """

    _config: Dict
    _tracer: sdk_trace.Tracer
    # The model / descriptor are None until `load` succeeds.
    _maybe_model: Optional[Any]
    _maybe_model_descriptor: Optional[ModelDescriptor]
    _logger: logging.Logger
    _status: "ModelWrapper.Status"
    # Bounds the number of concurrently running predictions.
    _predict_semaphore: Semaphore
    _poll_for_environment_updates_task: Optional[asyncio.Task]
    # Last environment dict passed to `setup_environment` (dedupes reruns).
    _environment: Optional[dict]

    class Status(enum.Enum):
        # Lifecycle states; FAILED loads are never retried (see
        # `start_load_thread`).
        NOT_READY = 0
        LOADING = 1
        READY = 2
        FAILED = 3

    def __init__(self, config: Dict, tracer: sdk_trace.Tracer):
        self._config = config
        self._tracer = tracer
        self._maybe_model = None
        self._maybe_model_descriptor = None
        # We need a logger that has all our server JSON logging setup applied in its
        # handlers and where this also hold in the loading thread. Creating a new
        # instance does not carry over the setup into the thread and using unspecified
        # `getLogger` may return non-compliant loggers if dependencies override the root
        # logger (c.g. https://github.com/numpy/numpy/issues/24213). We chose to get
        # the uvicorn logger that is set up in `truss_server`.
        self._logger = logging.getLogger("uvicorn")
        self.name = MODEL_BASENAME
        self._load_lock = Lock()
        self._status = ModelWrapper.Status.NOT_READY
        self._predict_semaphore = Semaphore(
            self._config.get("runtime", {}).get(
                "predict_concurrency", DEFAULT_PREDICT_CONCURRENCY
            )
        )
        self._poll_for_environment_updates_task = None
        self._environment = None

    @property
    def _model(self) -> Any:
        """The loaded model instance; raises `ModelNotReady` before load."""
        if self._maybe_model:
            return self._maybe_model
        else:
            raise errors.ModelNotReady(self.name)

    @property
    def model_descriptor(self) -> ModelDescriptor:
        """Descriptor of the loaded model; raises `ModelNotReady` before load."""
        if self._maybe_model_descriptor:
            return self._maybe_model_descriptor
        else:
            raise errors.ModelNotReady(self.name)

    @property
    def load_failed(self) -> bool:
        return self._status == ModelWrapper.Status.FAILED

    @property
    def ready(self) -> bool:
        return self._status == ModelWrapper.Status.READY

    @property
    def _model_file_name(self) -> str:
        # Used to filter user-code frames out of logged tracebacks.
        return self._config["model_class_filename"]

    def start_load_thread(self):
        """Kick off `load` on a background thread (at most once)."""
        # Don't retry failed loads.
        if self._status == ModelWrapper.Status.NOT_READY:
            thread = Thread(target=self.load)
            thread.start()

    def load(self):
        """Load the model, guarded by a lock so only one worker loads at a time.

        Any exception flips status to FAILED; it is logged, not re-raised.
        """
        if self.ready:
            return
        # if we are already loading, block on acquiring the lock;
        # this worker will return 503 while the worker with the lock is loading
        with self._load_lock:
            self._status = ModelWrapper.Status.LOADING
            self._logger.info("Executing model.load()...")
            try:
                start_time = time.perf_counter()
                self._load_impl()
                self._status = ModelWrapper.Status.READY
                self._logger.info(
                    f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
                )
            except Exception:
                self._logger.exception("Exception while loading model")
                self._status = ModelWrapper.Status.FAILED

    def _load_impl(self):
        """Import the model class (or extension override), instantiate it, and
        run its `load` hook with retries."""
        data_dir = Path("data")
        data_dir.mkdir(exist_ok=True)

        if "bundled_packages_dir" in self._config:
            # Make bundled packages importable by user code.
            bundled_packages_path = Path("/packages")
            if bundled_packages_path.exists():
                sys.path.append(str(bundled_packages_path))

        secrets = SecretsResolver.get_secrets(self._config)
        lazy_data_resolver = LazyDataResolver(data_dir)

        apply_patches(
            self._config.get("apply_library_patches", True),
            self._config["requirements"],
        )

        extensions = _init_extensions(
            self._config, data_dir, secrets, lazy_data_resolver
        )
        for extension in extensions.values():
            extension.load()

        model_class_file_path = (
            Path(self._config["model_module_dir"])
            / self._config["model_class_filename"]
        )
        if model_class_file_path.exists():
            self._logger.info("Loading truss model from file.")
            module_path = pathlib.Path(model_class_file_path).resolve()
            module_name = module_path.stem  # Use the file's name as the module name
            if not os.path.isfile(module_path):
                raise ImportError(
                    f"`{module_path}` is not a file. You must point to a python file where "
                    "the entrypoint chainlet is defined."
                )
            import_error_msg = f"Could not import `{module_path}`. Check path."
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            if not spec:
                raise ImportError(import_error_msg)
            if not spec.loader:
                raise ImportError(import_error_msg)
            module = importlib.util.module_from_spec(spec)
            try:
                spec.loader.exec_module(module)
            except ImportError as e:
                # Give a targeted hint for the common relative-import mistake.
                if "attempted relative import" in str(e):
                    raise ImportError(
                        f"During import of `{model_class_file_path}`. "
                        f"Since Truss v0.9.36 relative imports (starting with '.') in "
                        "the top-level model file are no longer supported. Please "
                        "replace them with absolute imports. For guidance on importing "
                        "custom packages refer to our documentation "
                        "https://docs.baseten.co/truss-reference/config#packages"
                    ) from e

                raise

            model_class = getattr(module, self._config["model_class_name"])
            model_init_params = _prepare_init_args(
                model_class, self._config, data_dir, secrets, lazy_data_resolver
            )
            signature = inspect.signature(model_class)
            # Pass each extension's args to __init__ only if accepted.
            for ext_name, ext in extensions.items():
                if _signature_accepts_keyword_arg(signature, ext_name):
                    model_init_params[ext_name] = ext.model_args()
            self._maybe_model = model_class(**model_init_params)

        elif TRT_LLM_EXTENSION_NAME in extensions:
            self._logger.info("Loading TRT LLM extension as model.")
            # trt_llm extension allows model.py to be absent. It supplies its
            # own model class in that case.
            trt_llm_extension = extensions["trt_llm"]
            self._maybe_model = trt_llm_extension.model_override()
        else:
            raise RuntimeError("No module class file found")

        self._maybe_model_descriptor = ModelDescriptor.from_model(self._model)

        if self._maybe_model_descriptor.setup_environment:
            self._initialize_environment_before_load()

        if hasattr(self._model, "load"):
            retry(
                self._model.load,
                NUM_LOAD_RETRIES,
                self._logger.warning,
                "Failed to load model.",
                gap_seconds=1.0,
            )

    def setup_polling_for_environment_updates(self):
        """Start the background task that watches for environment changes."""
        self._poll_for_environment_updates_task = asyncio.create_task(
            self.poll_for_environment_updates()
        )

    def _initialize_environment_before_load(self):
        """Run `setup_environment` synchronously with the currently persisted
        environment (if any) before the model's `load` hook executes."""
        environment_str = dynamic_config_resolver.get_dynamic_config_value_sync(
            dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
        )
        if environment_str:
            environment_json = json.loads(environment_str)
            self._logger.info(
                f"Executing model.setup_environment with environment: {environment_json}"
            )
            # TODO: Support calling an async setup_environment() here once we support async load()
            self._model.setup_environment(environment_json)
            self._environment = environment_json

    async def setup_environment(self, environment: Optional[dict]):
        """Invoke the model's `setup_environment` hook (no-op if undefined),
        off-threading sync implementations to keep the event loop free."""
        descriptor = self.model_descriptor.setup_environment
        if not descriptor:
            return
        self._logger.info(
            f"Executing model.setup_environment with environment: {environment}"
        )
        if descriptor.is_async:
            return await self._model.setup_environment(environment)
        else:
            return await to_thread.run_sync(self._model.setup_environment, environment)

    async def poll_for_environment_updates(self) -> None:
        """Periodically re-read the environment config file and rerun
        `setup_environment` when its mtime/content changes. Exits once it is
        known the model defines no `setup_environment`."""
        last_modified_time = None
        environment_config_filename = (
            dynamic_config_resolver.get_dynamic_config_file_path(
                dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
            )
        )

        while True:
            # Give control back to the event loop while waiting for environment updates
            await asyncio.sleep(POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS)

            # Wait for load to finish before checking for environment updates
            if not self.ready:
                continue

            # Skip polling if no setup_environment implementation provided
            if not self.model_descriptor.setup_environment:
                break

            if environment_config_filename.exists():
                try:
                    # Cheap mtime check first; only read the file on change.
                    current_mtime = os.path.getmtime(environment_config_filename)
                    if not last_modified_time or last_modified_time != current_mtime:
                        environment_str = await dynamic_config_resolver.get_dynamic_config_value_async(
                            dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
                        )
                        if environment_str:
                            last_modified_time = current_mtime
                            environment_json = json.loads(environment_str)
                            # Avoid rerunning `setup_environment` with the same environment
                            if self._environment != environment_json:
                                await self.setup_environment(environment_json)
                                self._environment = environment_json
                except Exception as e:
                    # Best-effort: log and keep polling on any failure.
                    self._logger.exception(
                        f"Exception while setting up environment: {str(e)}",
                        exc_info=errors.filter_traceback(self._model_file_name),
                    )

    async def is_healthy(self) -> Optional[bool]:
        """Run the model's health check.

        Returns None if no `is_healthy` is defined or load failed; otherwise
        the hook's result, or False if the hook itself raised.
        """
        descriptor = self.model_descriptor.is_healthy
        is_healthy: Optional[bool] = None
        if not descriptor or self.load_failed:
            # return early with None if model does not have is_healthy method or load failed
            return is_healthy
        try:
            if descriptor.is_async:
                is_healthy = await self._model.is_healthy()
            else:
                # Offload sync functions to thread, to not block event loop.
                is_healthy = await to_thread.run_sync(self._model.is_healthy)
        except Exception as e:
            is_healthy = False
            self._logger.exception(
                f"Exception while checking if model is healthy: {str(e)}",
                exc_info=errors.filter_traceback(self._model_file_name),
            )
        if not is_healthy and self.ready:
            # self.ready evaluates to True when the model's load function has completed,
            # we will only log health check failures to model logs when the model's load has completed
            self._logger.warning("Health check failed.")
        return is_healthy

    async def preprocess(
        self, inputs: InputType, request: starlette.requests.Request
    ) -> Any:
        """Run the model's `preprocess` hook (caller must ensure it exists)."""
        descriptor = self.model_descriptor.preprocess
        assert descriptor, (
            f"`{MethodName.PREPROCESS}` must only be called if model has it."
        )
        return await self._execute_async_model_fn(inputs, request, descriptor)

    async def predict(
        self, inputs: Any, request: starlette.requests.Request
    ) -> Union[OutputType, Any]:
        """Run the model's `predict` hook."""
        # The result can be a serializable data structure, byte-generator, a request,
        # or, if `postprocessing` is used, anything. In the last case postprocessing
        # must convert the result to something serializable.
        descriptor = self.model_descriptor.predict
        return await self._execute_async_model_fn(inputs, request, descriptor)

    async def postprocess(
        self, result: Union[InputType, Any], request: starlette.requests.Request
    ) -> OutputType:
        """Run the model's `postprocess` hook (caller must ensure it exists)."""
        # The postprocess function can handle outputs of `predict`, but not
        # generators and responses - in that case predict must return directly
        # and postprocess is skipped.
        # The result type can be the same as for predict.
        descriptor = self.model_descriptor.postprocess
        assert descriptor, (
            f"`{MethodName.POSTPROCESS}` must only be called if model has it."
        )
        return await self._execute_async_model_fn(result, request, descriptor)

    async def _write_response_to_queue(
        self,
        queue: asyncio.Queue,
        generator: AsyncGenerator[bytes, None],
        span: trace.Span,
    ) -> None:
        """Drain `generator` into `queue`; always terminates with SENTINEL so
        the consumer side knows the stream ended (even on error)."""
        with tracing.section_as_event(span, "write_response_to_queue"):
            try:
                async for chunk in generator:
                    await queue.put(chunk)
            except Exception as e:
                self._logger.exception(
                    f"Exception while generating streamed response: {str(e)}",
                    exc_info=errors.filter_traceback(self._model_file_name),
                )
            finally:
                await queue.put(SENTINEL)

    async def _stream_with_background_task(
        self,
        generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
        span: trace.Span,
        trace_ctx: trace.Context,
        cleanup_fn: Callable[[], None],
    ) -> AsyncGenerator[bytes, None]:
        """Decouple stream production from client consumption via a queue, so a
        slow client cannot hold the predict semaphore; `cleanup_fn` (e.g. the
        semaphore release) runs when production finishes."""
        # The streaming read timeout is the amount of time in between streamed chunk
        # before a timeout is triggered.
        streaming_read_timeout = self._config.get("runtime", {}).get(
            "streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
        )
        async_generator = _force_async_generator(generator)
        # To ensure that a partial read from a client does not keep the semaphore
        # claimed, we write all the data from the stream to the queue as it is produced,
        # irrespective of how fast it is consumed.
        # We then return a new generator that reads from the queue, and then
        # exits the semaphore block.
        response_queue: asyncio.Queue = asyncio.Queue()

        # `write_response_to_queue` keeps running the background until completion.
        gen_task = asyncio.create_task(
            self._write_response_to_queue(response_queue, async_generator, span)
        )
        # Defer the release of the semaphore until the write_response_to_queue task.
        gen_task.add_done_callback(lambda _: cleanup_fn())

        # The gap between responses in a stream must be < streaming_read_timeout
        # TODO: this whole buffering might be superfluous and sufficiently done by
        # by the FastAPI server already. See `test_limit_concurrency_with_sse`.
        async def _buffered_response_generator() -> AsyncGenerator[bytes, None]:
            # `span` is tied to the "producer" `gen_task` which might complete before
            # "consume" part here finishes, therefore a dedicated span is required.
            # Because all of this code is inside a `detach_context` block, we
            # explicitly propagate the tracing context for this span.
            with self._tracer.start_as_current_span(
                "buffered-response-generator", context=trace_ctx
            ):
                while True:
                    chunk = await asyncio.wait_for(
                        response_queue.get(), timeout=streaming_read_timeout
                    )
                    if chunk == SENTINEL:
                        return
                    yield chunk

        return _buffered_response_generator()

    async def _execute_async_model_fn(
        self,
        inputs: Union[InputType, Any],
        request: starlette.requests.Request,
        descriptor: MethodDescriptor,
    ) -> OutputType:
        """Call a model hook with args shaped per its descriptor: generators
        are returned un-awaited, coroutines awaited, sync code off-threaded."""
        args = ArgConfig.prepare_args(inputs, request, descriptor)
        with errors.intercept_exceptions(self._logger, self._model_file_name):
            if descriptor.is_generator:
                # Even for async generators, don't await here.
                return descriptor.method(*args)
            if descriptor.is_async:
                return await cast(Awaitable[OutputType], descriptor.method(*args))
            return await to_thread.run_sync(descriptor.method, *args)

    async def _process_model_fn(
        self,
        inputs: InputType,
        request: starlette.requests.Request,
        descriptor: MethodDescriptor,
    ) -> OutputType:
        """
        Wraps the execution of any model code other than `predict`.
        """
        fn_span = self._tracer.start_span(f"call-{descriptor.method_name}")
        # TODO(nikhil): Make it easier to start a section with detached context.
        with tracing.section_as_event(
            fn_span, descriptor.method_name
        ), tracing.detach_context() as detached_ctx:
            result = await self._execute_async_model_fn(inputs, request, descriptor)

        if inspect.isgenerator(result) or inspect.isasyncgen(result):
            return await self._handle_generator_response(
                request, result, fn_span, detached_ctx
            )

        return result

    def _should_gather_generator(self, request: starlette.requests.Request) -> bool:
        """True if a generator result should be collected into one response
        instead of streamed (JSON accept header from non-OpenAI clients)."""
        # The OpenAI SDK sends an accept header for JSON even in a streaming context,
        # but we need to stream results back for client compatibility. Luckily,
        # we can differentiate by looking at the user agent (e.g. OpenAI/Python 1.61.0)
        user_agent = request.headers.get("user-agent", "")
        if "openai" in user_agent.lower():
            return False
        # TODO(nikhil): determine if we can safely deprecate this behavior.
        return request.headers.get("accept") == "application/json"

    async def _handle_generator_response(
        self,
        request: starlette.requests.Request,
        generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
        span: trace.Span,
        trace_ctx: trace.Context,
        get_cleanup_fn: Callable[[], Callable[[], None]] = lambda: lambda: None,
    ):
        """Either gather the generator into a single value or stream it with a
        background producer task, depending on the request headers."""
        if self._should_gather_generator(request):
            return await _gather_generator(generator)
        else:
            return await self._stream_with_background_task(
                generator, span, trace_ctx, cleanup_fn=get_cleanup_fn()
            )

    async def completions(
        self, inputs: InputType, request: starlette.requests.Request
    ) -> OutputType:
        """Run the model's `completions` hook (caller must ensure it exists)."""
        descriptor = self.model_descriptor.completions
        assert descriptor, (
            f"`{MethodName.COMPLETIONS}` must only be called if model has it."
        )

        return await self._process_model_fn(inputs, request, descriptor)

    async def chat_completions(
        self, inputs: InputType, request: starlette.requests.Request
    ) -> OutputType:
        """Run the model's `chat_completions` hook (caller must ensure it exists)."""
        descriptor = self.model_descriptor.chat_completions
        assert descriptor, (
            f"`{MethodName.CHAT_COMPLETIONS}` must only be called if model has it."
        )

        return await self._process_model_fn(inputs, request, descriptor)

    async def __call__(
        self, inputs: Optional[InputType], request: starlette.requests.Request
    ) -> OutputType:
        """
        Returns result from: preprocess -> predictor -> postprocess.
        """
        if self.model_descriptor.preprocess:
            with self._tracer.start_as_current_span("call-pre") as span_pre:
                # TODO(nikhil): Make it easier to start a section with detached context.
                with tracing.section_as_event(
                    span_pre, "preprocess"
                ), tracing.detach_context():
                    preprocess_result = await self.preprocess(inputs, request)
        else:
            preprocess_result = inputs

        span_predict = self._tracer.start_span("call-predict")
        async with deferred_semaphore_and_span(
            self._predict_semaphore, span_predict
        ) as get_defer_fn:
            # TODO(nikhil): Make it easier to start a section with detached context.
            with tracing.section_as_event(
                span_predict, "predict"
            ), tracing.detach_context() as detached_ctx:
                # To prevent span pollution, we need to make sure spans created by user
                # code don't inherit context from our spans (which happens even if
                # different tracer instances are used).
                # Therefor, predict is run in `detach_context`.
                # There is one caveat with streaming predictions though:
                # The context manager only detaches spans that are created outside
                # the generator loop that yields the stream (because the parts of the
                # loop body will be executed in a "deferred" way (same reasoning as for
                # using `deferred_semaphore_and_span`). We assume that here that
                # creating spans inside the loop body is very unlikely. In order to
                # exactly handle that case we would need to apply `detach_context`
                # around each `next`-invocation that consumes the generator, which is
                # prohibitive.
                predict_result = await self.predict(preprocess_result, request)

            if inspect.isgenerator(predict_result) or inspect.isasyncgen(
                predict_result
            ):
                if self.model_descriptor.postprocess:
                    with errors.intercept_exceptions(
                        self._logger, self._model_file_name
                    ):
                        raise errors.ModelDefinitionError(
                            "If the predict function returns a generator (streaming), "
                            "you cannot use postprocessing. Include all processing in "
                            "the predict method."
                        )

                return await self._handle_generator_response(
                    request,
                    predict_result,
                    span_predict,
                    detached_ctx,
                    get_cleanup_fn=get_defer_fn,
                )

            if isinstance(predict_result, starlette.responses.Response):
                if self.model_descriptor.postprocess:
                    with errors.intercept_exceptions(
                        self._logger, self._model_file_name
                    ):
                        raise errors.ModelDefinitionError(
                            "If the predict function returns a response object, "
                            "you cannot use postprocessing."
                        )
                if isinstance(predict_result, starlette.responses.StreamingResponse):
                    # Defer the semaphore release, using a weakref on the response.
                    # This might keep the semaphore longer than using "native" truss
                    # streaming, because here the criterion is not the production of
                    # data by the generator, but the span of handling the request by
                    # the fastAPI server.
                    weakref.finalize(predict_result, get_defer_fn())

                return predict_result

        if self.model_descriptor.postprocess:
            with self._tracer.start_as_current_span("call-post") as span_post:
                # TODO(nikhil): Make it easier to start a section with detached context.
                with tracing.section_as_event(
                    span_post, "postprocess"
                ), tracing.detach_context():
                    postprocess_result = await self.postprocess(predict_result, request)
                return postprocess_result
        else:
            return predict_result
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
async def _gather_generator(
    predict_result: Union[AsyncGenerator[bytes, None], Generator[bytes, None, None]],
) -> str:
    """Fully consume a (sync or async) generator and join its chunks, each
    coerced via str(), into a single string."""
    parts = []
    async for chunk in _force_async_generator(predict_result):
        parts.append(str(chunk))
    return "".join(parts)
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
def _force_async_generator(gen: Union[Generator, AsyncGenerator]) -> AsyncGenerator:
|
|
870
|
+
"""
|
|
871
|
+
Takes a generator, and converts it into an async generator if it is not already.
|
|
872
|
+
"""
|
|
873
|
+
if inspect.isasyncgen(gen):
|
|
874
|
+
return gen
|
|
875
|
+
|
|
876
|
+
async def _convert_generator_to_async():
|
|
877
|
+
"""
|
|
878
|
+
Runs each iteration of the generator in an offloaded thread, to ensure
|
|
879
|
+
the main loop is not blocked, and yield to create an async generator.
|
|
880
|
+
"""
|
|
881
|
+
while True:
|
|
882
|
+
# Note that this is the equivalent of running:
|
|
883
|
+
# next(gen, FINAL_GENERATOR_VALUE) on a separate thread,
|
|
884
|
+
# ensuring that if there is anything blocking in the generator,
|
|
885
|
+
# it does not block the main loop.
|
|
886
|
+
chunk = await to_thread.run_sync(next, gen, SENTINEL)
|
|
887
|
+
if chunk == SENTINEL:
|
|
888
|
+
return
|
|
889
|
+
yield chunk
|
|
890
|
+
|
|
891
|
+
return _convert_generator_to_async()
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
def _signature_accepts_keyword_arg(signature: inspect.Signature, kwarg: str) -> bool:
|
|
895
|
+
return kwarg in signature.parameters or _signature_accepts_kwargs(signature)
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
def _signature_accepts_kwargs(signature: inspect.Signature) -> bool:
|
|
899
|
+
for param in signature.parameters.values():
|
|
900
|
+
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
|
901
|
+
return True
|
|
902
|
+
return False
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
def _elapsed_ms(since_micro_seconds: float) -> int:
|
|
906
|
+
return int((time.perf_counter() - since_micro_seconds) * 1000)
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
def _init_extensions(config, data_dir, secrets, lazy_data_resolver):
    """Discover extension sub-directories next to this file and instantiate
    each one, keyed by directory name. Returns an empty dict if none exist."""
    extensions = {}
    extensions_root = Path(__file__).parent / EXTENSIONS_DIR_NAME
    if not extensions_root.exists():
        return extensions
    for candidate in extensions_root.iterdir():
        if not candidate.is_dir():
            continue
        extensions[candidate.name] = _init_extension(
            candidate.name, config, data_dir, secrets, lazy_data_resolver
        )
    return extensions
|
|
921
|
+
|
|
922
|
+
|
|
923
|
+
def _init_extension(extension_name: str, config, data_dir, secrets, lazy_data_resolver):
    """Import the named extension's module and construct its extension class,
    passing only the init kwargs its signature accepts."""
    dotted_path = f"{EXTENSIONS_DIR_NAME}.{extension_name}.{EXTENSION_FILE_NAME}"
    extension_module = importlib.import_module(dotted_path)
    extension_class = getattr(extension_module, EXTENSION_CLASS_NAME)
    init_args = _prepare_init_args(
        extension_class,
        config=config,
        data_dir=data_dir,
        secrets=secrets,
        lazy_data_resolver=lazy_data_resolver,
    )
    return extension_class(**init_args)
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def _prepare_init_args(klass, config, data_dir, secrets, lazy_data_resolver):
    """Prepares init params based on signature.

    Used to pass params to extension and model class' __init__ function.
    Only kwargs the target signature accepts (directly or via **kwargs) are
    included.
    """
    signature = inspect.signature(klass)
    model_init_params = {}

    def _wants(name: str) -> bool:
        # Whether the constructor accepts this keyword argument.
        return _signature_accepts_keyword_arg(signature, name)

    if _wants("config"):
        model_init_params["config"] = config
    if _wants("data_dir"):
        model_init_params["data_dir"] = data_dir
    if _wants("secrets"):
        model_init_params["secrets"] = secrets
    if _wants("lazy_data_resolver"):
        # `fetch()` is invoked here, so requesting this kwarg triggers data
        # resolution at construction time.
        model_init_params["lazy_data_resolver"] = lazy_data_resolver.fetch()
    if _wants("environment"):
        environment = None
        environment_str = dynamic_config_resolver.get_dynamic_config_value_sync(
            dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
        )
        if environment_str:
            environment = json.loads(environment_str)
        # Passed even when no environment is configured (as None).
        model_init_params["environment"] = environment
    return model_init_params
|