truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import shutil
|
|
4
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List, Tuple
|
|
8
|
+
|
|
9
|
+
import pydantic
|
|
10
|
+
import requests
|
|
11
|
+
import yaml
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from shared.util import BLOB_DOWNLOAD_TIMEOUT_SECS
|
|
15
|
+
except ModuleNotFoundError:
|
|
16
|
+
from truss.templates.shared.util import BLOB_DOWNLOAD_TIMEOUT_SECS
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import truss_transfer
|
|
20
|
+
|
|
21
|
+
TRUSS_TRANSFER_AVAILABLE = True
|
|
22
|
+
except ImportError:
|
|
23
|
+
TRUSS_TRANSFER_AVAILABLE = False
|
|
24
|
+
|
|
25
|
+
LAZY_DATA_RESOLVER_PATH = Path("/bptr/bptr-manifest")
|
|
26
|
+
NUM_WORKERS = 4
|
|
27
|
+
CACHE_DIR = Path("/cache/org/artifacts")
|
|
28
|
+
BASETEN_FS_ENABLED_ENV_VAR = "BASETEN_FS_ENABLED"
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Resolution(pydantic.BaseModel):
    # Resolved download location of one Baseten pointer.
    # `url` is the address the artifact is fetched from.
    url: str
    # Compared against the current UTC epoch time (in seconds) when the
    # manifest is read; a value in the past makes the pointer expired.
    expiration_timestamp: int
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class BasetenPointer(pydantic.BaseModel):
    """Specification for lazy data resolution for download of large files, similar to Git LFS pointers"""

    # Where the payload can currently be downloaded from, and until when.
    resolution: Resolution
    uid: str
    # Path of the file relative to the data directory it is resolved into.
    file_name: str
    hashtype: str
    # Content hash; also used as the file name inside the artifact cache.
    hash: str
    # Compared against the free space of the cache before downloading.
    size: int
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class BasetenPointerManifest(pydantic.BaseModel):
    # Top-level shape of the manifest file: a flat list of pointers.
    pointers: List[BasetenPointer]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class LazyDataResolver:
    """Deprecation warning: This class is deprecated and will be removed in a future release.

    Please use LazyDataResolverV2 instead (using the `truss_transfer` package).

    Downloads every file listed in the Baseten pointer manifest into the data
    directory. When the `BASETEN_FS_ENABLED` environment variable is exactly
    `"True"`, files are stored in `CACHE_DIR` (keyed by hash) and symlinked
    into the data directory instead.
    """

    def __init__(self, data_dir: Path):
        self._data_dir: Path = data_dir
        # Maps file name -> (resolved URL, hash, size), as read from the manifest.
        self._bptr_resolution: Dict[str, Tuple[str, str, int]] = _read_bptr_resolution()
        self._resolution_done = False
        self._uses_b10_cache = (
            os.environ.get(BASETEN_FS_ENABLED_ENV_VAR, "False") == "True"
        )

    @staticmethod
    def _download_to_path(URL: str, destination: Path) -> None:
        """Stream `URL` into `destination` without ever exposing a partial file.

        The payload is written to a uniquely named sibling file and only moved
        into place with `os.replace` once the whole body has been copied, so a
        failed or interrupted download can never be mistaken for a complete
        file (which matters for the `exists()` check on cache entries).
        """
        import uuid  # Local import: only needed for the temporary file name.

        destination.parent.mkdir(parents=True, exist_ok=True)
        partial_path = destination.with_name(
            f"{destination.name}.{uuid.uuid4().hex}.part"
        )
        # Streaming download to keep memory usage low
        resp = requests.get(
            URL, allow_redirects=True, stream=True, timeout=BLOB_DOWNLOAD_TIMEOUT_SECS
        )
        try:
            resp.raise_for_status()
            with partial_path.open("wb") as file:
                shutil.copyfileobj(resp.raw, file)
            os.replace(partial_path, destination)
        finally:
            # Streamed responses hold their connection until closed.
            resp.close()
            # Only present if the download did not complete.
            if partial_path.exists():
                partial_path.unlink()

    @staticmethod
    def _symlink_into_data_dir(cached_file: Path, link_path: Path) -> None:
        """Create `link_path` as a symlink to `cached_file`, tolerating an existing link."""
        link_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            os.symlink(cached_file, link_path)
        except FileExistsError:
            # symlink may already exist if the inference server was restarted
            pass

    def cached_download_from_url_using_requests(
        self, URL: str, hash: str, file_name: str, size: int
    ):
        """Download object from URL, attempt to write to cache and symlink to data directory if applicable, data directory otherwise.
        In case of failure, write to data directory

        Args:
            URL: Location to download the object from.
            hash: Content hash of the object; used as its file name in the cache.
            file_name: Path of the object relative to the data directory.
            size: Size of the object, compared against the cache's free space.

        Raises:
            requests.RequestException: If the request fails or returns an
                error status. These are never treated as cache problems.
            OSError: If the object cannot be written to the data directory.
        """
        destination = self._data_dir / file_name

        if self._uses_b10_cache:
            cached_file = CACHE_DIR / hash
            if cached_file.exists():
                self._symlink_into_data_dir(cached_file, destination)
                return

            try:
                # Check whether the cache has sufficient space to store the file
                cache_free_space = shutil.disk_usage(CACHE_DIR).free
                if cache_free_space < size:
                    raise OSError(
                        f"Cache directory does not have sufficient space to save file {file_name}. Free space in cache: {cache_free_space}, file size: {size}"
                    )

                self._download_to_path(URL, cached_file)
                # symlink to data directory
                self._symlink_into_data_dir(cached_file, destination)
                return
            except requests.RequestException:
                # `RequestException` subclasses `OSError`; re-raise so that HTTP
                # and connection errors are not mistaken for a cache failure.
                raise
            except OSError as e:
                logger.debug(
                    "Failed to save artifact to cache dir, saving to data dir instead. Error: %s",
                    e,
                )
                # Cache likely has no space left on device; fall through and
                # download into the data dir with a fresh request, because the
                # previous response stream may already be partially consumed.

        self._download_to_path(URL, destination)

    def fetch(self):
        """Download all manifest files concurrently; a no-op after the first success.

        Raises:
            RuntimeError: If any download fails. The original exception is
                attached as `__cause__`.
        """
        if self._resolution_done:
            return

        with ThreadPoolExecutor(NUM_WORKERS) as executor:
            futures = {
                executor.submit(
                    self.cached_download_from_url_using_requests,
                    resolved_url,
                    file_hash,
                    file_name,
                    size,
                ): file_name
                for file_name, (
                    resolved_url,
                    file_hash,
                    size,
                ) in self._bptr_resolution.items()
            }
            for future in as_completed(futures):
                error = future.exception()
                if error is not None:
                    # Do not start downloads that are still queued; the
                    # executor only waits for the ones already running.
                    for queued in futures:
                        queued.cancel()
                    raise RuntimeError(
                        f"Download failure for file {futures[future]}"
                    ) from error
        self._resolution_done = True
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class LazyDataResolverV2:
    """Resolves lazy data by delegating to the optional `truss_transfer` package."""

    def __init__(self, data_dir: Path):
        self._data_dir: Path = data_dir

    def fetch(self):
        """Run `truss_transfer.lazy_data_resolve` on the data directory.

        Raises:
            ImportError: If `truss_transfer` could not be imported when this
                module was loaded (instead of an opaque `NameError`).
        """
        if not TRUSS_TRANSFER_AVAILABLE:
            raise ImportError(
                "LazyDataResolverV2 requires the `truss_transfer` package, "
                "which could not be imported."
            )
        truss_transfer.lazy_data_resolve(str(self._data_dir))
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _read_bptr_resolution() -> Dict[str, Tuple[str, str, int]]:
    """Load the Baseten pointer manifest into a per-file lookup table.

    Returns:
        A mapping of file name to `(url, hash, size)`; empty when the
        manifest file at `LAZY_DATA_RESOLVER_PATH` does not exist.

    Raises:
        RuntimeError: If any pointer's expiration timestamp is in the past.
    """
    if not LAZY_DATA_RESOLVER_PATH.is_file():
        return {}

    manifest_fields = yaml.safe_load(LAZY_DATA_RESOLVER_PATH.read_text())
    manifest = BasetenPointerManifest(**manifest_fields)

    resolved: Dict[str, Tuple[str, str, int]] = {}
    for pointer in manifest.pointers:
        # Expiry is checked against the current UTC time in whole seconds.
        current_epoch_secs = int(datetime.now(timezone.utc).timestamp())
        if pointer.resolution.expiration_timestamp < current_epoch_secs:
            raise RuntimeError("Baseten pointer lazy data resolution has expired")
        resolved[pointer.file_name] = (
            pointer.resolution.url,
            pointer.hash,
            pointer.size,
        )
    return resolved
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import urllib.parse
|
|
4
|
+
from typing import Any, Mapping
|
|
5
|
+
|
|
6
|
+
from pythonjsonlogger import jsonlogger
|
|
7
|
+
|
|
8
|
+
LOCAL_DATE_FORMAT = "%H:%M:%S"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _disable_json_logging() -> bool:
|
|
12
|
+
return bool(os.environ.get("DISABLE_JSON_LOGGING"))
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class _HealthCheckFilter(logging.Filter):
|
|
16
|
+
def filter(self, record: logging.LogRecord) -> bool:
|
|
17
|
+
excluded_paths = {
|
|
18
|
+
"GET / ",
|
|
19
|
+
"GET /v1/models/model ",
|
|
20
|
+
"GET /v1/models/model/loaded ",
|
|
21
|
+
}
|
|
22
|
+
msg = record.getMessage()
|
|
23
|
+
return not any(path in msg for path in excluded_paths)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class _AccessJsonFormatter(jsonlogger.JsonFormatter):
    """JSON formatter that rewrites uvicorn access records before formatting."""

    def format(self, record: logging.LogRecord) -> str:
        """Format `record`, replacing the message of 5-argument uvicorn access records."""
        # Uvicorn sets record.msg = '%s - "%s %s HTTP/%s" %d' and
        # record.args = (addr, method, path, version, status).
        # Python's logging system resolves final
        # record.message = record.msg % record.args unless we override record.msg.
        if not (
            record.name == "uvicorn.access" and record.args and len(record.args) == 5
        ):
            return super().format(record)

        client_addr, method, raw_path, version, status = record.args
        readable_path = urllib.parse.unquote(str(raw_path))
        record.msg = (
            f"Handled request from {client_addr} - {method} "
            f"{readable_path} HTTP/{version} {status}"
        )
        # Ensure Python doesn't reapply the old format string
        record.args = ()
        return super().format(record)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class _AccessFormatter(logging.Formatter):
|
|
45
|
+
def format(self, record: logging.LogRecord) -> str:
|
|
46
|
+
if record.name == "uvicorn.access" and record.args and len(record.args) == 5:
|
|
47
|
+
client_addr, method, raw_path, version, status = record.args
|
|
48
|
+
path_decoded = urllib.parse.unquote(str(raw_path))
|
|
49
|
+
new_message = (
|
|
50
|
+
f"Handled request from {client_addr} - {method} "
|
|
51
|
+
f"{path_decoded} HTTP/{version} {status}"
|
|
52
|
+
)
|
|
53
|
+
record.msg = new_message
|
|
54
|
+
record.args = ()
|
|
55
|
+
return super().format(record)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def make_log_config(log_level: str) -> Mapping[str, Any]:
    """Assemble the version-1 logging configuration mapping.

    JSON formatters are used unless ``_disable_json_logging()`` is truthy, in
    which case plain-text formatters with ``LOCAL_DATE_FORMAT`` are used.

    Args:
        log_level: Level applied to the ``uvicorn`` logger and the root
            logger. ``uvicorn.error`` and ``uvicorn.access`` are fixed at
            ``"INFO"``.

    Returns:
        A mapping with ``version``, ``disable_existing_loggers``, ``filters``,
        ``formatters``, ``handlers``, ``loggers`` and ``root`` entries.
    """
    # Warning: `ModelWrapper` depends on correctly setup `uvicorn` logger,
    # if you change/remove that logger, make sure `ModelWrapper` has a suitable
    # alternative logger that is also correctly setup in the load thread.
    if _disable_json_logging():
        # NOTE(review): `%(msecs)04d` pads milliseconds to four digits
        # (e.g. ".0123"); `%03d` is the usual width - confirm this is intended.
        formatters = {
            "default_formatter": {
                "format": "%(asctime)s.%(msecs)04d %(levelname)s %(message)s",
                "datefmt": LOCAL_DATE_FORMAT,
            },
            "access_formatter": {
                "()": _AccessFormatter,
                "format": "%(asctime)s.%(msecs)04d %(levelname)s %(message)s",
                "datefmt": LOCAL_DATE_FORMAT,
            },
        }
    else:
        formatters = {
            "default_formatter": {
                "()": jsonlogger.JsonFormatter,
                "format": "%(asctime)s %(levelname)s %(message)s",
            },
            "access_formatter": {
                "()": _AccessJsonFormatter,
                "format": "%(asctime)s %(levelname)s %(message)s",
            },
        }

    # General logs go to stderr, access logs to stdout.
    handlers = {
        "default_handler": {
            "formatter": "default_formatter",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stderr",
        },
        "access_handler": {
            "formatter": "access_formatter",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
        },
    }

    loggers = {
        "uvicorn": {
            "handlers": ["default_handler"],
            "level": log_level,
            "propagate": False,
        },
        "uvicorn.error": {
            "handlers": ["default_handler"],
            "level": "INFO",
            "propagate": False,
        },
        "uvicorn.access": {
            "handlers": ["access_handler"],
            "level": "INFO",
            "propagate": False,
            "filters": ["health_check_filter"],
        },
    }

    return {
        "version": 1,
        "disable_existing_loggers": False,
        "filters": {"health_check_filter": {"()": _HealthCheckFilter}},
        "formatters": formatters,
        "handlers": handlers,
        "loggers": loggers,
        # Catch-all for module loggers
        "root": {"handlers": ["default_handler"], "level": log_level},
    }
|
|
@@ -3,7 +3,7 @@ from collections.abc import Mapping
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Dict, Optional
|
|
5
5
|
|
|
6
|
-
SECRETS_DOC_LINK = "https://truss.baseten.co/
|
|
6
|
+
SECRETS_DOC_LINK = "https://truss.baseten.co/guides/secrets"
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class SecretNotFound(Exception):
|
|
@@ -62,7 +62,6 @@ def _secret_missing_error_message(key: str) -> str:
|
|
|
62
62
|
return f"""
|
|
63
63
|
Secret '{key}' not found. Please ensure that:
|
|
64
64
|
* Secret '{key}' is defined in the 'secrets' section of the Truss config file
|
|
65
|
-
* The model was pushed with the --trusted flag
|
|
66
65
|
* Secret '{key}' is defined in the secret manager
|
|
67
66
|
Read more about secrets here: {SECRETS_DOC_LINK}.
|
|
68
67
|
"""
|
|
@@ -2,11 +2,33 @@ import json
|
|
|
2
2
|
import uuid
|
|
3
3
|
from datetime import date, datetime, time, timedelta
|
|
4
4
|
from decimal import Decimal
|
|
5
|
-
from typing import Any, Callable, Dict, Optional, Union
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Recursive alias for plain JSON values: scalars, lists of JSON values, and
# string-keyed dicts of JSON values.
JSONType = Union[str, int, float, bool, None, List["JSONType"], Dict[str, "JSONType"]]
# Recursive alias used to annotate the msgpack (de)serialization helpers in
# this module: the JSON scalars plus date/time, Decimal, UUID and numpy arrays.
MsgPackType = Union[
    str,
    int,
    float,
    bool,
    None,
    date,
    Decimal,
    datetime,
    time,
    timedelta,
    uuid.UUID,
    # Forward reference: `NDArray` is only imported under `TYPE_CHECKING`.
    "NDArray",
    List["MsgPackType"],
    Dict[str, "MsgPackType"],
]
|
|
6
28
|
|
|
7
29
|
|
|
8
30
|
# mostly cribbed from django.core.serializers.json.DjangoJSONEncoder
|
|
9
|
-
def
|
|
31
|
+
def _truss_msgpack_encoder(
|
|
10
32
|
obj: Union[Decimal, date, time, timedelta, uuid.UUID, Dict],
|
|
11
33
|
chain: Optional[Callable] = None,
|
|
12
34
|
) -> Dict:
|
|
@@ -36,7 +58,7 @@ def truss_msgpack_encoder(
|
|
|
36
58
|
return obj if chain is None else chain(obj)
|
|
37
59
|
|
|
38
60
|
|
|
39
|
-
def
|
|
61
|
+
def _truss_msgpack_decoder(obj: Any, chain=None):
|
|
40
62
|
try:
|
|
41
63
|
if b"__dt_datetime_iso__" in obj:
|
|
42
64
|
return datetime.fromisoformat(obj[b"data"])
|
|
@@ -58,7 +80,7 @@ def truss_msgpack_decoder(obj: Any, chain=None):
|
|
|
58
80
|
|
|
59
81
|
|
|
60
82
|
# this json object is JSONType + np.array + datetime
|
|
61
|
-
def is_truss_serializable(obj) -> bool:
|
|
83
|
+
def is_truss_serializable(obj: Any) -> bool:
|
|
62
84
|
import numpy as np
|
|
63
85
|
|
|
64
86
|
# basic JSON types
|
|
@@ -72,21 +94,21 @@ def is_truss_serializable(obj) -> bool:
|
|
|
72
94
|
return False
|
|
73
95
|
|
|
74
96
|
|
|
75
|
-
def truss_msgpack_serialize(obj):
|
|
97
|
+
def truss_msgpack_serialize(obj: MsgPackType) -> bytes:
    """Pack ``obj`` with msgpack.

    Objects msgpack cannot pack natively are handed to
    ``_truss_msgpack_encoder``, with ``msgpack_numpy``'s encoder as its chain.
    """
    import msgpack
    import msgpack_numpy as mp_np

    def _encode_unsupported(value):
        # Truss-specific encoder first, chaining to the numpy encoder.
        return _truss_msgpack_encoder(value, chain=mp_np.encode)

    return msgpack.packb(obj, default=_encode_unsupported)
|
|
82
104
|
|
|
83
105
|
|
|
84
|
-
def truss_msgpack_deserialize(
|
|
106
|
+
def truss_msgpack_deserialize(data: bytes) -> MsgPackType:
    """Unpack msgpack ``data``.

    Each decoded map is passed through ``_truss_msgpack_decoder``, with
    ``msgpack_numpy``'s decoder as its chain.
    """
    import msgpack
    import msgpack_numpy as mp_np

    def _decode_map(value):
        # Truss-specific decoder first, chaining to the numpy decoder.
        return _truss_msgpack_decoder(value, chain=mp_np.decode)

    return msgpack.unpackb(data, object_hook=_decode_map)
|
|
91
113
|
|
|
92
114
|
|
|
@@ -101,4 +123,4 @@ class DeepNumpyEncoder(json.JSONEncoder):
|
|
|
101
123
|
elif isinstance(obj, np.ndarray):
|
|
102
124
|
return obj.tolist()
|
|
103
125
|
else:
|
|
104
|
-
return super(
|
|
126
|
+
return super().default(obj)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
import multiprocessing
|
|
2
2
|
import os
|
|
3
3
|
import sys
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import List
|
|
5
5
|
|
|
6
6
|
import psutil
|
|
7
7
|
|
|
8
|
+
BLOB_DOWNLOAD_TIMEOUT_SECS = 600 # 10 minutes
|
|
8
9
|
# number of seconds to wait for truss server child processes before sending kill signal
|
|
9
10
|
CHILD_PROCESS_WAIT_TIMEOUT_SECONDS = 120
|
|
10
11
|
|
|
@@ -12,9 +13,7 @@ CHILD_PROCESS_WAIT_TIMEOUT_SECONDS = 120
|
|
|
12
13
|
def model_supports_predict_proba(model: object) -> bool:
|
|
13
14
|
if not hasattr(model, "predict_proba"):
|
|
14
15
|
return False
|
|
15
|
-
if hasattr(
|
|
16
|
-
model, "_check_proba"
|
|
17
|
-
): # noqa eg Support Vector Machines *can* predict proba if they made certain choices while training
|
|
16
|
+
if hasattr(model, "_check_proba"): # noqa eg Support Vector Machines *can* predict proba if they made certain choices while training
|
|
18
17
|
try:
|
|
19
18
|
model._check_proba()
|
|
20
19
|
return True
|
|
@@ -76,12 +75,3 @@ def kill_child_processes(parent_pid: int):
|
|
|
76
75
|
)
|
|
77
76
|
for process in alive:
|
|
78
77
|
process.kill()
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
X = TypeVar("X")
|
|
82
|
-
Y = TypeVar("Y")
|
|
83
|
-
Z = TypeVar("Z")
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def transform_keys(d: Mapping[X, Z], fn: Callable[[X], Y]) -> Dict[Y, Z]:
|
|
87
|
-
return {fn(key): value for key, value in d.items()}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
|
|
3
|
+
import requests
|
|
4
|
+
import sigint_patch
|
|
5
|
+
|
|
6
|
+
# We patch TensorRT-LLM v0.11 to fix an issue where they set a SIGINT handler
|
|
7
|
+
# which breaks because we do not use the main thread.
|
|
8
|
+
sigint_patch.patch()
|
|
9
|
+
from whisper_trt import WhisperModel
|
|
10
|
+
from whisper_trt.custom_types import WhisperResult
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Model:
    """Model class exposing ``load``/``preprocess``/``predict`` around ``WhisperModel``.

    ``preprocess`` turns a request holding either base64 audio (``"audio"``)
    or an audio URL (``"url"``) into raw bytes; ``predict`` transcribes them.
    """

    # Upper bound (seconds) on fetching remote audio so that a stalled server
    # cannot block a request indefinitely.
    _AUDIO_DOWNLOAD_TIMEOUT_SECS = 600

    def __init__(self, **kwargs):
        """Store the ``data_dir`` and ``secrets`` keyword arguments (both required)."""
        self._data_dir = kwargs["data_dir"]
        self._secrets = kwargs["secrets"]
        self._model = None

    def load(self):
        """Instantiate the ``WhisperModel`` from the data directory."""
        self._model = WhisperModel(str(self._data_dir), max_queue_time=0.050)

    def preprocess(self, request: dict):
        """Extract the audio bytes from ``request``.

        Returns:
            ``(binary_data, request)`` on success, or a dict with a single
            ``"error"`` key when both or neither of ``"audio"``/``"url"`` are
            supplied.

        Raises:
            requests.HTTPError: If the URL download returns an error status.
        """
        audio_base64 = request.get("audio")
        audio_url = request.get("url")

        if audio_base64 and audio_url:
            return {
                "error": "Only a base64 audio file OR a URL can be passed to the API, not both of them."
            }
        if not audio_base64 and not audio_url:
            return {
                "error": "Please provide either an audio file in base64 string format or a URL to an audio file."
            }

        if audio_base64:
            binary_data = base64.b64decode(audio_base64.encode("utf-8"))
        else:
            resp = requests.get(audio_url, timeout=self._AUDIO_DOWNLOAD_TIMEOUT_SECS)
            # Fail loudly on 4xx/5xx instead of feeding an error page to the
            # audio decoder.
            resp.raise_for_status()
            binary_data = resp.content
        return binary_data, request

    async def predict(self, preprocessed_request) -> WhisperResult:
        """Transcribe the preprocessed audio.

        If ``preprocess`` reported invalid input (an ``{"error": ...}`` dict),
        that dict is returned unchanged instead of being unpacked.
        """
        # `preprocess` signals bad input with a dict; unpacking it as a
        # 2-tuple would raise ValueError and hide the message from the caller.
        if isinstance(preprocessed_request, dict):
            return preprocessed_request

        binary_data, request = preprocessed_request
        waveform = self._model.preprocess_audio(binary_data)
        return await self._model.transcribe(
            waveform, language="english", timestamps=True
        )
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import fileinput
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
MODULE_FILE_PATH = "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/hlapi/utils.py"
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def patch(module_file_path=None):
    """Comment out the SIGINT handler registration in the target module file.

    Every line containing ``signal.signal(signal.SIGINT, sigint_handler)``
    that is not already a comment is replaced by a commented-out copy. The
    operation is idempotent, and the file is only rewritten when something
    actually changed.

    Args:
        module_file_path: File to patch. Defaults to ``MODULE_FILE_PATH``.

    Raises:
        OSError: If the file cannot be read or written.
    """
    target = MODULE_FILE_PATH if module_file_path is None else module_file_path
    search_text = "signal.signal(signal.SIGINT, sigint_handler)"

    # Read and write explicitly rather than via `fileinput(inplace=True)`,
    # which redirects the process-wide `sys.stdout` into the file while it runs.
    # `newline=""` keeps the file's original line endings untouched.
    with open(target, encoding="utf-8", newline="") as source:
        original_lines = source.readlines()

    patched_lines = [
        " # " + line.lstrip()
        # Skip lines that are already commented so repeated calls do not keep
        # prepending "# ".
        if search_text in line and not line.lstrip().startswith("#")
        else line
        for line in original_lines
    ]

    if patched_lines != original_lines:
        with open(target, "w", encoding="utf-8", newline="") as sink:
            sink.writelines(patched_lines)
|