truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,816 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Chains currently assumes that everything from the directory in which the entrypoint
|
|
3
|
+
is defined (i.e. sibling files and nested dirs) could be imported/used. e.g.:
|
|
4
|
+
|
|
5
|
+
workspace/
|
|
6
|
+
entrypoint.py
|
|
7
|
+
helper.py
|
|
8
|
+
some_package/
|
|
9
|
+
utils.py
|
|
10
|
+
sub_package/
|
|
11
|
+
...
|
|
12
|
+
|
|
13
|
+
These sources are copied into truss's `/packages` and can be imported on the remote.
|
|
14
|
+
Using code *outside* of the workspace is not supported:
|
|
15
|
+
|
|
16
|
+
shared_lib/
|
|
17
|
+
common.py
|
|
18
|
+
workspace/
|
|
19
|
+
entrypoint.py
|
|
20
|
+
...
|
|
21
|
+
|
|
22
|
+
`shared_lib` can only be imported on the remote if it's installed as a pip
|
|
23
|
+
requirement (site-package); it will not be copied from the local host.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
import logging
|
|
27
|
+
import os
|
|
28
|
+
import pathlib
|
|
29
|
+
import re
|
|
30
|
+
import shlex
|
|
31
|
+
import shutil
|
|
32
|
+
import subprocess
|
|
33
|
+
import sys
|
|
34
|
+
import tempfile
|
|
35
|
+
import textwrap
|
|
36
|
+
from typing import Any, Iterable, Mapping, Optional, get_args, get_origin
|
|
37
|
+
|
|
38
|
+
import libcst
|
|
39
|
+
import truss
|
|
40
|
+
from truss.base import truss_config
|
|
41
|
+
from truss.contexts.image_builder import serving_image_builder
|
|
42
|
+
from truss.util import path as truss_path
|
|
43
|
+
|
|
44
|
+
from truss_chains import definitions, framework, utils
|
|
45
|
+
|
|
46
|
+
# Indentation unit used when rendering generated Python source.
_INDENT = " " * 4
_REQUIREMENTS_FILENAME = "pip_requirements.txt"
_MODEL_FILENAME = "model.py"
_MODEL_CLS_NAME = "TrussChainletModel"
_TRUSS_GIT = "git+https://github.com/basetenlabs/truss.git"
# Matches a requirement line that is `truss`, optionally followed by a single
# version specifier, e.g. `truss`, `truss==0.9.1`, `truss >= 1.0rc2`, `truss~=0.9`,
# `truss==0.9.50.dev3`, `truss==1.0.post1+local`.
# Group 1 captures the comparison operator. Every version suffix (pre-release,
# post-release, dev-release, local label) is optional on its own, so PEP 440
# versions like `1.0.dev3` or `1.0.post1` (without a pre-release tag) match too.
_TRUSS_PIP_PATTERN = re.compile(
    r"""
    ^truss
    (?:
        \s*(==|~=|>=|<=|!=|>|<)\s*    # Version comparison operators
        \d+(\.\d+)*                   # Release segment (e.g., 1, 1.0, 1.0.0)
        (?:\.?(?:a|b|rc|dev)\d*)?     # Optional pre-release (e.g., rc1, .dev3)
        (?:\.?post\d+)?               # Optional post-release (e.g., .post1)
        (?:\.?dev\d*)?                # Optional dev-release after post (e.g., .dev2)
        (?:\+[\w\.]+)?                # Optional local version label (e.g., +abc.1)
    )?$
    """,
    re.VERBOSE,
)

# Source of the model skeleton, located at `<package root>/remote_chainlet/`
# relative to this file's parent directory.
_MODEL_SKELETON_FILE = (
    pathlib.Path(__file__).parent.parent.resolve()
    / "remote_chainlet"
    / "model_skeleton.py"
)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _indent(text: str, num: int = 1) -> str:
    """Prefixes each non-whitespace-only line of `text` with `num` indent levels."""
    prefix = _INDENT * num
    return textwrap.indent(text, prefix)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _run_simple_subprocess(cmd: str) -> None:
    """Runs `cmd` (tokenized with `shlex.split`, no shell) and waits for it.

    Raises:
        ChildProcessError: If the command exits with a non-zero status. The message
            contains both stderr and stdout, because tools such as `ruff check`
            print their diagnostics to stdout rather than stderr.
    """
    result = subprocess.run(shlex.split(cmd), capture_output=True)
    if result.returncode != 0:
        # Decode leniently so undecodable bytes in tool output can't mask the
        # actual failure with a UnicodeDecodeError.
        outputs = (
            result.stderr.decode(errors="replace").strip(),
            result.stdout.decode(errors="replace").strip(),
        )
        details = "\n".join(text for text in outputs if text)
        raise ChildProcessError(f"Error: {details}")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _format_python_file(file_path: pathlib.Path) -> None:
    """Formats a generated Python file in place using `ruff`.

    First applies autofixes for unused imports (F401) and import sorting (I),
    then runs `ruff format`.
    """
    # The command string is tokenized with `shlex.split`, so the path must be
    # quoted; otherwise a path containing spaces becomes several arguments.
    quoted_path = shlex.quote(str(file_path))
    _run_simple_subprocess(f"ruff check {quoted_path} --fix --select F401,I")
    _run_simple_subprocess(f"ruff format {quoted_path}")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class _Source(definitions.SafeModelNonSerializable):
    """A snippet of generated source code plus the import lines it requires."""

    # The generated code snippet (e.g. a type reference or a class definition).
    src: str
    # Full import statements (e.g. "import pydantic") needed by `src`.
    # NOTE(review): callers mutate `.imports` in place (see `_gen_nested_pydantic`);
    # this relies on the model base copying the mutable default per instance.
    imports: set[str] = set()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _update_src(new_source: _Source, src_parts: list[str], imports: set[str]) -> None:
    """Accumulates `new_source` into the caller's collections, in place.

    The snippet is appended to `src_parts`; its imports are merged into `imports`.
    """
    src_parts += [new_source.src]
    imports.update(new_source.imports)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _gen_pydantic_import_and_ref(raw_type: Any) -> _Source:
    """Returns the source reference to `raw_type` and the import it requires.

    E.g. for a class `OutputType` defined in `sub_package.module` the reference is
    `"module.OutputType"` and the import is `"from sub_package import module"`;
    for a top-level module `module` the import is `"import module"`.

    Types defined in `__main__` are referenced via the main script's file name
    (without `.py`), assuming that file is copied into the package dir and can be
    imported on the remote.

    Raises:
        definitions.ChainsUsageError: If `raw_type` is defined in `__main__` and the
            main module is not backed by a `.py` file.
    """
    module_path = raw_type.__module__
    if module_path == "__main__":
        module_obj = sys.modules[module_path]
        # An interactive `__main__` may have no `__file__` attribute at all; report
        # that as a usage error instead of leaking an AttributeError.
        main_file = getattr(module_obj, "__file__", None)
        if not main_file:
            raise definitions.ChainsUsageError(
                f"File-based python code required. `{raw_type}` does not have a file."
            )
        file_name = os.path.basename(main_file)
        if not file_name.endswith(".py"):
            # Explicit check rather than `assert`, which is stripped under `-O`.
            raise definitions.ChainsUsageError(
                f"File-based python code required. `{raw_type}` is defined in "
                f"`{main_file}`, which is not a `.py` file."
            )
        # Strip only the trailing suffix; `str.replace` would drop every ".py".
        module_name = file_name[: -len(".py")]
        import_src = f"import {module_name}"
        ref_src = f"{module_name}.{raw_type.__name__}"
    else:
        *package_parts, leaf_module = module_path.split(".")
        ref_src = f"{leaf_module}.{raw_type.__name__}"
        if package_parts:
            import_src = f"from {'.'.join(package_parts)} import {leaf_module}"
        else:
            import_src = f"import {leaf_module}"

    return _Source(src=ref_src, imports={import_src})
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _gen_nested_pydantic(raw_type: Any) -> _Source:
    """Renders a simple container of pydantic types, e.g. `list[PydanticModel]`.

    The container and each of its type arguments are rendered with
    `_gen_type_import_and_ref`, so argument types defined in other files bring
    along the imports they need.
    """
    origin = get_origin(raw_type)
    assert origin in framework._SIMPLE_CONTAINERS
    result = _gen_type_import_and_ref(definitions.TypeDescriptor(raw=origin))
    rendered_args = [
        _gen_type_import_and_ref(definitions.TypeDescriptor(raw=arg))
        for arg in get_args(raw_type)
    ]
    for rendered in rendered_args:
        result.imports.update(rendered.imports)

    joined_args = ",".join(rendered.src for rendered in rendered_args)
    result.src = f"{result.src}[{joined_args}]"
    return result
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _gen_type_import_and_ref(type_descr: definitions.TypeDescriptor) -> _Source:
    """Renders a source reference to the described type plus required imports.

    - Pydantic models resolve to e.g. `module.OutputType` with
      `from sub_package import module`.
    - Simple containers with pydantic arguments (e.g. `list[Model]`) are rendered
      recursively.
    - Other classes must be builtins and are referenced by their bare name.
    - Anything else falls back to `str(raw)` without imports.

    Raises:
        TypeError: If the type is a class that is not a builtin.
    """
    raw = type_descr.raw
    if type_descr.is_pydantic:
        return _gen_pydantic_import_and_ref(raw)
    if type_descr.has_pydantic_args:
        return _gen_nested_pydantic(raw)
    if not isinstance(raw, type):
        # NOTE(review): for `typing` constructs `str()` yields e.g.
        # "typing.Optional[int]", but no `import typing` is added here - confirm
        # the generated file provides it when such types can reach this branch.
        return _Source(src=str(raw))
    if raw.__module__ != "builtins":
        raise TypeError(f"{raw} is not a builtin - cannot be rendered as source.")
    return _Source(src=raw.__name__)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _gen_streaming_type_import_and_ref(
    stream_type: definitions.StreamingTypeDescriptor,
) -> _Source:
    """Renders `<module>.<Origin>[<Arg>]` for a streaming return type.

    Unlike other `_gen`-helpers, this does not define a type, it creates a symbol.
    Only the origin type's module is imported; the argument type is referenced by
    its bare `__name__`.
    """
    origin = stream_type.origin_type
    origin_module = origin.__module__
    arg_name = stream_type.arg_type.__name__
    return _Source(
        src=f"{origin_module}.{origin.__name__}[{arg_name}]",
        imports={f"import {origin_module}"},
    )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _gen_chainlet_import_and_ref(
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
    """Returns the source reference to a chainlet class and the import it needs.

    E.g. `"module.MyChainlet"` with `"from sub_package import module"`; chainlet
    classes are resolved by module path exactly like pydantic models.
    """
    chainlet_cls = chainlet_descriptor.chainlet_cls
    return _gen_pydantic_import_and_ref(chainlet_cls)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# I/O used by Stubs and Truss models ###################################################
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _get_input_model_name(chainlet_name: str) -> str:
    """Name of the generated pydantic model wrapping a chainlet's inputs."""
    return "{}Input".format(chainlet_name)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _get_output_model_name(chainlet_name: str) -> str:
    """Name of the generated root model wrapping a chainlet's outputs."""
    return "{}Output".format(chainlet_name)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _gen_truss_input_pydantic(
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
    """Generates the pydantic model that wraps a chainlet endpoint's arguments.

    Produces e.g.:
    ```
    class MyChainletInput(pydantic.BaseModel):
        text: str
        max_len: Optional[int] = None
    ```
    Optional arguments become `Optional[...]` fields defaulting to `None`; an
    endpoint without arguments yields an empty model body (`pass`).
    """
    imports = {"import pydantic", "from typing import Optional"}
    field_lines: list[str] = []
    for arg in chainlet_descriptor.endpoint.input_args:
        type_ref = _gen_type_import_and_ref(arg.type)
        imports.update(type_ref.imports)
        annotation = (
            f"Optional[{type_ref.src}] = None" if arg.is_optional else type_ref.src
        )
        field_lines.append(f"{arg.name}: {annotation}")

    body = "\n".join(field_lines) if field_lines else "pass"
    model_name = _get_input_model_name(chainlet_descriptor.name)
    return _Source(
        src=f"class {model_name}(pydantic.BaseModel):\n{_indent(body)}",
        imports=imports,
    )
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _gen_truss_output_pydantic(
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
    """Generates a `pydantic.RootModel` alias for a chainlet endpoint's output.

    A single output type is wrapped directly; multiple output types are wrapped as
    a tuple, e.g. `MyChainletOutput = pydantic.RootModel[tuple[str,int]]`.
    """
    imports = {"import pydantic"}
    # Rendered source references of each output type (not pydantic fields).
    output_type_refs: list[str] = []
    for output_type in chainlet_descriptor.endpoint.output_types:
        _update_src(_gen_type_import_and_ref(output_type), output_type_refs, imports)

    model_name = _get_output_model_name(chainlet_descriptor.name)
    if len(output_type_refs) > 1:
        root_type = f"tuple[{','.join(output_type_refs)}]"
    else:
        # NOTE(review): assumes at least one output type; an empty list would
        # raise IndexError here - confirm the descriptor guarantees this.
        root_type = output_type_refs[0]
    src = f"{model_name} = pydantic.RootModel[{root_type}]"
    return _Source(src=src, imports=imports)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# Stub Gen #############################################################################
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _stub_endpoint_signature_src(
    endpoint: definitions.EndpointAPIDescriptor,
) -> _Source:
    """Generates the `def` line of a stub endpoint plus the imports it needs.

    E.g.:
    ```
    async def run_remote(
        self, inputs: shared_chainlet.SplitTextInput, extra_arg: int
    ) -> tuple[shared_chainlet.SplitTextOutput, int]:
    ```
    Streaming endpoints are annotated with the streaming type; otherwise a single
    output type is used as-is and several output types become a `tuple[...]`.
    """
    imports: set[str] = set()
    params = ["self"]
    for arg in endpoint.input_args:
        arg_ref = _gen_type_import_and_ref(arg.type)
        imports.update(arg_ref.imports)
        # NOTE(review): optional args get no default here, unlike the
        # `Optional[...] = None` fields of the generated input model - confirm.
        params.append(f"{arg.name}: {arg_ref.src}")

    if endpoint.is_streaming:
        stream_ref = _gen_streaming_type_import_and_ref(endpoint.streaming_type)
        imports.update(stream_ref.imports)
        output = stream_ref.src
    else:
        output_refs: list[str] = []
        for output_type in endpoint.output_types:
            _update_src(_gen_type_import_and_ref(output_type), output_refs, imports)
        output = (
            output_refs[0]
            if len(output_refs) == 1
            else f"tuple[{', '.join(output_refs)}]"
        )

    keyword = "async def" if endpoint.is_async else "def"
    signature = f"{keyword} {endpoint.name}({','.join(params)}) -> {output}:"
    return _Source(src=signature, imports=imports)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _stub_endpoint_body_src(
    endpoint: definitions.EndpointAPIDescriptor, chainlet_name: str
) -> _Source:
    """Generates source code for calling the stub and wrapping the I/O types.

    E.g.:
    ```
    return (await self.predict_async(
        SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput)).root
    ```

    Streaming endpoints instead iterate over `predict_async_stream` and yield
    the chunks, decoding them when the streaming type is a string.

    Raises:
        NotImplementedError: for streaming endpoints that are not async.
    """
    imports: set[str] = set()
    args = [f"{arg.name}={arg.name}" for arg in endpoint.input_args]
    if args:
        inputs = f"{_get_input_model_name(chainlet_name)}({', '.join(args)})"
    else:
        # Without arguments, an empty dict literal is passed instead of an
        # input model instance.
        inputs = "{}"

    parts: list[str] = []
    # Invoke remote.
    if not endpoint.is_streaming:
        output_model_name = _get_output_model_name(chainlet_name)
        if endpoint.is_async:
            parts.append(
                f"return (await self.predict_async({inputs}, {output_model_name})).root"
            )
        else:
            parts.append(
                f"return self.predict_sync({inputs}, {output_model_name}).root"
            )
    else:
        if not endpoint.is_async:
            raise NotImplementedError(
                "Streaming endpoints (containing `yield` statements) are only "
                "supported for async endpoints."
            )
        parts.append(f"async for data in await self.predict_async_stream({inputs}):")
        if endpoint.streaming_type.is_string:
            parts.append(_indent("yield data.decode()"))
        else:
            parts.append(_indent("yield data"))

    return _Source(src="\n".join(parts), imports=imports)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _gen_stub_src(chainlet: definitions.ChainletAPIDescriptor) -> _Source:
    """Generates the stub class used to call `chainlet` remotely, e.g.:

    ```
    <IMPORTS>

    class SplitTextInput(pydantic.BaseModel):
        inputs: shared_chainlet.SplitTextInput
        extra_arg: int

    SplitTextOutput = pydantic.RootModel[tuple[shared_chainlet.SplitTextOutput, int]]

    class SplitText(stub.StubBase):
        async def run_remote(
          self, inputs: shared_chainlet.SplitTextInput, extra_arg: int
        ) -> tuple[shared_chainlet.SplitTextOutput, int]:
            return (await self.predict_async(
              SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput)).root
    ```

    The output model is only emitted for non-streaming endpoints.
    """
    imports = {"from truss_chains.remote_chainlet import stub"}
    src_parts: list[str] = []
    _update_src(_gen_truss_input_pydantic(chainlet), src_parts, imports)
    if not chainlet.endpoint.is_streaming:
        _update_src(_gen_truss_output_pydantic(chainlet), src_parts, imports)

    endpoint = chainlet.endpoint
    signature = _stub_endpoint_signature_src(endpoint)
    imports.update(signature.imports)
    body = _stub_endpoint_body_src(endpoint, chainlet.name)
    imports.update(body.imports)

    src_parts += [
        f"class {chainlet.name}(stub.StubBase):",
        _indent(signature.src),
        _indent(body.src, 2),
        "\n",
    ]
    return _Source(src="\n".join(src_parts), imports=imports)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _gen_stub_src_for_deps(
    dependencies: Iterable[definitions.ChainletAPIDescriptor],
) -> Optional[_Source]:
    """Generates source code and imports for the stub classes of all dependencies.

    Returns `None` when nothing was generated, e.g. when there are no dependencies.
    """
    collected_imports: set[str] = set()
    stub_sources: list[str] = []
    for dependency in dependencies:
        _update_src(_gen_stub_src(dependency), stub_sources, collected_imports)

    if collected_imports or stub_sources:
        return _Source(src="\n".join(stub_sources), imports=collected_imports)
    return None
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
# Truss Chainlet Gen ###################################################################
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _name_to_dirname(name: str) -> str:
|
|
380
|
+
"""Make a string safe to use as a directory name."""
|
|
381
|
+
name = name.strip() # Remove leading and trailing spaces
|
|
382
|
+
name = re.sub(
|
|
383
|
+
r"[^\w.-]", "_", name
|
|
384
|
+
) # Replace non-alphanumeric characters with underscores
|
|
385
|
+
name = re.sub(r"_+", "_", name) # Collapse multiple underscores into a single one
|
|
386
|
+
return name
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def _make_chainlet_dir(
    chain_name: str,
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
    root: pathlib.Path,
) -> pathlib.Path:
    """Creates a fresh, empty directory for a chainlet's generated code.

    The directory is `root / GENERATED_CODE_DIR / <sanitized chain name> /
    chainlet_<chainlet name>`. Any directory already at that path is deleted
    first, so stale generated files never survive.
    """
    chainlet_subdir = f"chainlet_{chainlet_descriptor.name}"
    chain_dir = root / definitions.GENERATED_CODE_DIR / _name_to_dirname(chain_name)
    target = chain_dir / chainlet_subdir
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=False)
    return target
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
class _SpecifyChainletTypeAnnotation(libcst.CSTTransformer):
    """Inserts the concrete chainlet class into `_chainlet: definitions.ABCChainlet`.

    The annotation of every annotated assignment whose target is the bare name
    `_chainlet` is replaced. All other statements pass through unchanged.
    """

    def __init__(self, new_annotation: str) -> None:
        super().__init__()
        # Parsed once here and reused for every matching statement.
        self._new_annotation = libcst.parse_expression(new_annotation)

    def leave_SimpleStatementLine(
        self,
        original_node: libcst.SimpleStatementLine,
        updated_node: libcst.SimpleStatementLine,
    ) -> libcst.SimpleStatementLine:
        rewritten: list[Any] = []
        for stmt in updated_node.body:
            is_chainlet_field = (
                isinstance(stmt, libcst.AnnAssign)
                and isinstance(stmt.target, libcst.Name)
                and stmt.target.value == "_chainlet"
            )
            if is_chainlet_field:
                stmt = stmt.with_changes(
                    annotation=libcst.Annotation(annotation=self._new_annotation)
                )
            rewritten.append(stmt)

        return updated_node.with_changes(body=tuple(rewritten))
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
    """Generates the `load` method of the truss model.

    The generated `load` logs a message and instantiates the user's chainlet
    class. It passes a stub for every dependency and, if the chainlet accepts
    one, `context=self._context`.
    """
    imports = {"from truss_chains.remote_chainlet import stub", "import logging"}
    # The mapping key is the constructor argument name; `dep.name` is the class
    # name of the dependency.
    init_kwargs = [
        f"{arg_name}=stub.factory({dep.name}, self._context)"
        for arg_name, dep in chainlet_descriptor.dependencies.items()
    ]
    if chainlet_descriptor.has_context:
        init_kwargs.append("context=self._context")
    init_args = ", ".join(init_kwargs)

    chainlet_ref = _gen_chainlet_import_and_ref(chainlet_descriptor)
    imports.update(chainlet_ref.imports)
    body_lines = [
        f"logging.info(f'Loading Chainlet `{chainlet_descriptor.name}`.')",
        f"self._chainlet = {chainlet_ref.src}({init_args})",
    ]
    src = "\n".join(["def load(self) -> None:", _indent("\n".join(body_lines))])
    return _Source(src=src, imports=imports)
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def _gen_health_check_src(
    health_check: definitions.HealthCheckAPIDescriptor,
) -> _Source:
    """Generates the `is_healthy` method of the truss model.

    Once `self._chainlet` exists, the generated method returns the chainlet's
    `is_healthy()` result, awaited for async health checks. Before that it
    falls through and implicitly returns `None`.
    """
    def_str = "async def" if health_check.is_async else "def"
    maybe_await = "await " if health_check.is_async else ""
    # One line per statement, with the `return` nested under the `if`. The
    # previous concatenation only parsed as an accidental one-line `if`.
    src = "\n".join(
        [
            f"{def_str} is_healthy(self) -> Optional[bool]:",
            _indent('if hasattr(self, "_chainlet"):'),
            _indent(f"return {maybe_await}self._chainlet.is_healthy()", 2),
        ]
    )
    # The generated signature references `Optional`, so declare its import.
    return _Source(src=src, imports={"from typing import Optional"})
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
    """Generates the `predict` method of the truss model.

    The generated method calls the chainlet's endpoint inside
    `utils.predict_context(request)`.
    - Streaming endpoints yield the chunks from within that block.
    - Non-streaming endpoints wrap the result in the output model after the
      block.
    """
    endpoint = chainlet_descriptor.endpoint
    imports: set[str] = {
        "from truss_chains.remote_chainlet import stub",
        "from truss_chains.remote_chainlet import utils",
        "import starlette.requests",
    }
    if endpoint.is_streaming:
        streaming_ref = _gen_streaming_type_import_and_ref(endpoint.streaming_type)
        imports.update(streaming_ref.imports)
        output_type_name = streaming_ref.src
    else:
        output_type_name = _get_output_model_name(chainlet_descriptor.name)

    def_str = "async def" if endpoint.is_async else "def"
    input_model_name = _get_input_model_name(chainlet_descriptor.name)
    # Streaming results are iterated with `async for` below, so only
    # non-streaming async endpoints are awaited directly.
    maybe_await = "await " if endpoint.is_async and not endpoint.is_streaming else ""
    # See docs of `pydantic_set_field_dict` for why this is needed.
    call_args = "**utils.pydantic_set_field_dict(inputs)"
    invoke = f"result = {maybe_await}self._chainlet.{endpoint.name}({call_args})"

    lines = [
        f"{def_str} predict(self, inputs: {input_model_name}, "
        f"request: starlette.requests.Request) -> {output_type_name}:",
        # Error handling context manager.
        _indent("with utils.predict_context(request):"),
        _indent(invoke, 2),
    ]
    if endpoint.is_streaming:
        # Chunks are yielded while still inside the `predict_context` block.
        lines += [_indent("async for chunk in result:", 2), _indent("yield chunk", 3)]
    else:
        lines.append(_indent(f"return {output_type_name}(result)"))
    return _Source(src="\n".join(lines), imports=imports)
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def _gen_truss_chainlet_model(
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
    """Generates the truss model class that wraps a chainlet.

    The `_MODEL_CLS_NAME` class is taken from the model skeleton file and
    extended with generated methods:
    - `load` and `predict`;
    - `is_healthy`, if the chainlet defines a health check.

    Its `_chainlet` annotation is replaced by the concrete user chainlet class.
    The returned imports are the skeleton file's top-level import statements
    plus the imports needed by the generated methods.
    """
    skeleton_tree = libcst.parse_module(_MODEL_SKELETON_FILE.read_text())
    imports: set[str] = {
        libcst.Module(body=[node]).code
        for node in skeleton_tree.body
        if isinstance(node, libcst.SimpleStatementLine)
        and any(
            isinstance(stmt, (libcst.Import, libcst.ImportFrom)) for stmt in node.body
        )
    }
    class_definition: libcst.ClassDef = utils.expect_one(
        node
        for node in skeleton_tree.body
        if isinstance(node, libcst.ClassDef) and node.name.value == _MODEL_CLS_NAME
    )

    load_src = _gen_load_src(chainlet_descriptor)
    imports.update(load_src.imports)
    predict_src = _gen_predict_src(chainlet_descriptor)
    imports.update(predict_src.imports)

    new_body: list[Any] = list(class_definition.body.body) + [
        libcst.parse_statement(load_src.src),
        libcst.parse_statement(predict_src.src),
    ]

    if chainlet_descriptor.health_check is not None:
        health_check_src = _gen_health_check_src(chainlet_descriptor.health_check)
        # Merge its imports like for `load` / `predict`, so the generated method's
        # dependencies are not silently dropped.
        imports.update(health_check_src.imports)
        new_body.append(libcst.parse_statement(health_check_src.src))

    user_chainlet_ref = _gen_chainlet_import_and_ref(chainlet_descriptor)
    imports.update(user_chainlet_ref.imports)

    new_block = libcst.IndentedBlock(body=new_body)
    class_definition = class_definition.with_changes(body=new_block)
    class_definition = class_definition.visit(  # type: ignore[assignment]
        _SpecifyChainletTypeAnnotation(user_chainlet_ref.src)
    )
    model_class_src = libcst.Module(body=[class_definition]).code
    return _Source(src=model_class_src, imports=imports)
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def _gen_truss_chainlet_file(
    chainlet_dir: pathlib.Path,
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
    dependencies: Iterable[definitions.ChainletAPIDescriptor],
) -> pathlib.Path:
    """Generates code that wraps a Chainlet as a truss-compatible model.

    The model file is written into the model module directory, which is made
    a package via an empty `__init__.py`. The file contains:
    - stub classes for `dependencies`;
    - the chainlet's input model, plus an output model for non-streaming
      endpoints;
    - the truss model class.

    The file is then formatted, and its path is returned.
    """
    file_path = chainlet_dir / truss_config.DEFAULT_MODEL_MODULE_DIR / _MODEL_FILENAME
    file_path.parent.mkdir(parents=True, exist_ok=True)
    (file_path.parent / "__init__.py").touch()
    imports: set[str] = set()
    src_parts: list[str] = []

    if maybe_stub_src := _gen_stub_src_for_deps(dependencies):
        _update_src(maybe_stub_src, src_parts, imports)

    input_src = _gen_truss_input_pydantic(chainlet_descriptor)
    _update_src(input_src, src_parts, imports)
    if not chainlet_descriptor.endpoint.is_streaming:
        output_src = _gen_truss_output_pydantic(chainlet_descriptor)
        _update_src(output_src, src_parts, imports)
    model_src = _gen_truss_chainlet_model(chainlet_descriptor)
    _update_src(model_src, src_parts, imports)

    # Set iteration order of strings varies between interpreter runs (hash
    # randomization). Sort for a reproducible file, keeping any `__future__`
    # imports first because Python requires them at the top.
    ordered_imports = sorted(
        imports, key=lambda imp: (not imp.startswith("from __future__"), imp)
    )
    imports_str = "\n".join(ordered_imports)
    src_str = "\n".join(src_parts)
    file_path.write_text(f"{imports_str}\n{src_str}")
    _format_python_file(file_path)
    return file_path
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
# Truss Gen ############################################################################
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def _make_requirements(image: definitions.DockerImage) -> list[str]:
    """Merges file- and list-based requirements and adds truss if not present.

    Lines read from `image.pip_requirements_file` are stripped, and blank lines
    and `#` comment lines are dropped. If no truss requirement is present,
    `truss==<locally installed version>` is added. A truss requirement is one
    that matches `_TRUSS_PIP_PATTERN` or contains `_TRUSS_GIT`.

    Returns:
        The de-duplicated requirements in sorted order.
    """
    pip_requirements: set[str] = set()
    if image.pip_requirements_file:
        requirements_path = pathlib.Path(image.pip_requirements_file.abs_path)
        for line in requirements_path.read_text().splitlines():
            req = line.strip()
            # Skip blank lines too: they would otherwise become empty entries
            # in the generated requirements file.
            if req and not req.startswith("#"):
                pip_requirements.add(req)
    pip_requirements.update(image.pip_requirements)

    truss_pypi = next(
        (req for req in pip_requirements if _TRUSS_PIP_PATTERN.match(req)), None
    )

    truss_git = next((req for req in pip_requirements if _TRUSS_GIT in req), None)

    if truss_git:
        logging.warning(
            "The chainlet contains a truss version from github as a pip_requirement:\n"
            f"\t{truss_git}\n"
            "This could result in inconsistencies between the deploying client and the "
            "deployed chainlet. This is not recommended for production chains."
        )
    if truss_pypi:
        logging.warning(
            "The chainlet contains a pinned truss version as a pip_requirement:\n"
            f"\t{truss_pypi}\n"
            "This could result in inconsistencies between the deploying client and the "
            "deployed chainlet. This is not recommended for production chains. If "
            "`truss` is not manually added as a requirement, the same version as "
            "locally installed will be automatically added and ensure compatibility."
        )

    if not (truss_git or truss_pypi):
        truss_pip = f"truss=={truss.version()}"
        logging.debug(
            f"Truss not found in pip requirements, auto-adding: `{truss_pip}`."
        )
        pip_requirements.add(truss_pip)

    return sorted(pip_requirements)
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def _inplace_fill_base_image(
    image: definitions.DockerImage, mutable_truss_config: truss_config.TrussConfig
) -> None:
    """Copies the chainlet's base image settings into `mutable_truss_config`.

    - A `BasetenImage` sets the config's python version.
    - A `CustomImage` sets the config's base image, including its python
      executable path if one is given.
    - Any other type leaves the config untouched.

    Raises:
        NotImplementedError: if the base image is specified as a plain string.
    """
    base = image.base_image
    if isinstance(base, definitions.BasetenImage):
        mutable_truss_config.python_version = base.value
    elif isinstance(base, definitions.CustomImage):
        mutable_truss_config.base_image = truss_config.BaseImage(
            image=base.image, docker_auth=base.docker_auth
        )
        executable_path = base.python_executable_path
        if executable_path:
            mutable_truss_config.base_image.python_executable_path = executable_path
    elif isinstance(base, str):  # This option is deprecated.
        raise NotImplementedError(
            "Specifying docker base image as string is deprecated"
        )
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def _write_truss_config_yaml(
    chainlet_dir: pathlib.Path,
    chains_config: definitions.RemoteConfig,
    chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
    model_name: str,
    use_local_chains_src: bool,
):
    """Generate a truss config for a Chainlet and write it into `chainlet_dir`.

    The merged pip requirements file is written next to the config.

    Args:
        chainlet_dir: Directory of the generated chainlet truss.
        chains_config: The chainlet's remote config (compute, image, assets, options).
        chainlet_to_service: Service descriptors of the chainlet's dependencies,
            stored in the config's model metadata.
        model_name: Name of the truss model.
        use_local_chains_src: Forwarded to `config.use_local_chains_src`.
    """
    config = truss_config.TrussConfig()
    config.model_name = model_name
    config.model_class_filename = _MODEL_FILENAME
    config.model_class_name = _MODEL_CLS_NAME
    config.runtime.enable_tracing_data = chains_config.options.enable_b10_tracing
    config.environment_variables = dict(chains_config.options.env_variables)
    config.runtime.health_checks = chains_config.options.health_checks
    # Compute.
    compute = chains_config.get_compute_spec()
    config.resources.cpu = str(compute.cpu_count)
    config.resources.memory = str(compute.memory)
    config.resources.accelerator = compute.accelerator
    config.resources.use_gpu = bool(compute.accelerator.count)
    config.runtime.predict_concurrency = compute.predict_concurrency
    # Image.
    _inplace_fill_base_image(chains_config.docker_image, config)
    pip_requirements = _make_requirements(chains_config.docker_image)
    # TODO: `pip_requirements` will add server requirements which give version
    #  conflicts. Check if that's still the case after relaxing versions.
    #  config.requirements = pip_requirements
    pip_requirements_file_path = chainlet_dir / _REQUIREMENTS_FILENAME
    pip_requirements_file_path.write_text("\n".join(pip_requirements))
    # Absolute paths don't work with remote build.
    config.requirements_file = _REQUIREMENTS_FILENAME
    config.system_packages = chains_config.docker_image.apt_requirements
    # The guard is kept because `external_package_dirs` may be unset.
    if chains_config.docker_image.external_package_dirs:
        for ext_dir in chains_config.docker_image.external_package_dirs:
            config.external_package_dirs.append(ext_dir.abs_path)
    config.use_local_chains_src = use_local_chains_src
    # Assets.
    assets = chains_config.get_asset_spec()
    # Copy (like `environment_variables` above), so that adding the API key
    # placeholder below never mutates the asset spec's own secrets mapping.
    config.secrets = dict(assets.secrets)
    if definitions.BASETEN_API_SECRET_NAME not in config.secrets:
        config.secrets[definitions.BASETEN_API_SECRET_NAME] = definitions.SECRET_DUMMY
    else:
        logging.info(
            f"Chains automatically add {definitions.BASETEN_API_SECRET_NAME} "
            "to secrets - no need to manually add it."
        )
    config.model_cache.models = assets.cached
    config.external_data = truss_config.ExternalData(items=assets.external_data)
    # Metadata.
    chains_metadata: definitions.TrussMetadata = definitions.TrussMetadata(
        chainlet_to_service=chainlet_to_service
    )
    config.model_metadata[definitions.TRUSS_CONFIG_CHAINS_KEY] = (
        chains_metadata.model_dump()
    )
    config.write_to_yaml_file(
        chainlet_dir / serving_image_builder.CONFIG_FILE, verbose=True
    )
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
def gen_truss_model_from_source(
    model_src: pathlib.Path, use_local_chains_src: bool = False
) -> pathlib.Path:
    """Generates a truss directory for the model defined in the file `model_src`.

    The entrypoint class is imported from `model_src`. The file's parent
    directory serves as the model root, and the class's `display_name` as the
    model name.
    """
    # TODO(nikhil): Improve detection of directory structure, since right now
    #  we assume a flat structure
    model_root = model_src.absolute().parent
    with framework.ModelImporter.import_target(model_src) as model_cls:
        model_descriptor = framework.get_descriptor(model_cls)
        generated_dir = gen_truss_model(
            model_root=model_root,
            model_name=model_cls.display_name,
            model_descriptor=model_descriptor,
            use_local_chains_src=use_local_chains_src,
        )
        return generated_dir
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def gen_truss_model(
    model_root: pathlib.Path,
    model_name: str,
    model_descriptor: definitions.ChainletAPIDescriptor,
    use_local_chains_src: bool = False,
) -> pathlib.Path:
    """Generates a truss directory for a standalone model.

    A standalone model is generated like a single chainlet whose chain is named
    after the model. Returns the generated directory.
    """
    chainlet_dir = gen_truss_chainlet(
        chain_root=model_root,
        chain_name=model_name,
        chainlet_descriptor=model_descriptor,
        use_local_chains_src=use_local_chains_src,
    )
    return chainlet_dir
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def gen_truss_chainlet(
    chain_root: pathlib.Path,
    chain_name: str,
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
    model_name: Optional[str] = None,
    use_local_chains_src: bool = False,
) -> pathlib.Path:
    """Generates a truss directory for a single chainlet and returns its path.

    The directory is created under the system temp dir and contains:
    - the truss config and requirements;
    - a copy of `chain_root` as bundled packages;
    - the generated model file;
    - the chainlet's data dir, if configured.

    Raises:
        definitions.ChainsUsageError: if a top-level python file in `chain_root`
            contains `-` or uses the reserved model file name.
    """
    # Validate user file names before generating anything, so a usage error
    # fails fast and does not leave a partially generated directory behind.
    for file in chain_root.glob("*.py"):
        if "-" in file.name:
            raise definitions.ChainsUsageError(
                f"Python file `{file}` contains `-`, use `_` instead."
            )
        if file.name == _MODEL_FILENAME:
            raise definitions.ChainsUsageError(
                f"Python file name `{_MODEL_FILENAME}` is reserved and cannot be used."
            )

    # Filter needed services and customize options.
    dep_services = {
        dep.name: definitions.ServiceDescriptor(
            name=dep.name, display_name=dep.display_name, options=dep.options
        )
        for dep in chainlet_descriptor.dependencies.values()
    }
    remote_config = chainlet_descriptor.chainlet_cls.remote_config
    gen_root = pathlib.Path(tempfile.gettempdir())
    chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
    logging.info(
        f"Code generation for {chainlet_descriptor.chainlet_cls.entity_type} `{chainlet_descriptor.name}` "
        f"in `{chainlet_dir}`."
    )
    _write_truss_config_yaml(
        chainlet_dir=chainlet_dir,
        chains_config=remote_config,
        model_name=model_name or chain_name,
        chainlet_to_service=dep_services,
        use_local_chains_src=use_local_chains_src,
    )
    # This assumes all imports are absolute w.r.t chain root (or site-packages).
    truss_path.copy_tree_path(
        chain_root, chainlet_dir / truss_config.DEFAULT_BUNDLED_PACKAGES_DIR
    )
    chainlet_file = _gen_truss_chainlet_file(
        chainlet_dir,
        chainlet_descriptor,
        framework.get_dependencies(chainlet_descriptor),
    )
    if remote_config.docker_image.data_dir:
        data_dir = chainlet_dir / truss_config.DEFAULT_DATA_DIRECTORY
        data_dir.mkdir(parents=True, exist_ok=True)
        user_data_dir = remote_config.docker_image.data_dir.abs_path
        shutil.copytree(user_data_dir, data_dir, dirs_exist_ok=True)

    # Copy model file s.t. during debugging imports are properly resolved.
    shutil.copy(
        chainlet_file,
        chainlet_file.parent.parent
        / truss_config.DEFAULT_BUNDLED_PACKAGES_DIR
        / "_model_dbg.py",
    )
    return chainlet_dir
|