truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -1,216 +0,0 @@
|
|
|
1
|
-
from enum import Enum
|
|
2
|
-
from pathlib import Path
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
import tritonclient
|
|
7
|
-
import tritonclient.grpc.aio as grpcclient
|
|
8
|
-
from pydantic import BaseModel, ConfigDict, PrivateAttr
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class ModelInput:
|
|
12
|
-
def __init__(
|
|
13
|
-
self,
|
|
14
|
-
prompt: str,
|
|
15
|
-
request_id: int,
|
|
16
|
-
max_tokens: int = 50,
|
|
17
|
-
beam_width: int = 1,
|
|
18
|
-
bad_words_list: Optional[list] = None,
|
|
19
|
-
stop_words_list: Optional[list] = None,
|
|
20
|
-
repetition_penalty: float = 1.0,
|
|
21
|
-
ignore_eos: bool = False,
|
|
22
|
-
stream: bool = True,
|
|
23
|
-
eos_token_id: int = None, # type: ignore
|
|
24
|
-
) -> None:
|
|
25
|
-
self.stream = stream
|
|
26
|
-
self.request_id = request_id
|
|
27
|
-
self._prompt = prompt
|
|
28
|
-
self._max_tokens = max_tokens
|
|
29
|
-
self._beam_width = beam_width
|
|
30
|
-
self._bad_words_list = [""] if bad_words_list is None else bad_words_list
|
|
31
|
-
self._stop_words_list = [""] if stop_words_list is None else stop_words_list
|
|
32
|
-
self._repetition_penalty = repetition_penalty
|
|
33
|
-
self._eos_token_id = eos_token_id
|
|
34
|
-
self._ignore_eos = ignore_eos
|
|
35
|
-
|
|
36
|
-
def _prepare_grpc_tensor(
|
|
37
|
-
self, name: str, input_data: np.ndarray
|
|
38
|
-
) -> grpcclient.InferInput:
|
|
39
|
-
tensor = grpcclient.InferInput(
|
|
40
|
-
name,
|
|
41
|
-
input_data.shape,
|
|
42
|
-
tritonclient.utils.np_to_triton_dtype(input_data.dtype),
|
|
43
|
-
)
|
|
44
|
-
tensor.set_data_from_numpy(input_data)
|
|
45
|
-
return tensor
|
|
46
|
-
|
|
47
|
-
def to_tensors(self):
|
|
48
|
-
if self._eos_token_id is None and self._ignore_eos:
|
|
49
|
-
raise ValueError("eos_token_id is required when ignore_eos is True")
|
|
50
|
-
|
|
51
|
-
prompt_data = np.array([[self._prompt]], dtype=object)
|
|
52
|
-
output_len_data = np.ones_like(prompt_data, dtype=np.uint32) * self._max_tokens
|
|
53
|
-
bad_words_data = np.array([self._bad_words_list], dtype=object)
|
|
54
|
-
stop_words_data = np.array([self._stop_words_list], dtype=object)
|
|
55
|
-
stream_data = np.array([[self.stream]], dtype=bool)
|
|
56
|
-
beam_width_data = np.array([[self._beam_width]], dtype=np.uint32)
|
|
57
|
-
repetition_penalty_data = np.array(
|
|
58
|
-
[[self._repetition_penalty]], dtype=np.float32
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
inputs = [
|
|
62
|
-
self._prepare_grpc_tensor("text_input", prompt_data),
|
|
63
|
-
self._prepare_grpc_tensor("max_tokens", output_len_data),
|
|
64
|
-
self._prepare_grpc_tensor("bad_words", bad_words_data),
|
|
65
|
-
self._prepare_grpc_tensor("stop_words", stop_words_data),
|
|
66
|
-
self._prepare_grpc_tensor("stream", stream_data),
|
|
67
|
-
self._prepare_grpc_tensor("beam_width", beam_width_data),
|
|
68
|
-
self._prepare_grpc_tensor("repetition_penalty", repetition_penalty_data),
|
|
69
|
-
]
|
|
70
|
-
|
|
71
|
-
if not self._ignore_eos:
|
|
72
|
-
end_id_data = np.array([[self._eos_token_id]], dtype=np.uint32)
|
|
73
|
-
inputs.append(self._prepare_grpc_tensor("end_id", end_id_data))
|
|
74
|
-
|
|
75
|
-
return inputs
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
class Quant(Enum):
|
|
79
|
-
NO_QUANT = "no_quant"
|
|
80
|
-
WEIGHTS_ONLY = "weights_only"
|
|
81
|
-
WEIGHTS_KV_INT8 = "weights_kv_int8"
|
|
82
|
-
SMOOTH_QUANT = "smooth_quant"
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
class EngineType(Enum):
|
|
86
|
-
LLAMA = "llama"
|
|
87
|
-
MISTRAL = "mistral"
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
class ArgsConfig(BaseModel):
|
|
91
|
-
max_input_len: Optional[int] = None
|
|
92
|
-
max_output_len: Optional[int] = None
|
|
93
|
-
max_batch_size: Optional[int] = None
|
|
94
|
-
tp_size: Optional[int] = None
|
|
95
|
-
pp_size: Optional[int] = None
|
|
96
|
-
world_size: Optional[int] = None
|
|
97
|
-
gather_all_token_logits: Optional[bool] = None
|
|
98
|
-
multi_block_mode: Optional[bool] = None
|
|
99
|
-
remove_input_padding: Optional[bool] = None
|
|
100
|
-
use_gpt_attention_plugin: Optional[str] = None
|
|
101
|
-
paged_kv_cache: Optional[bool] = None
|
|
102
|
-
use_inflight_batching: Optional[bool] = None
|
|
103
|
-
enable_context_fmha: Optional[bool] = None
|
|
104
|
-
use_gemm_plugin: Optional[str] = None
|
|
105
|
-
use_weight_only: Optional[bool] = None
|
|
106
|
-
output_dir: Optional[str] = None
|
|
107
|
-
model_dir: Optional[str] = None
|
|
108
|
-
ft_model_dir: Optional[str] = None
|
|
109
|
-
dtype: Optional[str] = None
|
|
110
|
-
int8_kv_cache: Optional[bool] = None
|
|
111
|
-
use_smooth_quant: Optional[bool] = None
|
|
112
|
-
per_token: Optional[bool] = None
|
|
113
|
-
per_channel: Optional[bool] = None
|
|
114
|
-
parallel_build: Optional[bool] = None
|
|
115
|
-
|
|
116
|
-
# to disable warning because `model_dir` starts with `model_` prefix
|
|
117
|
-
model_config = ConfigDict(protected_namespaces=()) # type: ignore
|
|
118
|
-
|
|
119
|
-
def as_command_arguments(self) -> list:
|
|
120
|
-
non_bool_args = [
|
|
121
|
-
element
|
|
122
|
-
for arg, value in self.dict().items()
|
|
123
|
-
for element in [f"--{arg}", str(value)]
|
|
124
|
-
if value is not None and not isinstance(value, bool)
|
|
125
|
-
]
|
|
126
|
-
bool_args = [
|
|
127
|
-
f"--{arg}"
|
|
128
|
-
for arg, value in self.dict().items()
|
|
129
|
-
if isinstance(value, bool) and value
|
|
130
|
-
]
|
|
131
|
-
return non_bool_args + bool_args
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
class CalibrationConfig(BaseModel):
|
|
135
|
-
kv_cache: Optional[bool] = None # either to calibrate kv cache
|
|
136
|
-
sq_alpha: Optional[float] = None
|
|
137
|
-
|
|
138
|
-
def cache_path(self) -> Path:
|
|
139
|
-
if self.kv_cache is not None:
|
|
140
|
-
return Path("kv_cache")
|
|
141
|
-
else:
|
|
142
|
-
return Path(f"sq_{self.sq_alpha}")
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
class EngineBuildArgs(BaseModel, use_enum_values=True):
|
|
146
|
-
repo: Optional[str] = None
|
|
147
|
-
args: Optional[ArgsConfig] = None
|
|
148
|
-
quant: Optional[Quant] = None
|
|
149
|
-
calibration: Optional[CalibrationConfig] = None
|
|
150
|
-
engine_type: Optional[EngineType] = None
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
class TrussBuildConfig(BaseModel):
|
|
154
|
-
"""
|
|
155
|
-
This is a spec for what the config.yaml looks like to take advantage of TRT-LLM + TRT-LLM builds. We structure the
|
|
156
|
-
configuration with the below top-level keys.
|
|
157
|
-
|
|
158
|
-
Example (for building an engine)
|
|
159
|
-
```
|
|
160
|
-
build:
|
|
161
|
-
model_server: TRT_LLM
|
|
162
|
-
arguments:
|
|
163
|
-
tokenizer_repository: "mistralai/mistral-v2-instruct"
|
|
164
|
-
arguments:
|
|
165
|
-
max_input_len: 1024
|
|
166
|
-
max_output_len: 1024
|
|
167
|
-
max_batch_size: 64
|
|
168
|
-
quant: "weights_kv_int8"
|
|
169
|
-
tensor_parallel_count: 2
|
|
170
|
-
pipeline_parallel_count: 1
|
|
171
|
-
```
|
|
172
|
-
|
|
173
|
-
Example (for using an existing engine)
|
|
174
|
-
```
|
|
175
|
-
build:
|
|
176
|
-
model_server: TRT_LLM
|
|
177
|
-
arguments:
|
|
178
|
-
engine_repository: "baseten/mistral-v2-32k"
|
|
179
|
-
tensor_parallel_count: 2
|
|
180
|
-
pipeline_parallel_count: 1
|
|
181
|
-
```
|
|
182
|
-
|
|
183
|
-
"""
|
|
184
|
-
|
|
185
|
-
tokenizer_repository: str
|
|
186
|
-
quant: Quant = Quant.NO_QUANT
|
|
187
|
-
pipeline_parallel_count: int = 1
|
|
188
|
-
tensor_parallel_count: int = 1
|
|
189
|
-
arguments: Optional[ArgsConfig] = None
|
|
190
|
-
engine_repository: Optional[str] = None
|
|
191
|
-
calibration: Optional[CalibrationConfig] = None
|
|
192
|
-
engine_type: Optional[EngineType] = None
|
|
193
|
-
_engine_build_args: Optional[EngineBuildArgs] = PrivateAttr(default=None)
|
|
194
|
-
|
|
195
|
-
@property
|
|
196
|
-
def engine_build_args(self) -> EngineBuildArgs:
|
|
197
|
-
if self._engine_build_args is None:
|
|
198
|
-
repo = self.tokenizer_repository
|
|
199
|
-
quant = self.quant
|
|
200
|
-
calibration = self.calibration
|
|
201
|
-
engine_type = self.engine_type
|
|
202
|
-
args = self.arguments or ArgsConfig()
|
|
203
|
-
args.tp_size = self.tensor_parallel_count
|
|
204
|
-
args.pp_size = self.pipeline_parallel_count
|
|
205
|
-
self._engine_build_args = EngineBuildArgs(
|
|
206
|
-
repo=repo,
|
|
207
|
-
quant=quant,
|
|
208
|
-
calibration=calibration,
|
|
209
|
-
engine_type=engine_type,
|
|
210
|
-
args=args,
|
|
211
|
-
)
|
|
212
|
-
return self._engine_build_args
|
|
213
|
-
|
|
214
|
-
@property
|
|
215
|
-
def requires_build(self):
|
|
216
|
-
return self.engine_repository is None
|
|
@@ -1,246 +0,0 @@
|
|
|
1
|
-
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Redistribution and use in source and binary forms, with or without
|
|
4
|
-
# modification, are permitted provided that the following conditions
|
|
5
|
-
# are met:
|
|
6
|
-
# * Redistributions of source code must retain the above copyright
|
|
7
|
-
# notice, this list of conditions and the following disclaimer.
|
|
8
|
-
# * Redistributions in binary form must reproduce the above copyright
|
|
9
|
-
# notice, this list of conditions and the following disclaimer in the
|
|
10
|
-
# documentation and/or other materials provided with the distribution.
|
|
11
|
-
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
|
12
|
-
# contributors may be used to endorse or promote products derived
|
|
13
|
-
# from this software without specific prior written permission.
|
|
14
|
-
#
|
|
15
|
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
|
16
|
-
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
17
|
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
|
18
|
-
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
|
19
|
-
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
20
|
-
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
21
|
-
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
22
|
-
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
|
23
|
-
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
24
|
-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
25
|
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
26
|
-
|
|
27
|
-
name: "ensemble"
|
|
28
|
-
platform: "ensemble"
|
|
29
|
-
max_batch_size: 2048
|
|
30
|
-
input [
|
|
31
|
-
{
|
|
32
|
-
name: "text_input"
|
|
33
|
-
data_type: TYPE_STRING
|
|
34
|
-
dims: [ -1 ]
|
|
35
|
-
},
|
|
36
|
-
{
|
|
37
|
-
name: "max_tokens"
|
|
38
|
-
data_type: TYPE_UINT32
|
|
39
|
-
dims: [ -1 ]
|
|
40
|
-
},
|
|
41
|
-
{
|
|
42
|
-
name: "bad_words"
|
|
43
|
-
data_type: TYPE_STRING
|
|
44
|
-
dims: [ -1 ]
|
|
45
|
-
},
|
|
46
|
-
{
|
|
47
|
-
name: "stop_words"
|
|
48
|
-
data_type: TYPE_STRING
|
|
49
|
-
dims: [ -1 ]
|
|
50
|
-
},
|
|
51
|
-
{
|
|
52
|
-
name: "end_id"
|
|
53
|
-
data_type: TYPE_UINT32
|
|
54
|
-
dims: [ 1 ]
|
|
55
|
-
optional: true
|
|
56
|
-
},
|
|
57
|
-
{
|
|
58
|
-
name: "pad_id"
|
|
59
|
-
data_type: TYPE_UINT32
|
|
60
|
-
dims: [ 1 ]
|
|
61
|
-
optional: true
|
|
62
|
-
},
|
|
63
|
-
{
|
|
64
|
-
name: "top_k"
|
|
65
|
-
data_type: TYPE_UINT32
|
|
66
|
-
dims: [ 1 ]
|
|
67
|
-
optional: true
|
|
68
|
-
},
|
|
69
|
-
{
|
|
70
|
-
name: "top_p"
|
|
71
|
-
data_type: TYPE_FP32
|
|
72
|
-
dims: [ 1 ]
|
|
73
|
-
optional: true
|
|
74
|
-
},
|
|
75
|
-
{
|
|
76
|
-
name: "temperature"
|
|
77
|
-
data_type: TYPE_FP32
|
|
78
|
-
dims: [ 1 ]
|
|
79
|
-
optional: true
|
|
80
|
-
},
|
|
81
|
-
{
|
|
82
|
-
name: "length_penalty"
|
|
83
|
-
data_type: TYPE_FP32
|
|
84
|
-
dims: [ 1 ]
|
|
85
|
-
optional: true
|
|
86
|
-
},
|
|
87
|
-
{
|
|
88
|
-
name: "repetition_penalty"
|
|
89
|
-
data_type: TYPE_FP32
|
|
90
|
-
dims: [ 1 ]
|
|
91
|
-
optional: true
|
|
92
|
-
},
|
|
93
|
-
{
|
|
94
|
-
name: "min_length"
|
|
95
|
-
data_type: TYPE_UINT32
|
|
96
|
-
dims: [ 1 ]
|
|
97
|
-
optional: true
|
|
98
|
-
},
|
|
99
|
-
{
|
|
100
|
-
name: "presence_penalty"
|
|
101
|
-
data_type: TYPE_FP32
|
|
102
|
-
dims: [ 1 ]
|
|
103
|
-
optional: true
|
|
104
|
-
},
|
|
105
|
-
{
|
|
106
|
-
name: "random_seed"
|
|
107
|
-
data_type: TYPE_UINT64
|
|
108
|
-
dims: [ 1 ]
|
|
109
|
-
optional: true
|
|
110
|
-
},
|
|
111
|
-
{
|
|
112
|
-
name: "beam_width"
|
|
113
|
-
data_type: TYPE_UINT32
|
|
114
|
-
dims: [ 1 ]
|
|
115
|
-
optional: true
|
|
116
|
-
},
|
|
117
|
-
{
|
|
118
|
-
name: "stream"
|
|
119
|
-
data_type: TYPE_BOOL
|
|
120
|
-
dims: [ 1 ]
|
|
121
|
-
optional: true
|
|
122
|
-
}
|
|
123
|
-
]
|
|
124
|
-
output [
|
|
125
|
-
{
|
|
126
|
-
name: "text_output"
|
|
127
|
-
data_type: TYPE_STRING
|
|
128
|
-
dims: [ -1, -1 ]
|
|
129
|
-
}
|
|
130
|
-
]
|
|
131
|
-
ensemble_scheduling {
|
|
132
|
-
step [
|
|
133
|
-
{
|
|
134
|
-
model_name: "preprocessing"
|
|
135
|
-
model_version: -1
|
|
136
|
-
input_map {
|
|
137
|
-
key: "QUERY"
|
|
138
|
-
value: "text_input"
|
|
139
|
-
}
|
|
140
|
-
input_map {
|
|
141
|
-
key: "REQUEST_OUTPUT_LEN"
|
|
142
|
-
value: "max_tokens"
|
|
143
|
-
}
|
|
144
|
-
input_map {
|
|
145
|
-
key: "BAD_WORDS_DICT"
|
|
146
|
-
value: "bad_words"
|
|
147
|
-
}
|
|
148
|
-
input_map {
|
|
149
|
-
key: "STOP_WORDS_DICT"
|
|
150
|
-
value: "stop_words"
|
|
151
|
-
}
|
|
152
|
-
output_map {
|
|
153
|
-
key: "REQUEST_INPUT_LEN"
|
|
154
|
-
value: "_REQUEST_INPUT_LEN"
|
|
155
|
-
}
|
|
156
|
-
output_map {
|
|
157
|
-
key: "INPUT_ID"
|
|
158
|
-
value: "_INPUT_ID"
|
|
159
|
-
}
|
|
160
|
-
output_map {
|
|
161
|
-
key: "REQUEST_OUTPUT_LEN"
|
|
162
|
-
value: "_REQUEST_OUTPUT_LEN"
|
|
163
|
-
}
|
|
164
|
-
},
|
|
165
|
-
{
|
|
166
|
-
model_name: "tensorrt_llm"
|
|
167
|
-
model_version: -1
|
|
168
|
-
input_map {
|
|
169
|
-
key: "input_ids"
|
|
170
|
-
value: "_INPUT_ID"
|
|
171
|
-
}
|
|
172
|
-
input_map {
|
|
173
|
-
key: "input_lengths"
|
|
174
|
-
value: "_REQUEST_INPUT_LEN"
|
|
175
|
-
}
|
|
176
|
-
input_map {
|
|
177
|
-
key: "request_output_len"
|
|
178
|
-
value: "_REQUEST_OUTPUT_LEN"
|
|
179
|
-
}
|
|
180
|
-
input_map {
|
|
181
|
-
key: "end_id"
|
|
182
|
-
value: "end_id"
|
|
183
|
-
}
|
|
184
|
-
input_map {
|
|
185
|
-
key: "pad_id"
|
|
186
|
-
value: "pad_id"
|
|
187
|
-
}
|
|
188
|
-
input_map {
|
|
189
|
-
key: "runtime_top_k"
|
|
190
|
-
value: "top_k"
|
|
191
|
-
}
|
|
192
|
-
input_map {
|
|
193
|
-
key: "runtime_top_p"
|
|
194
|
-
value: "top_p"
|
|
195
|
-
}
|
|
196
|
-
input_map {
|
|
197
|
-
key: "temperature"
|
|
198
|
-
value: "temperature"
|
|
199
|
-
}
|
|
200
|
-
input_map {
|
|
201
|
-
key: "len_penalty"
|
|
202
|
-
value: "length_penalty"
|
|
203
|
-
}
|
|
204
|
-
input_map {
|
|
205
|
-
key: "repetition_penalty"
|
|
206
|
-
value: "repetition_penalty"
|
|
207
|
-
}
|
|
208
|
-
input_map {
|
|
209
|
-
key: "min_length"
|
|
210
|
-
value: "min_length"
|
|
211
|
-
}
|
|
212
|
-
input_map {
|
|
213
|
-
key: "presence_penalty"
|
|
214
|
-
value: "presence_penalty"
|
|
215
|
-
}
|
|
216
|
-
input_map {
|
|
217
|
-
key: "random_seed"
|
|
218
|
-
value: "random_seed"
|
|
219
|
-
}
|
|
220
|
-
input_map {
|
|
221
|
-
key: "beam_width"
|
|
222
|
-
value: "beam_width"
|
|
223
|
-
}
|
|
224
|
-
input_map {
|
|
225
|
-
key: "streaming"
|
|
226
|
-
value: "stream"
|
|
227
|
-
}
|
|
228
|
-
output_map {
|
|
229
|
-
key: "output_ids"
|
|
230
|
-
value: "_TOKENS_BATCH"
|
|
231
|
-
}
|
|
232
|
-
},
|
|
233
|
-
{
|
|
234
|
-
model_name: "postprocessing"
|
|
235
|
-
model_version: -1
|
|
236
|
-
input_map {
|
|
237
|
-
key: "TOKENS_BATCH"
|
|
238
|
-
value: "_TOKENS_BATCH"
|
|
239
|
-
}
|
|
240
|
-
output_map {
|
|
241
|
-
key: "OUTPUT"
|
|
242
|
-
value: "text_output"
|
|
243
|
-
}
|
|
244
|
-
}
|
|
245
|
-
]
|
|
246
|
-
}
|
|
@@ -1,181 +0,0 @@
|
|
|
1
|
-
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Redistribution and use in source and binary forms, with or without
|
|
4
|
-
# modification, are permitted provided that the following conditions
|
|
5
|
-
# are met:
|
|
6
|
-
# * Redistributions of source code must retain the above copyright
|
|
7
|
-
# notice, this list of conditions and the following disclaimer.
|
|
8
|
-
# * Redistributions in binary form must reproduce the above copyright
|
|
9
|
-
# notice, this list of conditions and the following disclaimer in the
|
|
10
|
-
# documentation and/or other materials provided with the distribution.
|
|
11
|
-
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
|
12
|
-
# contributors may be used to endorse or promote products derived
|
|
13
|
-
# from this software without specific prior written permission.
|
|
14
|
-
#
|
|
15
|
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
|
16
|
-
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
17
|
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
|
18
|
-
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
|
19
|
-
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
20
|
-
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
21
|
-
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
22
|
-
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
|
23
|
-
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
24
|
-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
25
|
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
26
|
-
|
|
27
|
-
import json
|
|
28
|
-
import os
|
|
29
|
-
from collections import OrderedDict
|
|
30
|
-
|
|
31
|
-
import numpy as np
|
|
32
|
-
import triton_python_backend_utils as pb_utils
|
|
33
|
-
from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class TritonPythonModel:
    """Triton Python-backend model that turns generated token IDs into text.

    Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.

    The model keeps the last token seen for each request ID so that a
    follow-up call with the same request ID can be decoded as a text delta.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.

        Loads the tokenizer, resolves the numpy dtype of the "OUTPUT" tensor
        and sets up the per-request token cache.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name

        Raises
        ------
        AttributeError
            If the configured ``tokenizer_type`` is not "t5", "auto" or "llama".
        KeyError
            If the ``TRITON_TOKENIZER_REPOSITORY`` environment variable is unset.
        """
        # Parse model configs
        model_config = json.loads(args["model_config"])
        # NOTE: Keep this in sync with the truss model.py variable
        tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"]
        tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"]

        if tokenizer_type == "t5":
            self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left")
        elif tokenizer_type == "auto":
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_dir, padding_side="left"
            )
        elif tokenizer_type == "llama":
            self.tokenizer = LlamaTokenizer.from_pretrained(
                tokenizer_dir, legacy=False, padding_side="left"
            )
        else:
            raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Parse model output configs
        output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT")
        # Convert Triton types to numpy types
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

        # Maps request ID -> {"prev_token": <last token seen>}, ordered from
        # least to most recently stored so the oldest entry can be evicted.
        self.state_dict = OrderedDict()
        # TODO(pankaj) This should come from the batch size
        self.cache_size = 2048

    def execute(self, requests):
        """Decode the "TOKENS_BATCH" input of every request into an "OUTPUT" string.

        `execute` must be implemented in every Python model. Depending on the
        batching configuration (e.g. Dynamic Batching) used, `requests` may
        contain multiple requests. Exactly one pb_utils.InferenceResponse is
        produced for every pb_utils.InferenceRequest; a request whose token
        batch is empty gets an empty-string output so the response list stays
        aligned with `requests`.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list is
          the same as `requests`
        """
        responses = []

        # Every Python backend must iterate over every one of the requests
        # and create a pb_utils.InferenceResponse for each of them.
        for request in requests:
            # NOTE(review): decoding state is keyed by request ID; presumably
            # callers use a distinct ID per generation stream - confirm.
            request_id = request.request_id()

            # Get input tensors
            tokens_batch = (
                pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH")
                .as_numpy()
                .flatten()
            )

            if len(tokens_batch) == 0:
                # Nothing to decode: leave the stored state untouched but
                # still answer, so len(responses) == len(requests).
                delta = ""
            else:
                delta = self._decode_delta(request_id, tokens_batch)

            # Create output tensor
            output_tensor = pb_utils.Tensor(
                "OUTPUT", np.array([delta]).astype(self.output_dtype)
            )
            responses.append(
                pb_utils.InferenceResponse(output_tensors=[output_tensor])
            )

        return responses

    def finalize(self):
        """Log that the model is being cleaned up."""
        print("Cleaning up...")

    def _decode_delta(self, request_id, tokens_batch):
        """Decode a non-empty token batch into the new text for `request_id`."""
        prev_token = self._get_prev_token(request_id)
        self._store_prev_token(request_id, tokens_batch[-1])
        if prev_token is None:
            return self.tokenizer.decode(tokens_batch)

        # TODO(pankaj) Figure out how to make tokenizer.decode not
        # ignore initial whitespace so we can avoid this hack.
        # Get string with and without previous token and diff. This hack
        # is needed because tokenizer.decode strips initial whitespace.
        old_string = self.tokenizer.decode([prev_token])
        with_prev_token = np.concatenate(([prev_token], tokens_batch))
        new_string = self.tokenizer.decode(with_prev_token)
        return self._compute_delta(old_string, new_string)

    def _store_prev_token(self, request_id, token):
        """Remember `token` as the last token of `request_id`.

        The cache holds at most `self.cache_size` request IDs; storing a new
        ID when it is full evicts the least recently stored one.
        """
        if request_id in self.state_dict:
            self.state_dict[request_id]["prev_token"] = token

            # Move request ID to end of queue to prevent it from being evicted
            self.state_dict.move_to_end(request_id)
            return

        # Evict the least recently used item if the cache is full. `>=` (not
        # `>`) keeps the cache at cache_size entries after the insert below.
        if self.state_dict and len(self.state_dict) >= self.cache_size:
            self.state_dict.popitem(last=False)

        self.state_dict[request_id] = {"prev_token": token}

    def _get_prev_token(self, request_id):
        """Return the last stored token for `request_id`, or None if unknown."""
        entry = self.state_dict.get(request_id)
        if entry is None:
            return None
        return entry["prev_token"]

    def _compute_delta(self, prev_str, new_str):
        """Return the part of `new_str` that follows its common prefix with `prev_str`.

        When `new_str` starts with `prev_str` this is simply the appended
        text. If the strings diverge earlier, everything from the first
        differing character onwards is returned, so no character of the new
        text is dropped just because it matches `prev_str` at the same index.
        """
        shared = 0
        for old_char, new_char in zip(prev_str, new_str):
            if old_char != new_char:
                break
            shared += 1
        return new_str[shared:]

    def _postprocessing(self, tokens):
        """Decode `tokens` into a string with the loaded tokenizer."""
        decoded_tokens = self.tokenizer.decode(tokens)
        return decoded_tokens
|
|
@@ -1,64 +0,0 @@
|
|
|
1
|
-
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Redistribution and use in source and binary forms, with or without
|
|
4
|
-
# modification, are permitted provided that the following conditions
|
|
5
|
-
# are met:
|
|
6
|
-
# * Redistributions of source code must retain the above copyright
|
|
7
|
-
# notice, this list of conditions and the following disclaimer.
|
|
8
|
-
# * Redistributions in binary form must reproduce the above copyright
|
|
9
|
-
# notice, this list of conditions and the following disclaimer in the
|
|
10
|
-
# documentation and/or other materials provided with the distribution.
|
|
11
|
-
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
|
12
|
-
# contributors may be used to endorse or promote products derived
|
|
13
|
-
# from this software without specific prior written permission.
|
|
14
|
-
#
|
|
15
|
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
|
16
|
-
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
17
|
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
|
18
|
-
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
|
19
|
-
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
20
|
-
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
21
|
-
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
22
|
-
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
|
23
|
-
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
24
|
-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
25
|
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
26
|
-
|
|
27
|
-
name: "postprocessing"
|
|
28
|
-
backend: "python"
|
|
29
|
-
max_batch_size: 2048
|
|
30
|
-
input [
|
|
31
|
-
{
|
|
32
|
-
name: "TOKENS_BATCH"
|
|
33
|
-
data_type: TYPE_INT32
|
|
34
|
-
dims: [ -1, -1 ]
|
|
35
|
-
}
|
|
36
|
-
]
|
|
37
|
-
output [
|
|
38
|
-
{
|
|
39
|
-
name: "OUTPUT"
|
|
40
|
-
data_type: TYPE_STRING
|
|
41
|
-
dims: [ -1, -1 ]
|
|
42
|
-
}
|
|
43
|
-
]
|
|
44
|
-
|
|
45
|
-
parameters {
|
|
46
|
-
key: "tokenizer_dir"
|
|
47
|
-
value: {
|
|
48
|
-
string_value: "NousResearch/Llama-2-7b-hf"
|
|
49
|
-
}
|
|
50
|
-
}
|
|
51
|
-
|
|
52
|
-
parameters {
|
|
53
|
-
key: "tokenizer_type"
|
|
54
|
-
value: {
|
|
55
|
-
string_value: "auto"
|
|
56
|
-
}
|
|
57
|
-
}
|
|
58
|
-
|
|
59
|
-
instance_group [
|
|
60
|
-
{
|
|
61
|
-
count: 1
|
|
62
|
-
kind: KIND_CPU
|
|
63
|
-
}
|
|
64
|
-
]
|