truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,1480 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import ast
|
|
3
|
+
import atexit
|
|
4
|
+
import collections
|
|
5
|
+
import contextlib
|
|
6
|
+
import contextvars
|
|
7
|
+
import enum
|
|
8
|
+
import functools
|
|
9
|
+
import importlib.util
|
|
10
|
+
import inspect
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
import pathlib
|
|
14
|
+
import pprint
|
|
15
|
+
import sys
|
|
16
|
+
import types
|
|
17
|
+
import warnings
|
|
18
|
+
from importlib.abc import Loader
|
|
19
|
+
from typing import (
|
|
20
|
+
Any,
|
|
21
|
+
Callable,
|
|
22
|
+
Iterable,
|
|
23
|
+
Iterator,
|
|
24
|
+
Mapping,
|
|
25
|
+
MutableMapping,
|
|
26
|
+
Optional,
|
|
27
|
+
Protocol,
|
|
28
|
+
Type,
|
|
29
|
+
TypeVar,
|
|
30
|
+
Union,
|
|
31
|
+
cast,
|
|
32
|
+
get_args,
|
|
33
|
+
get_origin,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
import pydantic
|
|
37
|
+
from typing_extensions import ParamSpec
|
|
38
|
+
|
|
39
|
+
from truss_chains import definitions, utils
|
|
40
|
+
|
|
41
|
+
# Scalar types accepted directly for Chainlet I/O (see `_validate_io_type`).
_SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None, pydantic.BaseModel}
# Container types that may hold the simple types above.
_SIMPLE_CONTAINERS = {list, dict}
# NOTE(review): presumably the item types allowed for streaming endpoints
# (cf. `_DOCS_URL_STREAMING`) — confirm against the streaming validation code.
_STREAM_TYPES = {str, bytes}

# Documentation links referenced from user-facing error messages.
_DOCS_URL_CHAINING = (
    "https://docs.baseten.co/chains/concepts#depends-call-other-chainlets"
)
_DOCS_URL_LOCAL = "https://docs.baseten.co/chains/guide#local-development"
_DOCS_URL_STREAMING = "https://docs.baseten.co/chains/guide#streaming"

# Neutral placeholder endpoint descriptor: substituted when validation fails so
# that checking can continue and further errors can still be reported.
_DUMMY_ENDPOINT_DESCRIPTOR = definitions.EndpointAPIDescriptor(
    input_args=[],
    output_types=[],
    is_async=False,
    is_streaming=False,
)

# Type variables for generic helpers and decorators in this module.
ChainletT = TypeVar("ChainletT", bound=definitions.ABCChainlet)
_P = ParamSpec("_P")
_R = TypeVar("_R")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Error Collector ######################################################################
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class _ErrorKind(str, enum.Enum):
    """Category of a validation problem found in user Chainlet code.

    When errors are formatted in this module, members are shown by ``name``.
    """

    TYPE_ERROR = enum.auto()
    IO_TYPE_ERROR = enum.auto()
    MISSING_API_ERROR = enum.auto()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class _ErrorLocation(definitions.SafeModel):
    """Where a validation error was detected.

    Attributes are the source file, and optionally the line number and the
    Chainlet / method the error belongs to.
    """

    src_path: str
    line: Optional[int] = None
    chainlet_name: Optional[str] = None
    method_name: Optional[str] = None

    def __str__(self) -> str:
        """Render as ``<src_path>:<line>`` plus an optional ``(<context>)`` suffix.

        The context is ``Chainlet.method`` when both names are set, otherwise
        whichever single name is set; it is omitted when neither is set.
        """
        value = f"{self.src_path}:{self.line}"
        if self.chainlet_name and self.method_name:
            context = f"{self.chainlet_name}.{self.method_name}"
        elif self.chainlet_name:
            context = self.chainlet_name
        elif self.method_name:
            # A method without an owning Chainlet name is still useful to show,
            # rather than being silently dropped.
            context = self.method_name
        else:
            context = None
        return f"{value} ({context})" if context else value
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class _ValidationError(definitions.SafeModel):
    """One validation problem: a message, its kind, and where it occurred."""

    msg: str
    kind: _ErrorKind
    location: _ErrorLocation

    def __str__(self) -> str:
        """Format as ``<location> [kind: <KIND_NAME>]: <msg>``."""
        kind_label = self.kind.name
        return f"{self.location} [kind: {kind_label}]: {self.msg}"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class _ErrorCollector:
    """Accumulates validation errors so they can be reported together.

    Each instance registers ``maybe_display_errors`` with ``atexit``, so errors
    that are still pending when the interpreter exits are written to stderr.
    """

    _errors: list[_ValidationError]

    def __init__(self) -> None:
        self._errors = []
        # Covers running a Chainlet file directly (no push): errors that were
        # never raised are surfaced at process exit.
        atexit.register(self.maybe_display_errors)

    def clear(self) -> None:
        """Discard every collected error."""
        self._errors.clear()

    def collect(self, error):
        """Append one error (expected to be a ``_ValidationError``)."""
        self._errors.append(error)

    @property
    def has_errors(self) -> bool:
        """True if at least one error has been collected."""
        return bool(self._errors)

    @property
    def num_errors(self) -> int:
        """Number of collected errors."""
        return len(self._errors)

    def format_errors(self) -> str:
        """Return the collected errors as text, one per line, in order."""
        return "\n".join(map(str, self._errors))

    def maybe_display_errors(self) -> None:
        """Write the formatted errors plus a trailing newline to stderr, if any."""
        if not self.has_errors:
            return
        sys.stderr.write(self.format_errors())
        sys.stderr.write("\n")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Module-level collector used by `_collect_error` and `raise_validation_errors`.
_global_error_collector = _ErrorCollector()


def _collect_error(msg: str, kind: _ErrorKind, location: _ErrorLocation):
    """Record a validation error on the global collector; does not raise."""
    error = _ValidationError(msg=msg, kind=kind, location=location)
    _global_error_collector.collect(error)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def raise_validation_errors() -> None:
    """Raise all collected validation errors as a single ``ChainsUsageError``.

    Does nothing if no errors were collected. The collector is cleared before
    raising, so the ``atexit`` hook will not print the same errors again.
    """
    if not _global_error_collector.has_errors:
        return
    error_msg = _global_error_collector.format_errors()
    _global_error_collector.clear()
    raise definitions.ChainsUsageError(
        "The user defined code does not comply with the required spec, "
        f"please fix below:\n{error_msg}"
    )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def raise_validation_errors_before(f: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator: flush pending validation errors before each call to ``f``.

    If errors were collected, the call raises ``ChainsUsageError`` (through
    ``raise_validation_errors``) and ``f`` is not executed.
    """

    @functools.wraps(f)
    def _checked(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        raise_validation_errors()
        return f(*args, **kwargs)

    return _checked
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class _BaseProvisionMarker:
    """Common base for placeholder objects that the framework later replaces
    via dependency injection (e.g. Chainlet and context dependencies)."""
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class ContextDependencyMarker(_BaseProvisionMarker):
    """Placeholder for an injected Chainlet context.

    The error text below ties it to ``chains.depends_context()``. It is only
    valid as an ``__init__`` argument default of a Chainlet; any attribute
    lookup that falls through to ``__getattr__`` is treated as misuse.
    """

    def __str__(self) -> str:
        return self.__class__.__name__

    def __getattr__(self, item: str) -> Any:
        """Log and raise ``ChainsRuntimeError`` for any unresolved attribute.

        Raises:
            definitions.ChainsRuntimeError: always, with usage guidance.
        """
        logging.error(f"Attempting to access attribute `{item}` on `{self}`.")
        raise definitions.ChainsRuntimeError(
            "It seems `chains.depends_context()` was used, but not as an argument "
            # Trailing "\n" keeps "supported." and "See ..." on separate lines,
            # matching the ChainletDependencyMarker message.
            "to the `__init__` method of a Chainlet - This is not supported.\n"
            f"See {_DOCS_URL_CHAINING}.\n"
            "Example of correct `__init__` with context:\n"
            f"{_example_chainlet_code()}"
        )
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class ChainletDependencyMarker(_BaseProvisionMarker):
    """Placeholder for a dependency Chainlet declared as an ``__init__`` argument.

    Stores the dependency's Chainlet class and its ``RPCOptions``. Any attribute
    lookup that falls through to ``__getattr__`` is treated as misuse.
    """

    # Annotations mirror the attributes actually set in `__init__` (previously a
    # stale `retries: int` was declared but never assigned).
    chainlet_cls: Type[definitions.ABCChainlet]
    options: definitions.RPCOptions

    def __init__(
        self,
        chainlet_cls: Type[definitions.ABCChainlet],
        options: definitions.RPCOptions,
    ) -> None:
        self.chainlet_cls = chainlet_cls
        self.options = options

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.chainlet_cls.name})"

    def __getattr__(self, item: str) -> Any:
        """Log and raise ``ChainsRuntimeError`` for any unresolved attribute.

        Raises:
            definitions.ChainsRuntimeError: always, with usage guidance.
        """
        logging.error(f"Attempting to access attribute `{item}` on `{self}`.")
        raise definitions.ChainsRuntimeError(
            f"It seems `chains.depends({self.chainlet_cls.name})` was used, but "
            "not as an argument to the `__init__` method of a Chainlet - This is not "
            "supported. Dependency Chainlets must be passed as init arguments.\n"
            f"See {_DOCS_URL_CHAINING}.\n"
            "Example of correct `__init__` with dependencies:\n"
            f"{_example_chainlet_code()}"
        )
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# Validation of Chainlet class definition ##############################################
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@functools.cache
def _example_chainlet_code() -> str:
    """Return the source of the reference ``HelloWorld`` Chainlet for error texts.

    If the reference module cannot be imported or is only partially initialized,
    the problem is logged and a placeholder string is returned instead.
    The result is cached.
    """
    # Note: this function requires all chains modules to be initialized, because
    # in `example_chainlet` the full chainlet validation process is triggered.
    # To avoid circular import dependencies, `_example_chainlet_code` should only be
    # called on erroneous code branches (which will not be triggered if
    # `example_chainlet` is free of errors).
    try:
        from truss_chains.reference_code import reference_chainlet

        # This lookup must be inside the `try`: in a circular import the module
        # object can already be importable (it is registered in `sys.modules`)
        # while `HelloWorld` is not bound yet, so the `AttributeError`
        # ("partially initialized module ...") is raised here, not on import.
        example_name = reference_chainlet.HelloWorld.name
    except (AttributeError, ImportError):
        logging.error("`reference_chainlet` is broken.", exc_info=True, stack_info=True)
        return "<EXAMPLE CODE MISSING/BROKEN>"

    return _get_cls_source(reference_chainlet.__file__, example_name)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@functools.cache
def _example_model_code() -> str:
    """Return the source of the reference ``HelloWorld`` model for error texts.

    If the reference module cannot be imported or is only partially initialized,
    the problem is logged and a placeholder string is returned instead.
    The result is cached.
    """
    try:
        from truss_chains.reference_code import reference_model

        # Guard the attribute access too: during a circular import the module can
        # be importable while `HelloWorld` is not yet defined on it.
        example_name = reference_model.HelloWorld.name
    except (AttributeError, ImportError):
        logging.error("`reference_model` is broken.", exc_info=True, stack_info=True)
        return "<EXAMPLE CODE MISSING/BROKEN>"

    return _get_cls_source(reference_model.__file__, example_name)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _get_cls_source(src_path: str, target_class_name: str) -> str:
    """Returns the source text of class `target_class_name` defined in `src_path`.

    The first matching `class` definition found by `ast.walk` (breadth-first,
    so outer definitions win) is returned, from its `class` line through its last
    line. Decorator lines are not included. Returns `""` if no such class exists.
    """
    # Python source files are UTF-8 by default (PEP 3120); don't depend on the
    # locale's preferred encoding.
    source = pathlib.Path(src_path).read_text(encoding="utf-8")
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef) and node.name == target_class_name:
            lines = source.splitlines()
            # `lineno` is 1-based, `end_lineno` inclusive.
            return "\n".join(lines[node.lineno - 1 : node.end_lineno])
    return ""
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _instantiation_error_msg(cls_name: str, location: Optional[str] = None) -> str:
    """Builds the help text shown when a Chainlet is instantiated directly.

    Args:
        cls_name: Name of the Chainlet class whose instantiation failed.
        location: Optional location description, put on its own line if given.

    Returns:
        A multi-line message listing the supported ways of using Chainlets,
        followed by example code.
    """
    parts = [f"Error when instantiating Chainlet `{cls_name}`.\n"]
    if location:
        parts.append(f"{location}\n")
    parts.append("Chainlets cannot be naively instantiated. Possible fixes:\n")
    parts.append(
        "1. To use Chainlets as dependencies in other Chainlets ('chaining'), "
        f"add them as init argument. See {_DOCS_URL_CHAINING}.\n"
    )
    parts.append(
        f"2. For local / debug execution, use the `{run_local.__name__}`-"
        f"context. See {_DOCS_URL_LOCAL}. You cannot use helper functions to "
        "instantiate the Chain in this case.\n"
    )
    parts.append("3. Push the chain and call the remote endpoint.\n")
    parts.append("Example of correct `__init__` with dependencies:\n")
    parts.append(f"{_example_chainlet_code()}")
    return "".join(parts)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _validate_io_type(
    annotation: Any, param_name: str, location: _ErrorLocation
) -> None:
    """Checks that `annotation` is an allowed Chainlet I/O type.

    Allowed are:
      * simple types from `_SIMPLE_TYPES` (int, str, float...),
      * built-in generic containers from `_SIMPLE_CONTAINERS` (e.g. `list[int]`)
        whose item types are simple types or pydantic models,
      * pydantic models.
    Anything else (including `typing.List`-style aliases and string
    annotations) is reported via `_collect_error`; nothing is raised.
    """
    container_names = [container.__name__ for container in _SIMPLE_CONTAINERS]
    simple_type_names = [
        "None" if simple is None else simple.__name__ for simple in _SIMPLE_TYPES
    ]

    if isinstance(annotation, str):
        # Stringified annotations cannot be validated as types here.
        _collect_error(
            f"A string-valued type annotation was found for `{param_name}` of type "
            f"`{annotation}`. Use only actual types objects and avoid "
            "`from __future__ import annotations` (if needed upgrade python).",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return

    if annotation in _SIMPLE_TYPES:
        return

    unsupported_msg = (
        f"Unsupported I/O type for `{param_name}` of type `{annotation}`. "
        "Supported are:\n"
        f"\t* simple types: {simple_type_names}\n"
        "\t* containers of these simple types, with annotated item types: "
        f"{container_names}, e.g. `dict[str, int]` (use built-in types, not "
        "`typing.Dict`).\n"
        "\t* For complicated / nested data structures: `pydantic` models."
    )

    if isinstance(annotation, types.GenericAlias):
        if get_origin(annotation) not in _SIMPLE_CONTAINERS:
            _collect_error(unsupported_msg, _ErrorKind.IO_TYPE_ERROR, location)
            return
        items_ok = all(
            item in _SIMPLE_TYPES or utils.issubclass_safe(item, pydantic.BaseModel)
            for item in get_args(annotation)
        )
        if not items_ok:
            _collect_error(unsupported_msg, _ErrorKind.IO_TYPE_ERROR, location)
        return

    if not utils.issubclass_safe(annotation, pydantic.BaseModel):
        _collect_error(unsupported_msg, _ErrorKind.IO_TYPE_ERROR, location)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _validate_streaming_output_type(
    annotation: Any, location: _ErrorLocation
) -> definitions.StreamingTypeDescriptor:
    """Validates the item type of an iterator-typed endpoint return annotation.

    `annotation` must have origin `Iterator` or `AsyncIterator` (asserted; the
    caller dispatches on this). Errors are collected, not raised:
      * A missing item type is reported, and the descriptor defaults to `bytes`.
      * An item type outside `_STREAM_TYPES` is reported, but is still returned
        in the descriptor.
    """
    origin = get_origin(annotation)
    assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator)
    args = get_args(annotation)
    if not args:
        stream_types = sorted(x.__name__ for x in _STREAM_TYPES)
        _collect_error(
            f"Iterators must be annotated with type (one of {stream_types}).",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        # Default the item type to `bytes` after reporting the error.
        return definitions.StreamingTypeDescriptor(
            raw=annotation, origin_type=origin, arg_type=bytes
        )

    assert len(args) == 1, "Iterator type annotations cannot have more than 1 arg."
    arg = args[0]
    if arg not in _STREAM_TYPES:
        msg = (
            "Streaming endpoints (containing `yield` statements) can only yield string "
            "or byte items. For streaming structured pydantic data, use `stream_writer` "
            "and `stream_reader` helpers.\n"
            f"See streaming docs: {_DOCS_URL_STREAMING}"
        )
        _collect_error(msg, _ErrorKind.IO_TYPE_ERROR, location)

    return definitions.StreamingTypeDescriptor(
        raw=annotation, origin_type=origin, arg_type=arg
    )
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _validate_method_signature(
    method_name: str, location: _ErrorLocation, params: list[inspect.Parameter]
) -> None:
    """Reports an error unless `params` starts with the `self` argument.

    Used to check that the endpoint and health-check callables are methods.
    Errors are collected, not raised.
    """
    if not params:
        _collect_error(
            f"`{method_name}` must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as "
            "first argument. Got function with no arguments.",
            _ErrorKind.TYPE_ERROR,
            location,
        )
        return

    first_name = params[0].name
    if first_name == definitions.SELF_ARG_NAME:
        return

    _collect_error(
        f"`{method_name}` must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as "
        f"first argument. Got `{first_name}` as first argument.",
        _ErrorKind.TYPE_ERROR,
        location,
    )
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _validate_endpoint_params(
    params: list[inspect.Parameter], location: _ErrorLocation
) -> list[definitions.InputArg]:
    """Validates endpoint parameters and describes them as `InputArg`s.

    The first parameter must be `self` and is skipped. Every other parameter
    must carry a supported I/O type annotation (see `_validate_io_type`).
    Parameters with a default value are marked optional. Unannotated parameters
    are reported and left out of the result. Errors are collected, not raised.
    """
    # Name the actual endpoint in errors (e.g. `predict` for Models). The caller
    # stores it in `location`; fall back to the Chainlet default otherwise.
    method_name = location.method_name or definitions.RUN_REMOTE_METHOD_NAME
    _validate_method_signature(method_name, location, params)

    input_args = []
    for param in params[1:]:  # Skip self argument.
        # Compare against the `empty` sentinel by identity: `==` / `!=` may
        # dispatch to a custom `__eq__` of the annotation or default value.
        if param.annotation is inspect.Parameter.empty:
            _collect_error(
                "Arguments of endpoints must have type annotations. "
                f"Parameter `{param.name}` has no type annotation.",
                _ErrorKind.IO_TYPE_ERROR,
                location,
            )
            continue

        _validate_io_type(param.annotation, param.name, location)
        input_args.append(
            definitions.InputArg(
                name=param.name,
                type=definitions.TypeDescriptor(raw=param.annotation),
                is_optional=param.default is not inspect.Parameter.empty,
            )
        )
    return input_args
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def _validate_endpoint_output_types(
    annotation: Any, signature, location: _ErrorLocation, is_streaming: bool
) -> list[definitions.TypeDescriptor]:
    """Validates an endpoint's return annotation and describes its output types.

    * A missing annotation is reported, and `[]` is returned.
    * `tuple[...]` annotations are validated element-wise, one descriptor each.
    * `Iterator[...]` / `AsyncIterator[...]` annotations are streaming outputs.
      They require a streaming endpoint (one containing `yield`).
    * Any other annotation is validated as a single I/O type.
    Conversely, a streaming endpoint must be annotated with an iterator type.
    Errors are collected, not raised.
    """
    if annotation == inspect.Parameter.empty:
        _collect_error(
            "Return values of endpoints must be type annotated. Got:\n"
            f"\t{location.method_name}{signature} -> !MISSING!",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return []

    origin = get_origin(annotation)
    returns_iterator = origin in (
        collections.abc.AsyncIterator,
        collections.abc.Iterator,
    )

    if origin is tuple:
        output_types = []
        for index, item_type in enumerate(get_args(annotation)):
            _validate_io_type(item_type, f"return_type[{index}]", location)
            output_types.append(definitions.TypeDescriptor(raw=item_type))
    elif returns_iterator:
        output_types = [_validate_streaming_output_type(annotation, location)]
        if not is_streaming:
            _collect_error(
                "If the endpoint returns an iterator (streaming), it must have `yield` "
                "statements.",
                _ErrorKind.IO_TYPE_ERROR,
                location,
            )
    else:
        _validate_io_type(annotation, "return_type", location)
        output_types = [definitions.TypeDescriptor(raw=annotation)]

    if is_streaming and not returns_iterator:
        _collect_error(
            "If the endpoint is streaming (has `yield` statements), the return type "
            "must be an iterator (e.g. `AsyncIterator[bytes]`). Got:\n"
            f"\t{location.method_name}{signature} -> {annotation}",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )

    return output_types
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def _validate_and_describe_endpoint(
    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
) -> definitions.EndpointAPIDescriptor:
    """Validates the endpoint method of `cls` and returns its API descriptor.

    The "endpoint method" of a Chainlet must have the following signature:

    ```
    [async] def run_remote(
        self, [param_0: anno_0, param_1: anno_1 = default_1, ...]) -> ret_anno:
    ```

    * The name must be `run_remote` for Chainlets, or `predict` for Models
      (`cls.endpoint_method_name`).
    * It can be sync or async def; sync endpoints emit a `DeprecationWarning`.
    * The number and names of parameters are arbitrary, both positional and named
      parameters are ok.
    * All parameters and the return value must have type annotations. See
      `_validate_io_type` for valid types.
    * Streaming endpoints (containing `yield`) must be async generators annotated
      with an iterator return type, e.g. `AsyncIterator[bytes]`.

    Errors are collected, not raised. If the method is missing or not a
    function, `_DUMMY_ENDPOINT_DESCRIPTOR` is returned.
    """
    method_name = cls.endpoint_method_name
    if not hasattr(cls, method_name):
        _collect_error(
            f"{cls.entity_type}s must have a `{method_name}` method.",
            _ErrorKind.MISSING_API_ERROR,
            location,
        )
        return _DUMMY_ENDPOINT_DESCRIPTOR

    # This is the unbound method.
    endpoint_method = getattr(cls, method_name)
    location = location.model_copy(update={"method_name": method_name})

    # Check this before `inspect.getsourcelines`: for non-functions (e.g. a class
    # variable) that raises `TypeError` instead of letting us report the problem.
    if not inspect.isfunction(endpoint_method):
        _collect_error("Endpoints must be a method.", _ErrorKind.TYPE_ERROR, location)
        return _DUMMY_ENDPOINT_DESCRIPTOR

    line = inspect.getsourcelines(endpoint_method)[1]
    location = location.model_copy(update={"line": line})

    signature = inspect.signature(endpoint_method)
    input_args = _validate_endpoint_params(
        list(signature.parameters.values()), location
    )

    if inspect.isasyncgenfunction(endpoint_method):
        is_async, is_streaming = True, True
    elif inspect.iscoroutinefunction(endpoint_method):
        is_async, is_streaming = True, False
    else:
        is_async = False
        is_streaming = inspect.isgeneratorfunction(endpoint_method)

    output_types = _validate_endpoint_output_types(
        signature.return_annotation, signature, location, is_streaming
    )

    if is_streaming and not is_async:
        _collect_error(
            "Streaming endpoints (containing `yield` statements) are only "
            "supported for async endpoints.",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )

    if not is_async:
        warnings.warn(
            f"`{method_name}` must be an async (coroutine) function in future releases. "
            f"Replace `def {method_name}(...)` with `async def {method_name}(...)`. "
            "Local testing and execution can be done with "
            f"`asyncio.run(my_chainlet.{method_name}(...))`.\n"
            "Note on concurrency: previously sync functions were run in threads by the "
            "Truss server.\n"
            "For some frameworks this was **unsafe** (e.g. in torch the CUDA context "
            "is not thread-safe).\n"
            "Additionally, python threads hold the GIL and therefore might not give "
            "actual throughput gains.\n"
            "To achieve safe and performant concurrency, use framework-specific async "
            "APIs (e.g. AsyncLLMEngine for vLLM) or generic async batching such "
            "as https://github.com/hussein-awala/async-batcher.",
            DeprecationWarning,
            stacklevel=1,
        )

    return definitions.EndpointAPIDescriptor(
        name=method_name,
        input_args=input_args,
        output_types=output_types,
        is_async=is_async,
        is_streaming=is_streaming,
    )
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def _get_generic_class_type(var):
    """Returns the unparametrized class of `var`.

    For a parametrized generic such as `SomeGeneric[T]` this is its origin
    (`SomeGeneric`). Any other value, including a plain `SomeGeneric`, is
    returned unchanged.
    """
    generic_origin = get_origin(var)
    if generic_origin is None:
        return var
    return generic_origin
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
class _ChainletInitValidator:
    """Validates the `__init__`-method of a Chainlet and extracts its dependencies.

    The `__init__`-method of a Chainlet must have the following signature:

    ```
    def __init__(
        self,
        [dep_0: dep_0_type = truss_chains.depends(dep_0_class),]
        [dep_1: dep_1_type = truss_chains.depends(dep_1_class),]
        ...
        [dep_N: dep_N_type = truss_chains.depends(dep_N_class),]
        [context: truss_chains.Context = truss_chains.depends_context()]
    ) -> None:
    ```
    * The context argument is optionally trailing and must have a default constructed
      with the `depends_context` directive.
    * The names and number of Chainlet "dependency" arguments are arbitrary.
    * Default values for dependencies must be constructed with the `depends` directive
      to make the dependency injection work. The argument to `depends` must be a
      Chainlet class.
    * The type annotation for dependencies can be a Chainlet class, but it can also be
      a `Protocol` with an equivalent endpoint method (e.g. for getting correct type
      checks when providing fake Chainlets for local testing). It may be omitted if
      the type is clear from the RHS.
    * If `supports_dependencies` is false for the class, the context argument is
      the only allowed `__init__` argument.

    Validation errors are collected, not raised. After construction,
    `has_context` and `validated_dependencies` describe the `__init__` signature.
    """

    _location: _ErrorLocation
    _cls: Type[definitions.ABCChainlet]
    has_context: bool = False
    validated_dependencies: Mapping[str, definitions.DependencyDescriptor] = {}

    def __init__(
        self, cls: Type[definitions.ABCChainlet], location: _ErrorLocation
    ) -> None:
        self._cls = cls
        # Always set per-instance results, so instances never share (or depend
        # on) the class-level defaults, e.g. when validation returns early.
        self.has_context = False
        self.validated_dependencies = {}
        if not self._cls.has_custom_init():
            return
        line = inspect.getsourcelines(cls.__init__)[1]
        self._location = location.model_copy(
            update={"line": line, "method_name": "__init__"}
        )

        params = list(inspect.signature(cls.__init__).parameters.values())
        self._validate_args(params)

    def _validate_args(self, params: list[inspect.Parameter]) -> None:
        # Each validation pops off "processed" arguments from the list.
        self._validate_self_arg(params)
        self._validate_context_arg(params)
        self._validate_dependencies(params)

    def _validate_self_arg(self, params: list[inspect.Parameter]) -> None:
        """Pops and checks the leading `self` argument."""
        if not params:
            _collect_error(
                "Methods must have first argument `self`, got no arguments.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return
        param = params.pop(0)
        if param.name != definitions.SELF_ARG_NAME:
            _collect_error(
                f"Methods must have first argument `self`, got `{param.name}`.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )

    def _validate_context_arg(self, params: list[inspect.Parameter]) -> None:
        """Pops and checks the optional trailing `context` argument."""
        # Snapshot for error messages: `params` is mutated (popped) below.
        all_params = list(params)

        def make_context_error_msg():
            return (
                f"If `{self._cls.entity_type}` uses context for initialization, it "
                f"must have `{definitions.CONTEXT_ARG_NAME}` argument of type "
                f"`{definitions.DeploymentContext}` as the last argument.\n"
                f"Got arguments: `{all_params}`.\n"
                "Example of correct `__init__` with context:\n"
                f"{self._example_code()}"
            )

        if not params:
            return

        has_context = params[-1].name == definitions.CONTEXT_ARG_NAME
        has_context_marker = isinstance(params[-1].default, ContextDependencyMarker)
        # Name and marker must come together: one without the other is an error.
        if has_context ^ has_context_marker:
            _collect_error(
                make_context_error_msg(), _ErrorKind.TYPE_ERROR, self._location
            )

        if not has_context:
            return

        self.has_context = True
        param = params.pop(-1)
        param_type = _get_generic_class_type(param.annotation)
        # We are lenient and allow omitting the type annotation for context.
        if (
            (param_type is not None)
            and (param_type != inspect.Parameter.empty)
            and (not utils.issubclass_safe(param_type, definitions.DeploymentContext))
        ):
            _collect_error(
                make_context_error_msg(), _ErrorKind.TYPE_ERROR, self._location
            )
        if not isinstance(param.default, ContextDependencyMarker):
            _collect_error(
                f"Incorrect default value `{param.default}` for `context` argument. "
                "Example of correct `__init__` with context:\n"
                f"{self._example_code()}",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )

    def _validate_dependencies(self, params: list[inspect.Parameter]) -> None:
        """Validates the remaining (dependency) arguments and records them."""
        used = set()
        dependencies = {}

        if params and not self._cls.supports_dependencies:
            _collect_error(
                f"The only supported argument to `__init__` for {self._cls.entity_type}s "
                "is the optional context argument.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return
        for param in params:
            marker = self._validate_dependency_param(param)
            if marker is None:
                continue
            if marker.chainlet_cls in used:
                _collect_error(
                    "The same Chainlet class cannot be used multiple times for "
                    "different arguments. Got previously used "
                    f"`{marker.chainlet_cls}` for `{param.name}`.",
                    _ErrorKind.TYPE_ERROR,
                    self._location,
                )

            dependencies[param.name] = definitions.DependencyDescriptor(
                chainlet_cls=marker.chainlet_cls, options=marker.options
            )
            used.add(marker.chainlet_cls)

        self.validated_dependencies = dependencies

    def _validate_dependency_param(
        self, param: inspect.Parameter
    ) -> Optional[ChainletDependencyMarker]:
        """
        Returns a valid ChainletDependencyMarker if found, None otherwise.
        """
        # TODO: handle subclasses, unions, optionals, check default value etc.
        if param.name == definitions.CONTEXT_ARG_NAME:
            _collect_error(
                f"The init argument name `{definitions.CONTEXT_ARG_NAME}` is reserved for "
                "the optional context argument, which must be trailing if used. Example "
                "of correct `__init__` with context:\n"
                f"{self._example_code()}",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )

        if not isinstance(param.default, ChainletDependencyMarker):
            _collect_error(
                "Any arguments of a Chainlet's __init__ (besides `context`) must have "
                "dependency Chainlets with default values from `chains.depends`-directive. "
                f"Got `{param}`.\n"
                "Example of correct `__init__` with dependencies:\n"
                f"{self._example_code()}",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return None

        chainlet_cls = param.default.chainlet_cls
        if not utils.issubclass_safe(chainlet_cls, definitions.ABCChainlet):
            _collect_error(
                "`chains.depends` must be used with a Chainlet class as argument, got "
                f"{chainlet_cls} instead.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return None
        # Check type annotation.
        # Also lenient with type annotation: since the RHS / default is asserted to be a
        # chainlet class, proper type inference is possible even without annotation.
        # TODO: `Protocol` is not a proper class and this might be version dependent.
        # Find a better way to inspect this.
        if not (
            param.annotation == inspect.Parameter.empty
            or utils.issubclass_safe(param.annotation, Protocol)  # type: ignore[arg-type]
            or utils.issubclass_safe(chainlet_cls, param.annotation)
        ):
            _collect_error(
                f"The type annotation for `{param.name}` must be a class/subclass of the "
                "Chainlet type specified by `chains.depends` or a compatible "
                f"`typing.Protocol`. Got `{param.annotation}`.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
        return param.default  # The Marker.

    def _example_code(self) -> str:
        # Not cached on the method: the module-level helpers are already
        # `functools.cache`d, and caching here would keep every validator
        # instance (and its Chainlet class) alive for the process lifetime.
        if self._cls.entity_type == "Model":
            return _example_model_code()
        return _example_chainlet_code()
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
def _validate_remote_config(
    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
) -> None:
    """Reports an error unless `cls` has a `RemoteConfig` instance as class variable.

    A missing attribute is reported like a wrongly typed one (as `NoneType`)
    instead of raising `AttributeError`. Errors are collected, not raised.
    """
    remote_config = getattr(cls, definitions.REMOTE_CONFIG_NAME, None)
    if isinstance(remote_config, definitions.RemoteConfig):
        return
    _collect_error(
        f"{cls.entity_type}s must have a `{definitions.REMOTE_CONFIG_NAME}` class variable "
        f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` "
        f"for `{cls}`.",
        _ErrorKind.TYPE_ERROR,
        location,
    )
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
def _validate_health_check(
    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
) -> Optional[definitions.HealthCheckAPIDescriptor]:
    """Validates the optional health-check method of `cls`.

    Expected signature:
    ```
    [async] def is_healthy(self) -> bool:
    ```
    * The name is `definitions.HEALTH_CHECK_METHOD_NAME`.
    * Sync and async definitions are both accepted.
    * `self` must be the only parameter.
    * The return annotation must be exactly `bool`.

    Returns `None` if the method is absent, is not a function, or lacks a return
    annotation. Otherwise returns a descriptor, even if other errors were
    collected. Errors are collected, not raised.
    """
    method_name = definitions.HEALTH_CHECK_METHOD_NAME
    if not hasattr(cls, method_name):
        return None

    method = getattr(cls, method_name)
    if not inspect.isfunction(method):
        _collect_error(
            f"`{method_name}` must be a method.",
            _ErrorKind.TYPE_ERROR,
            location,
        )
        return None

    location = location.model_copy(
        update={"line": inspect.getsourcelines(method)[1], "method_name": method_name}
    )
    is_async = inspect.iscoroutinefunction(method)
    signature = inspect.signature(method)
    params = list(signature.parameters.values())

    _validate_method_signature(method_name, location, params)
    if len(params) > 1:
        _collect_error(
            f"`{method_name}` must have only one argument: `{definitions.SELF_ARG_NAME}`.",
            _ErrorKind.TYPE_ERROR,
            location,
        )

    return_annotation = signature.return_annotation
    if return_annotation == inspect.Parameter.empty:
        _collect_error(
            "Return value of health check must be type annotated. Got:\n"
            f"\t{location.method_name}{signature} -> !MISSING!",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return None
    if return_annotation is not bool:
        _collect_error(
            "Return value of health check must be a boolean. Got:\n"
            f"\t{location.method_name}{signature} -> {return_annotation}",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )

    return definitions.HealthCheckAPIDescriptor(is_async=is_async)
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def validate_and_register_cls(cls: Type[definitions.ABCChainlet]) -> None:
    """Validates the Chainlet class `cls` and adds its descriptor to the registry.

    Validation errors are only collected, not raised, so Chainlets with issues
    are still added to the registry. Use `raise_validation_errors` to assert
    that all Chainlets are valid before performing operations that depend on
    these constraints.
    """
    src_path = os.path.abspath(inspect.getfile(cls))
    location = _ErrorLocation(
        src_path=src_path,
        line=inspect.getsourcelines(cls)[1],
        chainlet_name=cls.__name__,
    )

    # Validators run in this order: remote config, `__init__`, endpoint,
    # health check.
    _validate_remote_config(cls, location)
    init_validator = _ChainletInitValidator(cls, location)
    endpoint = _validate_and_describe_endpoint(cls, location)
    health_check = _validate_health_check(cls, location)

    chainlet_descriptor = definitions.ChainletAPIDescriptor(
        chainlet_cls=cls,
        dependencies=init_validator.validated_dependencies,
        has_context=init_validator.has_context,
        endpoint=endpoint,
        src_path=src_path,
        health_check=health_check,
    )
    logging.debug(
        f"Descriptor for {cls}:\n{pprint.pformat(chainlet_descriptor, indent=4)}\n"
    )
    _global_chainlet_registry.register_chainlet(chainlet_descriptor)
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
# Dependency-Injection / Registry ######################################################
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
class _ChainletRegistry:
    """Registry of validated Chainlet descriptors.

    Entries are keyed by Chainlet class. Names are tracked as well, to enforce
    that Chainlet names are globally unique.
    """

    # Dependencies must already be registered when a Chainlet is registered, so
    # the insertion order of this dict is a topological sorting of the
    # dependency graph.
    _chainlets: collections.OrderedDict[
        Type[definitions.ABCChainlet], definitions.ChainletAPIDescriptor
    ]
    _name_to_cls: MutableMapping[str, Type[definitions.ABCChainlet]]

    def __init__(self) -> None:
        self.clear()

    def clear(self):
        """Removes all registered Chainlets."""
        self._chainlets = collections.OrderedDict()
        self._name_to_cls = {}

    def register_chainlet(self, chainlet_descriptor: definitions.ChainletAPIDescriptor):
        """Adds `chainlet_descriptor`; its dependencies must be registered already.

        Raises:
            definitions.ChainsUsageError: if a Chainlet with the same name is
                already registered.
        """
        known = self._chainlets
        for dependency in chainlet_descriptor.dependencies.values():
            # A dependency class must be defined (its module initialized) before
            # it can be referenced, so it is already registered. This can only
            # fail if the internal registry is manipulated, hence an assertion.
            assert dependency.chainlet_cls in known, (
                "Cannot depend on Chainlet. Available Chainlets: "
                f"{list(known.keys())}"
            )

        # Classes are unique objects, but names may still collide across classes,
        # so prevent re-use / overwriting of names explicitly.
        name = chainlet_descriptor.name
        if name in self._name_to_cls:
            existing_source_path = known[self._name_to_cls[name]].src_path
            raise definitions.ChainsUsageError(
                f"A Chainlet with name `{name}` was already "
                f"defined, Chainlet names must be globally unique.\n"
                f"Pre-existing in: `{existing_source_path}`\n"
                f"New conflict in: `{chainlet_descriptor.src_path}`."
            )

        known[chainlet_descriptor.chainlet_cls] = chainlet_descriptor
        self._name_to_cls[name] = chainlet_descriptor.chainlet_cls

    def unregister_chainlet(self, chainlet_name: str) -> None:
        """Removes the Chainlet named `chainlet_name` (`KeyError` if unknown)."""
        del self._chainlets[self._name_to_cls.pop(chainlet_name)]

    @property
    def chainlet_descriptors(self) -> list[definitions.ChainletAPIDescriptor]:
        """All descriptors, in registration (dependency-topological) order."""
        return list(self._chainlets.values())

    def get_descriptor(
        self, chainlet_cls: Type[definitions.ABCChainlet]
    ) -> definitions.ChainletAPIDescriptor:
        """Returns the descriptor registered for `chainlet_cls`."""
        return self._chainlets[chainlet_cls]

    def get_dependencies(
        self, chainlet: definitions.ChainletAPIDescriptor
    ) -> Iterable[definitions.ChainletAPIDescriptor]:
        """Returns the registered descriptors of `chainlet`'s dependencies."""
        registered = self._chainlets[chainlet.chainlet_cls]
        return [
            self._chainlets[dependency.chainlet_cls]
            for dependency in registered.dependencies.values()
        ]

    def get_chainlet_names(self) -> set[str]:
        """Returns the names of all registered Chainlets."""
        return set(self._name_to_cls)
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
# Module-level registry instance; `validate_and_register_cls` adds to it and the
# accessor functions below read from it.
_global_chainlet_registry = _ChainletRegistry()


def get_dependencies(
    chainlet: definitions.ChainletAPIDescriptor,
) -> Iterable[definitions.ChainletAPIDescriptor]:
    """Returns the descriptors of the Chainlets that `chainlet` depends on."""
    registry = _global_chainlet_registry
    return registry.get_dependencies(chainlet)


def get_descriptor(
    chainlet_cls: Type[definitions.ABCChainlet],
) -> definitions.ChainletAPIDescriptor:
    """Returns the registered descriptor of `chainlet_cls`."""
    registry = _global_chainlet_registry
    return registry.get_descriptor(chainlet_cls)


def get_ordered_descriptors() -> list[definitions.ChainletAPIDescriptor]:
    """Returns all registered descriptors in registration order."""
    registry = _global_chainlet_registry
    return registry.chainlet_descriptors
|
|
943
|
+
|
|
944
|
+
|
|
945
|
+
# Chainlet class runtime utils #########################################################
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
def _determine_arguments(func: Callable, **kwargs):
    """Returns the effective keyword arguments for calling `func`.

    The given `kwargs` are bound to `func`'s signature (missing required
    arguments are allowed), and defaults are filled in for all unbound
    parameters that have one.
    """
    bound = inspect.signature(func).bind_partial(**kwargs)
    bound.apply_defaults()
    return bound.arguments
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def ensure_args_are_injected(cls, original_init: Callable, kwargs) -> None:
    """Asserts all provisioning markers are replaced by actual objects.

    Resolves the effective arguments of ``original_init`` (the given ``kwargs``
    plus signature defaults) and checks that:

    * the context argument (named ``definitions.CONTEXT_ARG_NAME``), if present,
      is a ``definitions.DeploymentContext``;
    * no other argument still holds a ``_BaseProvisionMarker``, i.e. a
      placeholder for a dependency Chainlet that was not replaced by an instance.

    Args:
        cls: The Chainlet class being initialized; its ``entity_type`` and
            ``name`` are used in the error messages.
        original_init: The unwrapped ``__init__`` whose signature is inspected.
        kwargs: Keyword arguments of the ``__init__`` call.

    Raises:
        definitions.ChainsRuntimeError: If a check fails (details are logged
            at error level before raising).
    """
    final_args = _determine_arguments(original_init, **kwargs)
    for name, value in final_args.items():
        if name == definitions.CONTEXT_ARG_NAME:
            if not isinstance(value, definitions.DeploymentContext):
                logging.error(
                    f"When initializing {cls.entity_type} `{cls.name}`, for context "
                    f"argument an incompatible value was passed, value: `{value}`."
                )
                raise definitions.ChainsRuntimeError(_instantiation_error_msg(cls.name))
        # A marker left in place means a dependency Chainlet was not injected.
        elif isinstance(value, _BaseProvisionMarker):
            logging.error(
                f"When initializing {cls.entity_type} `{cls.name}`, for dependency "
                f"Chainlet argument `{name}` an incompatible value was passed, "
                f"value: `{value}`."
            )
            raise definitions.ChainsRuntimeError(_instantiation_error_msg(cls.name))
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
# Local Execution ######################################################################
|
|
977
|
+
|
|
978
|
+
# A variable to track the stack depth relative to `run_local` context manager.
# `run_local` sets it on entry and resets it on exit; it has no default, so it is
# unset outside such a context (patched inits read it via `.get(None)`).
run_local_stack_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
    "run_local_stack_depth"
)

# Function names compared against `FrameInfo.function` when inspecting the call
# stack: the locally patched init and the regular (user-defined) init.
_INIT_LOCAL_NAME = "__init_local__"
_INIT_NAME = "__init__"
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
def _create_modified_init_for_local(
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
    cls_to_instance: MutableMapping[
        Type[definitions.ABCChainlet], definitions.ABCChainlet
    ],
    secrets: Mapping[str, str],
    data_dir: Optional[pathlib.Path],
    chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor],
):
    """Replaces the default argument values with local Chainlet instantiations.

    If this patch is used, Chainlets can be functionally instantiated without
    any init args (because the patched defaults are sufficient).

    Args:
        chainlet_descriptor: Descriptor of the Chainlet whose ``__init__`` is
            wrapped.
        cls_to_instance: Cache of dependency instances, filled and reused by the
            returned init, so each dependency class is instantiated once.
        secrets: Used for the injected ``DeploymentContext``.
        data_dir: Used for the injected ``DeploymentContext``.
        chainlet_to_service: Used for the injected ``DeploymentContext``.

    Returns:
        A keyword-only replacement ``__init__`` (function name ``__init_local__``).
    """

    def _detect_naive_instantiations(
        stack: list[inspect.FrameInfo], levels_below_run_local: int
    ) -> None:
        # The goal is to find cases where a chainlet is directly instantiated
        # in a place that is not immediately inside the `run_local`-contextmanager,
        # in particular chainlets instantiated in the `__init__` or `run_remote`
        # methods of other chainlets (instead of being passed as dependencies with
        # `chains.depends()`).
        #
        # We look into the call stack of any (wrapped) invocation of an
        # ABCChainlet-subclass's `__init__`. The part of the stack "above"
        # `run_local` is cut off, so that `run_local` can be used in nested code.
        #
        # A valid stack looks like this:
        # * `__init_local__` as deepest frame (which then calls
        #   `__init_with_arg_check__` -> `__init__` if validation passes).
        # * If a chainlet has no base classes, this can *only* be called from
        #   `__init_local__` - the part when the chainlet needs to be instantiated
        #   and added to `cls_to_instance`.
        # * If a chainlet has other chainlets as base classes, they may call a chain
        #   of `super().__init__()`. Each adds a triple of
        #   (__init__, __init_with_arg_check__, __init_local__) to the stack. While
        #   these 3 init layers belong to different base classes, the `self` arg
        #   is the same object.
        #
        # Rephrased for detection: `__init_local__` may only be called
        # * from `__init_local__` when needing to populate `cls_to_instance`, or
        # * from a subclass's `__init__` via `super().__init__()`. Then the `self`
        #   arg of the calling `__init__` and of the invoked `__init_local__` is
        #   the very same instance. In the forbidden situation that, for example,
        #   Chainlet `A` creates an instance of `B` inside its `__init__`, the
        #   `self` args are two different instances.
        substack = stack[:levels_below_run_local]
        parts = ["-------- Chainlet Instantiation Stack --------"]
        # (function name, `self` or None, frame) per frame, innermost first.
        transformed_stack = []
        for frame in substack:
            func_name = frame.function
            line_number = frame.lineno
            local_vars = frame.frame.f_locals
            init_owner_class = None
            self_value = None
            # Determine if "self" exists and extract the owner class.
            if "self" in local_vars:
                self_value = local_vars["self"]
                if func_name == _INIT_NAME:
                    try:
                        name_parts = frame.frame.f_code.co_qualname.split(".")  # type: ignore[attr-defined]
                    except AttributeError:  # `co_qualname` only in Python 3.11+.
                        name_parts = []
                    if len(name_parts) > 1:
                        init_owner_class = name_parts[-2]
                elif func_name == _INIT_LOCAL_NAME:
                    assert "init_owner_class" in local_vars, (
                        f"`{_INIT_LOCAL_NAME}` must capture `init_owner_class`"
                    )
                    init_owner_class = local_vars["init_owner_class"].__name__

                if init_owner_class:
                    parts.append(
                        f"{func_name}:l{line_number} | type(self)=<"
                        f"{self_value.__class__.__name__}> method of <"
                        f"{init_owner_class}>"
                    )
                else:
                    parts.append(
                        f"{func_name}:l{line_number} | type(self)=<"
                        f"{self_value.__class__.__name__}>"
                    )
            else:
                parts.append(f"{func_name}:l{line_number}")

            transformed_stack.append((func_name, self_value, frame))

        if len(parts) > 1:
            logging.debug("\n".join(parts))

        # Analyze each frame together with its caller (the next-outer frame).
        for (func_name, self_value, _), (up_func_name, up_self_value, up_frame) in zip(
            transformed_stack, transformed_stack[1:]
        ):
            if func_name != _INIT_LOCAL_NAME:
                continue  # OK, we only validate `__init_local__` invocations.
            # We are in `__init_local__`. Now check who and how called it.
            if up_func_name == _INIT_LOCAL_NAME:
                # Note: in this case `self` in the current frame is different than
                # `self` in the parent frame, since a new instance is created.
                continue  # OK, populating `cls_to_instance`.
            # Identity, not `==`: a user-defined `__eq__` must not influence this.
            if up_func_name == _INIT_NAME and self_value is up_self_value:
                continue  # OK, call to `super().__init__()`.

            # Everything else is invalid.
            # `code_context` is None when the source is unavailable (e.g. code
            # compiled from a string); still raise the intended error then.
            code_context = up_frame.code_context
            source_line = (
                code_context[0].strip() if code_context else "<source unavailable>"
            )
            location = (
                f"{up_frame.filename}:{up_frame.lineno} ({up_frame.function})\n"
                f" {source_line}"
            )
            raise definitions.ChainsRuntimeError(
                _instantiation_error_msg(chainlet_descriptor.name, location)
            )

    __original_init__ = chainlet_descriptor.chainlet_cls.__init__

    @functools.wraps(__original_init__)
    def __init_local__(self: definitions.ABCChainlet, **kwargs) -> None:
        logging.debug(f"Patched `__init__` of `{chainlet_descriptor.name}`.")
        stack_depth = run_local_stack_depth.get(None)
        assert stack_depth is not None, "__init_local__ is only called in context."
        stack = inspect.stack()
        current_stack_depth = len(stack)
        levels_below_run_local = current_stack_depth - stack_depth
        # Capture `init_owner_class` in locals (this exact name), because
        # `_detect_naive_instantiations` reads it from this frame's `f_locals`.
        init_owner_class = chainlet_descriptor.chainlet_cls  # noqa: F841
        _detect_naive_instantiations(stack, levels_below_run_local)

        kwargs_mod = dict(kwargs)
        if (
            chainlet_descriptor.has_context
            and definitions.CONTEXT_ARG_NAME not in kwargs_mod
        ):
            kwargs_mod[definitions.CONTEXT_ARG_NAME] = definitions.DeploymentContext(
                secrets=secrets,
                data_dir=data_dir,
                chainlet_to_service=chainlet_to_service,
            )
        for arg_name, dep in chainlet_descriptor.dependencies.items():
            chainlet_cls = dep.chainlet_cls
            if arg_name in kwargs_mod:
                logging.debug(
                    f"Use given instance for `{arg_name}` of type `{dep.name}`."
                )
                continue
            if chainlet_cls in cls_to_instance:
                logging.debug(
                    f"Use previously created `{arg_name}` of type `{dep.name}`."
                )
                kwargs_mod[arg_name] = cls_to_instance[chainlet_cls]
            else:
                logging.debug(
                    f"Create new instance for `{arg_name}` of type `{dep.name}`. "
                    f"Calling patched __init__."
                )
                assert chainlet_cls.meta_data.init_is_patched
                # Dependency chainlets are instantiated here, using their __init__
                # that is patched for local.
                logging.info(f"Making first {dep.name}.")
                instance = chainlet_cls()  # type: ignore # Here init args are patched.
                cls_to_instance[chainlet_cls] = instance
                kwargs_mod[arg_name] = instance

        logging.debug(f"Calling original __init__ of {chainlet_descriptor.name}.")
        __original_init__(self, **kwargs_mod)

    return __init_local__
|
|
1160
|
+
|
|
1161
|
+
|
|
1162
|
+
@contextlib.contextmanager
@raise_validation_errors_before
def run_local(
    secrets: Mapping[str, str],
    data_dir: Optional[pathlib.Path],
    chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor],
) -> Any:
    """Context to run Chainlets with dependency injection from local instances.

    While active, the ``__init__`` of every registered Chainlet is replaced by a
    local variant (built by ``_create_modified_init_for_local``) that fills in
    missing dependency and context arguments. On exit, each class's previous
    ``__init__`` and ``meta_data.init_is_patched`` flag are restored.

    Args:
        secrets: Passed into the injected ``DeploymentContext``.
        data_dir: Passed into the injected ``DeploymentContext``.
        chainlet_to_service: Passed into the injected ``DeploymentContext``.
    """
    type_to_instance: MutableMapping[
        Type[definitions.ABCChainlet], definitions.ABCChainlet
    ] = {}
    # Per class, the `__init__` and `init_is_patched` flag in effect *before* this
    # context patched it. Restoring the previous flag (rather than hard-coding
    # `False`) keeps an enclosing `run_local` context working when this one is
    # nested inside it.
    previous_state: MutableMapping[
        Type[definitions.ABCChainlet], tuple[Callable, bool]
    ] = {}

    # Capture the stack depth when entering the context manager. The stack is used
    # to check that chainlets' `__init__` methods are only called within this context
    # manager, to flag naive instantiations.
    stack_depth = len(inspect.stack())
    # Subtract 2 levels: `run_local` (this) and `__enter__` (from @contextmanager).
    token = run_local_stack_depth.set(stack_depth - 2)
    try:
        # Patch inside `try`, so classes patched before a failure are restored too.
        for chainlet_descriptor in _global_chainlet_registry.chainlet_descriptors:
            chainlet_cls = chainlet_descriptor.chainlet_cls
            previous_state[chainlet_cls] = (
                chainlet_cls.__init__,
                chainlet_cls.meta_data.init_is_patched,
            )
            init_for_local = _create_modified_init_for_local(
                chainlet_descriptor,
                type_to_instance,
                secrets,
                data_dir,
                chainlet_to_service,
            )
            chainlet_cls.__init__ = init_for_local  # type: ignore[method-assign]
            chainlet_cls.meta_data.init_is_patched = True
        yield
    finally:
        # Restore classes to the state they had before entering this context.
        for chainlet_cls, (original_init, was_patched) in previous_state.items():
            chainlet_cls.__init__ = original_init  # type: ignore[method-assign]
            chainlet_cls.meta_data.init_is_patched = was_patched

        run_local_stack_depth.reset(token)
|
|
1203
|
+
|
|
1204
|
+
|
|
1205
|
+
########################################################################################
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
def entrypoint(
    cls_or_chain_name: Optional[Union[Type[ChainletT], str]] = None,
) -> Union[Callable[[Type[ChainletT]], Type[ChainletT]], Type[ChainletT]]:
    """Decorator to tag a Chainlet as an entrypoint.

    Can be used with or without chain name argument, i.e. as ``@entrypoint``,
    ``@entrypoint()`` or ``@entrypoint("chain-name")``. A given chain name is
    stored in ``meta_data.chain_name``.

    Applying it to a non-Chainlet class collects a type error (reported by the
    framework's deferred validation) and returns the class unchanged.
    """

    def decorator(cls: Type[ChainletT]) -> Type[ChainletT]:
        if not utils.issubclass_safe(cls, definitions.ABCChainlet):
            src_path = os.path.abspath(inspect.getfile(cls))
            line = inspect.getsourcelines(cls)[1]
            location = _ErrorLocation(src_path=src_path, line=line)
            _collect_error(
                "Only Chainlet classes can be marked as entrypoint.",
                _ErrorKind.TYPE_ERROR,
                location,
            )
            # The error is only collected, not raised. Return early: a
            # non-Chainlet has no `meta_data`, and touching it would raise an
            # AttributeError that hides the collected error.
            return cls

        cls.meta_data.is_entrypoint = True
        if isinstance(cls_or_chain_name, str):
            cls.meta_data.chain_name = cls_or_chain_name
        return cls

    # Called with arguments (`@entrypoint("name")` or `@entrypoint()`).
    if cls_or_chain_name is None or isinstance(cls_or_chain_name, str):
        return decorator

    return decorator(cls_or_chain_name)  # Decorator used without arguments
|
|
1235
|
+
|
|
1236
|
+
|
|
1237
|
+
class _ABCImporter(abc.ABC):
    """Imports a user source file and resolves the target class defined in it.

    Subclasses provide the base class the target must inherit from
    (``_target_cls_type``) and the errors raised when the entrypoint cannot be
    determined unambiguously.
    """

    @classmethod
    @abc.abstractmethod
    def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
        """Returns the error to raise when the module has no entrypoint."""

    @classmethod
    @abc.abstractmethod
    def _multiple_entrypoints_error(
        cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
    ) -> ValueError:
        """Returns the error to raise when the module has several entrypoints."""

    @classmethod
    @abc.abstractmethod
    def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
        """Returns the base class that the target class must inherit from."""

    @classmethod
    def _get_entrypoint_chainlets(cls, symbols) -> set[Type[definitions.ABCChainlet]]:
        """Returns the subclasses of ``_target_cls_type()`` among ``symbols`` that
        are marked as entrypoint."""
        return {
            sym
            for sym in symbols
            if utils.issubclass_safe(sym, cls._target_cls_type())
            and cast(definitions.ABCChainlet, sym).meta_data.is_entrypoint
        }

    @classmethod
    def _load_module(cls, module_path: pathlib.Path) -> tuple[types.ModuleType, Loader]:
        """Creates, but does not execute, a module for the file at ``module_path``.

        Side effects: the module is registered in ``sys.modules`` under the file's
        stem, and the file's directory is prepended to ``sys.path``. Both are
        undone by ``_cleanup_module_imports`` when given the same ``module_path``.

        Returns:
            The not-yet-executed module and the loader that executes it.

        Raises:
            ImportError: If the path is not a file, no import spec can be
                created, or a module with the same name is already imported.
        """
        module_name = module_path.stem  # Use the file's name as the module name
        if not os.path.isfile(module_path):
            raise ImportError(
                f"`{module_path}` is not a file. You must point to a python file where "
                f"the entrypoint is defined."
            )

        import_error_msg = f"Could not import `{module_path}`. Check path."
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if not spec or not spec.loader:
            raise ImportError(import_error_msg)

        module = importlib.util.module_from_spec(spec)
        module.__file__ = str(module_path)
        # Since the framework depends on tracking the source files via `inspect` and this
        # depends on the modules being properly registered in `sys.modules`, we have to
        # manually do this here (because importlib does not do it automatically). This
        # registration has to stay at least until the push command has finished.
        if module_name in sys.modules:
            raise ImportError(
                f"{import_error_msg} There is already a module in `sys.modules` "
                f"with name `{module_name}`. Overwriting that value is unsafe. "
                "Try renaming your source file."
            )

        sys.modules[module_name] = module
        # Add path for making absolute imports relative to the source_module's dir.
        sys.path.insert(0, str(module_path.parent))

        return module, spec.loader

    @classmethod
    def _cleanup_module_imports(
        cls,
        modules_before: set[str],
        modules_after: set[str],
        module_path: pathlib.Path,
    ):
        """Removes modules imported in between and the ``sys.path`` entry.

        ``module_path`` must be the same path object that was passed to
        ``_load_module``; otherwise the ``sys.path`` entry is not found.
        """
        modules_diff = modules_after - modules_before
        # Apparently torch import leaves some side effects that cannot be reverted
        # by deleting the modules and would lead to a crash when another import
        # is attempted. Since torch is a common lib, we make this explicit special
        # case and just leave those modules.
        # TODO: this seems still brittle and other modules might cause similar problems.
        # it would be good to find a more principled solution.
        modules_to_delete = {
            s for s in modules_diff if not (s.startswith("torch.") or s == "torch")
        }
        if torch_modules := modules_diff - modules_to_delete:
            logging.debug(
                f"Keeping torch modules after import context: {torch_modules}"
            )

        logging.debug(
            f"Deleting modules when exiting import context: {modules_to_delete}"
        )
        for mod in modules_to_delete:
            # `pop` instead of `del`: the module may already have been removed.
            sys.modules.pop(mod, None)
        try:
            sys.path.remove(str(module_path.parent))
        except ValueError:  # In case the value was already removed for whatever reason.
            pass

    @classmethod
    @contextlib.contextmanager
    def import_target(
        cls, module_path: pathlib.Path, target_name: Optional[str] = None
    ) -> Iterator[Type[definitions.ABCChainlet]]:
        """Imports ``module_path`` and yields the target class.

        The target is ``target_name`` if given, otherwise the single entrypoint
        class found in the module. On exit, modules imported by the Model/Chain
        and Chainlets registered during the import are removed again, aiming at
        making the import idempotent for common usages (although there could be
        additional side effects not accounted for by this implementation).

        Raises:
            ImportError: If the module cannot be loaded.
            AttributeError: If ``target_name`` is not defined in the module.
            TypeError: If ``target_name`` is not a ``_target_cls_type()`` subclass.
            ValueError: If no or several entrypoints are found.
        """
        resolved_module_path = pathlib.Path(module_path).resolve()
        modules_before = set(sys.modules.keys())
        module, loader = cls._load_module(module_path)
        modules_after = set()

        chainlets_before = _global_chainlet_registry.get_chainlet_names()
        chainlets_after = set()
        try:
            try:
                loader.exec_module(module)
                raise_validation_errors()
            finally:
                modules_after = set(sys.modules.keys())
                chainlets_after = _global_chainlet_registry.get_chainlet_names()

            if target_name:
                target_cls = getattr(module, target_name, None)
                if target_cls is None:
                    raise AttributeError(
                        f"Target class `{target_name}` not found "
                        f"in `{resolved_module_path}`."
                    )
                if not utils.issubclass_safe(target_cls, cls._target_cls_type()):
                    raise TypeError(
                        f"Target `{target_cls}` is not a {cls._target_cls_type()}."
                    )
            else:
                module_vars = (getattr(module, name) for name in dir(module))
                entrypoints = cls._get_entrypoint_chainlets(module_vars)
                if len(entrypoints) == 0:
                    raise cls._no_entrypoint_error(module_path)
                elif len(entrypoints) > 1:
                    raise cls._multiple_entrypoints_error(module_path, entrypoints)
                target_cls = utils.expect_one(entrypoints)
            yield target_cls
        finally:
            # Use the same path as `_load_module`: with a relative `module_path`
            # the resolved parent dir differs from the inserted `sys.path` entry,
            # which would then never be removed.
            cls._cleanup_module_imports(modules_before, modules_after, module_path)
            for chainlet_name in chainlets_after - chainlets_before:
                _global_chainlet_registry.unregister_chainlet(chainlet_name)
|
|
1380
|
+
|
|
1381
|
+
|
|
1382
|
+
class ChainletImporter(_ABCImporter):
    """Importer that resolves the entrypoint Chainlet of a source file."""

    @classmethod
    def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
        """Error for a module in which no Chainlet is tagged as entrypoint."""
        message = (
            "No `target_name` was specified and no Chainlet in "
            f"`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag "
            "one Chainlet or provide the Chainlet class name."
        )
        return ValueError(message)

    @classmethod
    def _multiple_entrypoints_error(
        cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
    ) -> ValueError:
        """Error for a module in which several Chainlets are tagged as entrypoint."""
        found_names = [chainlet.name for chainlet in entrypoints]
        message = (
            "`target_name` was not specified and multiple Chainlets in "
            f"`{module_path}` were tagged with `@chains.mark_entrypoint`. Tag "
            "one Chainlet or provide the Chainlet class name. Found Chainlets: "
            f"\n{found_names}"
        )
        return ValueError(message)

    @classmethod
    def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
        """Targets must be Chainlets."""
        return ChainletBase
|
|
1405
|
+
|
|
1406
|
+
|
|
1407
|
+
class ModelImporter(_ABCImporter):
    """Importer that resolves the single Model class of a source file."""

    @classmethod
    def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
        """Error for a module without any Model class."""
        base_cls = cls._target_cls_type()
        return ValueError(
            f"No Model class in `{module_path}` inherits from {base_cls}."
        )

    @classmethod
    def _multiple_entrypoints_error(
        cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
    ) -> ValueError:
        """Error for a module with more than one Model class."""
        base_cls = cls._target_cls_type()
        model_names = [model_cls.name for model_cls in entrypoints]
        return ValueError(
            f"Multiple Model classes in `{module_path}` inherit from {base_cls}, "
            "but only one allowed. Found classes: "
            f"\n{model_names}"
        )

    @classmethod
    def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
        """Targets must be Models."""
        return ModelBase
|
|
1427
|
+
|
|
1428
|
+
|
|
1429
|
+
class ChainletBase(definitions.ABCChainlet):
    """Base class for all chainlets.

    Inheriting from this class adds validations to make sure subclasses adhere to the
    chainlet pattern and facilitates remote chainlet deployment.

    Refer to `the docs <https://docs.baseten.co/chains/getting-started>`_ and this
    `example chainlet <https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py>`_
    for more guidance on how to create subclasses.
    """

    def __init_subclass__(cls, **kwargs) -> None:
        # Runs for every subclass definition: sets the framework config and fresh
        # metadata, validates/registers the class, and wraps a custom `__init__`
        # with an argument check.
        super().__init_subclass__(**kwargs)
        cls._framework_config = definitions.FrameworkConfig(
            entity_type="Chainlet",
            supports_dependencies=True,
            endpoint_method_name=definitions.RUN_REMOTE_METHOD_NAME,
        )
        # Each sub-class has own, isolated metadata, e.g. we don't want
        # `mark_entrypoint` to propagate to subclasses.
        cls.meta_data = definitions.ChainletMetadata()
        validate_and_register_cls(cls)  # Errors are collected, not raised!
        # For default init (from `object`) we don't need to check anything.
        if cls.has_custom_init():
            original_init = cls.__init__

            # NOTE: this wrapper's function name appears in the stack frames that
            # local execution inspects and logs; keep the name stable.
            @functools.wraps(original_init)
            def __init_with_arg_check__(self, *args, **kwargs):
                # Only keyword arguments are accepted, so that injected values
                # can be matched to parameters by name.
                if args:
                    raise definitions.ChainsRuntimeError("Only kwargs are allowed.")
                # Raises if a context or dependency argument was not injected.
                ensure_args_are_injected(cls, original_init, kwargs)
                original_init(self, *args, **kwargs)

            cls.__init__ = __init_with_arg_check__  # type: ignore[method-assign]
|
|
1463
|
+
|
|
1464
|
+
|
|
1465
|
+
class ModelBase(definitions.ABCChainlet):
    """Base class for all standalone models.

    Inheriting from this class adds validations to make sure subclasses adhere to the
    truss model pattern.
    """

    def __init_subclass__(cls, **kwargs) -> None:
        # Runs for every subclass definition: configures it as a "Model" entity
        # (with `supports_dependencies=False`), then validates and registers it.
        super().__init_subclass__(**kwargs)
        cls._framework_config = definitions.FrameworkConfig(
            entity_type="Model",
            supports_dependencies=False,
            endpoint_method_name=definitions.MODEL_ENDPOINT_METHOD_NAME,
        )
        # Every Model subclass gets its own metadata and is an entrypoint by
        # default, so no explicit `mark_entrypoint` is needed.
        cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True)
        validate_and_register_cls(cls)
|