truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
from functools import lru_cache
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from subprocess import CalledProcessError, run
|
|
22
|
+
from typing import Optional, Union
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import soundfile
|
|
26
|
+
import torch
|
|
27
|
+
import torch.nn.functional as F
|
|
28
|
+
from anyio import to_thread
|
|
29
|
+
|
|
30
|
+
# A filesystem path given either as a plain string or a pathlib.Path.
Pathlike = Union[str, Path]

# Audio front-end hyperparameters.
SAMPLE_RATE = 16000  # Hz; rate audio is resampled to / required to be at
N_FFT = 400  # STFT window length, in samples
HOP_LENGTH = 160  # STFT hop between frames, in samples
CHUNK_LENGTH = 30  # seconds of audio per chunk
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read it as a mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open.

    sr: int
        The sample rate to resample the audio to.

    Returns
    -------
    A 1-D NumPy array containing the audio waveform, in float32 dtype, with
    samples scaled from int16 PCM into [-1, 1).

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; ffmpeg's stderr is included.
    FileNotFoundError
        Raised by ``subprocess.run`` when the ``ffmpeg`` executable is not
        found on PATH.
    """

    # This launches a subprocess to decode audio while down-mixing (-ac 1)
    # and resampling (-ar sr) as necessary, writing raw signed 16-bit
    # little-endian PCM to stdout. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac",
        "1", "-acodec", "pcm_s16le", "-ar",
        str(sr), "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        # ffmpeg's stderr is not guaranteed to be valid UTF-8; decoding
        # leniently keeps a UnicodeDecodeError from masking the real failure.
        stderr = e.stderr.decode(errors="replace")
        raise RuntimeError(f"Failed to load audio: {stderr}") from e

    # int16 PCM -> float32 in [-1, 1).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def load_audio_wav_format(wav_path):
    """
    Read a ``.wav`` file sampled at SAMPLE_RATE (16 kHz) using soundfile.

    Multi-channel files are down-mixed to mono by averaging the channels, so
    the returned waveform is always 1-D. Without this, a stereo file would be
    returned as a (frames, channels) array and later padding/trimming along
    the last axis would operate on the channel axis instead of on time.

    Parameters
    ----------
    wav_path: str or os.PathLike
        Path to the audio file; the ``.wav`` extension check is
        case-insensitive.

    Returns
    -------
    (waveform, sample_rate)
        The 1-D waveform as returned by ``soundfile.read`` (float64 by
        default) and its sample rate.

    Raises
    ------
    ValueError
        If the path does not end in ``.wav`` or the file's sample rate is not
        SAMPLE_RATE. (Explicit raises rather than ``assert`` so validation
        survives ``python -O``.)
    """
    # make sure audio in .wav format
    if not str(wav_path).lower().endswith(".wav"):
        raise ValueError(f"Only support .wav format, but got {wav_path}")
    waveform, sample_rate = soundfile.read(wav_path)
    if sample_rate != SAMPLE_RATE:
        raise ValueError(f"Only support 16k sample rate, but got {sample_rate}")
    if waveform.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel audio.
        waveform = waveform.mean(axis=1)
    return waveform, sample_rate
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.

    Works on both torch Tensors and NumPy arrays. Along ``axis``, the input
    is truncated to its first ``length`` entries when longer, and right-padded
    with zeros when shorter; every other axis is left untouched.
    """
    is_tensor = torch.is_tensor(array)

    # Trim first, so the padding step below only ever sees short inputs.
    if array.shape[axis] > length:
        if is_tensor:
            keep = torch.arange(length, device=array.device)
            array = array.index_select(dim=axis, index=keep)
        else:
            array = array.take(indices=range(length), axis=axis)

    missing = length - array.shape[axis]
    if missing > 0:
        widths = [(0, 0) for _ in range(array.ndim)]
        widths[axis] = (0, missing)
        if is_tensor:
            # F.pad expects one flat list ordered from the last dim backwards.
            flat = [amount for pair in reversed(widths) for amount in pair]
            array = F.pad(array, flat)
        else:
            array = np.pad(array, widths)

    return array
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@lru_cache(maxsize=None)
def mel_filters(
    device, n_mels: int, mel_filters_dir: Optional[str] = None
) -> torch.Tensor:
    """
    Load the Mel filterbank matrix for projecting an STFT into a Mel spectrogram.

    The matrix is stored under key ``mel_{n_mels}`` in ``mel_filters.npz`` so
    librosa is not needed at runtime; the archive was saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    Results are memoized per (device, n_mels, mel_filters_dir).

    Parameters
    ----------
    device:
        Torch device the returned tensor is moved to.
    n_mels: int
        Number of Mel bands; only 80 and 128 are supported.
    mel_filters_dir: Optional[str]
        Directory containing ``mel_filters.npz``. When None, the ``assets``
        directory next to this module is used.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    base_dir = (
        Path(__file__).parent / "assets"
        if mel_filters_dir is None
        else Path(mel_filters_dir)
    )
    with np.load(base_dir / "mel_filters.npz") as archive:
        matrix = archive[f"mel_{n_mels}"]
    return torch.from_numpy(matrix).to(device)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    mel_filters_dir: Optional[str] = None,
):
    """
    Compute the log-Mel spectrogram of an audio clip.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the
        audio waveform in 16 kHz. Paths ending in ``.wav`` are read with
        ``load_audio_wav_format``; any other path is decoded by ``load_audio``
        (ffmpeg). Paths and NumPy arrays are padded/trimmed to N_SAMPLES and
        cast to float32; Tensor input is used as-is.

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    mel_filters_dir: Optional[str]
        Directory containing ``mel_filters.npz``; None uses the default
        assets directory.

    Returns
    -------
    torch.Tensor, shape = (80 or 128, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if torch.is_tensor(audio):
        waveform = audio
    else:
        if isinstance(audio, str):
            if audio.endswith(".wav"):
                audio, _sample_rate = load_audio_wav_format(audio)
            else:
                audio = load_audio(audio)
        assert isinstance(audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
        waveform = torch.from_numpy(pad_or_trim(audio, N_SAMPLES).astype(np.float32))

    if device is not None:
        waveform = waveform.to(device)  # type: ignore
    if padding > 0:
        waveform = F.pad(waveform, (0, padding))

    hann = torch.hann_window(N_FFT).to(waveform.device)  # type: ignore
    spectrum = torch.stft(
        waveform, N_FFT, HOP_LENGTH, window=hann, return_complex=True
    )
    # Drop the final STFT frame, then take the power (squared magnitude).
    power = spectrum[..., :-1].abs() ** 2

    mel = mel_filters(waveform.device, n_mels, mel_filters_dir) @ power  # type: ignore

    log_mel = torch.clamp(mel, min=1e-10).log10()
    # Floor everything at 8 (log10 units) below the peak, then rescale.
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
    return (log_mel + 4.0) / 4.0
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
async def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    mel_filters_dir: Optional[str] = None,
):
    """
    Async wrapper around ``_log_mel_spectrogram``.

    The computation (including any file decoding) is blocking, so it runs in
    a worker thread via ``anyio.to_thread.run_sync``. This keeps the event
    loop free. All arguments are forwarded unchanged and the resulting
    spectrogram tensor is returned.
    """
    return await to_thread.run_sync(
        _log_mel_spectrogram, audio, n_mels, padding, device, mel_filters_dir
    )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def remove_tensor_padding(input_tensor, input_tensor_lengths=None, pad_value=0):
    """
    Strip padding from a batched tensor and pack the valid entries together.

    2-D input (batch, seq_len) — text case:
        Every element equal to ``pad_value`` is dropped. The surviving tokens
        are flattened into shape ``(1, total_tokens)``. The first token of each
        row must not be ``pad_value``, and ``input_tensor_lengths`` must be None.

    3-D input (batch, seq_len, feature_len) — audio case:
        Row ``i`` is cut to its first ``input_tensor_lengths[i]`` steps. The
        rows are concatenated along dim 0, giving
        ``(sum(lengths), feature_len)``.

    Raises ValueError for any other rank.
    """
    rank = input_tensor.dim()

    if rank == 2:
        assert torch.all(input_tensor[:, 0] != pad_value), (
            "First token in each sequence should not be pad_value"
        )
        assert input_tensor_lengths is None
        keep = input_tensor != pad_value
        return input_tensor[keep].view(1, -1)

    if rank == 3:
        assert input_tensor_lengths is not None, (
            "input_tensor_lengths must be provided for 3D input_tensor"
        )
        trimmed = [
            input_tensor[row, : input_tensor_lengths[row], :]
            for row in range(input_tensor.shape[0])
        ]
        return torch.cat(trimmed, dim=0)

    raise ValueError("Input tensor must have 2 or 3 dimensions")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def read_config(component, engine_dir):
    """
    Read ``<engine_dir>/<component>/config.json`` into a flat model config.

    The file's ``pretrained_config`` and ``build_config`` sections are merged
    into one OrderedDict. ``build_config`` entries overwrite same-named
    ``pretrained_config`` entries.

    Parameters
    ----------
    component: str
        Sub-directory of ``engine_dir`` holding the component's config.
    engine_dir: str or os.PathLike
        Engine root directory. Plain strings are accepted as well as Paths.

    Raises
    ------
    FileNotFoundError
        If the config file does not exist.
    KeyError
        If either expected section is missing.
    """
    # Path(...) makes a plain str engine_dir work; `/` on a str would raise.
    config_path = Path(engine_dir) / component / "config.json"
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    model_config = OrderedDict()
    model_config.update(config["pretrained_config"])
    model_config.update(config["build_config"])
    return model_config
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from briton.spec_dec_truss_model import Model as SpecDecModel
|
|
2
|
+
from briton.trtllm_config import TRTLLMConfiguration
|
|
3
|
+
from briton.truss_model import Model
|
|
4
|
+
|
|
5
|
+
# TODO(pankaj) Define an ABC base class for this. That baseclass should live in
|
|
6
|
+
# a new, smaller truss sub-library, perhaps called `truss-runtime` for inclusion
|
|
7
|
+
# in Truss runtime. Once we have that sub-library, we should define the Extension
|
|
8
|
+
# base class there and derive Extension class below from it.
|
|
9
|
+
#
|
|
10
|
+
# That base class would look like:
|
|
11
|
+
# class TrussExtension(ABC):
|
|
12
|
+
# @abstractmethod
|
|
13
|
+
# def model_override(self):
|
|
14
|
+
# pass
|
|
15
|
+
|
|
16
|
+
# @abstractmethod
|
|
17
|
+
# def model_args(self) -> dict:
|
|
18
|
+
# pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# @abstractmethod
|
|
22
|
+
# def load(self) -> dict:
|
|
23
|
+
# pass
|
|
24
|
+
class Extension:
    """
    trt_llm truss extension.

    Provides model_args to supply to model class, which contain the trtllm
    engine that corresponds to provided config.

    This extension also provides a full replacement of the model class, which is
    to be used if user doesn't supply it. This may be desired behavior in many
    cases where users want to just go by config and don't want to do any pre or
    post-processing.

    The wrapped engine is briton's speculative-decoding Model when the
    ``trt_llm`` config's ``build.speculator`` is set, and briton's regular
    Model otherwise.
    """

    def __init__(self, *args, **kwargs):
        """Create the briton engine for the truss config passed as ``config=``.

        All positional and keyword arguments are forwarded unchanged to the
        engine's constructor.

        Raises:
            KeyError: if no ``config`` keyword argument is supplied.
            ValueError: if the config has no ``trt_llm`` section.
        """
        self._config = kwargs["config"]
        trt_llm_config = self._config.get("trt_llm")
        if trt_llm_config is None:
            # Fail with an explicit message rather than the opaque
            # "argument after ** must be a mapping" TypeError from **None.
            raise ValueError(
                "The trt_llm extension requires a `trt_llm` section in the "
                "truss config."
            )
        config = TRTLLMConfiguration(**trt_llm_config)
        model_cls = SpecDecModel if config.build.speculator else Model
        self._model = model_cls(*args, **kwargs)

    def model_override(self):
        """Return a model object.

        This is used if model.py is omitted, which is allowed when using trt_llm.
        """
        return self._model

    def model_args(self) -> dict:
        """Return args to supply as input to Model class' __init__ method.

        Model class can use this to invoke the trt_llm engine.

        Returned engine is a typical Truss model class that provides a predict
        function. The predict function is async and returns an async generator.
        """
        return {"engine": self._model}

    def load(self):
        """Load the wrapped engine by delegating to its ``load()`` method."""
        self._model.load()
|