truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,871 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import concurrent.futures
|
|
3
|
+
import inspect
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import pathlib
|
|
7
|
+
import textwrap
|
|
8
|
+
import traceback
|
|
9
|
+
import uuid
|
|
10
|
+
from typing import (
|
|
11
|
+
TYPE_CHECKING,
|
|
12
|
+
Any,
|
|
13
|
+
Callable,
|
|
14
|
+
Dict,
|
|
15
|
+
Iterable,
|
|
16
|
+
Iterator,
|
|
17
|
+
Mapping,
|
|
18
|
+
Optional,
|
|
19
|
+
Type,
|
|
20
|
+
cast,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
import requests
|
|
24
|
+
import tenacity
|
|
25
|
+
import watchfiles
|
|
26
|
+
from truss.local import local_config_handler
|
|
27
|
+
from truss.remote import remote_factory
|
|
28
|
+
from truss.remote.baseten import core as b10_core
|
|
29
|
+
from truss.remote.baseten import custom_types as b10_types
|
|
30
|
+
from truss.remote.baseten import error as b10_errors
|
|
31
|
+
from truss.remote.baseten import remote as b10_remote
|
|
32
|
+
from truss.remote.baseten import service as b10_service
|
|
33
|
+
from truss.truss_handle import truss_handle
|
|
34
|
+
from truss.util import log_utils
|
|
35
|
+
from truss.util import path as truss_path
|
|
36
|
+
|
|
37
|
+
from truss_chains import definitions, framework, utils
|
|
38
|
+
from truss_chains.deployment import code_gen
|
|
39
|
+
|
|
40
|
+
if TYPE_CHECKING:
|
|
41
|
+
from rich import console as rich_console
|
|
42
|
+
from rich import progress
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _get_ordered_dependencies(
    chainlets: Iterable[Type[definitions.ABCChainlet]],
) -> Iterable[definitions.ChainletAPIDescriptor]:
    """Gather all Chainlets needed and return them in topological order.

    Starting from the descriptors of ``chainlets``, every transitive dependency
    reported by ``framework.get_dependencies`` is collected. The collected set is
    then emitted in the order of ``framework.get_ordered_descriptors()``.

    Args:
        chainlets: Chainlet classes whose (transitive) dependencies are needed.

    Returns:
        A list with the descriptors of ``chainlets`` and all their transitive
        dependencies, ordered as in ``framework.get_ordered_descriptors()``.
    """
    needed_chainlets: set[definitions.ChainletAPIDescriptor] = set()
    # Iterative DFS with a visited check: each descriptor's dependencies are
    # expanded at most once. Shared (diamond-shaped) dependencies are not
    # re-traversed once per path, and deep chains cannot hit the recursion limit.
    to_visit = [framework.get_descriptor(chainlet_cls) for chainlet_cls in chainlets]
    while to_visit:
        descriptor = to_visit.pop()
        if descriptor in needed_chainlets:
            continue
        needed_chainlets.add(descriptor)
        to_visit.extend(framework.get_dependencies(descriptor))

    # Get dependencies in topological order.
    return [
        descr
        for descr in framework.get_ordered_descriptors()
        if descr in needed_chainlets
    ]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_chain_root(entrypoint: Type[definitions.ABCChainlet]) -> pathlib.Path:
    """Return the absolute directory containing the source file of ``entrypoint``.

    This directory is used as the chain workspace. The log message emitted here
    tells the user that files under it are included as dependencies in the
    remote deployments and are importable.
    """
    # TODO: revisit how chain root is inferred/specified, current might be brittle.
    entrypoint_source = pathlib.Path(inspect.getfile(entrypoint))
    workspace_dir = entrypoint_source.absolute().parent
    logging.info(
        f"Using chain workspace dir: `{workspace_dir}` (files under this dir will "
        "be included as dependencies in the remote deployments and are importable)."
    )
    return workspace_dir
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ChainService(abc.ABC):
    """Handle for a deployed chain.

    Instances are returned by ``push``. A concrete subclass bundles the
    individual services for each chainlet of the chain and offers helpers to
    query their status and to invoke the entrypoint.
    """

    _name: str
    # Example input for the entrypoint. ``None`` until a caller assigns it
    # through the ``entrypoint_fake_json_data`` setter.
    _entrypoint_fake_json_data: Any

    def __init__(self, name: str):
        self._entrypoint_fake_json_data = None
        self._name = name

    @property
    def name(self) -> str:
        """Name of the chain, as passed to the constructor."""
        return self._name

    @property
    @abc.abstractmethod
    def status_page_url(self) -> str:
        """Link to status page on Baseten."""

    @property
    @abc.abstractmethod
    def run_remote_url(self) -> str:
        """URL to invoke the entrypoint."""

    @abc.abstractmethod
    def run_remote(self, json: Dict) -> Any:
        """Invoke the entrypoint with JSON data.

        Args:
            json: Request payload for the entrypoint.

        Returns:
            The JSON response.
        """

    @abc.abstractmethod
    def get_info(self) -> list[b10_types.DeployedChainlet]:
        """Query the statuses of all chainlets in the chain.

        Returns:
            One ``DeployedChainlet`` entry per chainlet.
        """

    @property
    def entrypoint_fake_json_data(self) -> Any:
        """Fake JSON example data that matches the entrypoint's input schema.

        This value is not computed here; it must be assigned externally via the
        setter first.

        Raises:
            ValueError: If the stored value is ``None`` (never set, or set to ``None``).
        """
        stored_data = self._entrypoint_fake_json_data
        if stored_data is None:
            raise ValueError("Fake data was not set.")
        return stored_data

    @entrypoint_fake_json_data.setter
    def entrypoint_fake_json_data(self, fake_data: Any) -> None:
        self._entrypoint_fake_json_data = fake_data
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class _ChainSourceGenerator:
    """Generates a Truss directory (artifact) for each Chainlet of a chain."""

    def __init__(self, options: definitions.PushOptions) -> None:
        self._options = options

    @property
    def _use_local_chains_src(self) -> bool:
        """Return the ``use_local_chains_src`` flag for local-Docker pushes.

        Any other options type always yields ``False``.
        """
        push_options = self._options
        return (
            push_options.use_local_chains_src
            if isinstance(push_options, definitions.PushOptionsLocalDocker)
            else False
        )

    def generate_chainlet_artifacts(
        self, entrypoint: Type[definitions.ABCChainlet]
    ) -> tuple[b10_types.ChainletArtifact, list[b10_types.ChainletArtifact]]:
        """Generate one artifact for the entrypoint and one per dependency.

        Args:
            entrypoint: The Chainlet class that serves as the chain's entrypoint.

        Returns:
            A pair ``(entrypoint_artifact, dependency_artifacts)``. The dependency
            artifacts follow the order returned by ``_get_ordered_dependencies``.

        Raises:
            definitions.ChainsUsageError: If two Chainlets share a display name.
        """
        chain_root = _get_chain_root(entrypoint)
        seen_display_names: set[str] = set()
        entrypoint_artifact: Optional[b10_types.ChainletArtifact] = None
        dependency_artifacts: list[b10_types.ChainletArtifact] = []

        for descriptor in _get_ordered_dependencies([entrypoint]):
            chainlet_display_name = descriptor.display_name
            if chainlet_display_name in seen_display_names:
                raise definitions.ChainsUsageError(
                    f"Chainlet names must be unique. Found multiple Chainlets with the name: '{chainlet_display_name}'."
                )
            seen_display_names.add(chainlet_display_name)

            # Since we are creating a distinct model for each deployment of the chain,
            # we add a random suffix (the first group of a UUID4).
            random_suffix = str(uuid.uuid4()).split("-")[0]
            truss_model_name = f"{chainlet_display_name}-{random_suffix}"

            artifact = b10_types.ChainletArtifact(
                truss_dir=code_gen.gen_truss_chainlet(
                    chain_root,
                    self._options.chain_name,
                    descriptor,
                    truss_model_name,
                    self._use_local_chains_src,
                ),
                name=descriptor.name,
                display_name=chainlet_display_name,
            )

            if descriptor.chainlet_cls == entrypoint:
                # Internal invariant: the entrypoint appears exactly once.
                assert entrypoint_artifact is None
                entrypoint_artifact = artifact
            else:
                dependency_artifacts.append(artifact)

        assert entrypoint_artifact is not None
        return entrypoint_artifact, dependency_artifacts
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@framework.raise_validation_errors_before
def push(
    entrypoint: Type[definitions.ABCChainlet],
    options: definitions.PushOptions,
    progress_bar: Optional[Type["progress.Progress"]] = None,
) -> Optional[ChainService]:
    """Generate Truss artifacts for a chain and deploy them.

    Args:
        entrypoint: The Chainlet class that serves as the chain's entrypoint.
        options: Push options. The concrete type selects the deployment target.
        progress_bar: Passed through to the Baseten deployment path only.

    Returns:
        ``None`` if ``options.only_generate_trusses`` is set. Otherwise, the
        ``ChainService`` produced by the Baseten or local-Docker helper.

    Raises:
        NotImplementedError: If ``options`` is neither ``PushOptionsBaseten`` nor
            ``PushOptionsLocalDocker``.
    """
    source_generator = _ChainSourceGenerator(options)
    entrypoint_artifact, dependency_artifacts = (
        source_generator.generate_chainlet_artifacts(entrypoint)
    )

    if options.only_generate_trusses:
        return None

    if isinstance(options, definitions.PushOptionsBaseten):
        return _create_baseten_chain(
            options, entrypoint_artifact, dependency_artifacts, progress_bar
        )

    if isinstance(options, definitions.PushOptionsLocalDocker):
        return _create_docker_chain(options, entrypoint_artifact, dependency_artifacts)

    raise NotImplementedError(options)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# Docker ###############################################################################
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class DockerChainletService(b10_service.TrussService):
    """This service is for Chainlets (not for Chains).

    Talks to a Chainlet server running locally on ``http://localhost:<port>``.
    """

    def __init__(self, port: int, **kwargs):
        super().__init__(f"http://localhost:{port}", is_draft=False, **kwargs)

    def authenticate(self) -> Dict[str, str]:
        """Return no auth headers, since the local service needs none."""
        return {}

    def _root_get_is_ok(self) -> bool:
        """Send a GET to the service root URL and report whether it returned 200."""
        root_response = self._send_request(self._service_url, "GET")
        return root_response.status_code == 200

    def is_live(self) -> bool:
        return self._root_get_is_ok()

    def is_ready(self) -> bool:
        return self._root_get_is_ok()

    @property
    def logs_url(self) -> str:
        raise NotImplementedError()

    @property
    def predict_url(self) -> str:
        return f"{self._service_url}/v1/models/model:predict"

    def poll_deployment_status(self, sleep_secs: int = 1) -> Iterator[str]:
        # Not a generator: calling this raises immediately.
        raise NotImplementedError()
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _push_service_docker(
    truss_dir: pathlib.Path,
    chainlet_display_name: str,
    options: definitions.PushOptionsLocalDocker,
    port: int,
) -> None:
    """Build and start one Chainlet's Truss as a detached local Docker container.

    The Baseten API key from ``options`` is added as a Truss secret before the
    container is run. The call waits for the server to be ready.

    Args:
        truss_dir: Directory of the generated Chainlet Truss.
        chainlet_display_name: Used as the container name prefix.
        options: Local-Docker push options; supplies ``baseten_chain_api_key``.
        port: Local port the container is exposed on.
    """
    handle = truss_handle.TrussHandle(truss_dir)
    handle.add_secret(
        definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key
    )
    handle.docker_run(
        local_port=port,
        detach=True,
        wait_for_server_ready=True,
        network="host",
        container_name_prefix=chainlet_display_name,
        disable_json_logging=True,
    )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class DockerChainService(ChainService):
    """Chain whose Chainlets run as local Docker containers.

    Invocations are routed to the entrypoint Chainlet's container.
    """

    _entrypoint_service: DockerChainletService

    def __init__(self, name: str, entrypoint_service: DockerChainletService) -> None:
        super().__init__(name)
        self._entrypoint_service = entrypoint_service

    @property
    def run_remote_url(self) -> str:
        """URL to invoke the entrypoint."""
        entrypoint = self._entrypoint_service
        return entrypoint.predict_url

    def run_remote(self, json: Dict) -> Any:
        """Invokes the entrypoint with JSON data.

        Returns:
            The JSON response."""
        entrypoint = self._entrypoint_service
        return entrypoint.predict(json)

    @property
    def status_page_url(self) -> str:
        """Not Implemented."""
        raise NotImplementedError()

    def get_info(self) -> list[b10_types.DeployedChainlet]:
        """Not Implemented."""
        raise NotImplementedError()
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _create_docker_chain(
    docker_options: definitions.PushOptionsLocalDocker,
    entrypoint_artifact: b10_types.ChainletArtifact,
    dependency_artifacts: list[b10_types.ChainletArtifact],
) -> DockerChainService:
    """Start every Chainlet of a Chain as a local Docker container.

    Dependencies are processed first and the entrypoint last. Before each
    container is started, the dynamic Chainlet config is rewritten with the
    predict URLs of all Chainlets processed so far.

    Returns:
        A ``DockerChainService`` wrapping the entrypoint's container service.
    """
    chainlet_to_predict_url: Dict[str, Dict[str, str]] = {}
    chainlet_to_service: Dict[str, DockerChainletService] = {}

    for artifact in (*dependency_artifacts, entrypoint_artifact):
        port = utils.get_free_port()
        service = DockerChainletService(port)

        # From inside a container, the host is addressed as
        # `host.docker.internal` rather than `localhost`.
        internal_url = service.predict_url.replace(
            "localhost", "host.docker.internal"
        )
        chainlet_to_predict_url[artifact.display_name] = {"predict_url": internal_url}
        chainlet_to_service[artifact.name] = service

        local_config_handler.LocalConfigHandler.set_dynamic_config(
            definitions.DYNAMIC_CHAINLET_CONFIG_KEY, json.dumps(chainlet_to_predict_url)
        )

        display_name = artifact.display_name
        logging.info(f"Building Chainlet `{display_name}` docker image.")
        _push_service_docker(artifact.truss_dir, display_name, docker_options, port)
        logging.info(f"Pushed Chainlet `{display_name}` as docker container.")
        logging.debug(
            "Internal model endpoint: "
            f"`{chainlet_to_predict_url[display_name]}`"
        )

    entrypoint_service = chainlet_to_service[entrypoint_artifact.name]
    return DockerChainService(docker_options.chain_name, entrypoint_service)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
# Baseten ##############################################################################
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class BasetenChainService(ChainService):
    """Handle to a Chain deployment on Baseten."""

    _chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic
    _remote: b10_remote.BasetenRemote

    def __init__(
        self,
        name: str,
        chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic,
        remote: b10_remote.BasetenRemote,
    ) -> None:
        super().__init__(name)
        self._chain_deployment_handle = chain_deployment_handle
        self._remote = remote

    @property
    def run_remote_url(self) -> str:
        """URL to invoke the entrypoint."""
        return b10_service.URLConfig.invocation_url(
            self._remote.api.rest_api_url,
            b10_service.URLConfig.CHAIN,
            self._chain_deployment_handle.chain_id,
            self._chain_deployment_handle.chain_deployment_id,
            self._chain_deployment_handle.is_draft,
        )

    def run_remote(self, json_data: Dict) -> Any:
        """Invokes the entrypoint with JSON data.

        Returns:
            For chunked (streaming) responses, a generator of decoded text chunks.
            Otherwise, the parsed JSON response. If the model is not in a ready
            state, this is a JSON object with an ``error`` key, returned as-is.

        Raises:
            ValueError: If authentication fails (HTTP 401).
        """
        headers = self._remote._auth_service.authenticate().header()
        response = requests.post(
            self.run_remote_url, json=json_data, headers=headers, stream=True
        )
        if response.status_code == 401:
            raise ValueError(
                f"Authentication failed with status code {response.status_code}"
            )

        if response.headers.get("transfer-encoding") == "chunked":
            return self._decode_stream(response)

        # Parse once and return the payload unchanged. Probing with
        # `"error" in payload` raised TypeError for scalar JSON (e.g. `42`,
        # `null`), and error payloads were returned unchanged anyway.
        return response.json()

    @staticmethod
    def _decode_stream(response: requests.Response) -> Iterator[str]:
        """Yield the text chunks of a streaming response.

        Depending on the response's content-type, ``iter_content`` yields either
        bytes (no encoding set by the backend) or already-decoded strings.
        Bytes are decoded incrementally, so a multi-byte character split across
        chunk boundaries is decoded correctly instead of raising.
        """
        import codecs

        decoder = None
        for chunk in response.iter_content(chunk_size=8192, decode_unicode=True):
            if not isinstance(chunk, bytes):
                yield chunk
                continue
            if decoder is None:
                decoder = codecs.getincrementaldecoder(
                    response.encoding or b10_service.DEFAULT_STREAM_ENCODING
                )()
            text = decoder.decode(chunk)
            if text:
                yield text
        if decoder is not None:
            # Flush any buffered bytes; raises if the stream ended mid-character.
            tail = decoder.decode(b"", final=True)
            if tail:
                yield tail

    @property
    def status_page_url(self) -> str:
        """Link to status page on Baseten."""
        return b10_service.URLConfig.status_page_url(
            self._remote.remote_url,
            b10_service.URLConfig.CHAIN,
            self._chain_deployment_handle.chain_id,
        )

    @tenacity.retry(
        stop=tenacity.stop_after_delay(300), wait=tenacity.wait_fixed(1), reraise=True
    )
    def get_info(self) -> list[b10_types.DeployedChainlet]:
        """Queries the statuses of all chainlets in the chain.

        Retried every second for up to 300 seconds on any exception; the last
        exception is re-raised.

        Returns:
            List of ``DeployedChainlet`` for each chainlet."""
        return self._remote.get_chainlets(
            self._chain_deployment_handle.chain_deployment_id
        )
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def _create_baseten_chain(
    baseten_options: definitions.PushOptionsBaseten,
    entrypoint_artifact: b10_types.ChainletArtifact,
    dependency_artifacts: list[b10_types.ChainletArtifact],
    progress_bar: Optional[Type["progress.Progress"]],
):
    """Push all Chainlet artifacts of a Chain to Baseten in one atomic deployment.

    The Chains API-key secret is created on the remote first if it is missing.

    Returns:
        A ``BasetenChainService`` for the new Chain deployment.
    """
    chain_name = baseten_options.chain_name
    logging.info(
        f"Pushing Chain '{chain_name}' to Baseten "
        f"(publish={baseten_options.publish}, environment={baseten_options.environment})."
    )
    created_remote = remote_factory.RemoteFactory.create(remote=baseten_options.remote)
    remote_provider = cast(b10_remote.BasetenRemote, created_remote)
    _create_chains_secret_if_missing(remote_provider)

    deployment_handle = remote_provider.push_chain_atomic(
        chain_name=chain_name,
        entrypoint_artifact=entrypoint_artifact,
        dependency_artifacts=dependency_artifacts,
        publish=baseten_options.publish,
        environment=baseten_options.environment,
        progress_bar=progress_bar,
    )
    return BasetenChainService(chain_name, deployment_handle, remote_provider)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote) -> None:
    """Ensure the API-key secret used by Chains exists on the Baseten remote.

    If no secret named ``definitions.BASETEN_API_SECRET_NAME`` exists, it is
    created with the value of ``remote_provider.api.auth_token``.
    """
    secret_name = definitions.BASETEN_API_SECRET_NAME
    all_secrets = remote_provider.api.get_all_secrets()["secrets"]
    existing_names = {secret["name"] for secret in all_secrets}
    if secret_name in existing_names:
        return

    logging.info(
        "It seems you are using chains for the first time, since there "
        f"is no `{secret_name}` secret on baseten. "
        "Creating secret automatically."
    )
    remote_provider.api.upsert_secret(
        secret_name, remote_provider.api.auth_token.value
    )
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
# Watch / Live Patching ################################################################
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def _create_watch_filter(root_dir: pathlib.Path):
    """Build a ``watchfiles`` filter from the trussignore patterns of ``root_dir``.

    Side effect: disables the ``watchfiles.main`` logger.

    Returns:
        Tuple ``(ignore_patterns, watch_filter)``. ``watch_filter(change, path)``
        is True for paths that are not matched by the ignore patterns.
    """
    patterns = truss_path.load_trussignore_patterns_from_truss_dir(root_dir)

    def watch_filter(_: watchfiles.Change, path: str) -> bool:
        ignored = truss_path.is_ignored(pathlib.Path(path), patterns)
        return not ignored

    # Silence watchfiles' own per-change logging.
    logging.getLogger("watchfiles.main").disabled = True
    return patterns, watch_filter
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def _handle_intercepted_logs(logs: list[str], console: "rich_console.Console"):
|
|
496
|
+
if logs:
|
|
497
|
+
formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
|
|
498
|
+
console.print(f"Intercepted logs from importing source code:\n{formatted_logs}")
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _handle_import_error(
|
|
502
|
+
exception: Exception,
|
|
503
|
+
console: "rich_console.Console",
|
|
504
|
+
error_console: "rich_console.Console",
|
|
505
|
+
stack_trace: Optional[str] = None,
|
|
506
|
+
):
|
|
507
|
+
error_console.print(
|
|
508
|
+
"Source files were changed, but pre-conditions for "
|
|
509
|
+
"live patching are not given. Most likely there is a "
|
|
510
|
+
"syntax error in the source files or names changed. "
|
|
511
|
+
"Try to fix the issue and save the file. Error:\n"
|
|
512
|
+
f"{textwrap.indent(str(exception), ' ' * 4)}"
|
|
513
|
+
)
|
|
514
|
+
if stack_trace:
|
|
515
|
+
error_console.print(stack_trace)
|
|
516
|
+
|
|
517
|
+
console.print(
|
|
518
|
+
"The watcher will continue and if you can resolve the "
|
|
519
|
+
"issue, subsequent patches might succeed.",
|
|
520
|
+
style="blue",
|
|
521
|
+
)
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
class _ModelWatcher:
    """Watches a Truss model's source directory and live-patches its dev deployment."""

    _source: pathlib.Path
    _model_name: str
    _remote_provider: b10_remote.BasetenRemote
    _ignore_patterns: list[str]
    _watch_filter: Callable[[watchfiles.Change, str], bool]
    _console: "rich_console.Console"
    _error_console: "rich_console.Console"

    def __init__(
        self,
        source: pathlib.Path,
        model_name: str,
        remote_provider: b10_remote.BasetenRemote,
        console: "rich_console.Console",
        error_console: "rich_console.Console",
    ) -> None:
        """Set up the watcher.

        Raises:
            b10_errors.RemoteError: If no development deployment of ``model_name``
                exists on the remote.
        """
        self._source = source
        self._model_name = model_name
        self._remote_provider = remote_provider
        self._console = console
        self._error_console = error_console
        self._ignore_patterns, self._watch_filter = _create_watch_filter(
            source.absolute().parent
        )

        # Fail fast: live patching requires an existing development deployment.
        dev_version = b10_core.get_dev_version(self._remote_provider.api, model_name)
        if not dev_version:
            raise b10_errors.RemoteError(
                "No development model found. Run `truss push` then try again."
            )

    def _patch(self) -> None:
        """Regenerate the Truss from source and live-patch the dev deployment.

        Any exception is reported on the consoles instead of being raised, so the
        watcher keeps running. Logs intercepted during generation/patching are
        printed on both success and failure.
        """
        exception_raised: Optional[Exception] = None
        with log_utils.LogInterceptor() as log_interceptor, self._console.status(
            " Live Patching Model.\n", spinner="arrow3"
        ):
            try:
                gen_truss_path = code_gen.gen_truss_model_from_source(self._source)
                # Do not `return` here: doing so skipped the intercepted-log
                # handling below, silently dropping logs on successful patches.
                self._remote_provider.patch(
                    gen_truss_path,
                    self._ignore_patterns,
                    self._console,
                    self._error_console,
                )
            except Exception as e:
                exception_raised = e
            finally:
                logs = log_interceptor.get_logs()

        _handle_intercepted_logs(logs, self._console)
        if exception_raised:
            _handle_import_error(exception_raised, self._console, self._error_console)

    def watch(self) -> None:
        """Patch once at startup, then re-patch whenever a watched file changes."""
        self._patch()
        self._console.print("👀 Watching for new changes.", style="blue")

        # TODO(nikhil): Improve detection of directory structure, since right now
        # we assume a flat structure
        root_dir = self._source.absolute().parent
        for _ in watchfiles.watch(
            root_dir, watch_filter=self._watch_filter, raise_interrupt=False
        ):
            self._patch()
            self._console.print("👀 Watching for new changes.", style="blue")
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
class _Watcher:
    """Watches a Chain's source tree and live-patches its development deployment.

    On each detected change, the Chain is re-imported. Then, for every included
    Chainlet, code is regenerated and the deployed Chainlet is patched; these
    per-Chainlet tasks are submitted to an executor and run concurrently.
    """

    _source: pathlib.Path
    _entrypoint: Optional[str]
    _deployed_chain_name: str
    _remote_provider: b10_remote.BasetenRemote
    _chainlet_data: Mapping[str, b10_types.DeployedChainlet]
    _watch_filter: Callable[[watchfiles.Change, str], bool]
    _console: "rich_console.Console"
    _error_console: "rich_console.Console"
    _show_stack_trace: bool
    _included_chainlets: set[str]
    # Set in `__init__`; previously undeclared.
    _chain_root: pathlib.Path
    _ignore_patterns: list[str]
    _status_page_url: str

    def __init__(
        self,
        source: pathlib.Path,
        entrypoint: Optional[str],
        name: Optional[str],
        remote: str,
        console: "rich_console.Console",
        error_console: "rich_console.Console",
        show_stack_trace: bool,
        included_chainlets: Optional[list[str]],
    ) -> None:
        """Resolve the deployed development Chain and validate it against the source.

        Raises:
            definitions.ChainsDeploymentError: If requested Chainlets do not exist,
                the Chain or its development deployment is not found, or the
                deployed Chainlet names differ from those in the source.
        """
        self._source = source
        self._entrypoint = entrypoint
        self._console = console
        self._error_console = error_console
        self._show_stack_trace = show_stack_trace
        self._remote_provider = cast(
            b10_remote.BasetenRemote, remote_factory.RemoteFactory.create(remote=remote)
        )
        with framework.ChainletImporter.import_target(
            source, entrypoint
        ) as entrypoint_cls:
            self._deployed_chain_name = name or entrypoint_cls.__name__
            self._chain_root = _get_chain_root(entrypoint_cls)
            chainlet_names = set(
                desc.display_name
                for desc in _get_ordered_dependencies([entrypoint_cls])
            )

        if included_chainlets:
            if not_matched := (set(included_chainlets) - chainlet_names):
                raise definitions.ChainsDeploymentError(
                    "Requested to watch specific chainlets, but did not find "
                    f"{not_matched} among available chainlets {chainlet_names}."
                )
            self._included_chainlets = set(included_chainlets)
        else:
            self._included_chainlets = chainlet_names

        chain_id = b10_core.get_chain_id_by_name(
            self._remote_provider.api, self._deployed_chain_name
        )
        if not chain_id:
            # Report the name that was looked up; `chain_id` is empty here.
            raise definitions.ChainsDeploymentError(
                f"Chain `{self._deployed_chain_name}` was not found."
            )
        self._status_page_url = b10_service.URLConfig.status_page_url(
            self._remote_provider.remote_url, b10_service.URLConfig.CHAIN, chain_id
        )
        chain_deployment = b10_core.get_dev_chain_deployment(
            self._remote_provider.api, chain_id
        )
        if chain_deployment is None:
            raise definitions.ChainsDeploymentError(
                f"No development deployment was found for Chain `{chain_id}`. "
                "You cannot live-patch production deployments. Check the Chain's "
                f"status page for available deployments: {self._status_page_url}."
            )
        deployed_chainlets = self._remote_provider.get_chainlets(chain_deployment["id"])
        non_draft_chainlets = [
            chainlet.name for chainlet in deployed_chainlets if not chainlet.is_draft
        ]
        assert not (non_draft_chainlets), (
            "If the chain is draft, the oracles must be draft."
        )

        self._chainlet_data = {c.name: c for c in deployed_chainlets}
        self._assert_chainlet_names_same(chainlet_names)
        self._ignore_patterns, self._watch_filter = _create_watch_filter(
            self._chain_root
        )

    @property
    def _original_chainlet_names(self) -> set[str]:
        """Names of the Chainlets in the deployed development Chain."""
        return set(self._chainlet_data.keys())

    def _assert_chainlet_names_same(self, new_names: set[str]) -> None:
        """Raise ``ChainsDeploymentError`` if ``new_names`` differ from the deployed names."""
        missing = self._original_chainlet_names - new_names
        added = new_names - self._original_chainlet_names
        if not (missing or added):
            return
        msg_parts = [
            "The deployed Chainlets and the Chainlets in the current workspace differ. "
            "Live patching is not possible if the set of Chainlet names differ."
        ]
        if missing:
            msg_parts.append(f"Chainlets missing in current workspace: {list(missing)}")
        if added:
            msg_parts.append(f"Chainlets added in current workspace: {list(added)}")

        raise definitions.ChainsDeploymentError("\n".join(msg_parts))

    def _code_gen_and_patch_thread(
        self, descr: definitions.ChainletAPIDescriptor
    ) -> tuple[b10_remote.PatchResult, list[str]]:
        """Generate one Chainlet's Truss and patch it. Runs as an executor task.

        Returns:
            The patch result and the logs intercepted while generating/patching.
        """
        with log_utils.LogInterceptor() as log_interceptor:
            # TODO: Maybe try-except code_gen errors explicitly.
            chainlet_dir = code_gen.gen_truss_chainlet(
                self._chain_root,
                self._deployed_chain_name,
                descr,
                self._chainlet_data[descr.display_name].oracle_name,
                use_local_chains_src=False,
            )
            patch_result = self._remote_provider.patch_for_chainlet(
                chainlet_dir, self._ignore_patterns
            )
            logs = log_interceptor.get_logs()
        return patch_result, logs

    def _patch(self, executor: concurrent.futures.Executor) -> None:
        """Re-import the Chain and patch all included Chainlets concurrently.

        Import/validation errors are reported on the consoles and abort this
        round without raising, so the watcher keeps running.
        """
        exception_raised = None
        stack_trace = ""
        with log_utils.LogInterceptor() as log_interceptor, self._console.status(
            " Live Patching Chain.\n", spinner="arrow3"
        ):
            # Handle import errors gracefully (e.g. if user saved file, but there
            # are syntax errors, undefined symbols etc.).
            try:
                with framework.ChainletImporter.import_target(
                    self._source, self._entrypoint
                ) as entrypoint_cls:
                    chainlet_descriptors = _get_ordered_dependencies([entrypoint_cls])
                    chain_root_new = _get_chain_root(entrypoint_cls)
                    assert chain_root_new == self._chain_root
                    self._assert_chainlet_names_same(
                        set(desc.display_name for desc in chainlet_descriptors)
                    )
                    future_to_display_name = {}
                    for chainlet_descr in chainlet_descriptors:
                        if chainlet_descr.display_name not in self._included_chainlets:
                            self._console.print(
                                f"⏩ Skipping patching `{chainlet_descr.display_name}`.",
                                style="grey50",
                            )
                            continue

                        future = executor.submit(
                            self._code_gen_and_patch_thread, chainlet_descr
                        )
                        future_to_display_name[future] = chainlet_descr.display_name
                    # Threads need to finish while inside the `import_target`-context.
                    done_futures = {
                        future_to_display_name[future]: future
                        for future in concurrent.futures.as_completed(
                            future_to_display_name
                        )
                    }
            except Exception as e:
                exception_raised = e
                stack_trace = traceback.format_exc()
            finally:
                logs = log_interceptor.get_logs()

        _handle_intercepted_logs(logs, self._console)
        if exception_raised:
            _handle_import_error(
                exception_raised,
                self._console,
                self._error_console,
                stack_trace=stack_trace if self._show_stack_trace else None,
            )
            return

        self._check_patch_results(done_futures)

    def _check_patch_results(
        self,
        display_name_to_done_future: Mapping[
            str, concurrent.futures.Future[tuple[b10_remote.PatchResult, list[str]]]
        ],
    ) -> None:
        """Print a per-Chainlet summary of patch results, plus a warning on failures."""
        has_errors = False
        for display_name, future in display_name_to_done_future.items():
            # It is not expected that code_gen_and_patch raises an exception, errors
            # should be handled by setting `b10_remote.PatchStatus`.
            # If an exception is raised anyway, it should bubble up the default way.
            patch_result, logs = future.result()
            if logs:
                formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
                # Rich markup: the style must be closed with `[/grey70]`.
                logs_output = f" [grey70]Intercepted logs:\n{formatted_logs}[/grey70]"
            else:
                logs_output = ""

            if patch_result.status == b10_remote.PatchStatus.SUCCESS:
                self._console.print(
                    f"✅ Patched Chainlet `{display_name}`.{logs_output}", style="green"
                )
            elif patch_result.status == b10_remote.PatchStatus.SKIPPED:
                self._console.print(
                    f"💤 Nothing to do for Chainlet `{display_name}`.{logs_output}",
                    style="grey50",
                )
            else:
                has_errors = True
                self._error_console.print(
                    f"❌ Failed to patch Chainlet `{display_name}`. "
                    f"{patch_result.message}{logs_output}"
                )

        if has_errors:
            msg = (
                "Some Chainlets could not be live patched. See above error messages. "
                "The watcher will continue, and try patching new changes. However, the "
                "safest way to proceed and ensure a consistent state is to re-deploy "
                "the entire development Chain."
            )
            self._error_console.print(msg)

    def watch(self) -> None:
        """Patch once at startup, then re-patch on every detected source change."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Perform one initial patch at startup.
            self._patch(executor)
            self._console.print("👀 Watching for new changes.", style="blue")
            for _ in watchfiles.watch(
                self._chain_root, watch_filter=self._watch_filter, raise_interrupt=False
            ):
                self._patch(executor)
                self._console.print("👀 Watching for new changes.", style="blue")
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
@framework.raise_validation_errors_before
def watch(
    source: pathlib.Path,
    entrypoint: Optional[str],
    name: Optional[str],
    remote: str,
    console: "rich_console.Console",
    error_console: "rich_console.Console",
    show_stack_trace: bool,
    included_chainlets: Optional[list[str]],
) -> None:
    """Watch a Chain's source code and live-patch its development deployment.

    Blocks while watching. Arguments are forwarded unchanged to ``_Watcher``.
    """
    console.print(
        "👀 Starting to watch for Chain source code and applying live patches "
        "when changes are detected.",
        style="blue",
    )
    watcher = _Watcher(
        source=source,
        entrypoint=entrypoint,
        name=name,
        remote=remote,
        console=console,
        error_console=error_console,
        show_stack_trace=show_stack_trace,
        included_chainlets=included_chainlets,
    )
    watcher.watch()
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
def watch_model(
    source: pathlib.Path,
    model_name: str,
    remote_provider: b10_remote.TrussRemote,
    console: "rich_console.Console",
    error_console: "rich_console.Console",
):
    """Watch a Truss model's source and live-patch its development deployment.

    Blocks while watching.
    """
    # NOTE(review): the remote is assumed to be a Baseten remote; `cast` does
    # not check this at runtime.
    baseten_remote = cast(b10_remote.BasetenRemote, remote_provider)
    _ModelWatcher(
        source=source,
        model_name=model_name,
        remote_provider=baseten_remote,
        console=console,
        error_console=error_console,
    ).watch()
|