truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import builtins
|
|
3
|
+
import contextlib
|
|
4
|
+
import contextvars
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import sys
|
|
8
|
+
import textwrap
|
|
9
|
+
import threading
|
|
10
|
+
import traceback
|
|
11
|
+
from typing import Any, Dict, Iterator, Mapping, NoReturn, Optional, Type, TypeVar
|
|
12
|
+
|
|
13
|
+
import aiohttp
|
|
14
|
+
import fastapi
|
|
15
|
+
import httpx
|
|
16
|
+
import pydantic
|
|
17
|
+
import starlette.requests
|
|
18
|
+
from truss.templates.shared import dynamic_config_resolver
|
|
19
|
+
|
|
20
|
+
from truss_chains import definitions
|
|
21
|
+
|
|
22
|
+
# Generic type variable; its users are not visible in this part of the file.
T = TypeVar("T")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def populate_chainlet_service_predict_urls(
    chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
) -> Mapping[str, definitions.DeployedServiceDescriptor]:
    """Attach predict URLs from the dynamic Chainlet config to each dependency.

    Args:
        chainlet_to_service: Dependency Chainlets of this Chainlet, keyed by
            Chainlet name.

    Returns:
        A mapping with the same keys, whose values are
        `DeployedServiceDescriptor`s carrying the resolved `predict_url`.
        Empty if there are no dependencies.

    Raises:
        definitions.MissingDependencyError: If the dynamic Chainlet config value
            is missing/empty, or has no entry for one of the dependencies.
    """
    # Without dependencies there is nothing to resolve, so the dynamic config
    # is not even read.
    if not chainlet_to_service:
        return {}

    raw_config = dynamic_config_resolver.get_dynamic_config_value_sync(
        definitions.DYNAMIC_CHAINLET_CONFIG_KEY
    )
    if not raw_config:
        raise definitions.MissingDependencyError(
            f"No '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}' "
            "found. Cannot override Chainlet configs."
        )
    config_by_backend_name = json.loads(raw_config)

    resolved: Dict[str, definitions.DeployedServiceDescriptor] = {}
    for chainlet_name, descriptor in chainlet_to_service.items():
        # The CLI-side `display_name` is the backend's Chainlet `name`, and the
        # dynamic config is keyed by that backend name.
        backend_name = descriptor.display_name
        if backend_name not in config_by_backend_name:
            raise definitions.MissingDependencyError(
                f"Chainlet '{backend_name}' not found in "
                f"'{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}'. "
                f"Dynamic Chainlet config keys: {list(config_by_backend_name)}."
            )
        resolved[chainlet_name] = definitions.DeployedServiceDescriptor(
            display_name=backend_name,
            name=descriptor.name,
            options=descriptor.options,
            predict_url=config_by_backend_name[backend_name]["predict_url"],
        )
    return resolved
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class AsyncSafeCounter:
    """Integer counter whose updates are serialized by an `asyncio.Lock`.

    Usable as an async context manager: entering increments the counter and
    returns the new value; exiting decrements it again. Exceptions raised in
    the body are not suppressed.
    """

    def __init__(self, initial: int = 0) -> None:
        self._lock = asyncio.Lock()
        self._counter = initial

    async def _adjust(self, delta: int) -> int:
        """Add `delta` to the counter under the lock; return the new value."""
        async with self._lock:
            self._counter += delta
            return self._counter

    async def increment(self) -> int:
        """Increase the counter by one and return the new value."""
        return await self._adjust(1)

    async def decrement(self) -> int:
        """Decrease the counter by one and return the new value."""
        return await self._adjust(-1)

    async def __aenter__(self) -> int:
        return await self.increment()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.decrement()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ThreadSafeCounter:
    """Integer counter whose updates are serialized by a `threading.Lock`.

    Usable as a context manager: entering increments the counter and returns
    the new value; exiting decrements it again. Exceptions raised in the body
    are not suppressed.
    """

    def __init__(self, initial: int = 0) -> None:
        self._lock = threading.Lock()
        self._counter = initial

    def _adjust(self, delta: int) -> int:
        """Add `delta` to the counter under the lock; return the new value."""
        with self._lock:
            self._counter += delta
            return self._counter

    def increment(self) -> int:
        """Increase the counter by one and return the new value."""
        return self._adjust(1)

    def decrement(self) -> int:
        """Decrease the counter by one and return the new value."""
        return self._adjust(-1)

    def __enter__(self) -> int:
        return self.increment()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.decrement()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Trace-parent value for the current context. It is set/reset by
# `_trace_parent` and `trace_parent_raw`. It is created without a default, so
# `.get()` with no argument raises `LookupError` while unset.
_trace_parent_context: contextvars.ContextVar[str] = contextvars.ContextVar("trace_parent")
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@contextlib.contextmanager
def _trace_parent(request: starlette.requests.Request) -> Iterator[None]:
    """Bind the request's trace-parent header to the current context.

    The binding lasts for the duration of the `with` body. A missing header is
    stored as an empty string. The previous value is restored on exit, also
    when the body raises.
    """
    header_value = request.headers.get(definitions.OTEL_TRACE_PARENT_HEADER_KEY, "")
    reset_token = _trace_parent_context.set(header_value)
    try:
        yield
    finally:
        _trace_parent_context.reset(reset_token)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@contextlib.contextmanager
def trace_parent_raw(trace_parent: str) -> Iterator[None]:
    """Bind an explicit trace-parent string to the current context.

    The binding lasts for the duration of the `with` body. This is the variant
    for callers that already hold the raw value rather than a request. The
    previous value is restored on exit, also when the body raises.
    """
    reset_token = _trace_parent_context.set(trace_parent)
    try:
        yield
    finally:
        _trace_parent_context.reset(reset_token)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def get_trace_parent() -> Optional[str]:
    """Return the trace parent bound to the current context, or `None`.

    The value is bound by `_trace_parent` / `trace_parent_raw`. When nothing is
    bound (e.g. called outside of those context managers), this returns `None`
    as the annotation promises, instead of letting `ContextVar.get()` raise
    `LookupError`. This matters because error messages for failed dependency
    calls embed this value; a `LookupError` here would mask the real error.
    """
    return _trace_parent_context.get(None)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, Any]:
    """Like `BaseModel.model_dump(exclude_unset=True)`, but only top-level.

    Values are returned as-is (via `getattr`), not dumped. Nested models
    therefore stay model instances, and values can be of any type.

    This is used to get kwargs for invoking a function, while dropping fields for
    which there is no value explicitly set in the pydantic model. A field is
    considered unset if its key was not present in the incoming JSON request (from
    which the model was parsed/initialized) and the pydantic model has a default
    value, such as `None`.

    By dropping these unset fields, the default values from the function definition
    are used instead. This ensures correct handling of arguments where the function
    has a default, such as in the case of `run_remote`. If the model has an optional
    field defaulting to `None`, this approach distinguishes between the user
    explicitly passing `None` and the field being unset in the request.
    """
    return {name: getattr(obj, name) for name in obj.model_fields_set}
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# Error Propagation Utils. #############################################################
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _handle_exception(exception: Exception) -> NoReturn:
    """Re-raise `exception` as an HTTP 500 whose detail is a `RemoteErrorDetail`.

    The reported stack trace is trimmed to user-defined code:

    - Frames up to and including the last `predict` frame in `model/model.py`
      are dropped.
    - At the first `predict*` frame in `remote_chainlet/stub.py`, that frame,
      the one right before it, and everything after are dropped.

    See test_e2e.py::test_chain for expected results.

    Raises:
        fastapi.HTTPException: Always, with status 500 and the serialized error
            detail, chained to `exception`.
    """
    module_name = getattr(exception, "__module__", None)
    frames = traceback.extract_tb(exception.__traceback__)

    start = 0
    stop = len(frames)
    for idx, frame in enumerate(frames):
        # No `break` here: the last matching `predict` frame determines `start`.
        if frame.name == "predict" and frame.filename.endswith("model/model.py"):
            start = idx + 1
        # Stub predict methods (sync, async and streaming variants) all start
        # with "predict".
        # NOTE(review): the frame right before the stub call is excluded too —
        # presumably a generated stub wrapper; confirm against test_e2e.py.
        if frame.filename.endswith("remote_chainlet/stub.py") and frame.name.startswith(
            "predict"
        ):
            stop = idx - 1
            break

    user_frames = [
        definitions.StackFrame.from_frame_summary(summary)
        for summary in frames[start:stop]
    ]
    detail = definitions.RemoteErrorDetail(
        exception_cls_name=exception.__class__.__name__,
        exception_module_name=module_name,
        exception_message=str(exception),
        user_stack_trace=user_frames,
    )
    raise fastapi.HTTPException(status_code=500, detail=detail.model_dump()) from exception
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@contextlib.contextmanager
def _exception_to_http_error() -> Iterator[None]:
    """Translate any `Exception` raised in the `with` body into an HTTP 500.

    The translation is delegated to `_handle_exception`, which always raises.
    """
    try:
        yield
    except Exception as exc:  # Deliberately broad: every user error is reported.
        _handle_exception(exc)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _resolve_exception_class(error: definitions.RemoteErrorDetail) -> Type[Exception]:
    """Map a remote error to a local exception class.

    The class is looked up in `builtins` when no module name is given, otherwise
    in the already-imported module `error.exception_module_name`. Modules are
    not imported on demand.

    Falls back to `definitions.GenericRemoteException` if:

    - no class is found;
    - the resolved attribute is not an `Exception` subclass;
    - it is a pydantic `ValidationError`, which cannot be re-raised from a
      plain message.
    """
    exception_cls = None
    if error.exception_module_name is None:
        exception_cls = getattr(builtins, error.exception_cls_name, None)
    elif mod := sys.modules.get(error.exception_module_name):
        exception_cls = getattr(mod, error.exception_cls_name, None)

    # The name may resolve to something that is not an exception class (e.g. a
    # function or an ordinary class with the same name). `issubclass` would then
    # raise `TypeError` for non-classes, and callers could not `raise` it.
    if not (isinstance(exception_cls, type) and issubclass(exception_cls, Exception)):
        logging.warning(
            f"Could not resolve exception with name `{error.exception_cls_name}` "
            f"and module `{error.exception_module_name}` - fall back to "
            f"`{definitions.GenericRemoteException.__name__}`."
        )
        exception_cls = definitions.GenericRemoteException

    if issubclass(exception_cls, pydantic.ValidationError):
        # Cannot re-raise naively.
        # https://github.com/pydantic/pydantic/issues/6734.
        exception_cls = definitions.GenericRemoteException

    return exception_cls
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _handle_response_error(response_json: dict, base_msg: str):
    """Parses the JSON body of a failed remote call and raises a local exception.

    Args:
        response_json: Decoded JSON body of the error response. It is expected to
            contain an `error` field holding either a plain string (old truss
            models) or a serialized `definitions.RemoteErrorDetail`.
        base_msg: Headline describing the failed call; prepended to the message.

    Raises:
        ValueError: If the body has no `error` field or it cannot be parsed.
        definitions.GenericRemoteException: If `error` is a plain string, or the
            resolved exception class cannot be constructed from a single message.
        Exception: Otherwise, an instance of the class resolved by
            `_resolve_exception_class`, carrying the chained remote error trace.
    """
    try:
        error_json = response_json["error"]
    except (KeyError, TypeError) as e:
        # `TypeError` covers bodies that are valid JSON but not an object
        # (e.g. a list, a string or `null`).
        logging.error(f"response_json: {response_json}")
        raise ValueError(
            f"{base_msg}. Could not get `error` field from JSON response."
        ) from e

    try:
        error = definitions.RemoteErrorDetail.model_validate(error_json)
    except pydantic.ValidationError as e:
        if isinstance(error_json, str):
            msg = f"{base_msg}: '{error_json}'"
            raise definitions.GenericRemoteException(msg) from None
        raise ValueError(
            f"{base_msg}: Could not parse chainlet error. Error details are expected "
            "to be either a plain string (old truss models) or a serialized "
            f"`{definitions.RemoteErrorDetail.__name__}`, got:\n{repr(error_json)}"
        ) from e

    exception_cls = _resolve_exception_class(error)
    formatted_lines = textwrap.indent(error.format(), "│ ").splitlines()
    # Close the box drawing on the last line (guarded against empty output).
    if formatted_lines and formatted_lines[-1].startswith("│"):
        formatted_lines[-1] = f"╰{formatted_lines[-1][1:]}"
    error_format = "\n".join(formatted_lines)
    msg = (
        "(showing chained remote errors, root error at the bottom)\n"
        f"├─ {base_msg}\n"
        f"{error_format}"
    )
    try:
        exception = exception_cls(msg)
    except TypeError:
        # Some exception classes require more than one constructor argument
        # (e.g. `UnicodeDecodeError`); don't let that mask the remote error.
        exception = definitions.GenericRemoteException(msg)
    raise exception
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _make_base_error_message(remote_name: str, http_status: int) -> str:
    """Builds the headline for a failed dependency Chainlet call.

    The headline includes the remote's name, the HTTP status and the trace parent
    returned by `get_trace_parent()`.
    """
    trace_id = get_trace_parent()
    return (
        f"Error calling dependency Chainlet `{remote_name}`, "
        f"HTTP status={http_status}, trace ID=`{trace_id}`."
    )
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def response_raise_errors(response: httpx.Response, remote_name: str) -> None:
    """Raises an exception if `response` indicates an HTTP error; no-op otherwise.

    If the error body contains a `RemoteErrorDetail`, the exception that was raised
    remotely is re-raised locally when its class can be resolved. Otherwise it falls
    back to `GenericRemoteException`. A body that is not JSON results in a
    `ValueError`.

    Exception messages are chained to trace back to the root cause, i.e. the first
    Chainlet that raised an exception. E.g. the message might look like this:

    ```
    Chainlet-Traceback (most recent call last):
      File "/packages/itest_chain.py", line 132, in run_remote
        value = self._accumulate_parts(text_parts.parts)
      File "/packages/itest_chain.py", line 144, in _accumulate_parts
        value += self._text_to_num.run_remote(part)
    ValueError: (showing chained remote errors, root error at the bottom)
    ├─ Error calling dependency Chainlet `TextToNum`, HTTP status=500, trace ID=`...`.
    │ Chainlet-Traceback (most recent call last):
    │   File "/packages/itest_chain.py", line 87, in run_remote
    │     generated_text = self._replicator.run_remote(data)
    │ ValueError: (showing chained remote errors, root error at the bottom)
    │ ├─ Error calling dependency Chainlet `TextReplicator`, HTTP status=500, ...
    │ │ Chainlet-Traceback (most recent call last):
    │ │   File "/packages/itest_chain.py", line 52, in run_remote
    │ │     validate_data(data)
    │ │   File "/packages/itest_chain.py", line 36, in validate_data
    │ │     raise ValueError(f"This input is too long: {len(data)}.")
    ╰ ╰ ValueError: This input is too long: 100.
    ```
    """
    if not response.is_error:
        return

    base_msg = _make_base_error_message(remote_name, response.status_code)
    try:
        response_json = response.json()
    except Exception as e:
        raise ValueError(base_msg) from e
    _handle_response_error(response_json, base_msg)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
async def async_response_raise_errors(
    response: aiohttp.ClientResponse, remote_name: str
) -> None:
    """Async version of `response_raise_errors` for `aiohttp` responses.

    If the status is >= 400, parses the JSON body and raises the corresponding
    exception (see `_handle_response_error`). Raises `ValueError` if the body is
    not valid JSON. Does nothing for successful responses.
    """
    if response.status < 400:
        return

    base_msg = _make_base_error_message(remote_name, response.status)
    try:
        # `content_type=None` disables aiohttp's strict content-type check. This
        # matches the httpx-based sync variant, which parses JSON regardless of
        # the declared content type.
        response_json = await response.json(content_type=None)
    except Exception as e:
        raise ValueError(base_msg) from e
    _handle_response_error(response_json, base_msg)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
@contextlib.contextmanager
def predict_context(request: starlette.requests.Request) -> Iterator[None]:
    """Wraps the handling of one predict request.

    Enters `_trace_parent(request)` (presumably propagating tracing context from
    the request - confirm), and inside it `_exception_to_http_error()`, so that
    exceptions raised in the body are forwarded to `_handle_exception`.
    """
    with _trace_parent(request):
        with _exception_to_http_error():
            yield
|
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import dataclasses
|
|
3
|
+
import enum
|
|
4
|
+
import struct
|
|
5
|
+
import sys
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
from typing import Generic, Optional, Protocol, Type, TypeVar, overload
|
|
8
|
+
|
|
9
|
+
import pydantic
|
|
10
|
+
|
|
11
|
+
# Every frame starts with a tag: a `uint8` delimiter followed by a `uint32`
# payload length, big-endian (struct format ">BI", as used for (un)packing tags).
_TAG_SIZE = struct.calcsize(">BI")  # == 5 bytes.

_T = TypeVar("_T")

if sys.version_info < (3, 10):
    # `anext` only became a builtin in Python 3.10; provide a minimal fallback.

    async def anext(iterable: AsyncIterator[_T]) -> _T:
        return await iterable.__anext__()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Note on the (verbose) typing in this module: we want exact typing of the reader and
|
|
22
|
+
# writer helpers, while also allowing flexibility to users to leave out header/footer
|
|
23
|
+
# if not needed.
|
|
24
|
+
# Putting both a constraint on the header/footer types to be pydantic
|
|
25
|
+
# models, but also letting them be optional is not well-supported by typing tools,
|
|
26
|
+
# (the missing feature is using type variables as constraints on other type variables).
|
|
27
|
+
#
|
|
28
|
+
# A functional, yet verbose workaround that gives correct variadic type inference,
|
|
29
|
+
# is to use intermediate type variables `HeaderT` <-> `HeaderTT` in conjunction with
|
|
30
|
+
# mapping out all usage combinations with overloads (the overloads essentially allow
|
|
31
|
+
# "conditional" binding of type vars). These overloads also allow to use granular
|
|
32
|
+
# reader/writer sub-classes conditionally, that have the read/write methods only for the
|
|
33
|
+
# data types configured, and implemented DRY with mixin classes.
|
|
34
|
+
# Item, header and footer payloads are each (de)serialized as pydantic models.
ItemT = TypeVar("ItemT", bound=pydantic.BaseModel)
HeaderT = TypeVar("HeaderT", bound=pydantic.BaseModel)
FooterT = TypeVar("FooterT", bound=pydantic.BaseModel)

# Header and footer are optional, so each needs a second, unbounded type variable
# that stands for either `Type[HeaderT]` / `Type[FooterT]` or `None` (spelling the
# absent case as `Type[None]` causes issues).
HeaderTT = TypeVar("HeaderTT")
FooterTT = TypeVar("FooterTT")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclasses.dataclass
class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]):
    """Bundle of the pydantic model types describing one stream.

    Build instances with `stream_types(...)` to get precise generic typing.
    """

    # Model type of every streamed item.
    item_type: Type[ItemT]
    # Either `Type[HeaderT]`, or `None` if the stream has no header.
    header_type: HeaderTT
    # Either `Type[FooterT]`, or `None` if the stream has no footer.
    footer_type: FooterTT
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# The overloads bind `None` as the header/footer type argument whenever that type
# is omitted, so readers/writers built from the result are typed precisely.
@overload
def stream_types(
    item_type: Type[ItemT], *, header_type: Type[HeaderT], footer_type: Type[FooterT]
) -> StreamTypes[ItemT, HeaderT, FooterT]: ...


@overload
def stream_types(
    item_type: Type[ItemT], *, header_type: Type[HeaderT]
) -> StreamTypes[ItemT, HeaderT, None]: ...


@overload
def stream_types(
    item_type: Type[ItemT], *, footer_type: Type[FooterT]
) -> StreamTypes[ItemT, None, FooterT]: ...


@overload
def stream_types(item_type: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ...


def stream_types(
    item_type: Type[ItemT],
    *,
    header_type: Optional[Type[HeaderT]] = None,
    footer_type: Optional[Type[FooterT]] = None,
) -> StreamTypes:
    """Creates a bundle of the item type and the optional header/footer types.

    All given types are expected to be pydantic models. Going through this factory
    (rather than instantiating `StreamTypes` directly) is what lets the overloads
    above infer generic arguments.
    """
    return StreamTypes(
        item_type=item_type, header_type=header_type, footer_type=footer_type
    )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Common ###############################################################################
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class _Delimiter(enum.IntEnum):
    """Kind of a stream frame; packed as the `uint8` part of each frame's tag.

    The numeric values are part of the wire format and must not change.
    """

    NOT_SET = 1  # Writer state before anything has been sent.
    HEADER = 2
    ITEM = 3
    FOOTER = 4
    END = 5  # Returned by the reader when the stream is exhausted.
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class _Streamer(Generic[ItemT, HeaderTT, FooterTT]):
    """Shared base of stream readers and writers; keeps the stream's model types."""

    _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]

    def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None:
        # Read by subclasses/mixins to (de)serialize items, header and footer.
        self._stream_types = types
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Reading ##############################################################################
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class _ByteReader:
|
|
107
|
+
"""Helper to provide `readexactly` API for an async bytes iterator."""
|
|
108
|
+
|
|
109
|
+
def __init__(self, source: AsyncIterator[bytes]) -> None:
|
|
110
|
+
self._source = source
|
|
111
|
+
self._buffer = bytearray()
|
|
112
|
+
|
|
113
|
+
async def readexactly(self, num_bytes: int) -> bytes:
|
|
114
|
+
while len(self._buffer) < num_bytes:
|
|
115
|
+
try:
|
|
116
|
+
chunk = await anext(self._source)
|
|
117
|
+
except StopAsyncIteration:
|
|
118
|
+
break
|
|
119
|
+
self._buffer.extend(chunk)
|
|
120
|
+
|
|
121
|
+
if len(self._buffer) < num_bytes:
|
|
122
|
+
if len(self._buffer) == 0:
|
|
123
|
+
raise EOFError()
|
|
124
|
+
raise asyncio.IncompleteReadError(self._buffer, num_bytes)
|
|
125
|
+
|
|
126
|
+
result = bytes(self._buffer[:num_bytes])
|
|
127
|
+
del self._buffer[:num_bytes]
|
|
128
|
+
return result
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class _StreamReaderProtocol(Protocol[ItemT, HeaderTT, FooterTT]):
    """Structural type of a stream reader.

    Used to annotate `self` in the reader mixins so they can access the reader's
    state and its `_read` method.
    """

    _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]
    _footer_data: Optional[bytes]

    async def _read(self) -> tuple[_Delimiter, bytes]:
        """Reads the next frame as `(delimiter, payload)`."""
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class _StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]):
    """Reads tag-framed, JSON-serialized pydantic models from an async byte stream.

    Each frame is a `_TAG_SIZE`-byte tag (`uint8` delimiter + big-endian `uint32`
    payload length) followed by the payload. This class only reads items; reading
    header/footer is added by the mixins.
    """

    _stream: _ByteReader
    # Raw footer payload, set by `read_items` when it reaches the footer frame.
    _footer_data: Optional[bytes]

    def __init__(
        self,
        types: StreamTypes[ItemT, HeaderTT, FooterTT],
        stream: AsyncIterator[bytes],
    ) -> None:
        super().__init__(types)
        self._stream = _ByteReader(stream)
        self._footer_data = None

    @staticmethod
    def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]:
        enum_value, length = struct.unpack(">BI", tag)
        return _Delimiter(enum_value), length

    async def _read(self) -> tuple[_Delimiter, bytes]:
        """Reads the next frame as `(delimiter, payload)`.

        Returns `(_Delimiter.END, b"")` if the stream ended at a frame boundary.

        Raises:
            asyncio.IncompleteReadError: If the stream ends inside a frame.
            ValueError: If the tag holds an unknown delimiter value.
        """
        try:
            tag = await self._stream.readexactly(_TAG_SIZE)
        # It's ok to read nothing (end of stream), but unexpected to read partial.
        # `IncompleteReadError` subclasses `EOFError`, so it must be listed first.
        except asyncio.IncompleteReadError:
            raise
        except EOFError:
            return _Delimiter.END, b""

        delimiter, length = self._unpack_tag(tag)
        if not length:
            return delimiter, b""
        try:
            data_bytes = await self._stream.readexactly(length)
        except asyncio.IncompleteReadError:
            raise
        except EOFError as e:
            # The stream ended right after the tag: the payload is missing.
            raise asyncio.IncompleteReadError(b"", length) from e
        return delimiter, data_bytes

    async def read_items(self) -> AsyncIterator[ItemT]:
        """Yields all items, stopping at the footer or at the end of the stream.

        If a footer frame is reached, its raw payload is kept in `_footer_data`
        so that a subsequent `read_footer` (where available) can parse it.

        Raises:
            ValueError: If an unconsumed header is found, or if a frame other
                than item, footer or end of stream appears among the items.
        """
        delimiter, data_bytes = await self._read()
        if delimiter == _Delimiter.HEADER:
            raise ValueError(
                "Called `read_items`, but the stream contains header data, which "
                "is not consumed. Call `read_header` first or remove sending a header."
            )

        while delimiter == _Delimiter.ITEM:
            yield self._stream_types.item_type.model_validate_json(data_bytes)
            # We don't know if the next data is another item, footer or the end.
            delimiter, data_bytes = await self._read()

        if delimiter == _Delimiter.FOOTER:
            self._footer_data = data_bytes
        elif delimiter != _Delimiter.END:
            # E.g. a header after items; never parse it as an item.
            raise ValueError(f"Unexpected `{delimiter.name}` data in item stream.")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterTT]):
    """Adds `read_header` to readers of streams that carry a header."""

    async def read_header(
        self: _StreamReaderProtocol[ItemT, HeaderT, FooterTT],
    ) -> HeaderT:
        """Reads and parses the header frame; call it before `read_items`.

        Raises:
            ValueError: If the next frame is not a header.
        """
        delimiter, payload = await self._read()
        if delimiter == _Delimiter.HEADER:
            return self._stream_types.header_type.model_validate_json(payload)
        raise ValueError("Stream does not contain header.")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]):
    """Adds `read_footer` to readers of streams that carry a footer."""

    # Raw footer payload, cached when `read_items` already consumed the footer.
    _footer_data: Optional[bytes]

    async def read_footer(
        self: _StreamReaderProtocol[ItemT, HeaderTT, FooterT],
    ) -> FooterT:
        """Parses the footer; call it after all items have been consumed.

        Uses the cached payload if present, otherwise reads the next frame. The
        cache is cleared only after the footer was parsed successfully.

        Raises:
            ValueError: If nothing is cached and the next frame is not a footer.
        """
        cached = self._footer_data
        if cached is None:
            delimiter, cached = await self._read()
            if delimiter != _Delimiter.FOOTER:
                raise ValueError("Stream does not contain footer.")
            self._footer_data = cached

        parsed = self._stream_types.footer_type.model_validate_json(cached)
        self._footer_data = None
        return parsed
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class StreamReaderWithHeader(
    _StreamReader[ItemT, HeaderT, FooterTT],
    _HeaderReadMixin[ItemT, HeaderT, FooterTT],
):
    """Stream reader offering `read_header` and `read_items`."""
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class StreamReaderWithFooter(
    _StreamReader[ItemT, HeaderTT, FooterT],
    _FooterReadMixin[ItemT, HeaderTT, FooterT],
):
    """Stream reader offering `read_items` and `read_footer`."""
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class StreamReaderFull(
    _StreamReader[ItemT, HeaderT, FooterT],
    _HeaderReadMixin[ItemT, HeaderT, FooterT],
    _FooterReadMixin[ItemT, HeaderT, FooterT],
):
    """Stream reader offering `read_header`, `read_items` and `read_footer`."""
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# The overloads select the reader class whose read methods match the configured
# header/footer types.
@overload
def stream_reader(
    types: StreamTypes[ItemT, None, None], stream: AsyncIterator[bytes]
) -> _StreamReader[ItemT, None, None]: ...


@overload
def stream_reader(
    types: StreamTypes[ItemT, HeaderT, None], stream: AsyncIterator[bytes]
) -> StreamReaderWithHeader[ItemT, HeaderT, None]: ...


@overload
def stream_reader(
    types: StreamTypes[ItemT, None, FooterT], stream: AsyncIterator[bytes]
) -> StreamReaderWithFooter[ItemT, None, FooterT]: ...


@overload
def stream_reader(
    types: StreamTypes[ItemT, HeaderT, FooterT], stream: AsyncIterator[bytes]
) -> StreamReaderFull[ItemT, HeaderT, FooterT]: ...


def stream_reader(
    types: StreamTypes[ItemT, HeaderTT, FooterTT], stream: AsyncIterator[bytes]
) -> _StreamReader:
    """Creates a reader for `stream` exposing exactly the read methods that match
    the header/footer types configured in `types`."""
    has_header = types.header_type is not None
    has_footer = types.footer_type is not None
    if has_header and has_footer:
        return StreamReaderFull(types, stream)
    if has_header:
        return StreamReaderWithHeader(types, stream)
    if has_footer:
        return StreamReaderWithFooter(types, stream)
    return _StreamReader(types, stream)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# Writing ##############################################################################
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class _StreamWriterProtocol(Protocol[ItemT, HeaderTT, FooterTT]):
    """Structural type of a stream writer.

    Used to annotate `self` in the writer mixins so they can access the writer's
    state and its `_serialize` method.
    """

    _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]
    _last_sent: _Delimiter

    def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes:
        """Serializes `obj` into one tagged frame."""
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
class _StreamWriter(_Streamer[ItemT, HeaderTT, FooterTT]):
    """Serializes pydantic models into tag-framed chunks of a byte stream.

    Each chunk is a tag (`uint8` delimiter + big-endian `uint32` payload length)
    followed by the model's JSON. The writer remembers the kind of the last chunk
    produced so out-of-order writes can be rejected. This class only writes items;
    header/footer writing is added by the mixins.
    """

    def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None:
        super().__init__(types)  # Also stores `types` as `self._stream_types`.
        self._last_sent = _Delimiter.NOT_SET

    @staticmethod
    def _pack_tag(delimiter: _Delimiter, length: int) -> bytes:
        return struct.pack(">BI", delimiter.value, length)

    def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes:
        """Returns the tag for `delimiter` followed by the JSON of `obj`."""
        payload = obj.model_dump_json().encode()
        frame = bytearray(self._pack_tag(delimiter, len(payload)))
        frame += payload
        # Starlette cannot handle a `bytearray`, but a `memoryview` works.
        return memoryview(frame)

    def yield_item(self, item: ItemT) -> bytes:
        """Serializes `item`; rejected once the footer was sent or the stream closed."""
        if self._last_sent not in (_Delimiter.FOOTER, _Delimiter.END):
            self._last_sent = _Delimiter.ITEM
            return self._serialize(item, _Delimiter.ITEM)
        raise ValueError("Cannot yield item after sending footer / closing stream.")
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class _HeaderWriteMixin(_Streamer[ItemT, HeaderT, FooterTT]):
    """Adds `yield_header` to writers of streams that carry a header."""

    def yield_header(
        self: _StreamWriterProtocol[ItemT, HeaderT, FooterTT], header: HeaderT
    ) -> bytes:
        """Serializes `header`; only allowed as the very first chunk.

        Raises:
            ValueError: If any chunk has already been produced.
        """
        nothing_sent_yet = self._last_sent == _Delimiter.NOT_SET
        if not nothing_sent_yet:
            raise ValueError("Cannot yield header after other data has been sent.")
        self._last_sent = _Delimiter.HEADER
        return self._serialize(header, _Delimiter.HEADER)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
class _FooterWriteMixin(_Streamer[ItemT, HeaderTT, FooterT]):
    """Adds `yield_footer` to writers of streams that carry a footer."""

    def yield_footer(
        self: _StreamWriterProtocol[ItemT, HeaderTT, FooterT], footer: FooterT
    ) -> bytes:
        """Serializes `footer`; it must be the last chunk and is sent at most once.

        The reader stops reading items at the first footer and parses a single
        footer, so a second footer would never be read correctly.

        Raises:
            ValueError: If a footer was already sent or the stream was closed.
        """
        # Previously only `END` was checked, which writers never set, so any
        # number of footers could be emitted.
        if self._last_sent in (_Delimiter.FOOTER, _Delimiter.END):
            raise ValueError(
                "Cannot yield footer after sending footer / closing stream."
            )
        self._last_sent = _Delimiter.FOOTER
        return self._serialize(footer, _Delimiter.FOOTER)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class StreamWriterWithHeader(
    _StreamWriter[ItemT, HeaderT, FooterTT],
    _HeaderWriteMixin[ItemT, HeaderT, FooterTT],
):
    """Stream writer offering `yield_header` and `yield_item`."""
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
class StreamWriterWithFooter(
    _StreamWriter[ItemT, HeaderTT, FooterT],
    _FooterWriteMixin[ItemT, HeaderTT, FooterT],
):
    """Stream writer offering `yield_item` and `yield_footer`."""
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class StreamWriterFull(
    _StreamWriter[ItemT, HeaderT, FooterT],
    _HeaderWriteMixin[ItemT, HeaderT, FooterT],
    _FooterWriteMixin[ItemT, HeaderT, FooterT],
):
    """Stream writer offering `yield_header`, `yield_item` and `yield_footer`."""
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
# The overloads select the writer class whose write methods match the configured
# header/footer types.
@overload
def stream_writer(
    types: StreamTypes[ItemT, None, None],
) -> _StreamWriter[ItemT, None, None]: ...


@overload
def stream_writer(
    types: StreamTypes[ItemT, HeaderT, None],
) -> StreamWriterWithHeader[ItemT, HeaderT, None]: ...


@overload
def stream_writer(
    types: StreamTypes[ItemT, None, FooterT],
) -> StreamWriterWithFooter[ItemT, None, FooterT]: ...


@overload
def stream_writer(
    types: StreamTypes[ItemT, HeaderT, FooterT],
) -> StreamWriterFull[ItemT, HeaderT, FooterT]: ...


def stream_writer(types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> _StreamWriter:
    """Creates a writer exposing exactly the write methods that match the
    header/footer types configured in `types`."""
    has_header = types.header_type is not None
    has_footer = types.footer_type is not None
    if has_header and has_footer:
        return StreamWriterFull(types)
    if has_header:
        return StreamWriterWithHeader(types)
    if has_footer:
        return StreamWriterWithFooter(types)
    return _StreamWriter(types)
|