truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- /truss/{patch → truss_handle/patch}/constants.py +0 -0
- /truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from truss.remote.baseten import service
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def test_model_invocation_url_prod():
|
|
5
|
+
url = service.URLConfig.invocation_url(
|
|
6
|
+
"https://api.baseten.co", service.URLConfig.MODEL, "123", "789", is_draft=False
|
|
7
|
+
)
|
|
8
|
+
assert url == "https://model-123.api.baseten.co/deployment/789/predict"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def test_model_invocation_url_draft():
|
|
12
|
+
url = service.URLConfig.invocation_url(
|
|
13
|
+
"https://api.baseten.co", service.URLConfig.MODEL, "123", "789", is_draft=True
|
|
14
|
+
)
|
|
15
|
+
assert url == "https://model-123.api.baseten.co/development/predict"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_chain_invocation_url_prod():
|
|
19
|
+
url = service.URLConfig.invocation_url(
|
|
20
|
+
"https://api.baseten.co", service.URLConfig.CHAIN, "abc", "666", is_draft=False
|
|
21
|
+
)
|
|
22
|
+
assert url == "https://chain-abc.api.baseten.co/deployment/666/run_remote"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_chain_invocation_url_draft():
|
|
26
|
+
url = service.URLConfig.invocation_url(
|
|
27
|
+
"https://api.baseten.co", service.URLConfig.CHAIN, "abc", "666", is_draft=True
|
|
28
|
+
)
|
|
29
|
+
assert url == "https://chain-abc.api.baseten.co/development/run_remote"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def test_model_status_page_url():
|
|
33
|
+
url = service.URLConfig.status_page_url(
|
|
34
|
+
"https://app.baseten.co", service.URLConfig.MODEL, "123"
|
|
35
|
+
)
|
|
36
|
+
assert url == "https://app.baseten.co/models/123/overview"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def test_chain_status_page_url():
|
|
40
|
+
url = service.URLConfig.status_page_url(
|
|
41
|
+
"https://app.baseten.co", service.URLConfig.CHAIN, "abc"
|
|
42
|
+
)
|
|
43
|
+
assert url == "https://app.baseten.co/chains/abc/overview"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_model_logs_url():
|
|
47
|
+
url = service.URLConfig.model_logs_url("https://app.baseten.co", "123", "789")
|
|
48
|
+
assert url == "https://app.baseten.co/models/123/logs/789"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_chain_logs_url():
|
|
52
|
+
url = service.URLConfig.chainlet_logs_url(
|
|
53
|
+
"https://app.baseten.co", "abc", "666", "543"
|
|
54
|
+
)
|
|
55
|
+
assert url == "https://app.baseten.co/chains/abc/logs/666/543"
|
|
@@ -2,7 +2,7 @@ from unittest import mock
|
|
|
2
2
|
|
|
3
3
|
import pytest
|
|
4
4
|
from truss.remote.remote_factory import RemoteFactory
|
|
5
|
-
from truss.remote.truss_remote import RemoteConfig,
|
|
5
|
+
from truss.remote.truss_remote import RemoteConfig, RemoteUser, TrussRemote
|
|
6
6
|
|
|
7
7
|
SAMPLE_CONFIG = {"api_key": "test_key", "remote_url": "http://test.com"}
|
|
8
8
|
|
|
@@ -25,10 +25,9 @@ remote_provider=test_remote
|
|
|
25
25
|
"""
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
class
|
|
28
|
+
class TrussTestRemote(TrussRemote):
|
|
29
29
|
def __init__(self, api_key, remote_url):
|
|
30
30
|
self.api_key = api_key
|
|
31
|
-
self.remote_url = remote_url
|
|
32
31
|
|
|
33
32
|
def authenticate(self):
|
|
34
33
|
return {"Authorization": self.api_key}
|
|
@@ -36,20 +35,19 @@ class TestRemote(TrussRemote):
|
|
|
36
35
|
def push(self):
|
|
37
36
|
return {"status": "success"}
|
|
38
37
|
|
|
39
|
-
def get_remote_logs_url(self, service: TrussService) -> str:
|
|
40
|
-
raise NotImplementedError
|
|
41
|
-
|
|
42
38
|
def get_service(self, **kwargs):
|
|
43
39
|
raise NotImplementedError
|
|
44
40
|
|
|
45
41
|
def sync_truss_to_dev_version_by_name(self, model_name: str, target_directory: str):
|
|
46
42
|
raise NotImplementedError
|
|
47
43
|
|
|
44
|
+
def whoami(self) -> RemoteUser:
|
|
45
|
+
return RemoteUser("test_user", "test_email")
|
|
46
|
+
|
|
48
47
|
|
|
49
48
|
def mock_service_config():
|
|
50
49
|
return RemoteConfig(
|
|
51
|
-
name="mock-service",
|
|
52
|
-
configs={"remote_provider": "test_remote", **SAMPLE_CONFIG},
|
|
50
|
+
name="mock-service", configs={"remote_provider": "test_remote", **SAMPLE_CONFIG}
|
|
53
51
|
)
|
|
54
52
|
|
|
55
53
|
|
|
@@ -60,7 +58,7 @@ def mock_incorrect_service_config():
|
|
|
60
58
|
)
|
|
61
59
|
|
|
62
60
|
|
|
63
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
61
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
64
62
|
@mock.patch(
|
|
65
63
|
"truss.remote.remote_factory.RemoteFactory.load_remote_config",
|
|
66
64
|
return_value=mock_service_config(),
|
|
@@ -69,10 +67,10 @@ def test_create(mock_load_remote_config):
|
|
|
69
67
|
service_name = "test_service"
|
|
70
68
|
remote = RemoteFactory.create(service_name)
|
|
71
69
|
mock_load_remote_config.assert_called_once_with(service_name)
|
|
72
|
-
assert isinstance(remote,
|
|
70
|
+
assert isinstance(remote, TrussTestRemote)
|
|
73
71
|
|
|
74
72
|
|
|
75
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
73
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
76
74
|
@mock.patch(
|
|
77
75
|
"truss.remote.remote_factory.RemoteFactory.load_remote_config",
|
|
78
76
|
return_value=mock_incorrect_service_config(),
|
|
@@ -83,7 +81,7 @@ def test_create_no_service(mock_load_remote_config):
|
|
|
83
81
|
RemoteFactory.create(service_name)
|
|
84
82
|
|
|
85
83
|
|
|
86
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
84
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
87
85
|
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
|
|
88
86
|
@mock.patch("pathlib.Path.exists", return_value=True)
|
|
89
87
|
def test_load_remote_config(mock_exists, mock_open):
|
|
@@ -92,7 +90,7 @@ def test_load_remote_config(mock_exists, mock_open):
|
|
|
92
90
|
assert service.configs == {"remote_provider": "test_remote", **SAMPLE_CONFIG}
|
|
93
91
|
|
|
94
92
|
|
|
95
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
93
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
96
94
|
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
|
|
97
95
|
@mock.patch("pathlib.Path.exists", return_value=False)
|
|
98
96
|
def test_load_remote_config_no_file(mock_exists, mock_open):
|
|
@@ -100,7 +98,7 @@ def test_load_remote_config_no_file(mock_exists, mock_open):
|
|
|
100
98
|
RemoteFactory.load_remote_config("test")
|
|
101
99
|
|
|
102
100
|
|
|
103
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
101
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
104
102
|
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
|
|
105
103
|
@mock.patch("pathlib.Path.exists", return_value=True)
|
|
106
104
|
def test_load_remote_config_no_service(mock_exists, mock_open):
|
|
@@ -108,13 +106,13 @@ def test_load_remote_config_no_service(mock_exists, mock_open):
|
|
|
108
106
|
RemoteFactory.load_remote_config("nonexistent_service")
|
|
109
107
|
|
|
110
108
|
|
|
111
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
109
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
112
110
|
def test_required_params():
|
|
113
|
-
required_params = RemoteFactory.required_params(
|
|
111
|
+
required_params = RemoteFactory.required_params(TrussTestRemote)
|
|
114
112
|
assert required_params == {"api_key", "remote_url"}
|
|
115
113
|
|
|
116
114
|
|
|
117
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
115
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
118
116
|
@mock.patch(
|
|
119
117
|
"builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_REMOTE
|
|
120
118
|
)
|
|
@@ -125,7 +123,7 @@ def test_validate_remote_config_no_remote(mock_exists, mock_open):
|
|
|
125
123
|
RemoteFactory.validate_remote_config(service.configs, "test")
|
|
126
124
|
|
|
127
125
|
|
|
128
|
-
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote":
|
|
126
|
+
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
|
|
129
127
|
@mock.patch(
|
|
130
128
|
"builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_PARAMS
|
|
131
129
|
)
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from typing import Iterator
|
|
1
2
|
from unittest import mock
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
@@ -7,28 +8,36 @@ from truss.remote.truss_remote import TrussService
|
|
|
7
8
|
TEST_SERVICE_URL = "http://test.com"
|
|
8
9
|
|
|
9
10
|
|
|
10
|
-
class
|
|
11
|
-
def __init__(self,
|
|
12
|
-
super().__init__(
|
|
13
|
-
self._remote_url = remote_url
|
|
11
|
+
class TrussTestService(TrussService):
|
|
12
|
+
def __init__(self, _service_url: str, is_draft: bool, **kwargs):
|
|
13
|
+
super().__init__(_service_url, is_draft, **kwargs)
|
|
14
14
|
|
|
15
15
|
def authenticate(self):
|
|
16
|
-
return {
|
|
16
|
+
return {}
|
|
17
17
|
|
|
18
|
-
def is_live(self):
|
|
19
|
-
response = self._send_request(self.
|
|
18
|
+
def is_live(self) -> bool:
|
|
19
|
+
response = self._send_request(self._service_url, "GET")
|
|
20
20
|
if response.status_code == 200:
|
|
21
21
|
return True
|
|
22
22
|
return False
|
|
23
23
|
|
|
24
|
-
def is_ready(self):
|
|
25
|
-
response = self._send_request(self.
|
|
24
|
+
def is_ready(self) -> bool:
|
|
25
|
+
response = self._send_request(self._service_url, "GET")
|
|
26
26
|
if response.status_code == 200:
|
|
27
27
|
return True
|
|
28
28
|
return False
|
|
29
29
|
|
|
30
|
-
|
|
31
|
-
|
|
30
|
+
@property
|
|
31
|
+
def logs_url(self) -> str:
|
|
32
|
+
raise NotImplementedError()
|
|
33
|
+
|
|
34
|
+
@property
|
|
35
|
+
def predict_url(self) -> str:
|
|
36
|
+
return f"{self._service_url}/v1/models/model:predict"
|
|
37
|
+
|
|
38
|
+
def poll_deployment_status(self, sleep_secs: int = 1) -> Iterator[str]:
|
|
39
|
+
for status in ["DEPLOYING", "ACTIVE"]:
|
|
40
|
+
yield status
|
|
32
41
|
|
|
33
42
|
|
|
34
43
|
def mock_successful_response():
|
|
@@ -46,37 +55,37 @@ def mock_unsuccessful_response():
|
|
|
46
55
|
|
|
47
56
|
@mock.patch("requests.request", return_value=mock_successful_response())
|
|
48
57
|
def test_is_live(mock_request):
|
|
49
|
-
service =
|
|
58
|
+
service = TrussTestService(TEST_SERVICE_URL, True)
|
|
50
59
|
assert service.is_live()
|
|
51
60
|
|
|
52
61
|
|
|
53
62
|
@mock.patch("requests.request", return_value=mock_unsuccessful_response())
|
|
54
63
|
def test_is_not_live(mock_request):
|
|
55
|
-
service =
|
|
64
|
+
service = TrussTestService(TEST_SERVICE_URL, True)
|
|
56
65
|
assert service.is_live() is False
|
|
57
66
|
|
|
58
67
|
|
|
59
68
|
@mock.patch("requests.request", return_value=mock_successful_response())
|
|
60
69
|
def test_is_ready(mock_request):
|
|
61
|
-
service =
|
|
70
|
+
service = TrussTestService(TEST_SERVICE_URL, True)
|
|
62
71
|
assert service.is_ready()
|
|
63
72
|
|
|
64
73
|
|
|
65
74
|
@mock.patch("requests.request", return_value=mock_unsuccessful_response())
|
|
66
75
|
def test_is_not_ready(mock_request):
|
|
67
|
-
service =
|
|
76
|
+
service = TrussTestService(TEST_SERVICE_URL, True)
|
|
68
77
|
assert service.is_ready() is False
|
|
69
78
|
|
|
70
79
|
|
|
71
80
|
@mock.patch("requests.request", return_value=mock_successful_response())
|
|
72
81
|
def test_predict(mock_request):
|
|
73
|
-
service =
|
|
82
|
+
service = TrussTestService(TEST_SERVICE_URL, True)
|
|
74
83
|
response = service.predict({"model_input": "test"})
|
|
75
84
|
assert response.status_code == 200
|
|
76
85
|
|
|
77
86
|
|
|
78
87
|
@mock.patch("requests.request", return_value=mock_successful_response())
|
|
79
88
|
def test_predict_no_data(mock_request):
|
|
80
|
-
service =
|
|
89
|
+
service = TrussTestService(TEST_SERVICE_URL, True)
|
|
81
90
|
with pytest.raises(ValueError):
|
|
82
91
|
service._send_request(TEST_SERVICE_URL, "POST")
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from truss.templates.control.control.helpers.context_managers import current_directory
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_current_directory(tmp_path):
|
|
7
|
+
orig_cwd = os.getcwd()
|
|
8
|
+
with current_directory(tmp_path):
|
|
9
|
+
assert os.getcwd() == str(tmp_path)
|
|
10
|
+
|
|
11
|
+
assert os.getcwd() == orig_cwd
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from unittest import mock
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
from truss.base.truss_config import TrussConfig
|
|
8
|
+
|
|
9
|
+
# Needed to simulate the set up on the model docker container
|
|
10
|
+
sys.path.append(
|
|
11
|
+
str(
|
|
12
|
+
Path(__file__).parent.parent.parent.parent.parent.parent
|
|
13
|
+
/ "templates"
|
|
14
|
+
/ "control"
|
|
15
|
+
/ "control"
|
|
16
|
+
)
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Have to use imports in this form, otherwise isinstance checks fail on helper classes
|
|
20
|
+
from helpers.truss_patch.model_container_patch_applier import ( # noqa
|
|
21
|
+
ModelContainerPatchApplier,
|
|
22
|
+
)
|
|
23
|
+
from helpers.custom_types import ( # noqa
|
|
24
|
+
Action,
|
|
25
|
+
ConfigPatch,
|
|
26
|
+
EnvVarPatch,
|
|
27
|
+
ExternalDataPatch,
|
|
28
|
+
ModelCodePatch,
|
|
29
|
+
PackagePatch,
|
|
30
|
+
Patch,
|
|
31
|
+
PatchType,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
|
|
36
|
+
def patch_applier(truss_container_fs):
|
|
37
|
+
return ModelContainerPatchApplier(truss_container_fs / "app", mock.Mock())
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_patch_applier_model_code_patch_add(
|
|
41
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
42
|
+
):
|
|
43
|
+
patch = Patch(
|
|
44
|
+
type=PatchType.MODEL_CODE,
|
|
45
|
+
body=ModelCodePatch(action=Action.ADD, path="dummy", content=""),
|
|
46
|
+
)
|
|
47
|
+
patch_applier(patch, os.environ.copy())
|
|
48
|
+
assert (truss_container_fs / "app" / "model" / "dummy").exists()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_patch_applier_model_code_patch_remove(
|
|
52
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
53
|
+
):
|
|
54
|
+
patch = Patch(
|
|
55
|
+
type=PatchType.MODEL_CODE,
|
|
56
|
+
body=ModelCodePatch(action=Action.REMOVE, path="model.py"),
|
|
57
|
+
)
|
|
58
|
+
assert (truss_container_fs / "app" / "model" / "model.py").exists()
|
|
59
|
+
patch_applier(patch, os.environ.copy())
|
|
60
|
+
assert not (truss_container_fs / "app" / "model" / "model.py").exists()
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_patch_applier_model_code_patch_update(
|
|
64
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
65
|
+
):
|
|
66
|
+
new_model_file_content = """
|
|
67
|
+
class Model:
|
|
68
|
+
pass
|
|
69
|
+
"""
|
|
70
|
+
patch = Patch(
|
|
71
|
+
type=PatchType.MODEL_CODE,
|
|
72
|
+
body=ModelCodePatch(
|
|
73
|
+
action=Action.UPDATE, path="model.py", content=new_model_file_content
|
|
74
|
+
),
|
|
75
|
+
)
|
|
76
|
+
patch_applier(patch, os.environ.copy())
|
|
77
|
+
assert (
|
|
78
|
+
truss_container_fs / "app" / "model" / "model.py"
|
|
79
|
+
).read_text() == new_model_file_content
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def test_patch_applier_package_patch_add(
|
|
83
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
84
|
+
):
|
|
85
|
+
patch = Patch(
|
|
86
|
+
type=PatchType.PACKAGE,
|
|
87
|
+
body=PackagePatch(
|
|
88
|
+
action=Action.ADD, path="test_package/test.py", content="foobar"
|
|
89
|
+
),
|
|
90
|
+
)
|
|
91
|
+
patch_applier(patch, os.environ.copy())
|
|
92
|
+
assert (truss_container_fs / "packages" / "test_package" / "test.py").exists()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def test_patch_applier_package_patch_remove(
|
|
96
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
97
|
+
):
|
|
98
|
+
patch = Patch(
|
|
99
|
+
type=PatchType.PACKAGE,
|
|
100
|
+
body=PackagePatch(action=Action.REMOVE, path="test_package/test.py"),
|
|
101
|
+
)
|
|
102
|
+
assert (truss_container_fs / "packages" / "test_package" / "test.py").exists()
|
|
103
|
+
patch_applier(patch, os.environ.copy())
|
|
104
|
+
assert not (truss_container_fs / "packages" / "test_package" / "test.py").exists()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def test_patch_applier_package_patch_update(
|
|
108
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
109
|
+
):
|
|
110
|
+
new_package_content = """X = 2"""
|
|
111
|
+
patch = Patch(
|
|
112
|
+
type=PatchType.PACKAGE,
|
|
113
|
+
body=PackagePatch(
|
|
114
|
+
action=Action.UPDATE,
|
|
115
|
+
path="test_package/test.py",
|
|
116
|
+
content=new_package_content,
|
|
117
|
+
),
|
|
118
|
+
)
|
|
119
|
+
patch_applier(patch, os.environ.copy())
|
|
120
|
+
assert (
|
|
121
|
+
truss_container_fs / "packages" / "test_package" / "test.py"
|
|
122
|
+
).read_text() == new_package_content
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def test_patch_applier_config_patch_update(
|
|
126
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
127
|
+
):
|
|
128
|
+
new_config_dict = {"model_name": "foobar"}
|
|
129
|
+
patch = Patch(
|
|
130
|
+
type=PatchType.CONFIG,
|
|
131
|
+
body=ConfigPatch(action=Action.UPDATE, config=new_config_dict),
|
|
132
|
+
)
|
|
133
|
+
patch_applier(patch, os.environ.copy())
|
|
134
|
+
new_config = TrussConfig.from_yaml(truss_container_fs / "app" / "config.yaml")
|
|
135
|
+
assert new_config.model_name == "foobar"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def test_patch_applier_env_var_patch_update(patch_applier: ModelContainerPatchApplier):
|
|
139
|
+
env_var_dict = {"FOO": "BAR"}
|
|
140
|
+
patch = Patch(
|
|
141
|
+
type=PatchType.ENVIRONMENT_VARIABLE,
|
|
142
|
+
body=EnvVarPatch(action=Action.UPDATE, item={"FOO": "BAR-PATCHED"}),
|
|
143
|
+
)
|
|
144
|
+
patch_applier(patch, env_var_dict)
|
|
145
|
+
assert env_var_dict["FOO"] == "BAR-PATCHED"
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def test_patch_applier_env_var_patch_add(patch_applier: ModelContainerPatchApplier):
|
|
149
|
+
env_var_dict = {"FOO": "BAR"}
|
|
150
|
+
patch = Patch(
|
|
151
|
+
type=PatchType.ENVIRONMENT_VARIABLE,
|
|
152
|
+
body=EnvVarPatch(action=Action.ADD, item={"BAR": "FOO"}),
|
|
153
|
+
)
|
|
154
|
+
patch_applier(patch, env_var_dict)
|
|
155
|
+
assert env_var_dict["FOO"] == "BAR"
|
|
156
|
+
assert env_var_dict["BAR"] == "FOO"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def test_patch_applier_env_var_patch_remove(patch_applier: ModelContainerPatchApplier):
|
|
160
|
+
env_var_dict = {"FOO": "BAR"}
|
|
161
|
+
patch = Patch(
|
|
162
|
+
type=PatchType.ENVIRONMENT_VARIABLE,
|
|
163
|
+
body=EnvVarPatch(action=Action.REMOVE, item={"FOO": "BAR"}),
|
|
164
|
+
)
|
|
165
|
+
patch_applier(patch, env_var_dict)
|
|
166
|
+
with pytest.raises(KeyError):
|
|
167
|
+
_ = env_var_dict["FOO"]
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def test_patch_applier_external_data_patch_add(
|
|
171
|
+
patch_applier: ModelContainerPatchApplier, truss_container_fs
|
|
172
|
+
):
|
|
173
|
+
patch = Patch(
|
|
174
|
+
type=PatchType.EXTERNAL_DATA,
|
|
175
|
+
body=ExternalDataPatch(
|
|
176
|
+
action=Action.ADD,
|
|
177
|
+
item={
|
|
178
|
+
"url": "https://raw.githubusercontent.com/basetenlabs/truss/main/docs/favicon.svg",
|
|
179
|
+
"local_data_path": "truss_icon",
|
|
180
|
+
},
|
|
181
|
+
),
|
|
182
|
+
)
|
|
183
|
+
patch_applier(patch, os.environ.copy())
|
|
184
|
+
assert (truss_container_fs / "app" / "data" / "truss_icon").exists()
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from truss.templates.control.control.helpers.truss_patch.requirement_name_identifier import (
|
|
3
|
+
RequirementMeta,
|
|
4
|
+
identify_requirement_name,
|
|
5
|
+
reqs_by_name,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.mark.parametrize(
|
|
10
|
+
"req, expected_name",
|
|
11
|
+
[
|
|
12
|
+
("pytorch", "pytorch"),
|
|
13
|
+
(
|
|
14
|
+
"git+https://github.com/huggingface/transformers.git#egg=transformers",
|
|
15
|
+
"git+github.com/huggingface/transformers.git",
|
|
16
|
+
),
|
|
17
|
+
(
|
|
18
|
+
"git+https://github.com/huggingface/transformers.git",
|
|
19
|
+
"git+github.com/huggingface/transformers.git",
|
|
20
|
+
),
|
|
21
|
+
(
|
|
22
|
+
"git+https://github.com/huggingface/transformers.git@main#egg=transformers",
|
|
23
|
+
"git+github.com/huggingface/transformers.git",
|
|
24
|
+
),
|
|
25
|
+
(
|
|
26
|
+
" git+https://github.com/huggingface/transformers.git ",
|
|
27
|
+
"git+github.com/huggingface/transformers.git",
|
|
28
|
+
),
|
|
29
|
+
("pytorch==1.0", "pytorch"),
|
|
30
|
+
("pytorch>=1.0", "pytorch"),
|
|
31
|
+
("pytorch<=1.0", "pytorch"),
|
|
32
|
+
],
|
|
33
|
+
)
|
|
34
|
+
def test_identify_requirement_name(req, expected_name):
|
|
35
|
+
assert expected_name == identify_requirement_name(req)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def test_reqs_by_name():
|
|
39
|
+
reqs = ["pytorch", " ", "jinja==1.0"]
|
|
40
|
+
assert reqs_by_name(reqs) == {"pytorch": "pytorch", "jinja": "jinja==1.0"}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.mark.parametrize(
|
|
44
|
+
"desc, req, expected_meta",
|
|
45
|
+
[
|
|
46
|
+
(
|
|
47
|
+
"handles simple requirement",
|
|
48
|
+
"pytorch",
|
|
49
|
+
RequirementMeta(
|
|
50
|
+
requirement="pytorch",
|
|
51
|
+
name="pytorch",
|
|
52
|
+
is_url_based_requirement=False,
|
|
53
|
+
egg_tag=None,
|
|
54
|
+
),
|
|
55
|
+
),
|
|
56
|
+
(
|
|
57
|
+
"handles python package with version",
|
|
58
|
+
"pytorch==1.0",
|
|
59
|
+
RequirementMeta(
|
|
60
|
+
requirement="pytorch==1.0",
|
|
61
|
+
name="pytorch",
|
|
62
|
+
is_url_based_requirement=False,
|
|
63
|
+
egg_tag=None,
|
|
64
|
+
),
|
|
65
|
+
),
|
|
66
|
+
(
|
|
67
|
+
"handles url-based requirement with egg tag",
|
|
68
|
+
"git+https://github.com/huggingface/transformers.git@main#egg=transformers",
|
|
69
|
+
RequirementMeta(
|
|
70
|
+
requirement="git+https://github.com/huggingface/transformers.git@main#egg=transformers",
|
|
71
|
+
name="git+github.com/huggingface/transformers.git",
|
|
72
|
+
is_url_based_requirement=True,
|
|
73
|
+
egg_tag=["transformers"],
|
|
74
|
+
),
|
|
75
|
+
),
|
|
76
|
+
(
|
|
77
|
+
"handles url-based requirement without egg tag",
|
|
78
|
+
"git+https://github.com/huggingface/transformers.git",
|
|
79
|
+
RequirementMeta(
|
|
80
|
+
requirement="git+https://github.com/huggingface/transformers.git",
|
|
81
|
+
name="git+github.com/huggingface/transformers.git",
|
|
82
|
+
is_url_based_requirement=True,
|
|
83
|
+
egg_tag=None,
|
|
84
|
+
),
|
|
85
|
+
),
|
|
86
|
+
],
|
|
87
|
+
)
|
|
88
|
+
def test_requirement_meta_from_req(desc, req: str, expected_meta: RequirementMeta):
|
|
89
|
+
assert expected_meta == RequirementMeta.from_req(req), desc
|