truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- truss/{patch → truss_handle/patch}/constants.py +0 -0
- truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
truss/__init__.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import warnings
|
|
3
2
|
from pathlib import Path
|
|
4
3
|
|
|
4
|
+
from pydantic import PydanticDeprecatedSince20
|
|
5
5
|
from single_source import get_version
|
|
6
6
|
|
|
7
|
+
# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
|
|
8
|
+
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
|
|
9
|
+
|
|
7
10
|
__version__ = get_version(__name__, Path(__file__).parent.parent)
|
|
8
11
|
|
|
9
12
|
|
|
@@ -11,4 +14,8 @@ def version():
|
|
|
11
14
|
return __version__
|
|
12
15
|
|
|
13
16
|
|
|
14
|
-
from truss.
|
|
17
|
+
from truss.api import login, push, whoami
|
|
18
|
+
from truss.base import truss_config
|
|
19
|
+
from truss.truss_handle.build import load # TODO: Refactor all usages and remove.
|
|
20
|
+
|
|
21
|
+
__all__ = ["push", "login", "load", "whoami", "truss_config"]
|
truss/api/__init__.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Optional, Type, cast
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from rich import progress
|
|
5
|
+
|
|
6
|
+
from truss.api import definitions
|
|
7
|
+
from truss.remote.baseten.service import BasetenService
|
|
8
|
+
from truss.remote.remote_factory import RemoteFactory
|
|
9
|
+
from truss.remote.truss_remote import RemoteConfig
|
|
10
|
+
from truss.truss_handle.build import load
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def login(api_key: str):
    """
    Log in to a Baseten account.

    The credentials are saved as the ``baseten`` remote in ~/.trussrc, so this
    only has to be called once.

    Args:
        api_key: Baseten API Key
    """
    settings = {
        "remote_provider": "baseten",
        "api_key": api_key,
        "remote_url": "https://app.baseten.co",
    }
    RemoteFactory.update_remote_config(RemoteConfig(name="baseten", configs=settings))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _resolve_remote_name(remote: Optional[str]) -> str:
    """Return ``remote``, or the only configured remote when none is given.

    Shared by ``whoami`` and ``push`` so both resolve remotes identically.

    Args:
        remote: Explicit remote name; a falsy value triggers auto-detection.

    Raises:
        ValueError: If no remote was given and either no remote or more than
            one remote is configured.
    """
    if remote:
        return remote
    available_remotes = RemoteFactory.get_available_config_names()
    if len(available_remotes) == 1:
        return available_remotes[0]
    if len(available_remotes) == 0:
        raise ValueError(
            "Please authenticate via truss.login and pass it as an argument."
        )
    raise ValueError(
        "Multiple remotes found. Please pass the remote as an argument."
    )


def whoami(remote: Optional[str] = None):
    """
    Returns account information for the current user.

    Args:
        remote: Name of the remote in .trussrc to query. May be omitted when
            exactly one remote is configured.

    Returns:
        Whatever the remote provider's ``whoami()`` returns.

    Raises:
        ValueError: If ``remote`` is omitted and zero or several remotes exist.
    """
    remote_provider = RemoteFactory.create(remote=_resolve_remote_name(remote))
    return remote_provider.whoami()


def push(
    target_directory: str,
    remote: Optional[str] = None,
    model_name: Optional[str] = None,
    publish: bool = False,
    promote: bool = False,
    preserve_previous_production_deployment: bool = False,
    trusted: bool = False,
    deployment_name: Optional[str] = None,
    environment: Optional[str] = None,
    progress_bar: Optional[Type["progress.Progress"]] = None,
) -> definitions.ModelDeployment:
    """
    Pushes a Truss to Baseten.

    Args:
        target_directory: Directory of Truss to push.
        remote: Name of the remote in .trussrc to patch changes to. May be
            omitted when exactly one remote is configured.
        model_name: The name of the model, if different from the one in the config.yaml.
        publish: Push the truss as a published deployment. If no production deployment exists,
            promote the truss to production after deploy completes.
        promote: Push the truss as a published deployment. Even if a production deployment exists,
            promote the truss to production after deploy completes.
        preserve_previous_production_deployment: Preserve the previous production deployment's autoscaling
            setting. When not specified, the previous production deployment will be updated to allow it to
            scale to zero. Can only be used in combination with `promote` option.
        trusted: Give Truss access to secrets on remote host.
        deployment_name: Name of the deployment created by the push. Can only be
            used in combination with `publish` or `promote`. Deployment name must
            only contain alphanumeric, '.', '-' or '_' characters.
        environment: Name of stable environment on baseten.
        progress_bar: Optional `rich.progress.Progress` if output is desired.

    Returns:
        The newly created ModelDeployment.

    Raises:
        ValueError: If ``remote`` is omitted and zero or several remotes exist,
            or if no model name is given and config.yaml does not define one.
    """
    remote_provider = RemoteFactory.create(remote=_resolve_remote_name(remote))
    tr = load(target_directory)
    # An explicit argument wins over the name declared in config.yaml.
    model_name = model_name or tr.spec.config.model_name
    if not model_name:
        raise ValueError(
            "No model name provided. Please specify a model name in config.yaml."
        )

    service = remote_provider.push(
        tr,
        model_name=model_name,
        publish=publish,
        trusted=trusted,
        promote=promote,
        preserve_previous_prod_deployment=preserve_previous_production_deployment,
        deployment_name=deployment_name,
        environment=environment,
        progress_bar=progress_bar,
    )  # type: ignore

    return definitions.ModelDeployment(cast(BasetenService, service))
|
truss/api/definitions.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
import pydantic
|
|
4
|
+
|
|
5
|
+
from truss.remote.baseten import service
|
|
6
|
+
from truss.remote.baseten.core import ACTIVE_STATUS, DEPLOYING_STATUSES
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ModelDeployment:
    """Handle to a model deployment, as returned by ``truss.api.push``.

    Wraps a ``BasetenService`` and exposes its model and deployment ids.
    """

    # NOTE(review): this class does not inherit from a pydantic model, so this
    # config dict has no effect on it; kept only for compatibility.
    model_config = pydantic.ConfigDict(protected_namespaces=())

    model_id: str
    model_deployment_id: str
    _baseten_service: service.BasetenService

    def __init__(self, service: service.BasetenService):
        # The `service` parameter shadows the imported `service` module inside
        # this method. The annotation is evaluated at definition time, before
        # the parameter exists, so it still refers to the module.
        self.model_id = service._model_id
        self.model_deployment_id = service._model_version_id
        self._baseten_service = service

    def wait_for_active(self, timeout_seconds: int = 600) -> bool:
        """
        Waits for the deployment to be active.

        Args:
            timeout_seconds: The maximum number of seconds to wait for the
                deployment to be active. ``None`` waits indefinitely.

        Returns:
            True once the deployment reports an active status.

        Raises:
            TimeoutError: If the deployment is still deploying after
                ``timeout_seconds``.
            Exception: If the deployment reaches a status that is neither
                active nor one of the deploying statuses.
            RuntimeError: If status polling ends before the deployment
                becomes active.
        """
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start_time = time.monotonic()
        for deployment_status in self._baseten_service.poll_deployment_status():
            # Inspect the freshly polled status before checking the deadline,
            # so a deployment that becomes active (or fails) on a late poll is
            # reported accurately instead of as a timeout.
            if deployment_status == ACTIVE_STATUS:
                return True

            if deployment_status not in DEPLOYING_STATUSES:
                raise Exception(f"Deployment failed with status: {deployment_status}")

            if (
                timeout_seconds is not None
                and time.monotonic() - start_time > timeout_seconds
            ):
                raise TimeoutError("Deployment timed out.")

        raise RuntimeError("Error polling deployment status.")

    def __repr__(self):
        return (
            f"ModelDeployment(model_id={self.model_id}, "
            f"model_deployment_id={self.model_deployment_id})"
        )
|
truss/base/constants.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import pathlib
|
|
2
|
+
from typing import Set
|
|
3
|
+
|
|
4
|
+
# String identifiers for the supported model frameworks.
SKLEARN: str = "sklearn"
TENSORFLOW: str = "tensorflow"
KERAS: str = "keras"
XGBOOST: str = "xgboost"
PYTORCH: str = "pytorch"
CUSTOM: str = "custom"
HUGGINGFACE_TRANSFORMER: str = "huggingface_transformer"
LIGHTGBM: str = "lightgbm"
|
|
12
|
+
|
|
13
|
+
# Directory two levels above this module (the package root), symlinks resolved.
_TRUSS_ROOT = pathlib.Path(__file__).parent.parent.resolve()

# Template directories shipped inside the package.
TEMPLATES_DIR = _TRUSS_ROOT.joinpath("templates")
TRADITIONAL_CUSTOM_TEMPLATE_DIR = TEMPLATES_DIR.joinpath("custom")
PYTHON_DX_CUSTOM_TEMPLATE_DIR = TEMPLATES_DIR.joinpath("custom_python_dx")
DOCKER_SERVER_TEMPLATES_DIR = TEMPLATES_DIR.joinpath("docker_server")
SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR.joinpath("server")
TRITON_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR.joinpath("triton")
TRTLLM_TRUSS_DIR: pathlib.Path = TEMPLATES_DIR.joinpath("trtllm-briton")
SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME = "shared"
SHARED_SERVING_AND_TRAINING_CODE_DIR: pathlib.Path = TEMPLATES_DIR.joinpath(
    SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME
)
CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR.joinpath("control")
# `truss-chains/truss_chains`, located next to the package root directory.
CHAINS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent.joinpath(
    "truss-chains", "truss_chains"
)

# Supported Python versions; custom base images use a separate min/max bound.
SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"}
MAX_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.12"
MIN_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.8"

# TensorRT-LLM / BEI serving limits.
TRTLLM_PREDICT_CONCURRENCY = 512
BEI_TRTLLM_CLIENT_BATCH_SIZE = 128
BEI_MAX_CONCURRENCY_TARGET_REQUESTS = 2048
# Minimum `max_num_tokens` enforced for encoder (BEI) engine builds.
BEI_REQUIRED_MAX_NUM_TOKENS = 16384

TRTLLM_MIN_MEMORY_REQUEST_GI = 16
HF_MODELS_API_URL = "https://huggingface.co/api/models"
HF_ACCESS_TOKEN_KEY = "hf_access_token"
TRUSSLESS_MAX_PAYLOAD_SIZE = "64M"
# Alias for TEMPLATES_DIR (the very same Path object).
SERVING_DIR: pathlib.Path = TEMPLATES_DIR
|
|
44
|
+
|
|
45
|
+
# Requirement and system-package file names.
REQUIREMENTS_TXT_FILENAME = "requirements.txt"
USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME = "user_requirements.txt"
BASE_SERVER_REQUIREMENTS_TXT_FILENAME = "base_server_requirements.txt"
SERVER_REQUIREMENTS_TXT_FILENAME = "server_requirements.txt"
SYSTEM_PACKAGES_TXT_FILENAME = "system_packages.txt"

# Lookup of the file names above by descriptive key.
FILENAME_CONSTANTS_MAP = dict(
    config_requirements_filename=REQUIREMENTS_TXT_FILENAME,
    user_supplied_requirements_filename=USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME,
    base_server_requirements_filename=BASE_SERVER_REQUIREMENTS_TXT_FILENAME,
    server_requirements_filename=SERVER_REQUIREMENTS_TXT_FILENAME,
    system_packages_filename=SYSTEM_PACKAGES_TXT_FILENAME,
)

# Jinja template names and the corresponding output file names.
SERVER_DOCKERFILE_TEMPLATE_NAME = "server.Dockerfile.jinja"
MODEL_DOCKERFILE_NAME = "Dockerfile"
README_TEMPLATE_NAME = "README.md.jinja"
MODEL_README_NAME = "README.md"

CONFIG_FILE = "config.yaml"
DOCKERFILE = "Dockerfile"
# Used to indicate whether to associate a container with Truss.
TRUSS = "truss"
# Used to create a unique identifier based on the last time the truss was updated.
TRUSS_MODIFIED_TIME = "truss_modified_time"
# Path of the Truss, used to identify which Truss is being referred to.
TRUSS_DIR = "truss_dir"
TRUSS_HASH = "truss_hash"
|
|
74
|
+
|
|
75
|
+
# Empty: no module names are listed for Hugging Face transformers.
# `set()` rather than `set({})`, which builds a throwaway empty dict first.
HUGGINGFACE_TRANSFORMER_MODULE_NAME: Set[str] = set()

# list from https://scikit-learn.org/stable/developers/advanced_installation.html
SKLEARN_REQ_MODULE_NAMES: Set[str] = {
    "numpy",
    "scipy",
    "joblib",
    "scikit-learn",
    "threadpoolctl",
}

XGBOOST_REQ_MODULE_NAMES: Set[str] = {"xgboost"}

# list from https://www.tensorflow.org/install/pip
# if problematic, let's look to https://www.tensorflow.org/install/source
TENSORFLOW_REQ_MODULE_NAMES: Set[str] = {"tensorflow"}

LIGHTGBM_REQ_MODULE_NAMES: Set[str] = {"lightgbm"}

# list from https://pytorch.org/get-started/locally/
PYTORCH_REQ_MODULE_NAMES: Set[str] = {"torch", "torchvision", "torchaudio"}

MLFLOW_REQ_MODULE_NAMES: Set[str] = {"mlflow"}
|
|
98
|
+
|
|
99
|
+
INFERENCE_SERVER_PORT = 8080

HTTP_PUBLIC_BLOB_BACKEND = "http_public"

# Name prefix for Docker-registry-related build secrets.
REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"

# Model names used for TensorRT-LLM speculative decoding.
TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"
TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"

# Pinned TensorRT-LLM (Briton) server image, its interpreter, and requirements.
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.16.0-5be7b58"
TRTLLM_PYTHON_EXECUTABLE = "/usr/local/briton/venv/bin/python"
BASE_TRTLLM_REQUIREMENTS = ["briton==0.4.2"]

# BEI image, pinned by tag and content digest.
BEI_TRTLLM_BASE_IMAGE = (
    "baseten/bei:0.0.17"
    "@sha256:9c3577f6ec672d6da5aca18e9c0ebdddd65ed80c8858e757fbde7e9cf48de01d"
)
BEI_TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"

OPENAI_COMPATIBLE_TAG = "openai-compatible"

PRODUCTION_ENVIRONMENT_NAME = "production"
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# TODO(marius/TaT): kill this.
class ModelFrameworkType(Enum):
    """Model framework identifiers.

    Each member's value is the lowercase framework string, e.g. ``"sklearn"``.
    """

    SKLEARN = "sklearn"
    TENSORFLOW = "tensorflow"
    KERAS = "keras"
    PYTORCH = "pytorch"
    HUGGINGFACE_TRANSFORMER = "huggingface_transformer"
    XGBOOST = "xgboost"
    LIGHTGBM = "lightgbm"
    MLFLOW = "mlflow"
    CUSTOM = "custom"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class Example:
    """A named example input, convertible to and from a plain dict."""

    name: str
    input: Any

    @staticmethod
    def from_dict(example_dict):
        """Build an ``Example`` from a mapping with ``name`` and ``input`` keys.

        Raises:
            KeyError: if either key is missing.
        """
        example_name, payload = example_dict["name"], example_dict["input"]
        return Example(name=example_name, input=payload)

    def to_dict(self) -> dict:
        """Return ``{"name": ..., "input": ...}`` for this example."""
        return dict(name=self.name, input=self.input)
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import warnings
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any, Optional
|
|
9
|
+
|
|
10
|
+
from huggingface_hub.errors import HFValidationError
|
|
11
|
+
from huggingface_hub.utils import validate_repo_id
|
|
12
|
+
from pydantic import BaseModel, PydanticDeprecatedSince20, model_validator, validator
|
|
13
|
+
|
|
14
|
+
from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(name=__name__)
# Pydantic V1-style APIs are still used in this module for backwards
# compatibility; silence their deprecation warnings.
warnings.filterwarnings(action="ignore", category=PydanticDeprecatedSince20)
|
|
19
|
+
|
|
20
|
+
# Read once at import time; enabled only when the environment variable is the
# exact string "True" (any other value, or unset, disables it).
ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION = (
    os.getenv("ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION", "False") == "True"
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TrussTRTLLMModel(str, Enum):
    """Base model architectures accepted for a TensorRT-LLM build.

    ``ENCODER`` triggers encoder-specific build adjustments in
    ``TrussTRTLLMBuildConfiguration``.
    """

    LLAMA = "llama"
    MISTRAL = "mistral"
    DEEPSEEK = "deepseek"
    WHISPER = "whisper"
    QWEN = "qwen"
    ENCODER = "encoder"
    PALMYRA = "palmyra"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TrussTRTLLMQuantizationType(str, Enum):
    """Quantization scheme selected for a TensorRT-LLM engine build.

    Values containing ``_kv`` select a KV-cache-specific datatype; these are
    rejected for encoder models, which have no KV cache.
    """

    NO_QUANT = "no_quant"
    WEIGHTS_ONLY_INT8 = "weights_int8"
    WEIGHTS_KV_INT8 = "weights_kv_int8"
    WEIGHTS_ONLY_INT4 = "weights_int4"
    WEIGHTS_INT4_KV_INT8 = "weights_int4_kv_int8"
    SMOOTH_QUANT = "smooth_quant"
    FP8 = "fp8"
    # Required when plugin_configuration.use_fp8_context_fmha is enabled.
    FP8_KV = "fp8_kv"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TrussTRTLLMPluginConfiguration(BaseModel):
    """TensorRT-LLM plugin toggles for the engine build.

    Cross-field rules are enforced by ``TrussTRTLLMBuildConfiguration``:
    paged/FP8 context FMHA require ``paged_kv_cache``, and FP8 context FMHA
    additionally requires paged context FMHA and the ``fp8_kv`` quantization.
    """

    paged_kv_cache: bool = True
    gemm_plugin: str = "auto"
    use_paged_context_fmha: bool = True
    use_fp8_context_fmha: bool = False
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class CheckpointSource(str, Enum):
    """Kind of storage a model checkpoint is fetched from.

    Only ``HF`` repository ids are validated locally (see
    ``CheckpointRepository``).
    """

    HF = "HF"
    GCS = "GCS"
    LOCAL = "LOCAL"
    # REMOTE_URL is useful when the checkpoint lives on remote storage accessible via HTTP (e.g a presigned URL)
    REMOTE_URL = "REMOTE_URL"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class CheckpointRepository(BaseModel):
    """Location of a model checkpoint: its source, repo, and optional revision.

    For Hugging Face sources the ``repo`` id is validated on construction.
    """

    source: CheckpointSource
    repo: str
    revision: Optional[str] = None

    def __init__(self, **data):
        super().__init__(**data)
        # Only Hugging Face sources are validated here.
        if self.source != CheckpointSource.HF:
            return
        self._validate_hf_repo_id()

    def _validate_hf_repo_id(self):
        """Raise ``ValueError`` if ``repo`` is not a valid Hugging Face repo id."""
        try:
            validate_repo_id(self.repo)
        except HFValidationError as err:
            message = f"HuggingFace repository validation failed: {err}"
            raise ValueError(message) from err
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class TrussTRTLLMBatchSchedulerPolicy(str, Enum):
    """Batch scheduler policy for the TensorRT-LLM runtime.

    ``TrussTRTLLMRuntimeConfiguration`` defaults to ``GUARANTEED_NO_EVICT``.
    """

    MAX_UTILIZATION = "max_utilization"
    GUARANTEED_NO_EVICT = "guaranteed_no_evict"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class TrussSpecDecMode(str, Enum):
    """Speculative decoding mode; only external draft tokens are defined."""

    DRAFT_EXTERNAL = "DRAFT_TOKENS_EXTERNAL"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class TrussTRTLLMRuntimeConfiguration(BaseModel):
    """Serving-time (runtime) settings for a TensorRT-LLM engine.

    For non-encoder models, ``enable_chunked_context=True`` requires the build's
    ``use_paged_context_fmha`` and ``paged_kv_cache`` plugin flags; that check
    lives in ``TRTLLMConfiguration``.
    """

    kv_cache_free_gpu_mem_fraction: float = 0.9
    kv_cache_host_memory_bytes: Optional[int] = None
    enable_chunked_context: bool = True
    batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
        TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
    )
    request_default_max_tokens: Optional[int] = None
    total_token_limit: int = 500000
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class TrussTRTLLMBuildConfiguration(BaseModel):
    """Engine build settings for TensorRT-LLM.

    Construction runs pydantic field validation, then cross-field checks on the
    KV-cache / context-FMHA plugin flags and the speculator config, and finally
    encoder-specific adjustments (``_bei_specfic_migration``).
    """

    base_model: TrussTRTLLMModel
    max_seq_len: int
    max_batch_size: int = 256
    max_num_tokens: int = 8192
    max_beam_width: int = 1
    max_prompt_embedding_table_size: int = 0
    checkpoint_repository: CheckpointRepository
    gather_all_token_logits: bool = False
    strongly_typed: bool = False
    quantization_type: TrussTRTLLMQuantizationType = (
        TrussTRTLLMQuantizationType.NO_QUANT
    )
    tensor_parallel_count: int = 1
    pipeline_parallel_count: int = 1
    plugin_configuration: TrussTRTLLMPluginConfiguration = (
        TrussTRTLLMPluginConfiguration()
    )
    num_builder_gpus: Optional[int] = None
    speculator: Optional[TrussSpeculatorConfiguration] = None

    class Config:
        # Unknown keys in the build section are rejected.
        extra = "forbid"

    def __init__(self, **data):
        super().__init__(**data)
        self._validate_kv_cache_flags()
        self._validate_speculator_config()
        # Runs last: it may overwrite fields that were validated above.
        self._bei_specfic_migration()

    @validator("max_beam_width")
    def check_max_beam_width(cls, v: int):
        """Reject any beam width other than 1."""
        if isinstance(v, int):
            if v != 1:
                raise ValueError(
                    "max_beam_width greater than 1 is not currently supported"
                )
        return v

    def _bei_specfic_migration(self):
        """Apply embedding-specific settings (no KV cache, larger token budget).

        Only acts when ``base_model`` is ``ENCODER``. It raises ``max_num_tokens``
        to at least ``BEI_REQUIRED_MAX_NUM_TOKENS`` and disables both the paged KV
        cache and paged context FMHA.

        Raises:
            ValueError: if an encoder is combined with a KV-cache quantization
                type (any value containing ``_kv``).
        """
        if self.base_model != TrussTRTLLMModel.ENCODER:
            return

        # Reject invalid quantization before mutating anything or logging
        # adjustments that would never take effect.
        if "_kv" in self.quantization_type.value:
            raise ValueError(
                "encoder does not have a kv-cache, therefore a kv specific datatype is not valid. "
                f"You selected build.quantization_type {self.quantization_type.value}"
            )

        logger.info(
            f"Your setting of `build.max_seq_len={self.max_seq_len}` is not used and "
            "automatically inferred from the model repo config.json -> `max_position_embeddings`"
        )

        if self.max_num_tokens < BEI_REQUIRED_MAX_NUM_TOKENS:
            logger.warning(
                f"build.max_num_tokens={self.max_num_tokens}, upgrading to {BEI_REQUIRED_MAX_NUM_TOKENS}"
            )
            self.max_num_tokens = BEI_REQUIRED_MAX_NUM_TOKENS
        self.plugin_configuration.paged_kv_cache = False
        self.plugin_configuration.use_paged_context_fmha = False

    def _validate_kv_cache_flags(self):
        """Check that the context-FMHA flags are consistent with the KV cache.

        Raises:
            ValueError: if paged or FP8 context FMHA is enabled without a paged
                KV cache, if FP8 context FMHA is enabled without paged context
                FMHA, or if FP8 context FMHA is used without ``fp8_kv``.
        """
        plugins = self.plugin_configuration
        if not plugins.paged_kv_cache and (
            plugins.use_paged_context_fmha or plugins.use_fp8_context_fmha
        ):
            raise ValueError(
                "Using paged context fmha or fp8 context fmha requires paged kv cache"
            )
        if plugins.use_fp8_context_fmha and not plugins.use_paged_context_fmha:
            raise ValueError("Using fp8 context fmha requires paged context fmha")
        if (
            plugins.use_fp8_context_fmha
            and not self.quantization_type == TrussTRTLLMQuantizationType.FP8_KV
        ):
            raise ValueError("Using fp8 context fmha requires fp8 kv cache dtype")
        return self

    def _validate_speculator_config(self):
        """Check that a configured speculator is compatible with this build.

        Raises:
            ValueError: for Whisper models, when paged context FMHA or the paged
                KV cache is disabled, or when the draft build's tensor
                parallelism differs from this build's.
        """
        if self.speculator:
            if self.base_model is TrussTRTLLMModel.WHISPER:
                raise ValueError("Speculative decoding for Whisper is not supported.")
            if not all(
                [
                    self.plugin_configuration.use_paged_context_fmha,
                    self.plugin_configuration.paged_kv_cache,
                ]
            ):
                raise ValueError(
                    "KV cache block reuse must be enabled for speculative decoding target model."
                )
            if self.speculator.build:
                if (
                    self.tensor_parallel_count
                    != self.speculator.build.tensor_parallel_count
                ):
                    raise ValueError(
                        "Speculative decoding requires the same tensor parallelism for target and draft models."
                    )

    @property
    def max_draft_len(self) -> Optional[int]:
        """Number of draft tokens from the speculator, or ``None`` without one."""
        if self.speculator:
            return self.speculator.num_draft_tokens
        return None
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class TrussSpeculatorConfiguration(BaseModel):
    """Draft-model configuration for speculative decoding.

    Exactly one of ``checkpoint_repository`` (a prebuilt draft) or ``build``
    (a draft engine to build) must be set.
    """

    speculative_decoding_mode: TrussSpecDecMode = TrussSpecDecMode.DRAFT_EXTERNAL
    num_draft_tokens: int
    checkpoint_repository: Optional[CheckpointRepository] = None
    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
    build: Optional[TrussTRTLLMBuildConfiguration] = None

    def __init__(self, **data):
        super().__init__(**data)
        self._validate_checkpoint()

    def _validate_checkpoint(self):
        """Raise ``ValueError`` unless exactly one checkpoint origin is set."""
        has_repository = bool(self.checkpoint_repository)
        has_build = bool(self.build)
        # Both set or both missing is invalid (the negation of XOR).
        if has_repository == has_build:
            raise ValueError(
                "Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
            )

    @property
    def resolved_checkpoint_repository(self) -> CheckpointRepository:
        """The draft checkpoint: the build's repository if building, else the direct one."""
        if self.build:
            return self.build.checkpoint_repository
        if self.checkpoint_repository:
            return self.checkpoint_repository
        raise ValueError(
            "Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
        )
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class TRTLLMConfiguration(BaseModel):
    """Top-level TensorRT-LLM configuration: runtime settings plus build settings."""

    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
    build: TrussTRTLLMBuildConfiguration

    @model_validator(mode="before")
    @classmethod
    def migrate_runtime_fields(cls, data: Any) -> Any:
        """Move deprecated runtime fields found under ``build`` into ``runtime``.

        Keys in ``build`` that are build fields stay in ``build``. Keys that are
        runtime fields are copied into ``runtime`` unless ``runtime`` already
        sets them. Any other keys are dropped from ``build`` (with a warning).

        NOTE(review): this mutates ``data`` (and an existing ``runtime`` dict)
        in place; callers passing their own dict will observe the migration.

        Args:
            data: Raw validator input; only dicts with a dict ``build`` are migrated.

        Returns:
            ``data``, possibly modified.
        """
        # A "before" model validator can receive non-dict input (e.g. an already
        # constructed instance); only raw dicts need migrating.
        if not isinstance(data, dict):
            return data
        build = data.get("build")
        if not isinstance(build, dict):
            return data

        extra_runtime_fields = {}
        valid_build_fields = {}
        dropped_fields = []
        for key, value in build.items():
            if key in TrussTRTLLMBuildConfiguration.__annotations__:
                valid_build_fields[key] = value
            elif key in TrussTRTLLMRuntimeConfiguration.__annotations__:
                logger.warning(f"Found runtime.{key}: {value} in build config")
                extra_runtime_fields[key] = value
            else:
                dropped_fields.append(key)

        if dropped_fields:
            # Previously removed silently; surface it so typos are noticed.
            logger.warning(
                f"Ignoring unrecognized fields {dropped_fields} in build configuration."
            )

        if extra_runtime_fields:
            logger.warning(
                f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
                " This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
            )
            runtime = data.get("runtime")
            if runtime:
                # Explicit runtime values win over values migrated from build.
                runtime.update(
                    {
                        k: v
                        for k, v in extra_runtime_fields.items()
                        if k not in runtime
                    }
                )
            else:
                data.update({"runtime": dict(extra_runtime_fields)})
        data.update({"build": valid_build_fields})
        return data

    @model_validator(mode="after")
    def after(self: "TRTLLMConfiguration") -> "TRTLLMConfiguration":
        """Enforce that chunked context has the plugin flags it needs.

        For non-encoder models with ``runtime.enable_chunked_context``, both
        ``use_paged_context_fmha`` and ``paged_kv_cache`` must be enabled. When
        ``ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION`` is set they are switched on
        with a warning; otherwise a ``ValueError`` is raised.
        """
        # check if there is an error wrt. runtime.enable_chunked_context
        if (
            self.runtime.enable_chunked_context
            and (self.build.base_model != TrussTRTLLMModel.ENCODER)
            and not (
                self.build.plugin_configuration.use_paged_context_fmha
                and self.build.plugin_configuration.paged_kv_cache
            )
        ):
            if ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION:
                logger.warning(
                    "If trt_llm.runtime.enable_chunked_context is True, then trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache should be True. "
                    "Setting trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache to True."
                )
                self.build.plugin_configuration.use_paged_context_fmha = True
                self.build.plugin_configuration.paged_kv_cache = True
            else:
                raise ValueError(
                    "If runtime.enable_chunked_context is True, then build.plugin_configuration.use_paged_context_fmha and build.plugin_configuration.paged_kv_cache should be True"
                )

        return self

    @property
    def requires_build(self):
        """Whether a build section is present."""
        return self.build is not None

    # TODO(Abu): Replace this with model_dump(json=True)
    # when pydantic v2 is used here
    def to_json_dict(self, verbose=True):
        """Return a JSON-compatible dict; ``verbose=False`` omits unset fields."""
        return json.loads(self.json(exclude_unset=not verbose))
|