truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
truss/server/model_wrapper.py
DELETED
|
@@ -1,434 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import importlib
|
|
3
|
-
import inspect
|
|
4
|
-
import logging
|
|
5
|
-
import os
|
|
6
|
-
import sys
|
|
7
|
-
import time
|
|
8
|
-
from collections.abc import Generator
|
|
9
|
-
from contextlib import asynccontextmanager
|
|
10
|
-
from enum import Enum
|
|
11
|
-
from multiprocessing import Lock
|
|
12
|
-
from pathlib import Path
|
|
13
|
-
from threading import Thread
|
|
14
|
-
from typing import Any, AsyncGenerator, Dict, Optional, Set, Union
|
|
15
|
-
|
|
16
|
-
import pydantic
|
|
17
|
-
from anyio import Semaphore, to_thread
|
|
18
|
-
from fastapi import HTTPException
|
|
19
|
-
from pydantic import BaseModel
|
|
20
|
-
from truss.server.common.patches import apply_patches
|
|
21
|
-
from truss.server.common.retry import retry
|
|
22
|
-
from truss.server.common.schema import TrussSchema
|
|
23
|
-
from truss.server.shared.secrets_resolver import SecretsResolver
|
|
24
|
-
|
|
25
|
-
MODEL_BASENAME = "model"
|
|
26
|
-
|
|
27
|
-
NUM_LOAD_RETRIES = int(os.environ.get("NUM_LOAD_RETRIES_TRUSS", "1"))
|
|
28
|
-
STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS = 60
|
|
29
|
-
DEFAULT_PREDICT_CONCURRENCY = 1
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class DeferredSemaphoreManager:
|
|
33
|
-
"""
|
|
34
|
-
Helper class for supported deferred semaphore release.
|
|
35
|
-
"""
|
|
36
|
-
|
|
37
|
-
def __init__(self, semaphore: Semaphore):
|
|
38
|
-
self.semaphore = semaphore
|
|
39
|
-
self.deferred = False
|
|
40
|
-
|
|
41
|
-
def defer(self):
|
|
42
|
-
"""
|
|
43
|
-
Track that this semaphore is to be deferred, and return
|
|
44
|
-
a release method that the context block can use to release
|
|
45
|
-
the semaphore.
|
|
46
|
-
"""
|
|
47
|
-
self.deferred = True
|
|
48
|
-
|
|
49
|
-
return self.semaphore.release
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@asynccontextmanager
|
|
53
|
-
async def deferred_semaphore(semaphore: Semaphore):
|
|
54
|
-
"""
|
|
55
|
-
Context manager that allows deferring the release of a semaphore.
|
|
56
|
-
It yields a DeferredSemaphoreManager -- in your use of this context manager,
|
|
57
|
-
if you call DeferredSemaphoreManager.defer(), you will get back a function that releases
|
|
58
|
-
the semaphore that you must call.
|
|
59
|
-
"""
|
|
60
|
-
semaphore_manager = DeferredSemaphoreManager(semaphore)
|
|
61
|
-
await semaphore.acquire()
|
|
62
|
-
|
|
63
|
-
try:
|
|
64
|
-
yield semaphore_manager
|
|
65
|
-
finally:
|
|
66
|
-
if not semaphore_manager.deferred:
|
|
67
|
-
semaphore.release()
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
class ModelWrapper:
|
|
71
|
-
class Status(Enum):
|
|
72
|
-
NOT_READY = 0
|
|
73
|
-
LOADING = 1
|
|
74
|
-
READY = 2
|
|
75
|
-
FAILED = 3
|
|
76
|
-
|
|
77
|
-
def __init__(self, config: Dict):
|
|
78
|
-
self._config = config
|
|
79
|
-
self._logger = logging.getLogger()
|
|
80
|
-
self.name = MODEL_BASENAME
|
|
81
|
-
self.ready = False
|
|
82
|
-
self._load_lock = Lock()
|
|
83
|
-
self._status = ModelWrapper.Status.NOT_READY
|
|
84
|
-
self._predict_semaphore = Semaphore(
|
|
85
|
-
self._config.get("runtime", {}).get(
|
|
86
|
-
"predict_concurrency", DEFAULT_PREDICT_CONCURRENCY
|
|
87
|
-
)
|
|
88
|
-
)
|
|
89
|
-
self._background_tasks: Set[asyncio.Task] = set()
|
|
90
|
-
self.truss_schema: Optional[TrussSchema] = None
|
|
91
|
-
self.app_home: Path = Path(os.environ.get("APP_HOME", default=str(Path.cwd())))
|
|
92
|
-
|
|
93
|
-
def load(self) -> bool:
|
|
94
|
-
if self.ready:
|
|
95
|
-
return self.ready
|
|
96
|
-
|
|
97
|
-
# if we are already loading, block on aquiring the lock;
|
|
98
|
-
# this worker will return 503 while the worker with the lock is loading
|
|
99
|
-
with self._load_lock:
|
|
100
|
-
self._status = ModelWrapper.Status.LOADING
|
|
101
|
-
|
|
102
|
-
self._logger.info("Executing model.load()...")
|
|
103
|
-
|
|
104
|
-
try:
|
|
105
|
-
start_time = time.perf_counter()
|
|
106
|
-
self.try_load()
|
|
107
|
-
self.ready = True
|
|
108
|
-
self._status = ModelWrapper.Status.READY
|
|
109
|
-
self._logger.info(
|
|
110
|
-
f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
|
|
111
|
-
)
|
|
112
|
-
|
|
113
|
-
return self.ready
|
|
114
|
-
except Exception:
|
|
115
|
-
self._logger.exception("Exception while loading model")
|
|
116
|
-
self._status = ModelWrapper.Status.FAILED
|
|
117
|
-
|
|
118
|
-
return self.ready
|
|
119
|
-
|
|
120
|
-
def start_load(self):
|
|
121
|
-
if self.should_load():
|
|
122
|
-
thread = Thread(target=self.load)
|
|
123
|
-
thread.start()
|
|
124
|
-
|
|
125
|
-
def load_failed(self) -> bool:
|
|
126
|
-
return self._status == ModelWrapper.Status.FAILED
|
|
127
|
-
|
|
128
|
-
def should_load(self) -> bool:
|
|
129
|
-
# don't retry failed loads
|
|
130
|
-
return not self._status == ModelWrapper.Status.FAILED and not self.ready
|
|
131
|
-
|
|
132
|
-
def try_load(self):
|
|
133
|
-
data_dir = Path("data")
|
|
134
|
-
data_dir.mkdir(exist_ok=True)
|
|
135
|
-
|
|
136
|
-
sys.path.append(str(self.app_home))
|
|
137
|
-
if "bundled_packages_dir" in self._config:
|
|
138
|
-
bundled_packages_path = self.app_home / "packages"
|
|
139
|
-
if bundled_packages_path.exists():
|
|
140
|
-
sys.path.append(str(bundled_packages_path))
|
|
141
|
-
|
|
142
|
-
model_module_name = str(
|
|
143
|
-
Path(self._config["model_class_filename"]).with_suffix("")
|
|
144
|
-
)
|
|
145
|
-
|
|
146
|
-
module = importlib.import_module(
|
|
147
|
-
f"{self._config['model_module_dir']}.{model_module_name}"
|
|
148
|
-
)
|
|
149
|
-
model_class = getattr(module, self._config["model_class_name"])
|
|
150
|
-
model_class_signature = inspect.signature(model_class)
|
|
151
|
-
model_init_params = {}
|
|
152
|
-
if _signature_accepts_keyword_arg(model_class_signature, "config"):
|
|
153
|
-
model_init_params["config"] = self._config
|
|
154
|
-
if _signature_accepts_keyword_arg(model_class_signature, "data_dir"):
|
|
155
|
-
model_init_params["data_dir"] = data_dir
|
|
156
|
-
if _signature_accepts_keyword_arg(model_class_signature, "secrets"):
|
|
157
|
-
model_init_params["secrets"] = SecretsResolver.get_secrets(self._config)
|
|
158
|
-
apply_patches(
|
|
159
|
-
self._config.get("apply_library_patches", True),
|
|
160
|
-
self._config["requirements"],
|
|
161
|
-
)
|
|
162
|
-
self._model = model_class(**model_init_params)
|
|
163
|
-
|
|
164
|
-
self.set_truss_schema()
|
|
165
|
-
|
|
166
|
-
if hasattr(self._model, "load"):
|
|
167
|
-
retry(
|
|
168
|
-
self._model.load,
|
|
169
|
-
NUM_LOAD_RETRIES,
|
|
170
|
-
self._logger.warn,
|
|
171
|
-
"Failed to load model.",
|
|
172
|
-
gap_seconds=1.0,
|
|
173
|
-
)
|
|
174
|
-
|
|
175
|
-
def set_truss_schema(self):
|
|
176
|
-
parameters = (
|
|
177
|
-
inspect.signature(self._model.preprocess).parameters
|
|
178
|
-
if hasattr(self._model, "preprocess")
|
|
179
|
-
else inspect.signature(self._model.predict).parameters
|
|
180
|
-
)
|
|
181
|
-
|
|
182
|
-
outputs_annotation = (
|
|
183
|
-
inspect.signature(self._model.postprocess).return_annotation
|
|
184
|
-
if hasattr(self._model, "postprocess")
|
|
185
|
-
else inspect.signature(self._model.predict).return_annotation
|
|
186
|
-
)
|
|
187
|
-
|
|
188
|
-
self.truss_schema = TrussSchema.from_signature(parameters, outputs_annotation)
|
|
189
|
-
|
|
190
|
-
async def preprocess(
|
|
191
|
-
self,
|
|
192
|
-
payload: Any,
|
|
193
|
-
headers: Optional[Dict[str, str]] = None,
|
|
194
|
-
) -> Any:
|
|
195
|
-
if not hasattr(self._model, "preprocess"):
|
|
196
|
-
return payload
|
|
197
|
-
|
|
198
|
-
if inspect.iscoroutinefunction(self._model.preprocess):
|
|
199
|
-
return await _intercept_exceptions_async(self._model.preprocess)(payload)
|
|
200
|
-
else:
|
|
201
|
-
return await to_thread.run_sync(
|
|
202
|
-
_intercept_exceptions_sync(self._model.preprocess), payload
|
|
203
|
-
)
|
|
204
|
-
|
|
205
|
-
async def predict(
|
|
206
|
-
self,
|
|
207
|
-
payload: Any,
|
|
208
|
-
headers: Optional[Dict[str, str]] = None,
|
|
209
|
-
) -> Any:
|
|
210
|
-
# It's possible for the user's predict function to be a:
|
|
211
|
-
# 1. Generator function (function that returns a generator)
|
|
212
|
-
# 2. Async generator (function that returns async generator)
|
|
213
|
-
# In these cases, just return the generator or async generator,
|
|
214
|
-
# as we will be propagating these up. No need for await at this point.
|
|
215
|
-
# 3. Coroutine -- in this case, await the predict function as it is async
|
|
216
|
-
# 4. Normal function -- in this case, offload to a separate thread to prevent
|
|
217
|
-
# blocking the main event loop
|
|
218
|
-
if inspect.isasyncgenfunction(
|
|
219
|
-
self._model.predict
|
|
220
|
-
) or inspect.isgeneratorfunction(self._model.predict):
|
|
221
|
-
return self._model.predict(payload)
|
|
222
|
-
|
|
223
|
-
if inspect.iscoroutinefunction(self._model.predict):
|
|
224
|
-
return await _intercept_exceptions_async(self._model.predict)(payload)
|
|
225
|
-
|
|
226
|
-
return await to_thread.run_sync(
|
|
227
|
-
_intercept_exceptions_sync(self._model.predict), payload
|
|
228
|
-
)
|
|
229
|
-
|
|
230
|
-
async def postprocess(
|
|
231
|
-
self,
|
|
232
|
-
response: Any,
|
|
233
|
-
headers: Optional[Dict[str, str]] = None,
|
|
234
|
-
) -> Any:
|
|
235
|
-
# Similar to the predict function, it is possible for postprocess
|
|
236
|
-
# to return either a generator or async generator, in which case
|
|
237
|
-
# just return the generator.
|
|
238
|
-
#
|
|
239
|
-
# It can also return a coroutine or just be a function, in which
|
|
240
|
-
# case either await, or offload to a thread respectively.
|
|
241
|
-
if not hasattr(self._model, "postprocess"):
|
|
242
|
-
return response
|
|
243
|
-
|
|
244
|
-
if inspect.isasyncgenfunction(
|
|
245
|
-
self._model.postprocess
|
|
246
|
-
) or inspect.isgeneratorfunction(self._model.postprocess):
|
|
247
|
-
return self._model.postprocess(response)
|
|
248
|
-
|
|
249
|
-
if inspect.iscoroutinefunction(self._model.postprocess):
|
|
250
|
-
return await _intercept_exceptions_async(self._model.postprocess)(response)
|
|
251
|
-
|
|
252
|
-
return await to_thread.run_sync(
|
|
253
|
-
_intercept_exceptions_sync(self._model.postprocess), response
|
|
254
|
-
)
|
|
255
|
-
|
|
256
|
-
async def write_response_to_queue(
|
|
257
|
-
self, queue: asyncio.Queue, generator: AsyncGenerator
|
|
258
|
-
):
|
|
259
|
-
try:
|
|
260
|
-
async for chunk in generator:
|
|
261
|
-
await queue.put(ResponseChunk(chunk))
|
|
262
|
-
except Exception as e:
|
|
263
|
-
self._logger.exception("Exception while reading stream response: " + str(e))
|
|
264
|
-
finally:
|
|
265
|
-
await queue.put(None)
|
|
266
|
-
|
|
267
|
-
async def __call__(
|
|
268
|
-
self, body: Any, headers: Optional[Dict[str, str]] = None
|
|
269
|
-
) -> Union[Dict, Generator]:
|
|
270
|
-
"""Method to call predictor or explainer with the given input.
|
|
271
|
-
|
|
272
|
-
Args:
|
|
273
|
-
body (Any): Request payload body.
|
|
274
|
-
headers (Dict): Request headers.
|
|
275
|
-
|
|
276
|
-
Returns:
|
|
277
|
-
Dict: Response output from preprocess -> predictor -> postprocess
|
|
278
|
-
Generator: In case of streaming response
|
|
279
|
-
"""
|
|
280
|
-
|
|
281
|
-
# The streaming read timeout is the amount of time in between streamed chunks before a timeout is triggered
|
|
282
|
-
streaming_read_timeout = self._config.get("runtime", {}).get(
|
|
283
|
-
"streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
|
|
284
|
-
)
|
|
285
|
-
|
|
286
|
-
if self.truss_schema is not None:
|
|
287
|
-
try:
|
|
288
|
-
body = self.truss_schema.input_type(**body)
|
|
289
|
-
except pydantic.ValidationError as e:
|
|
290
|
-
self._logger.info("Request Validation Error: %s", {str(e)})
|
|
291
|
-
raise HTTPException(
|
|
292
|
-
status_code=400, detail="Request Validation Error"
|
|
293
|
-
) from e
|
|
294
|
-
|
|
295
|
-
payload = await self.preprocess(body, headers)
|
|
296
|
-
|
|
297
|
-
async with deferred_semaphore(self._predict_semaphore) as semaphore_manager:
|
|
298
|
-
response = await self.predict(payload, headers)
|
|
299
|
-
|
|
300
|
-
# Streaming cases
|
|
301
|
-
if inspect.isgenerator(response) or inspect.isasyncgen(response):
|
|
302
|
-
if hasattr(self._model, "postprocess"):
|
|
303
|
-
logging.warning(
|
|
304
|
-
"Predict returned a streaming response, while a postprocess is defined."
|
|
305
|
-
"Note that in this case, the postprocess will run within the predict lock."
|
|
306
|
-
)
|
|
307
|
-
|
|
308
|
-
response = await self.postprocess(response)
|
|
309
|
-
|
|
310
|
-
async_generator = _force_async_generator(response)
|
|
311
|
-
|
|
312
|
-
if headers and headers.get("accept") == "application/json":
|
|
313
|
-
# In the case of a streaming response, consume stream
|
|
314
|
-
# if the http accept header is set, and json is requested.
|
|
315
|
-
return await _convert_streamed_response_to_string(async_generator)
|
|
316
|
-
|
|
317
|
-
# To ensure that a partial read from a client does not cause the semaphore
|
|
318
|
-
# to stay claimed, we immediately write all of the data from the stream to a
|
|
319
|
-
# queue. We then return a new generator that reads from the queue, and then
|
|
320
|
-
# exit the semaphore block.
|
|
321
|
-
response_queue: asyncio.Queue = asyncio.Queue()
|
|
322
|
-
|
|
323
|
-
# This task will be triggered and run in the background.
|
|
324
|
-
task = asyncio.create_task(
|
|
325
|
-
self.write_response_to_queue(response_queue, async_generator)
|
|
326
|
-
)
|
|
327
|
-
|
|
328
|
-
# We add the task to the ModelWrapper instance to ensure it does
|
|
329
|
-
# not get garbage collected after the predict method completes,
|
|
330
|
-
# and continues running.
|
|
331
|
-
self._background_tasks.add(task)
|
|
332
|
-
|
|
333
|
-
# Defer the release of the semaphore until the write_response_to_queue
|
|
334
|
-
# task.
|
|
335
|
-
semaphore_release_function = semaphore_manager.defer()
|
|
336
|
-
task.add_done_callback(lambda _: semaphore_release_function())
|
|
337
|
-
task.add_done_callback(self._background_tasks.discard)
|
|
338
|
-
|
|
339
|
-
# The gap between responses in a stream must be < streaming_read_timeout
|
|
340
|
-
async def _response_generator():
|
|
341
|
-
while True:
|
|
342
|
-
chunk = await asyncio.wait_for(
|
|
343
|
-
response_queue.get(),
|
|
344
|
-
timeout=streaming_read_timeout,
|
|
345
|
-
)
|
|
346
|
-
if chunk is None:
|
|
347
|
-
return
|
|
348
|
-
yield chunk.value
|
|
349
|
-
|
|
350
|
-
return _response_generator()
|
|
351
|
-
|
|
352
|
-
processed_response = await self.postprocess(response)
|
|
353
|
-
|
|
354
|
-
if isinstance(processed_response, BaseModel):
|
|
355
|
-
# If we return a pydantic object, convert it back to a dict
|
|
356
|
-
processed_response = processed_response.dict()
|
|
357
|
-
return processed_response
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
class ResponseChunk:
|
|
361
|
-
def __init__(self, value):
|
|
362
|
-
self.value = value
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
async def _convert_streamed_response_to_string(response: AsyncGenerator):
|
|
366
|
-
return "".join([str(chunk) async for chunk in response])
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
def _force_async_generator(gen: Union[Generator, AsyncGenerator]) -> AsyncGenerator:
|
|
370
|
-
"""
|
|
371
|
-
Takes a generator, and converts it into an async generator if it is not already.
|
|
372
|
-
"""
|
|
373
|
-
if inspect.isasyncgen(gen):
|
|
374
|
-
return gen
|
|
375
|
-
|
|
376
|
-
async def _convert_generator_to_async():
|
|
377
|
-
"""
|
|
378
|
-
Runs each iteration of the generator in an offloaded thread, to ensure
|
|
379
|
-
the main loop is not blocked, and yield to create an async generator.
|
|
380
|
-
"""
|
|
381
|
-
FINAL_GENERATOR_VALUE = object()
|
|
382
|
-
while True:
|
|
383
|
-
# Note that this is the equivalent of running:
|
|
384
|
-
# next(gen, FINAL_GENERATOR_VALUE) on a separate thread,
|
|
385
|
-
# ensuring that if there is anything blocking in the generator,
|
|
386
|
-
# it does not block the main loop.
|
|
387
|
-
chunk = await to_thread.run_sync(next, gen, FINAL_GENERATOR_VALUE)
|
|
388
|
-
if chunk == FINAL_GENERATOR_VALUE:
|
|
389
|
-
break
|
|
390
|
-
yield chunk
|
|
391
|
-
|
|
392
|
-
return _convert_generator_to_async()
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
def _signature_accepts_keyword_arg(signature: inspect.Signature, kwarg: str) -> bool:
|
|
396
|
-
return kwarg in signature.parameters or _signature_accepts_kwargs(signature)
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
def _signature_accepts_kwargs(signature: inspect.Signature) -> bool:
|
|
400
|
-
for param in signature.parameters.values():
|
|
401
|
-
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
|
402
|
-
return True
|
|
403
|
-
return False
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
def _elapsed_ms(since_micro_seconds: float) -> int:
|
|
407
|
-
return int((time.perf_counter() - since_micro_seconds) * 1000)
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
def _handle_exception():
|
|
411
|
-
# Note that logger.exception logs the stacktrace, such that the user can
|
|
412
|
-
# debug this error from the logs.
|
|
413
|
-
logging.exception("Internal Server Error")
|
|
414
|
-
raise HTTPException(status_code=500, detail="Internal Server Error")
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
def _intercept_exceptions_sync(func):
|
|
418
|
-
def inner(*args, **kwargs):
|
|
419
|
-
try:
|
|
420
|
-
return func(*args, **kwargs)
|
|
421
|
-
except Exception:
|
|
422
|
-
_handle_exception()
|
|
423
|
-
|
|
424
|
-
return inner
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
def _intercept_exceptions_async(func):
|
|
428
|
-
async def inner(*args, **kwargs):
|
|
429
|
-
try:
|
|
430
|
-
return await func(*args, **kwargs)
|
|
431
|
-
except Exception:
|
|
432
|
-
_handle_exception()
|
|
433
|
-
|
|
434
|
-
return inner
|
truss/server/shared/logging.py
DELETED
|
@@ -1,81 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import os
|
|
3
|
-
import sys
|
|
4
|
-
|
|
5
|
-
from pythonjsonlogger import jsonlogger
|
|
6
|
-
|
|
7
|
-
LEVEL: int = logging.INFO
|
|
8
|
-
|
|
9
|
-
use_json_logs = os.environ.get("JSON_LOG", default=False)
|
|
10
|
-
|
|
11
|
-
JSON_LOG_HANDLER = logging.StreamHandler(stream=sys.stdout)
|
|
12
|
-
JSON_LOG_HANDLER.set_name("json_logger_handler")
|
|
13
|
-
JSON_LOG_HANDLER.setLevel(LEVEL)
|
|
14
|
-
JSON_LOG_HANDLER.setFormatter(
|
|
15
|
-
jsonlogger.JsonFormatter("%(asctime)s %(levelname)s %(message)s")
|
|
16
|
-
)
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class HealthCheckFilter(logging.Filter):
|
|
20
|
-
def filter(self, record: logging.LogRecord) -> bool:
|
|
21
|
-
# for any health check endpoints, lets skip logging
|
|
22
|
-
return (
|
|
23
|
-
record.getMessage().find("GET / ") == -1
|
|
24
|
-
and record.getMessage().find("GET /v1/models/model ") == -1
|
|
25
|
-
)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class StreamToLogger:
|
|
29
|
-
"""
|
|
30
|
-
StreamToLogger redirects stdout and stderr to logger
|
|
31
|
-
"""
|
|
32
|
-
|
|
33
|
-
def __init__(self, logger, log_level, stream):
|
|
34
|
-
self.logger = logger
|
|
35
|
-
self.log_level = log_level
|
|
36
|
-
self.stream = stream
|
|
37
|
-
|
|
38
|
-
def __getattr__(self, name):
|
|
39
|
-
# we need to pass `isatty` from the stream for uvicorn
|
|
40
|
-
# this is a more general, less hacky fix
|
|
41
|
-
return getattr(self.stream, name)
|
|
42
|
-
|
|
43
|
-
def write(self, buf):
|
|
44
|
-
self.logger.log(self.log_level, buf)
|
|
45
|
-
|
|
46
|
-
def flush(self):
|
|
47
|
-
"""
|
|
48
|
-
This is a no-op function. It only exists to prevent
|
|
49
|
-
AttributeError in case some part of the code attempts to call flush()
|
|
50
|
-
on instances of StreamToLogger. Thus, we define this method as a safety
|
|
51
|
-
measure.
|
|
52
|
-
"""
|
|
53
|
-
pass
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def setup_logging() -> None:
|
|
57
|
-
loggers = [logging.getLogger()] + [
|
|
58
|
-
logging.getLogger(name) for name in logging.root.manager.loggerDict
|
|
59
|
-
]
|
|
60
|
-
|
|
61
|
-
sys.stdout = StreamToLogger(logging.getLogger(), logging.INFO, sys.__stdout__) # type: ignore
|
|
62
|
-
sys.stderr = StreamToLogger(logging.getLogger(), logging.INFO, sys.__stderr__) # type: ignore
|
|
63
|
-
|
|
64
|
-
for logger in loggers:
|
|
65
|
-
logger.setLevel(LEVEL)
|
|
66
|
-
logger.propagate = False
|
|
67
|
-
|
|
68
|
-
setup = False
|
|
69
|
-
if use_json_logs:
|
|
70
|
-
# let's not thrash the handlers unnecessarily
|
|
71
|
-
for handler in logger.handlers:
|
|
72
|
-
if handler.name == JSON_LOG_HANDLER.name:
|
|
73
|
-
setup = True
|
|
74
|
-
|
|
75
|
-
if not setup:
|
|
76
|
-
logger.handlers.clear()
|
|
77
|
-
logger.addHandler(JSON_LOG_HANDLER)
|
|
78
|
-
|
|
79
|
-
# some special handling for request logging
|
|
80
|
-
if logger.name == "uvicorn.access":
|
|
81
|
-
logger.addFilter(HealthCheckFilter())
|
|
@@ -1,97 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from itertools import count
|
|
3
|
-
|
|
4
|
-
import build_engine_utils
|
|
5
|
-
from constants import (
|
|
6
|
-
GRPC_SERVICE_PORT,
|
|
7
|
-
HF_AUTH_KEY_CONSTANT,
|
|
8
|
-
HTTP_SERVICE_PORT,
|
|
9
|
-
TOKENIZER_KEY_CONSTANT,
|
|
10
|
-
)
|
|
11
|
-
from schema import ModelInput, TrussBuildConfig
|
|
12
|
-
from transformers import AutoTokenizer
|
|
13
|
-
from triton_client import TritonClient, TritonServer
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class Model:
|
|
17
|
-
def __init__(self, data_dir, config, secrets):
|
|
18
|
-
self._data_dir = data_dir
|
|
19
|
-
self._config = config
|
|
20
|
-
self._secrets = secrets
|
|
21
|
-
self._request_id_counter = count(start=1)
|
|
22
|
-
self.triton_client = None
|
|
23
|
-
self.triton_server = None
|
|
24
|
-
self.tokenizer = None
|
|
25
|
-
self.uses_openai_api = None
|
|
26
|
-
|
|
27
|
-
def load(self):
|
|
28
|
-
build_config = TrussBuildConfig(**self._config["build"]["arguments"])
|
|
29
|
-
self.uses_openai_api = "openai-compatible" in self._config.get(
|
|
30
|
-
"model_metadata", {}
|
|
31
|
-
).get("tags", [])
|
|
32
|
-
hf_access_token = None
|
|
33
|
-
if "hf_access_token" in self._secrets._base_secrets.keys():
|
|
34
|
-
hf_access_token = self._secrets["hf_access_token"]
|
|
35
|
-
|
|
36
|
-
# TODO(Abu): Move to pre-runtime
|
|
37
|
-
if build_config.requires_build:
|
|
38
|
-
build_engine_utils.build_engine_from_config_args(
|
|
39
|
-
engine_build_args=build_config.engine_build_args,
|
|
40
|
-
dst=self._data_dir,
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
self.triton_server = TritonServer(
|
|
44
|
-
grpc_port=GRPC_SERVICE_PORT,
|
|
45
|
-
http_port=HTTP_SERVICE_PORT,
|
|
46
|
-
)
|
|
47
|
-
|
|
48
|
-
self.triton_server.create_model_repository(
|
|
49
|
-
truss_data_dir=self._data_dir,
|
|
50
|
-
engine_repository_path=build_config.engine_repository
|
|
51
|
-
if not build_config.requires_build
|
|
52
|
-
else None,
|
|
53
|
-
huggingface_auth_token=hf_access_token,
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
env = {}
|
|
57
|
-
if hf_access_token:
|
|
58
|
-
env[HF_AUTH_KEY_CONSTANT] = hf_access_token
|
|
59
|
-
env[TOKENIZER_KEY_CONSTANT] = build_config.tokenizer_repository
|
|
60
|
-
|
|
61
|
-
self.triton_server.start(
|
|
62
|
-
tensor_parallelism=build_config.tensor_parallel_count,
|
|
63
|
-
env=env,
|
|
64
|
-
)
|
|
65
|
-
|
|
66
|
-
self.triton_client = TritonClient(
|
|
67
|
-
grpc_service_port=GRPC_SERVICE_PORT,
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
71
|
-
build_config.tokenizer_repository, token=hf_access_token
|
|
72
|
-
)
|
|
73
|
-
self.eos_token_id = self.tokenizer.eos_token_id
|
|
74
|
-
|
|
75
|
-
async def predict(self, model_input):
|
|
76
|
-
model_input["request_id"] = str(os.getpid()) + str(
|
|
77
|
-
next(self._request_id_counter)
|
|
78
|
-
)
|
|
79
|
-
model_input["eos_token_id"] = self.eos_token_id
|
|
80
|
-
|
|
81
|
-
self.triton_client.start_grpc_stream()
|
|
82
|
-
|
|
83
|
-
model_input = ModelInput(**model_input)
|
|
84
|
-
|
|
85
|
-
result_iterator = self.triton_client.infer(model_input)
|
|
86
|
-
|
|
87
|
-
async def generate():
|
|
88
|
-
async for result in result_iterator:
|
|
89
|
-
yield result
|
|
90
|
-
|
|
91
|
-
if model_input.stream:
|
|
92
|
-
return generate()
|
|
93
|
-
else:
|
|
94
|
-
if self.uses_openai_api:
|
|
95
|
-
return "".join(generate())
|
|
96
|
-
else:
|
|
97
|
-
return {"text": "".join(generate())}
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
from pathlib import Path
|
|
2
|
-
|
|
3
|
-
from schema import EngineBuildArgs
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def build_engine_from_config_args(
|
|
7
|
-
engine_build_args: EngineBuildArgs,
|
|
8
|
-
dst: Path,
|
|
9
|
-
):
|
|
10
|
-
import os
|
|
11
|
-
import shutil
|
|
12
|
-
import sys
|
|
13
|
-
|
|
14
|
-
# NOTE: These are provided by the underlying base image
|
|
15
|
-
# TODO(Abu): Remove this when we have a better way of handling this
|
|
16
|
-
sys.path.append("/app/baseten")
|
|
17
|
-
from build_engine import Engine, build_engine
|
|
18
|
-
from trtllm_utils import docker_tag_aware_file_cache
|
|
19
|
-
|
|
20
|
-
engine = Engine(**engine_build_args.model_dump())
|
|
21
|
-
|
|
22
|
-
with docker_tag_aware_file_cache("/root/.cache/trtllm"):
|
|
23
|
-
built_engine = build_engine(engine, download_remote=True)
|
|
24
|
-
|
|
25
|
-
if not os.path.exists(dst):
|
|
26
|
-
os.makedirs(dst)
|
|
27
|
-
|
|
28
|
-
for filename in os.listdir(str(built_engine)):
|
|
29
|
-
source_file = os.path.join(str(built_engine), filename)
|
|
30
|
-
destination_file = os.path.join(dst, filename)
|
|
31
|
-
if not os.path.exists(destination_file):
|
|
32
|
-
shutil.copy(source_file, destination_file)
|
|
33
|
-
|
|
34
|
-
return dst
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
from pathlib import Path
|
|
2
|
-
|
|
3
|
-
# If changing model repo path, please updated inside tensorrt_llm config.pbtxt as well
|
|
4
|
-
TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path(
|
|
5
|
-
"/app/packages/tensorrt_llm_model_repository/"
|
|
6
|
-
)
|
|
7
|
-
GRPC_SERVICE_PORT = 8001
|
|
8
|
-
HTTP_SERVICE_PORT = 8003
|
|
9
|
-
HF_AUTH_KEY_CONSTANT = "HUGGING_FACE_HUB_TOKEN"
|
|
10
|
-
TOKENIZER_KEY_CONSTANT = "TRITON_TOKENIZER_REPOSITORY"
|
|
11
|
-
ENTRYPOINT_MODEL_NAME = "ensemble"
|