truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss has been flagged as potentially problematic; see the release advisory in the package registry for details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import re
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import tensorrt_llm
|
|
7
|
+
import torch
|
|
8
|
+
import torchaudio
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from whisper_trt.assets import download_assets
|
|
12
|
+
from whisper_trt.batching import WhisperBatchProcessor
|
|
13
|
+
from whisper_trt.custom_types import (
|
|
14
|
+
DEFAULT_MAX_NEW_TOKENS,
|
|
15
|
+
DEFAULT_NUM_BEAMS,
|
|
16
|
+
SUPPORTED_SAMPLE_RATE,
|
|
17
|
+
BatchWhisperItem,
|
|
18
|
+
Segment,
|
|
19
|
+
WhisperResult,
|
|
20
|
+
)
|
|
21
|
+
from whisper_trt.modeling import WhisperDecoding, WhisperEncoding
|
|
22
|
+
from whisper_trt.tokenizer import REVERSED_LANGUAGES, get_tokenizer
|
|
23
|
+
from whisper_trt.utils import log_mel_spectrogram
|
|
24
|
+
|
|
25
|
+
# Matches one timestamped segment in Whisper's decoded output, e.g.
# "<|0.00|>Hello world.<|2.40|>" -> groups: (start time, text, end time).
SEGMENTS_PATTERN = re.compile(r"<\|([\d.]+)\|>([^<]+)<\|([\d.]+)\|>")
# Matches a two-letter language token such as "<|en|>".
LANG_CODE_PATTERN = re.compile(r"<\|([a-z]{2})\|>")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class WhisperModel(object):
    """TensorRT-LLM-backed Whisper speech-to-text model.

    Wraps a compiled encoder/decoder engine pair plus the Whisper tokenizer,
    and funnels all inference through a ``WhisperBatchProcessor`` so that
    concurrent requests are micro-batched up to the engine's
    ``max_batch_size``.
    """

    def __init__(
        self,
        engine_dir,  # path to the compiled TensorRT-LLM engine directory
        tokenizer_name="multilingual",
        debug_mode=False,
        assets_dir=None,  # tokenizer/mel-filter assets; downloaded when None
        max_queue_time=0.01,  # 10 ms by default
    ):
        # Single-process setup: pin this rank's GPU before loading engines.
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

        engine_dir = Path(engine_dir)

        # Fetch tokenizer + mel-filter assets into the local cache if the
        # caller did not supply a directory.
        self.assets_dir = assets_dir
        if self.assets_dir is None:
            self.assets_dir = download_assets()

        self.encoder = WhisperEncoding(engine_dir)
        self.decoder = WhisperDecoding(
            engine_dir, runtime_mapping, debug_mode=debug_mode
        )
        # Batch size is dictated by the compiled decoder engine.
        self.batch_size = self.decoder.decoder_config["max_batch_size"]
        self.n_mels = self.encoder.n_mels
        self.tokenizer = get_tokenizer(
            name=tokenizer_name,
            num_languages=self.encoder.num_languages,
            tokenizer_dir=self.assets_dir,
        )
        # Token id used to stop generation.
        self.eot_id = self.tokenizer.encode(
            "<|endoftext|>", allowed_special=self.tokenizer.special_tokens_set
        )[0]

        # All requests go through this batcher (see detect_audio_and_language
        # and transcribe below).
        self.batch_processor = WhisperBatchProcessor(
            self, max_batch_size=self.batch_size, max_queue_time=max_queue_time
        )

    def preprocess_audio(self, binary_data) -> Tensor:
        """Decode raw audio bytes and resample to ``SUPPORTED_SAMPLE_RATE``.

        Returns the waveform tensor produced by ``torchaudio.load`` (the
        original ``-> dict`` annotation was wrong — a waveform is returned).
        """
        audio_stream = io.BytesIO(binary_data)
        waveform, sample_rate = torchaudio.load(audio_stream)

        # Resample audio to rate compatible with what the model was trained at
        if sample_rate != SUPPORTED_SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=SUPPORTED_SAMPLE_RATE
            )(waveform)
            sample_rate = SUPPORTED_SAMPLE_RATE

        return waveform

    def _get_text_prefix(
        self,
        language: str = "english",
        prompt: Optional[str] = None,
        timestamps: bool = False,
        task: str = "transcribe",
        prefix: Optional[str] = None,
    ):
        """Build the Whisper special-token prompt string for decoding.

        ``language`` may be a full language name (mapped via
        ``REVERSED_LANGUAGES``) or already a language code; unknown names fall
        through and are used verbatim as the code.
        """
        try:
            language_code = REVERSED_LANGUAGES[language]
        except KeyError:
            # Assume the caller already passed a language code.
            language_code = language
        text_prefix = f"<|startoftranscript|><|{language_code}|><|{task}|>"
        if prompt is not None:
            text_prefix = f"<|startofprev|> {prompt}" + text_prefix
        if timestamps:
            text_prefix += "<|0.00|>"
        else:
            text_prefix += "<|notimestamps|>"
        if prefix is not None:
            text_prefix += prefix
        return text_prefix

    def process_batch(
        self,
        mel_batch,
        decoder_input_ids,
        num_beams=DEFAULT_NUM_BEAMS,
        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
    ) -> Tensor:
        """Run the encoder + decoder over one padded batch and return output ids."""
        encoder_output = self.encoder.get_audio_features(mel_batch)
        output_ids = self.decoder.generate(
            decoder_input_ids,
            encoder_output,
            self.eot_id,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )
        return output_ids

    def decode_output_ids(self, output_ids, text_prefix):
        """Decode the first row of *output_ids* to text.

        NOTE(review): the ``text.replace`` below is a no-op — ``str.replace``
        returns a new string and the result is discarded, so ``text_prefix``
        is NOT stripped here. Callers currently depend on the prefix tokens
        surviving (``transcribe`` feeds this text to
        ``_postprocess_transcript``, which extracts the language code from
        those tokens), so "fixing" this line would change behavior — confirm
        intent before touching it.
        """
        text = self.tokenizer.decode(output_ids[0]).strip()
        text.replace(text_prefix, "")
        return text

    async def detect_audio_and_language(self, mel) -> Optional[str]:
        """
        Detects the audio and language from the given mel spectrogram.

        Args:
            mel: The mel spectrogram of the audio.

        Returns:
            The detected language code, or None if no speech is detected.
        """
        text_prefix = "<|startoftranscript|>"

        prompt_ids = self.tokenizer.encode(
            text_prefix, allowed_special=self.tokenizer.special_tokens_set
        )

        # A single new token is enough: the model's first prediction after
        # <|startoftranscript|> is the language (or <|nospeech|>) token.
        output_ids = await self.batch_processor.process(
            item=BatchWhisperItem(mel=mel, prompt_ids=prompt_ids, max_new_tokens=1)
        )
        text = self.decode_output_ids(output_ids, text_prefix)
        if text == "<|nospeech|>":
            return None
        # Strip the prompt and the "<|" / "|>" delimiters, leaving the code.
        return text.replace(text_prefix, "").replace("<|", "").replace("|>", "")

    async def transcribe(
        self,
        waveform,
        prompt: Optional[str] = None,
        language: Optional[str] = None,
        timestamps: bool = False,
        num_beams: int = DEFAULT_NUM_BEAMS,
        prefix: Optional[str] = None,
        task: str = "transcribe",
        max_new_tokens=128,
    ):
        """Transcribe *waveform* and return a ``WhisperResult``.

        Args:
            waveform: Audio tensor at ``SUPPORTED_SAMPLE_RATE`` (as produced
                by ``preprocess_audio``).
            prompt: Optional conditioning text inserted via ``<|startofprev|>``.
            language: Language name/code; auto-detected when ``None``.
            timestamps: Whether to request timestamp tokens. NOTE(review):
                with ``timestamps=False`` the decoded text contains no
                timestamp tokens, so the returned segments list will be empty
                — presumably intended; confirm with callers.
            num_beams: NOTE(review): accepted but never forwarded — the batch
                item only carries mel/prompt_ids/max_new_tokens and the batch
                processor always decodes with ``DEFAULT_NUM_BEAMS``.
            prefix: Optional decoder text prefix appended to the prompt.
            task: "transcribe" or "translate".
            max_new_tokens: Cap on generated tokens for this request.

        Returns:
            A ``WhisperResult`` with parsed segments and the language code;
            empty segments and ``language_code=None`` when no speech is
            detected.
        """
        mel = await log_mel_spectrogram(
            waveform.numpy(),
            self.n_mels,
            device="cuda",
            mel_filters_dir=self.assets_dir,
        )
        mel = mel.type(torch.float16)
        if language is None:
            language = await self.detect_audio_and_language(mel)
            if language is None:
                # No speech was detected. Can result empty segments
                return WhisperResult(segments=[], language_code=None)
        text_prefix = self._get_text_prefix(
            language=language,
            prompt=prompt,
            timestamps=timestamps,
            prefix=prefix,
            task=task,
        )

        prompt_ids = self.tokenizer.encode(
            text_prefix, allowed_special=self.tokenizer.special_tokens_set
        )

        output_ids: Tensor = await self.batch_processor.process(
            item=BatchWhisperItem(
                mel=mel, prompt_ids=prompt_ids, max_new_tokens=max_new_tokens
            )
        )

        return self._postprocess_transcript(
            self.decode_output_ids(output_ids, text_prefix)
        )

    def _postprocess_transcript(self, transcribed_text: str) -> WhisperResult:
        """
        Post-process the output of the transcription model.

        Extracts the language code from the first ``<|xx|>`` token and parses
        ``<|start|>text<|end|>`` pairs into ``Segment`` objects.
        NOTE(review): ``[0]`` assumes a language token is always present in
        the decoded text (it comes from the prompt prefix, which
        ``decode_output_ids`` does not strip); raises ``IndexError`` otherwise.
        """
        language_code = LANG_CODE_PATTERN.findall(transcribed_text)[0]

        # Find all matches in the input string
        matches = SEGMENTS_PATTERN.findall(transcribed_text)

        # Process matches to create the desired output format
        segments = []
        for match in matches:
            start, text, end = match

            segments.append(
                Segment(
                    **{"start": float(start), "end": float(end), "text": text.strip()}
                )
            )

        return WhisperResult(segments=segments, language_code=language_code)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import urllib.request
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _download_files(urls):
|
|
6
|
+
ASSETS_DIR = os.path.join(
|
|
7
|
+
os.path.expanduser("~"), ".cache", "whisper-trt", "assets"
|
|
8
|
+
)
|
|
9
|
+
os.makedirs(ASSETS_DIR, exist_ok=True)
|
|
10
|
+
|
|
11
|
+
for url in urls:
|
|
12
|
+
file_name = os.path.basename(url)
|
|
13
|
+
file_path = os.path.join(ASSETS_DIR, file_name)
|
|
14
|
+
if not os.path.exists(file_path):
|
|
15
|
+
urllib.request.urlretrieve(url, file_path)
|
|
16
|
+
return ASSETS_DIR
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def download_assets():
    """Fetch the tokenizer and mel-filter assets whisper-trt needs.

    Downloads OpenAI's multilingual tiktoken vocabulary and mel filter bank
    into the local asset cache (skipping files already present) and returns
    the cache directory path.
    """
    asset_urls = [
        "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken",
        "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz",
    ]
    return _download_files(asset_urls)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import TYPE_CHECKING, List
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from whisper_trt import WhisperModel
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from async_batcher.batcher import AsyncBatcher
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from whisper_trt.custom_types import DEFAULT_NUM_BEAMS, BatchWhisperItem
|
|
12
|
+
|
|
13
|
+
# Canonical English-transcription prompt prefix. Not referenced in this module;
# NOTE(review): the name misspells "PREFIX", but it is module-level and may be
# imported elsewhere, so the typo is preserved for compatibility.
FIXED_TEXT_PRFIX = "<|startoftranscript|><|en|><|transcribe|><|0.00|>"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class WhisperBatchProcessor(AsyncBatcher[List[BatchWhisperItem], List[str]]):
    """Micro-batches concurrent Whisper requests into a single model call.

    ``AsyncBatcher`` queues incoming items for up to ``max_queue_time`` and
    hands them to :meth:`process_batch` as one list; the sliced results are
    routed back to the individual awaiting callers.
    """

    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Back-reference to the owning model that performs the actual inference.
        self.model: "WhisperModel" = model

    def concat_and_pad_mels(self, tensors: List[Tensor]) -> Tensor:
        """Concatenate mel spectrograms along dim 0, padded to ``max_batch_size``.

        The last spectrogram is repeated as padding. Works on a copy so the
        caller's list is not mutated (the original appended to it in place).
        """
        padded = list(tensors)
        while len(padded) < self.max_batch_size:
            padded.append(padded[-1])
        return torch.cat(padded, dim=0).type(torch.float16)

    def concat_and_pad_prompts(self, prompts: List[List]) -> Tensor:
        """Stack prompt-id lists into a tensor, padded to ``max_batch_size``.

        The last prompt is repeated as padding. Works on a copy so the
        caller's list is not mutated.
        """
        padded = list(prompts)
        while len(padded) < self.max_batch_size:
            padded.append(padded[-1])
        return Tensor(padded)

    def process_batch(self, batch: List[BatchWhisperItem]) -> Tensor:
        """Run one padded batch through the model and return per-item outputs.

        Returns the model's output ids for the first ``len(batch)`` rows
        (the original ``List[float]`` annotation did not match the tensor
        slice actually returned).
        """
        # ``logging.warn`` is a deprecated alias for ``logging.warning``; use
        # the real name with lazy %-args instead of an eager f-string.
        logging.warning("Processing batch of size %d", len(batch))

        # Need to pad the batch up to the maximum batch size.
        decoder_input_ids = self.concat_and_pad_prompts(
            [item.prompt_ids for item in batch]
        )
        mel_batch = self.concat_and_pad_mels([item.mel for item in batch])

        max_new_tokens = max(item.max_new_tokens for item in batch)
        batch_result = self.model.process_batch(
            mel_batch,
            decoder_input_ids,
            max_new_tokens=max_new_tokens,
            num_beams=DEFAULT_NUM_BEAMS,
        )
        # Splicing to len(batch) removes the padding rows added by
        # `concat_and_pad_mels` and `concat_and_pad_prompts`.
        return batch_result[: len(batch)]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from typing import List, NamedTuple
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
# Audio sample rate (Hz) expected by the pipeline.
SUPPORTED_SAMPLE_RATE = 16_000
# Default beam width for decoding (1 == greedy).
DEFAULT_NUM_BEAMS = 1
# Default cap on the number of tokens generated per request.
DEFAULT_MAX_NEW_TOKENS = 128
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BatchWhisperItem(NamedTuple):
    """One transcription request queued for batched inference."""

    # Mel spectrogram of the audio clip.
    mel: Tensor
    # Decoder prompt token ids for this request.
    prompt_ids: Tensor
    # Per-request cap on generated tokens.
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    # Requested beam width for this item.
    num_beams: int = DEFAULT_NUM_BEAMS
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Segment(BaseModel):
    """A single span of transcribed text with its start/end offsets.

    NOTE(review): start/end are parsed from Whisper timestamp tokens by the
    caller; presumably seconds — confirm against the producing regex.
    """

    start: float
    end: float
    text: str
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class WhisperResult(BaseModel):
    """Full transcription result: ordered segments plus the detected or
    requested language code."""

    segments: List[Segment]
    language_code: str
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
|
|
3
|
+
import tensorrt_llm
|
|
4
|
+
import tensorrt_llm.logger as logger
|
|
5
|
+
import torch
|
|
6
|
+
from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_torch
|
|
7
|
+
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
|
|
8
|
+
from tensorrt_llm.runtime.session import Session, TensorInfo
|
|
9
|
+
|
|
10
|
+
from whisper_trt.custom_types import DEFAULT_NUM_BEAMS
|
|
11
|
+
from whisper_trt.utils import read_config, remove_tensor_padding
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class WhisperEncoding:
    """Wrapper around the serialized TensorRT-LLM Whisper *encoder* engine.

    Loads the rank-0 encoder engine from ``engine_dir`` and exposes
    ``get_audio_features`` to run mel spectrograms through it on the GPU.
    """

    def __init__(self, engine_dir):
        # `engine_dir` behaves like a pathlib.Path (the `/` operator is used
        # below to locate the engine file).
        self.session = self.get_session(engine_dir)

    def get_session(self, engine_dir):
        """Read the encoder config and deserialize the rank-0 engine.

        Side effect: caches the parsed config and a few frequently used
        fields (dtype, n_mels, num_languages) on the instance.
        """
        config = read_config("encoder", engine_dir)
        self.encoder_config = config

        self.dtype = config["dtype"]
        self.n_mels = config["n_mels"]
        self.num_languages = config["num_languages"]

        serialize_path = engine_dir / "encoder" / "rank0.engine"

        with open(serialize_path, "rb") as f:
            session = Session.from_serialized_engine(f.read())

        return session

    def get_audio_features(self, mel):
        """Run a batch of mel spectrograms through the encoder.

        Args:
            mel: batched mel tensor; indexing below implies shape
                (batch, n_mels, time) — TODO confirm.

        Returns:
            The engine's "encoder_output" tensor (on CUDA).
        """
        # One length per batch element; time dim halved — presumably matches
        # the encoder's downsampling, confirm against the engine build.
        input_lengths = torch.tensor(
            [mel.shape[2] // 2 for _ in range(mel.shape[0])],
            dtype=torch.int32,
            device=mel.device,
        )
        if self.encoder_config["plugin_config"]["remove_input_padding"]:
            mel_input_lengths = torch.full(
                (mel.shape[0],), mel.shape[2], dtype=torch.int32, device="cuda"
            )
            # mel B,D,T -> B,T,D -> BxT, D
            mel = mel.transpose(1, 2)
            mel = remove_tensor_padding(mel, mel_input_lengths)

        inputs = OrderedDict()
        inputs["input_features"] = mel
        inputs["input_lengths"] = input_lengths

        output_list = [
            TensorInfo("input_features", str_dtype_to_trt(self.dtype), mel.shape),
            TensorInfo("input_lengths", str_dtype_to_trt("int32"), input_lengths.shape),
        ]

        # Ask TRT what output shapes these inputs produce, then pre-allocate
        # matching CUDA buffers for the engine to write into.
        output_info = (self.session).infer_shapes(output_list)

        logger.debug(f"output info {output_info}")
        outputs = {
            t.name: torch.empty(
                tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda"
            )
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)
        assert ok, "Engine execution failed"
        # The run is async on the stream; wait before reading the outputs.
        stream.synchronize()
        audio_features = outputs["encoder_output"]
        return audio_features
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class WhisperDecoding:
    """Wrapper around the serialized TensorRT-LLM Whisper *decoder* engine.

    Builds a GenerationSession from the rank-0 decoder engine and exposes
    ``generate`` for cross-attention decoding over encoder outputs.
    """

    def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
        self.decoder_config = read_config("decoder", engine_dir)
        self.decoder_generation_session = self.get_session(
            engine_dir, runtime_mapping, debug_mode
        )

    def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
        """Deserialize the decoder engine and build its GenerationSession.

        The ModelConfig fields mirror the decoder build config stored next
        to the engine.
        """
        serialize_path = engine_dir / "decoder" / "rank0.engine"
        with open(serialize_path, "rb") as f:
            decoder_engine_buffer = f.read()

        decoder_model_config = ModelConfig(
            max_batch_size=self.decoder_config["max_batch_size"],
            max_beam_width=self.decoder_config["max_beam_width"],
            num_heads=self.decoder_config["num_attention_heads"],
            # Whisper uses full multi-head attention: kv heads == heads.
            num_kv_heads=self.decoder_config["num_attention_heads"],
            hidden_size=self.decoder_config["hidden_size"],
            vocab_size=self.decoder_config["vocab_size"],
            # Decoder attends over the encoder output.
            cross_attention=True,
            num_layers=self.decoder_config["num_hidden_layers"],
            gpt_attention_plugin=self.decoder_config["plugin_config"][
                "gpt_attention_plugin"
            ],
            remove_input_padding=self.decoder_config["plugin_config"][
                "remove_input_padding"
            ],
            has_position_embedding=self.decoder_config["has_position_embedding"],
            dtype=self.decoder_config["dtype"],
            has_token_type_embedding=False,
        )
        decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
            decoder_model_config,
            decoder_engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode,
        )

        return decoder_generation_session

    def generate(
        self,
        decoder_input_ids,
        encoder_outputs,
        eot_id,
        max_new_tokens=40,
        num_beams=DEFAULT_NUM_BEAMS,
    ):
        """Decode token ids for a batch of encoder outputs.

        Args:
            decoder_input_ids: per-item decoder prompt token ids.
            encoder_outputs: encoder features; indexing below implies shape
                (batch, enc_len, hidden) — TODO confirm.
            eot_id: end-of-transcript token id, used as both end and pad id.
            max_new_tokens: generation cap.
            num_beams: beam width for sampling.

        Returns:
            Nested Python list of output token ids (moved to CPU).
        """
        # Every batch element is assigned the full encoder length — the
        # attention mask below is likewise all-ones (no per-item masking).
        encoder_input_lengths = torch.tensor(
            [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])],
            dtype=torch.int32,
            device="cuda",
        )

        decoder_input_lengths = torch.tensor(
            [decoder_input_ids.shape[-1] for _ in range(decoder_input_ids.shape[0])],
            dtype=torch.int32,
            device="cuda",
        )
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        cross_attention_mask = (
            torch.ones([encoder_outputs.shape[0], 1, encoder_outputs.shape[1]])
            .int()
            .cuda()
        )

        # generation config
        sampling_config = SamplingConfig(
            end_id=eot_id, pad_id=eot_id, num_beams=num_beams
        )
        # Session setup must precede decode; sizes here bound the buffers
        # the session allocates.
        self.decoder_generation_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            beam_width=num_beams,
            encoder_max_input_length=encoder_outputs.shape[1],
        )

        torch.cuda.synchronize()

        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
        if self.decoder_config["plugin_config"]["remove_input_padding"]:
            # 50256 is the index of <pad> for all whisper models' decoder
            WHISPER_PAD_TOKEN_ID = 50256
            decoder_input_ids = remove_tensor_padding(
                decoder_input_ids, pad_value=WHISPER_PAD_TOKEN_ID
            )
            if encoder_outputs.dim() == 3:
                encoder_output_lens = torch.full(
                    (encoder_outputs.shape[0],),
                    encoder_outputs.shape[1],
                    dtype=torch.int32,
                    device="cuda",
                )

                encoder_outputs = remove_tensor_padding(
                    encoder_outputs, encoder_output_lens
                )
        output_ids = self.decoder_generation_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_outputs,
            encoder_input_lengths=encoder_input_lengths,
            cross_attention_mask=cross_attention_mask,
        )
        torch.cuda.synchronize()

        # get the list of int from output_ids tensor
        output_ids = output_ids.cpu().numpy().tolist()
        return output_ids
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# Modified from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
|
|
16
|
+
import base64
|
|
17
|
+
import os
|
|
18
|
+
from typing import Optional
|
|
19
|
+
|
|
20
|
+
import tiktoken
|
|
21
|
+
|
|
22
|
+
# Language code -> English name for every language Whisper supports.
# Order matters: get_tokenizer() turns the first `num_languages` keys into
# the `<|xx|>` special tokens, so entries must stay in this exact order.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# Inverse mapping: English language name -> code.
REVERSED_LANGUAGES = {v: k for k, v in LANGUAGES.items()}
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def get_tokenizer(
    name: str = "multilingual",
    num_languages: int = 99,
    tokenizer_dir: Optional[str] = None,
):
    """Build the tiktoken Encoding for a Whisper tokenizer vocabulary.

    Args:
        name: vocab file stem, e.g. "multilingual" -> "multilingual.tiktoken".
        num_languages: how many entries of LANGUAGES get `<|xx|>` tokens.
        tokenizer_dir: directory holding the vocab file; defaults to the
            package-local ``assets`` directory.

    Returns:
        tiktoken.Encoding with Whisper's special tokens appended after the
        base vocabulary.
    """
    if tokenizer_dir is None:
        vocab_path = os.path.join(os.path.dirname(__file__), f"assets/{name}.tiktoken")
    else:
        vocab_path = os.path.join(tokenizer_dir, f"{name}.tiktoken")
    # Fix: open via a context manager so the vocab file handle is closed
    # deterministically (the original left it to the garbage collector).
    with open(vocab_path) as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}

    # Special tokens are appended after the base vocab, in this exact order,
    # so their ids match the ids the Whisper models were trained with.
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        # 1501 timestamp tokens: <|0.00|> .. <|30.00|> in 0.02 s steps.
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
if __name__ == "__main__":
    # Manual smoke test: round-trip a prompt with special tokens, decode two
    # hard-coded id sequences, and round-trip a Chinese string.
    enc = get_tokenizer()
    mytest_str = "<|startofprev|> Nvidia<|startoftranscript|><|en|><|transcribe|>"
    encoding = enc.encode(mytest_str, allowed_special=enc.special_tokens_set)
    mystr = enc.decode([50361, 45, 43021, 50258, 50259, 50359])
    mystr2 = enc.decode([50361, 46284, 50258, 50259, 50359])
    print(encoding, mystr, mystr2)
    # Print the id assigned to the start-of-transcript special token.
    print(
        enc.encode("<|startoftranscript|>", allowed_special=enc.special_tokens_set)[0]
    )

    my_zh_str = "好好学习"
    encoding = enc.encode(my_zh_str, allowed_special=enc.special_tokens_set)
    decoding = enc.decode(encoding)
    print(type(decoding))
    print(encoding, decoding)