truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,756 @@
|
|
|
1
|
+
# TODO: this file contains too much implementation -> restructure.
|
|
2
|
+
import abc
|
|
3
|
+
import enum
|
|
4
|
+
import logging
|
|
5
|
+
import pathlib
|
|
6
|
+
import traceback
|
|
7
|
+
from typing import ( # type: ignore[attr-defined] # Chains uses Python >=3.9.
|
|
8
|
+
Any,
|
|
9
|
+
Callable,
|
|
10
|
+
ClassVar,
|
|
11
|
+
Generic,
|
|
12
|
+
GenericAlias, # This causes above type error.
|
|
13
|
+
Iterable,
|
|
14
|
+
Literal,
|
|
15
|
+
Mapping,
|
|
16
|
+
Optional,
|
|
17
|
+
Protocol,
|
|
18
|
+
Type,
|
|
19
|
+
TypeVar,
|
|
20
|
+
Union,
|
|
21
|
+
cast,
|
|
22
|
+
get_args,
|
|
23
|
+
get_origin,
|
|
24
|
+
runtime_checkable,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
import pydantic
|
|
28
|
+
from truss.base import truss_config
|
|
29
|
+
from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
|
|
30
|
+
|
|
31
|
+
# --- Keys and directory names used by the chains framework. ---
BASETEN_API_SECRET_NAME = "baseten_chain_api_key"
SECRET_DUMMY = "***"
TRUSS_CONFIG_CHAINS_KEY = "chains_metadata"
GENERATED_CODE_DIR = ".chains_generated"
DYNAMIC_CHAINLET_CONFIG_KEY = "dynamic_chainlet_config"
OTEL_TRACE_PARENT_HEADER_KEY = "traceparent"

# --- Method names looked up on user-defined classes. ---
# Chainlet method name exposed as endpoint.
RUN_REMOTE_METHOD_NAME = "run_remote"
# Model method name exposed as endpoint.
MODEL_ENDPOINT_METHOD_NAME = "predict"
HEALTH_CHECK_METHOD_NAME = "is_healthy"

# --- Argument / attribute names. ---
# Below arg names must correspond to `definitions.ABCChainlet`.
# Referring to Chainlets `__init__` signature.
CONTEXT_ARG_NAME = "context"
SELF_ARG_NAME = "self"
REMOTE_CONFIG_NAME = "remote_config"
|
|
44
|
+
|
|
45
|
+
# Key / value type variables (used by the mapping protocol in this module).
K = TypeVar("K", contravariant=True)
V = TypeVar("V", covariant=True)


# Owner-class type variable for the class-property descriptor below.
C = TypeVar("C")


class _classproperty(Generic[C, V]):
    """Read-only descriptor that evaluates a getter against the owning class.

    The getter may be wrapped in ``classmethod`` (the stacking
    ``@classproperty`` over ``@classmethod``), wrapped in ``staticmethod``,
    or be a plain callable that accepts the class as its only argument.
    """

    def __init__(self, fget: Callable[[Type[C]], V]) -> None:
        """Store the getter.

        Args:
            fget: Callable (optionally a ``classmethod``/``staticmethod``
              object) that produces the property value.
        """
        self._fget = fget

    def __get__(self, instance: object, owner: Type[C]) -> V:
        """Return the property value for ``owner``.

        Args:
            instance: The instance the attribute was accessed through, or
              ``None`` for class-level access. It is not used.
            owner: The class the attribute was looked up on.

        Returns:
            Whatever the getter returns for ``owner``.
        """
        fget = self._fget
        if isinstance(fget, (classmethod, staticmethod)):
            # Binding through the descriptor protocol yields a callable that
            # already carries `owner` (classmethod) or needs no argument
            # (staticmethod).
            return fget.__get__(None, owner)()
        # A plain function bound with `instance=None` is returned unbound, so
        # calling it without arguments would raise `TypeError`; pass the owner
        # class explicitly instead.
        return fget(owner)


def classproperty(fget: Callable[[Type[C]], V]) -> _classproperty[C, V]:
    """Decorator turning a class-level getter into a read-only class property.

    Args:
        fget: The getter; either a ``classmethod`` object or a plain callable
          taking the class.

    Returns:
        A descriptor evaluating ``fget`` on attribute access.
    """
    return _classproperty(fget)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@runtime_checkable
class MappingNoIter(Protocol[K, V]):
    """Structural type for mapping-like objects without an iteration requirement.

    An object conforms when it supports membership tests, ``len()`` and item
    access; ``__iter__`` is deliberately not part of the protocol.
    """

    def __contains__(self, key: K) -> bool:
        """Return whether ``key`` is present."""
        ...

    def __len__(self) -> int:
        """Return the number of entries."""
        ...

    def __getitem__(self, key: K) -> V:
        """Return the value stored under ``key``."""
        ...
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class SafeModel(pydantic.BaseModel):
    """Pydantic base model with a conservative default configuration.

    Validation is strict, assignments to fields are re-validated and arbitrary
    (non-pydantic) field types are rejected.
    """

    model_config = pydantic.ConfigDict(
        strict=True,
        validate_assignment=True,
        arbitrary_types_allowed=False,
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class SafeModelNonSerializable(pydantic.BaseModel):
    """Pydantic base model with a conservative configuration that permits
    arbitrary field types.

    Validation is strict, assignments to fields are re-validated and unknown
    fields are rejected (``extra="forbid"``).
    """

    model_config = pydantic.ConfigDict(
        extra="forbid",
        strict=True,
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ChainsUsageError(TypeError):
    """Signals that a user-defined Chainlet violates the API constraints."""
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class MissingDependencyError(TypeError):
    """Signals that a required resource is undefined or could not be found."""
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class ChainsRuntimeError(Exception):
    """Signals that a component was used in an unexpected way at runtime."""
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ChainsDeploymentError(Exception):
    """Signals that interaction with a Chain deployment is not possible."""
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class AbsPath:
    """Absolute path together with a record of how it was created.

    Besides the absolute path, the module that created it and the path as
    originally specified are kept so that they can be quoted in log and error
    messages. Existence is checked on every access of ``abs_path``.
    """

    _abs_file_path: str
    _creating_module: str
    _original_path: str

    def __init__(
        self, abs_file_path: str, creating_module: str, original_path: str
    ) -> None:
        """Record the path data; nothing is checked at construction time.

        Args:
            abs_file_path: The resolved absolute path.
            creating_module: Name of the module in which the path was created.
            original_path: The path as originally specified.
        """
        self._original_path = original_path
        self._creating_module = creating_module
        self._abs_file_path = abs_file_path

    def _raise_if_not_exists(self, abs_path: str) -> None:
        """Raise unless ``abs_path`` is a file or a non-empty directory.

        Raises:
            MissingDependencyError: If the path is neither an existing file
              nor a directory with at least one entry.
        """
        candidate = pathlib.Path(abs_path)
        if candidate.is_file():
            return
        if candidate.is_dir() and any(candidate.iterdir()):
            return
        raise MissingDependencyError(
            f"With the file path `{self._original_path}` an absolute path relative "
            f"to the calling module `{self._creating_module}` was created, "
            f"resulting `{self._abs_file_path}` - but no file was found."
        )

    @property
    def abs_path(self) -> str:
        """The absolute path, validated for existence on each access.

        Emits a debug log message when the absolute path differs from the
        originally specified one.

        Raises:
            MissingDependencyError: If nothing usable exists at the path.
        """
        resolved = self._abs_file_path
        if resolved != self._original_path:
            logging.debug(
                f"Using abs path `{self._abs_file_path}` for path specified as "
                f"`{self._original_path}` (in `{self._creating_module}`)."
            )
        self._raise_if_not_exists(resolved)
        return resolved
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class BasetenImage(enum.Enum):
    """Baseten-curated default images, one per supported python version.

    If a Chainlet uses GPUs, drivers will be included in the image.
    """

    # Member values are the canonical truss python version identifiers.
    PY39 = "py39"
    PY310 = "py310"
    PY311 = "py311"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class CustomImage(SafeModel):
    """Configures the usage of a custom image hosted on dockerhub."""

    # Image reference, e.g. ``"python:3.11-slim"``.
    image: str
    # Optional path to the python executable inside the image.
    python_executable_path: Optional[str] = None
    # Optional authentication settings for pulling the image.
    docker_auth: Optional[truss_config.DockerAuthSettings] = None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class DockerImage(SafeModelNonSerializable):
    """Configures the docker image in which a remote chainlet is deployed.

    Note:
        Any paths are relative to the source file where ``DockerImage`` is
        defined and must be created with the helper function ``make_abs_path_here``.
        This allows you, for example, to organize chainlets in different
        (potentially nested) modules and keep their requirement files right next
        to their python source files.

    Args:
        base_image: The base image used by the chainlet. Other dependencies and
          assets are included as additional layers on top of that image. You can choose
          a Baseten default image for a supported python version (e.g.
          ``BasetenImage.PY311``), this will also include GPU drivers if needed, or
          provide a custom image (e.g. ``CustomImage(image="python:3.11-slim")``).
        pip_requirements_file: Path to a file containing pip requirements. The file
          content is naively concatenated with ``pip_requirements``.
        pip_requirements: A list of pip requirements to install. The items are
          naively concatenated with the content of the ``pip_requirements_file``.
        apt_requirements: A list of apt requirements to install.
        data_dir: Data from this directory is copied into the docker image and
          accessible to the remote chainlet at runtime.
        external_package_dirs: A list of directories containing additional python
          packages outside the chain's workspace dir, e.g. a shared library. This code
          is copied into the docker image and importable at runtime.
    """

    base_image: Union[BasetenImage, CustomImage] = BasetenImage.PY311
    pip_requirements_file: Optional[AbsPath] = None
    pip_requirements: list[str] = []
    apt_requirements: list[str] = []
    data_dir: Optional[AbsPath] = None
    external_package_dirs: Optional[list[AbsPath]] = None

    # `model_validator(mode="before")` is the pydantic v2 replacement for the
    # deprecated `root_validator(pre=True)`; this module already relies on the
    # v2 API (`pydantic.ConfigDict`).
    @pydantic.model_validator(mode="before")
    @classmethod
    def migrate_fields(cls, values: Any) -> Any:
        """Reject the deprecated string form of ``base_image`` before validation.

        Args:
            values: The raw input passed to the model.

        Returns:
            ``values`` unchanged.

        Raises:
            ChainsUsageError: If ``base_image`` is given as a string.
        """
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict - only inspect it when it is one.
        if isinstance(values, dict) and "base_image" in values:
            base_image = values["base_image"]
            if isinstance(base_image, str):
                doc_link = "https://docs.baseten.co/chains-reference/sdk#class-truss-chains-dockerimage"
                raise ChainsUsageError(
                    "`DockerImage.base_image` as string is deprecated. Specify as "
                    f"`BasetenImage` or `CustomImage` (see docs: {doc_link})."
                )
        return values
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class ComputeSpec(pydantic.BaseModel):
    """Parsed and validated compute resources.

    See ``Compute`` for more information.
    """

    # TODO[rcano] add node count
    cpu_count: int = 1
    predict_concurrency: int = 1
    # Memory amount as a string, e.g. "2Gi".
    memory: str = "2Gi"
    accelerator: truss_config.AcceleratorSpec = truss_config.AcceleratorSpec()
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# Literal type and matching constant for the string "cpu_count"; the type is
# accepted for `predict_concurrency` in `Compute.__init__` as an alternative to
# an explicit integer.
CpuCountT = Literal["cpu_count"]
CPU_COUNT: CpuCountT = "cpu_count"
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class Compute:
    """Specifies which compute resources a chainlet has in the *remote* deployment.

    Note:
        Not all combinations can be exactly satisfied by available hardware, in some
        cases more powerful machine types are chosen to make sure requirements are met
        or over-provisioned. Refer to the
        `baseten instance reference <https://docs.baseten.co/performance/instances>`_.
    """

    # Builder to create ComputeSpec.
    # This extra layer around `ComputeSpec` is needed to parse the accelerator options.

    _spec: ComputeSpec

    def __init__(
        self,
        cpu_count: int = 1,
        memory: str = "2Gi",
        gpu: Union[str, truss_config.Accelerator, None] = None,
        gpu_count: int = 1,
        predict_concurrency: Union[int, CpuCountT] = 1,
    ) -> None:
        """
        Args:
            cpu_count: Minimum number of CPUs to allocate.
            memory: Minimum memory to allocate, e.g. "2Gi" (2 gibibytes).
            gpu: GPU accelerator type, e.g. "A10G", "A100", refer to the
              `truss config <https://docs.baseten.co/reference/config#resources-accelerator>`_
              for more choices.
            gpu_count: Number of GPUs to allocate.
            predict_concurrency: Number of concurrent requests a single replica of a
              deployed chainlet handles. Pass ``"cpu_count"`` to use the value of
              ``cpu_count``.

        Concurrency concepts are explained in `this guide <https://docs.baseten.co/deploy/guides/concurrency#predict-concurrency>`_. # noqa: E501
        It is important to understand the difference between `predict_concurrency` and
        the concurrency target (used for autoscaling, i.e. adding or removing replicas).
        Furthermore, the ``predict_concurrency`` of a single instance is implemented in
        two ways:

        - Via python's ``asyncio``, if ``run_remote`` is an async def. This
          requires that ``run_remote`` yields to the event loop.

        - With a threadpool if it's a synchronous function. This requires
          that the threads don't have significant CPU load (due to the GIL).
        """
        # Build the accelerator spec exactly once. `gpu_count` is only taken
        # into account when a GPU type was actually requested.
        if gpu:
            accelerator = truss_config.AcceleratorSpec(
                accelerator=truss_config.Accelerator(gpu), count=gpu_count
            )
        else:
            accelerator = truss_config.AcceleratorSpec()

        # Resolve the symbolic "cpu_count" value to a concrete integer.
        if predict_concurrency == CPU_COUNT:
            predict_concurrency_int = cpu_count
        else:
            # Narrows the type for static checkers; the only non-int value the
            # signature allows has been handled above.
            assert isinstance(predict_concurrency, int)
            predict_concurrency_int = predict_concurrency

        self._spec = ComputeSpec(
            cpu_count=cpu_count,
            memory=memory,
            accelerator=accelerator,
            predict_concurrency=predict_concurrency_int,
        )

    def get_spec(self) -> ComputeSpec:
        """Returns a deep copy of the parsed compute spec.

        A copy is returned so that callers cannot mutate the spec held by this
        builder.
        """
        return self._spec.model_copy(deep=True)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class AssetSpec(SafeModel):
    """Parsed and validated assets. See ``Assets`` for more information."""

    # Maps secret names to values; ``Assets`` fills every value with the
    # ``SECRET_DUMMY`` placeholder.
    secrets: dict[str, str] = pydantic.Field(default={})
    # Model repositories to cache.
    cached: list[truss_config.ModelRepo] = []
    # External data items to make available in the deployment.
    external_data: list[truss_config.ExternalDataItem] = []
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class Assets:
    """Specifies which assets a chainlet can access in the remote deployment.

    For example, model weight caching can be used like this::

      import truss_chains as chains
      from truss.base import truss_config

      mistral_cache = truss_config.ModelRepo(
          repo_id="mistralai/Mistral-7B-Instruct-v0.2",
          allow_patterns=["*.json", "*.safetensors", ".model"]
      )
      chains.Assets(cached=[mistral_cache], ...)

    See `truss caching guide <https://docs.baseten.co/deploy/guides/model-cache#enabling-caching-for-a-model>`_
    for more details on caching.
    """

    # Builder to create asset spec.
    # This extra layer around `AssetSpec` is needed to add secret_keys.
    _spec: AssetSpec

    def __init__(
        self,
        cached: Iterable[truss_config.ModelRepo] = (),
        secret_keys: Iterable[str] = (),
        external_data: Iterable[truss_config.ExternalDataItem] = (),
    ) -> None:
        """
        Args:
            cached: One or more ``truss_config.ModelRepo`` objects.
            secret_keys: Names of secrets stored on baseten, that the
              chainlet should have access to. You can manage secrets on baseten
              `here <https://app.baseten.co/settings/secrets>`_.
            external_data: Data to be downloaded from public URLs and made available
              in the deployment (via ``context.data_dir``). See
              `here <https://docs.baseten.co/reference/config#external-data>`_ for
              more details.
        """
        # Materialize the iterables so the spec owns plain containers.
        cached_repos = list(cached)
        # Only the secret *names* are known here; every value is the placeholder.
        placeholder_secrets = dict.fromkeys(secret_keys, SECRET_DUMMY)
        external_items = list(external_data)
        self._spec = AssetSpec(
            cached=cached_repos,
            secrets=placeholder_secrets,
            external_data=external_items,
        )

    def get_spec(self) -> AssetSpec:
        """Returns parsed and validated assets (as a deep copy)."""
        spec_copy = self._spec.model_copy(deep=True)
        return spec_copy
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class ChainletOptions(SafeModelNonSerializable):
    """Optional per-chainlet settings.

    Args:
        enable_b10_tracing: Enables baseten-internal trace data collection. This
          helps baseten engineers better analyze chain performance in case of issues.
          It is independent of a potentially user-configured tracing instrumentation.
          Turning this on could add performance overhead.
        env_variables: Static environment variables available to the deployed chainlet.
        health_checks: Configures health checks for the chainlet.
    """

    # Off by default.
    enable_b10_tracing: bool = False
    # No extra environment variables by default.
    env_variables: Mapping[str, str] = {}
    # Default-constructed health check configuration.
    health_checks: truss_config.HealthChecks = truss_config.HealthChecks()
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class ChainletMetadata(SafeModelNonSerializable):
    # Bookkeeping attached to each chainlet class (see `ABCChainlet.meta_data`).
    # NOTE(review): the exact meaning of each flag is defined by the framework
    # code that sets it - confirm there before relying on it.

    # Presumably marks the chainlet that serves as the chain's entrypoint.
    is_entrypoint: bool = False
    # Name of the chain, if one has been assigned.
    chain_name: Optional[str] = None
    # Presumably records whether `__init__` has been wrapped by the framework.
    init_is_patched: bool = False
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
class FrameworkConfig(SafeModelNonSerializable):
    # Static per-entity settings; `ABCChainlet` exposes each of these fields
    # through a class property of the same name.

    # Which kind of entity a class represents.
    entity_type: Literal["Chainlet", "Model"]
    # Whether the entity supports dependencies.
    supports_dependencies: bool
    # Name of the method that serves as the entity's endpoint.
    endpoint_method_name: str
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
class RemoteConfig(SafeModelNonSerializable):
    """Bundles config values needed to deploy a chainlet remotely.

    This is specified as a class variable for each chainlet class, e.g.::

            import truss_chains as chains


            class MyChainlet(chains.ChainletBase):
                remote_config = chains.RemoteConfig(
                    docker_image=chains.DockerImage(
                        pip_requirements=["torch==2.0.1", ...]
                    ),
                    compute=chains.Compute(cpu_count=2, gpu="A10G", ...),
                    assets=chains.Assets(secret_keys=["hf_access_token"], ...),
                )

    """

    # Each part falls back to its default-constructed value when not given.
    docker_image: DockerImage = DockerImage()
    compute: Compute = Compute()
    assets: Assets = Assets()
    # Optional name; `ABCChainlet.display_name` prefers it over the class name.
    name: Optional[str] = None
    options: ChainletOptions = ChainletOptions()

    def get_compute_spec(self) -> ComputeSpec:
        """Returns the parsed compute spec built by ``compute``."""
        compute_builder = self.compute
        return compute_builder.get_spec()

    def get_asset_spec(self) -> AssetSpec:
        """Returns the parsed asset spec built by ``assets``."""
        asset_builder = self.assets
        return asset_builder.get_spec()
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# Default timeout in seconds; used as the default of `RPCOptions.timeout_sec`.
DEFAULT_TIMEOUT_SEC = 600.0
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
class RPCOptions(SafeModel):
    """Options to customize RPCs to dependency chainlets.

    Args:
        retries: The number of times to retry the remote chainlet in case of failures
          (e.g. due to transient network issues). For streaming, retries are only made
          if the request fails before streaming any results back. Failures mid-stream
          are not retried.
        timeout_sec: Timeout for the HTTP request to this chainlet.
        use_binary: Whether to send data in binary format. This can give a parsing
          speedup and message size reduction (~25%) for numpy arrays. Use
          ``NumpyArrayField`` as a field type on pydantic models for integration and set
          this option to ``True``. For simple text data, there is no significant benefit.
    """

    retries: int = 1
    # Falls back to the module-wide default timeout.
    timeout_sec: float = DEFAULT_TIMEOUT_SEC
    # Binary transport is opt-in.
    use_binary: bool = False
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
class ServiceDescriptor(SafeModel):
    """Bundles values to establish an RPC session to a dependency chainlet,
    specifically with ``StubBase``."""

    # Name of the chainlet.
    name: str
    # Name of the chainlet used for display purposes.
    display_name: str
    # RPC options to use when calling this chainlet.
    options: RPCOptions
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class DeployedServiceDescriptor(ServiceDescriptor):
    # A `ServiceDescriptor` that additionally carries a predict URL.
    # NOTE(review): presumably the URL under which the deployed chainlet is
    # invoked - confirm against the code that populates it.
    predict_url: str
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class Environment(SafeModel):
    """The environment the chainlet is deployed in.

    Args:
        name: The name of the environment.
    """

    name: str
    # can add more fields here as we add them to dynamic_config configmap
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
class DeploymentContext(SafeModelNonSerializable):
    """Bundles config values and resources needed to instantiate Chainlets.

    The context can optionally be added as a trailing argument in a Chainlet's
    ``__init__`` method and then used to set up the chainlet (e.g. using a secret as
    an access token for downloading model weights).

    Args:
        data_dir: The directory where the chainlet can store and access data,
          e.g. for downloading model weights.
        chainlet_to_service: A mapping from chainlet names to service descriptors.
          This is used to create RPC sessions to dependency chainlets. It contains only
          the chainlet services that are dependencies of the current chainlet.
        secrets: A mapping from secret names to secret values. It contains only the
          secrets that are listed in ``remote_config.assets.secret_keys`` of the
          current chainlet.
        environment: The environment that the chainlet is deployed in.
          None if the chainlet is not associated with an environment.
    """

    data_dir: Optional[pathlib.Path] = None
    chainlet_to_service: Mapping[str, DeployedServiceDescriptor]
    secrets: MappingNoIter[str, str]
    environment: Optional[Environment] = None

    def get_service_descriptor(self, chainlet_name: str) -> DeployedServiceDescriptor:
        """Returns the service descriptor of the dependency ``chainlet_name``.

        Raises:
            MissingDependencyError: If ``chainlet_name`` is not a key of
              ``chainlet_to_service``.
        """
        if chainlet_name not in self.chainlet_to_service:
            raise MissingDependencyError(f"{chainlet_name}")
        return self.chainlet_to_service[chainlet_name]

    def get_baseten_api_key(self) -> str:
        """Returns the baseten API key stored in ``secrets``.

        Raises:
            ChainsRuntimeError: If ``secrets`` is ``None``.
            MissingDependencyError: If the API key secret is absent or still holds
              the ``SECRET_DUMMY`` placeholder value.
        """
        if self.secrets is None:
            raise ChainsRuntimeError(
                f"Secrets not set in `{self.__class__.__name__}` object."
            )
        error_msg = (
            "For using chains, it is required to set up an API key with name "
            f"`{BASETEN_API_SECRET_NAME}` on Baseten to allow chain Chainlet to "
            "call other Chainlets. For local execution, secrets can be provided "
            "to `run_local`."
        )
        if BASETEN_API_SECRET_NAME not in self.secrets:
            raise MissingDependencyError(error_msg)

        api_key = self.secrets[BASETEN_API_SECRET_NAME]
        if api_key == SECRET_DUMMY:
            # `error_msg` already ends with a period, so only a space is added.
            raise MissingDependencyError(
                f"{error_msg} Retrieved dummy value of `{api_key}`."
            )
        return api_key
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
class TrussMetadata(SafeModel):
    """Plugin for the truss config (in config["model_metadata"]["chains_metadata"])."""

    # Maps chainlet names to the descriptors of their services.
    chainlet_to_service: Mapping[str, ServiceDescriptor]
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
class ABCChainlet(abc.ABC):
    # Abstract base declaring the class-level configuration and metadata that
    # every chainlet class carries.

    remote_config: ClassVar[RemoteConfig] = RemoteConfig()
    # `meta_data` is not shared between subclasses, each has an isolated copy.
    meta_data: ClassVar[ChainletMetadata] = ChainletMetadata()
    # Declared only; no value is assigned in this class.
    _framework_config: ClassVar[FrameworkConfig]

    @classmethod
    def has_custom_init(cls) -> bool:
        """Returns whether the class resolves ``__init__`` to something other
        than ``object.__init__``."""
        uses_object_init = cls.__init__ is object.__init__
        return not uses_object_init

    @classproperty
    @classmethod
    def name(cls) -> str:
        """The class name."""
        return cls.__name__

    @classproperty
    @classmethod
    def display_name(cls) -> str:
        """The name from ``remote_config`` if set, otherwise the class name."""
        configured_name = cls.remote_config.name
        if configured_name:
            return configured_name
        return cls.name

    @classproperty
    @classmethod
    def supports_dependencies(cls) -> bool:
        """Forwards ``_framework_config.supports_dependencies``."""
        framework_config = cls._framework_config
        return framework_config.supports_dependencies

    @classproperty
    @classmethod
    def entity_type(cls) -> Literal["Chainlet", "Model"]:
        """Forwards ``_framework_config.entity_type``."""
        framework_config = cls._framework_config
        return framework_config.entity_type

    @classproperty
    @classmethod
    def endpoint_method_name(cls) -> str:
        """Forwards ``_framework_config.endpoint_method_name``."""
        framework_config = cls._framework_config
        return framework_config.endpoint_method_name

    # Cannot add this abstract method to API, because we want to allow arbitrary
    # arg/kwarg names and specifying any function signature here would give type errors
    # @abc.abstractmethod
    # def run_remote(self, *args, **kwargs) -> Any:
    #     ...
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
class TypeDescriptor(SafeModelNonSerializable):
    """For describing I/O types of Chainlets."""

    raw: Any  # The raw type annotation object (could be a type or GenericAlias).

    @property
    def is_pydantic(self) -> bool:
        """Whether ``raw`` is itself a pydantic model class (not a generic alias)."""
        annotation = self.raw
        # Rule out non-classes and generic aliases before calling `issubclass`.
        if not isinstance(annotation, type):
            return False
        if isinstance(annotation, GenericAlias):
            return False
        return issubclass(annotation, pydantic.BaseModel)

    @property
    def has_pydantic_args(self) -> bool:
        """Whether ``raw`` is a parametrized type with at least one type argument
        that is a pydantic model class."""
        if not get_origin(self.raw):
            return False
        for type_arg in get_args(self.raw):
            if isinstance(type_arg, type) and issubclass(type_arg, pydantic.BaseModel):
                return True
        return False
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
class StreamingTypeDescriptor(TypeDescriptor):
    # Type descriptor for streaming outputs, split into an origin type and a
    # single argument type.
    # NOTE(review): presumably the origin is the iterator/generator type and the
    # argument is the streamed item type - confirm where instances are created.
    origin_type: type
    arg_type: type

    @property
    def is_string(self) -> bool:
        """Whether the argument type is exactly ``str``."""
        item_type = self.arg_type
        return item_type is str

    @property
    def is_pydantic(self) -> bool:
        """Always ``False`` for streaming type descriptors."""
        return False
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
class InputArg(SafeModelNonSerializable):
    # Describes one input argument (used in `EndpointAPIDescriptor.input_args`).

    # Argument name.
    name: str
    # Descriptor of the argument's type annotation.
    type: TypeDescriptor
    # Whether the argument is optional.
    is_optional: bool
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
class EndpointAPIDescriptor(SafeModelNonSerializable):
    # Describes the endpoint method: its name, inputs, outputs and call style.

    name: str = RUN_REMOTE_METHOD_NAME
    input_args: list[InputArg]
    output_types: list[TypeDescriptor]
    is_async: bool
    is_streaming: bool

    @property
    def streaming_type(self) -> StreamingTypeDescriptor:
        """Returns the single streaming output type of this endpoint.

        Raises:
            ValueError: If the endpoint is not streaming, does not have exactly
              one output type, or that type is not a ``StreamingTypeDescriptor``.
        """
        has_single_streaming_output = (
            self.is_streaming
            and len(self.output_types) == 1
            and isinstance(self.output_types[0], StreamingTypeDescriptor)
        )
        if not has_single_streaming_output:
            raise ValueError(f"{self} is not a streaming endpoint.")
        return cast(StreamingTypeDescriptor, self.output_types[0])
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
class DependencyDescriptor(SafeModelNonSerializable):
    # Pairs a dependency chainlet class with the RPC options used to call it.

    chainlet_cls: Type[ABCChainlet]
    options: RPCOptions

    @property
    def name(self) -> str:
        """Name of the dependency chainlet class."""
        dependency_cls = self.chainlet_cls
        return dependency_cls.name

    @property
    def display_name(self) -> str:
        """Display name of the dependency chainlet class."""
        dependency_cls = self.chainlet_cls
        return dependency_cls.display_name
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
class HealthCheckAPIDescriptor(SafeModelNonSerializable):
    # Describes the health check method of a chainlet.

    # Method name; defaults to the framework-wide health check method name.
    name: str = HEALTH_CHECK_METHOD_NAME
    # Whether the health check method is async.
    is_async: bool
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
class ChainletAPIDescriptor(SafeModelNonSerializable):
    # Describes a chainlet class: its source, dependencies and API methods.

    chainlet_cls: Type[ABCChainlet]
    src_path: str
    has_context: bool
    # Dependency descriptors keyed by string.
    # NOTE(review): presumably keyed by the `__init__` argument name - confirm.
    dependencies: Mapping[str, DependencyDescriptor]
    endpoint: EndpointAPIDescriptor
    # `None` when the chainlet has no health check descriptor.
    health_check: Optional[HealthCheckAPIDescriptor]

    def __hash__(self) -> int:
        # Hash solely by the described chainlet class.
        described_cls = self.chainlet_cls
        return hash(described_cls)

    @property
    def name(self) -> str:
        """Name of the described chainlet class."""
        described_cls = self.chainlet_cls
        return described_cls.name

    @property
    def display_name(self) -> str:
        """Display name of the described chainlet class."""
        described_cls = self.chainlet_cls
        return described_cls.display_name
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
class StackFrame(SafeModel):
    # Pydantic counterpart of `traceback.FrameSummary`, holding the same four
    # values so frames can be converted in both directions.

    filename: str
    lineno: Optional[int]
    name: str
    line: Optional[str]

    @classmethod
    def from_frame_summary(cls, frame: traceback.FrameSummary):
        """Creates a ``StackFrame`` from the fields of a ``traceback.FrameSummary``."""
        frame_fields = {
            "filename": frame.filename,
            "lineno": frame.lineno,
            "name": frame.name,
            "line": frame.line,
        }
        return cls(**frame_fields)

    def to_frame_summary(self) -> traceback.FrameSummary:
        """Converts this frame back into a ``traceback.FrameSummary``."""
        summary = traceback.FrameSummary(
            filename=self.filename,
            lineno=self.lineno,
            name=self.name,
            line=self.line,
        )
        return summary
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
class RemoteErrorDetail(SafeModel):
    """When a remote chainlet raises an exception, this pydantic model contains
    information about the error and stack trace and is included in JSON form in the
    error response.
    """

    exception_cls_name: str
    exception_module_name: Optional[str]
    exception_message: str
    user_stack_trace: list[StackFrame]

    def _to_stack_summary(self) -> traceback.StackSummary:
        """Rebuilds a ``traceback.StackSummary`` from the stored frames."""
        frame_summaries = (
            stack_frame.to_frame_summary() for stack_frame in self.user_stack_trace
        )
        return traceback.StackSummary.from_list(frame_summaries)

    def format(self) -> str:
        """Format the error for printing, similar to how Python formats exceptions
        with stack traces."""
        formatted_frames = traceback.format_list(self._to_stack_summary())
        stack = "".join(formatted_frames)
        # The module hint is only appended when a module name is known.
        if self.exception_module_name:
            exc_info = (
                f"\n(Exception class defined in `{self.exception_module_name}`.)"
            )
        else:
            exc_info = ""
        return (
            f"Chainlet-Traceback (most recent call last):\n"
            f"{stack}{self.exception_cls_name}: {self.exception_message}{exc_info}"
        )
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
class GenericRemoteException(Exception):
    """Exception type without any behavior beyond ``Exception``.

    NOTE(review): presumably raised for remote errors that cannot be mapped to a
    more specific exception class - confirm at the raise sites.
    """
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
########################################################################################
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
class PushOptions(SafeModelNonSerializable):
    # Options common to all push targets.

    # Name of the chain being pushed.
    chain_name: str
    # Off by default.
    # NOTE(review): presumably stops after generating the truss artifacts
    # without deploying them - confirm in the push implementation.
    only_generate_trusses: bool = False
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
class PushOptionsBaseten(PushOptions):
    # Push options for the baseten remote.

    remote: str
    publish: bool
    environment: Optional[str]

    @classmethod
    def create(
        cls,
        chain_name: str,
        publish: bool,
        promote: Optional[bool],
        only_generate_trusses: bool,
        remote: str,
        environment: Optional[str] = None,
    ) -> "PushOptionsBaseten":
        """Builds push options, normalizing ``promote``/``environment``/``publish``.

        - ``promote`` without an explicit ``environment`` selects
          ``PRODUCTION_ENVIRONMENT_NAME``.
        - Whenever an environment ends up being set, ``publish`` is forced to
          ``True``.
        """
        target_environment = environment
        if not target_environment and promote:
            target_environment = PRODUCTION_ENVIRONMENT_NAME
        # Pushing to an environment always publishes.
        effective_publish = True if target_environment else publish
        return PushOptionsBaseten(
            remote=remote,
            chain_name=chain_name,
            publish=effective_publish,
            only_generate_trusses=only_generate_trusses,
            environment=target_environment,
        )
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
class PushOptionsLocalDocker(PushOptions):
|
|
748
|
+
# Local docker-to-docker requests don't need auth, but we need to set a
|
|
749
|
+
# value different from `SECRET_DUMMY` to not trigger the check that the secret
|
|
750
|
+
# is unset. Additionally, if local docker containers make calls to models deployed
|
|
751
|
+
# on baseten, a real API key must be provided (i.e. the default must be overridden).
|
|
752
|
+
baseten_chain_api_key: str = "docker_dummy_key"
|
|
753
|
+
# If enabled, chains code is copied from the local package into `/app/truss_chains`
|
|
754
|
+
# in the docker image (which takes precedence over potential pip/site-packages).
|
|
755
|
+
# This should be used for integration tests or quick local dev loops.
|
|
756
|
+
use_local_chains_src: bool = False
|