truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/__init__.py +10 -3
- truss/api/__init__.py +123 -0
- truss/api/definitions.py +51 -0
- truss/base/constants.py +116 -0
- truss/base/custom_types.py +29 -0
- truss/{errors.py → base/errors.py} +4 -0
- truss/base/trt_llm_config.py +310 -0
- truss/{truss_config.py → base/truss_config.py} +344 -31
- truss/{truss_spec.py → base/truss_spec.py} +20 -6
- truss/{validation.py → base/validation.py} +60 -11
- truss/cli/cli.py +841 -88
- truss/{remote → cli}/remote_cli.py +2 -7
- truss/contexts/docker_build_setup.py +67 -0
- truss/contexts/image_builder/cache_warmer.py +2 -8
- truss/contexts/image_builder/image_builder.py +1 -1
- truss/contexts/image_builder/serving_image_builder.py +292 -46
- truss/contexts/image_builder/util.py +1 -3
- truss/contexts/local_loader/docker_build_emulator.py +58 -0
- truss/contexts/local_loader/load_model_local.py +2 -2
- truss/contexts/local_loader/truss_module_loader.py +1 -1
- truss/contexts/local_loader/utils.py +1 -1
- truss/local/local_config.py +2 -6
- truss/local/local_config_handler.py +20 -5
- truss/patch/__init__.py +1 -0
- truss/patch/hash.py +4 -70
- truss/patch/signature.py +4 -16
- truss/patch/truss_dir_patch_applier.py +3 -78
- truss/remote/baseten/api.py +308 -23
- truss/remote/baseten/auth.py +3 -3
- truss/remote/baseten/core.py +257 -50
- truss/remote/baseten/custom_types.py +44 -0
- truss/remote/baseten/error.py +4 -0
- truss/remote/baseten/remote.py +369 -118
- truss/remote/baseten/service.py +118 -11
- truss/remote/baseten/utils/status.py +29 -0
- truss/remote/baseten/utils/tar.py +34 -22
- truss/remote/baseten/utils/transfer.py +36 -23
- truss/remote/remote_factory.py +14 -5
- truss/remote/truss_remote.py +72 -45
- truss/templates/base.Dockerfile.jinja +18 -16
- truss/templates/cache.Dockerfile.jinja +3 -3
- truss/{server → templates/control}/control/application.py +14 -35
- truss/{server → templates/control}/control/endpoints.py +39 -9
- truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
- truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
- truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
- truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
- truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
- truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
- truss/{server → templates/control}/control/server.py +11 -6
- truss/templates/control/requirements.txt +9 -0
- truss/templates/custom_python_dx/my_model.py +28 -0
- truss/templates/docker_server/proxy.conf.jinja +42 -0
- truss/templates/docker_server/supervisord.conf.jinja +27 -0
- truss/templates/docker_server_requirements.txt +1 -0
- truss/templates/server/common/errors.py +231 -0
- truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
- truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
- truss/{server → templates/server}/common/retry.py +1 -0
- truss/{server → templates/server}/common/schema.py +11 -9
- truss/templates/server/common/tracing.py +157 -0
- truss/templates/server/main.py +9 -0
- truss/templates/server/model_wrapper.py +961 -0
- truss/templates/server/requirements.txt +21 -0
- truss/templates/server/truss_server.py +447 -0
- truss/templates/server.Dockerfile.jinja +62 -14
- truss/templates/shared/dynamic_config_resolver.py +28 -0
- truss/templates/shared/lazy_data_resolver.py +164 -0
- truss/templates/shared/log_config.py +125 -0
- truss/{server → templates}/shared/secrets_resolver.py +1 -2
- truss/{server → templates}/shared/serialization.py +31 -9
- truss/{server → templates}/shared/util.py +3 -13
- truss/templates/trtllm-audio/model/model.py +49 -0
- truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
- truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
- truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
- truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
- truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
- truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
- truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
- truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
- truss/templates/trtllm-briton/src/extension.py +64 -0
- truss/tests/conftest.py +302 -94
- truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
- truss/tests/contexts/local_loader/test_load_local.py +2 -2
- truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
- truss/tests/patch/test_calc_patch.py +439 -127
- truss/tests/patch/test_dir_signature.py +3 -12
- truss/tests/patch/test_hash.py +1 -1
- truss/tests/patch/test_signature.py +1 -1
- truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
- truss/tests/patch/test_types.py +2 -2
- truss/tests/remote/baseten/test_api.py +153 -58
- truss/tests/remote/baseten/test_auth.py +2 -1
- truss/tests/remote/baseten/test_core.py +160 -12
- truss/tests/remote/baseten/test_remote.py +489 -77
- truss/tests/remote/baseten/test_service.py +55 -0
- truss/tests/remote/test_remote_factory.py +16 -18
- truss/tests/remote/test_truss_remote.py +26 -17
- truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
- truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
- truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
- truss/tests/{server → templates/control}/control/test_server.py +79 -24
- truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
- truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
- truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
- truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
- truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
- truss/tests/{server → templates/server}/common/test_retry.py +3 -3
- truss/tests/templates/server/test_model_wrapper.py +248 -0
- truss/tests/{server → templates/server}/test_schema.py +3 -5
- truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
- truss/tests/test_build.py +9 -52
- truss/tests/test_config.py +336 -77
- truss/tests/test_context_builder_image.py +3 -11
- truss/tests/test_control_truss_patching.py +7 -12
- truss/tests/test_custom_server.py +38 -0
- truss/tests/test_data/context_builder_image_test/test.py +3 -0
- truss/tests/test_data/happy.ipynb +56 -0
- truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
- truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
- truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
- truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
- truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
- truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
- truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
- truss/tests/test_data/test_async_truss/__init__.py +0 -0
- truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/__init__.py +0 -0
- truss/tests/test_data/test_basic_truss/config.yaml +16 -0
- truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/__init__.py +0 -0
- truss/tests/test_data/test_build_commands/config.yaml +13 -0
- truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
- truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
- truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
- truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
- truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
- truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
- truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
- truss/tests/test_data/test_openai/__init__.py +0 -0
- truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
- truss/tests/test_data/test_openai/model/__init__.py +0 -0
- truss/tests/test_data/test_openai/model/model.py +15 -0
- truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
- truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
- truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
- truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
- truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
- truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
- truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
- truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
- truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
- truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
- truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
- truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
- truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
- truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
- truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
- truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
- truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
- truss/tests/test_data/test_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss/config.yaml +4 -0
- truss/tests/test_data/test_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss/model/dummy +0 -0
- truss/tests/test_data/test_truss/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
- truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
- truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
- truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
- truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
- truss/tests/test_docker.py +2 -1
- truss/tests/test_model_inference.py +1340 -292
- truss/tests/test_model_schema.py +33 -26
- truss/tests/test_testing_utilities_for_other_tests.py +50 -5
- truss/tests/test_truss_gatherer.py +3 -5
- truss/tests/test_truss_handle.py +62 -59
- truss/tests/test_util.py +2 -1
- truss/tests/test_validation.py +15 -13
- truss/tests/trt_llm/test_trt_llm_config.py +41 -0
- truss/tests/trt_llm/test_validation.py +91 -0
- truss/tests/util/test_config_checks.py +40 -0
- truss/tests/util/test_env_vars.py +14 -0
- truss/tests/util/test_path.py +10 -23
- truss/trt_llm/config_checks.py +43 -0
- truss/trt_llm/validation.py +42 -0
- truss/truss_handle/__init__.py +0 -0
- truss/truss_handle/build.py +122 -0
- truss/{decorators.py → truss_handle/decorators.py} +1 -1
- truss/truss_handle/patch/__init__.py +0 -0
- truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
- truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
- truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
- truss/truss_handle/patch/hash.py +71 -0
- truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
- truss/truss_handle/patch/signature.py +22 -0
- truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
- truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
- truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
- truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
- truss/util/.truss_ignore +3 -0
- truss/{docker.py → util/docker.py} +6 -2
- truss/util/download.py +6 -15
- truss/util/env_vars.py +41 -0
- truss/util/log_utils.py +52 -0
- truss/util/path.py +20 -20
- truss/util/requirements.py +11 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
- truss-0.60.0.dist-info/RECORD +324 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
- truss-0.60.0.dist-info/entry_points.txt +4 -0
- truss_chains/__init__.py +71 -0
- truss_chains/definitions.py +756 -0
- truss_chains/deployment/__init__.py +0 -0
- truss_chains/deployment/code_gen.py +816 -0
- truss_chains/deployment/deployment_client.py +871 -0
- truss_chains/framework.py +1480 -0
- truss_chains/public_api.py +231 -0
- truss_chains/py.typed +0 -0
- truss_chains/pydantic_numpy.py +131 -0
- truss_chains/reference_code/reference_chainlet.py +34 -0
- truss_chains/reference_code/reference_model.py +10 -0
- truss_chains/remote_chainlet/__init__.py +0 -0
- truss_chains/remote_chainlet/model_skeleton.py +60 -0
- truss_chains/remote_chainlet/stub.py +380 -0
- truss_chains/remote_chainlet/utils.py +332 -0
- truss_chains/streaming.py +378 -0
- truss_chains/utils.py +178 -0
- CODE_OF_CONDUCT.md +0 -131
- CONTRIBUTING.md +0 -48
- README.md +0 -137
- context_builder.Dockerfile +0 -24
- truss/blob/blob_backend.py +0 -10
- truss/blob/blob_backend_registry.py +0 -23
- truss/blob/http_public_blob_backend.py +0 -23
- truss/build/__init__.py +0 -2
- truss/build/build.py +0 -143
- truss/build/configure.py +0 -63
- truss/cli/__init__.py +0 -2
- truss/cli/console.py +0 -5
- truss/cli/create.py +0 -5
- truss/config/trt_llm.py +0 -81
- truss/constants.py +0 -61
- truss/model_inference.py +0 -123
- truss/patch/types.py +0 -30
- truss/pytest.ini +0 -7
- truss/server/common/errors.py +0 -100
- truss/server/common/termination_handler_middleware.py +0 -64
- truss/server/common/truss_server.py +0 -389
- truss/server/control/patch/model_code_patch_applier.py +0 -46
- truss/server/control/patch/requirement_name_identifier.py +0 -17
- truss/server/inference_server.py +0 -29
- truss/server/model_wrapper.py +0 -434
- truss/server/shared/logging.py +0 -81
- truss/templates/trtllm/model/model.py +0 -97
- truss/templates/trtllm/packages/build_engine_utils.py +0 -34
- truss/templates/trtllm/packages/constants.py +0 -11
- truss/templates/trtllm/packages/schema.py +0 -216
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
- truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
- truss/templates/trtllm/packages/triton_client.py +0 -150
- truss/templates/trtllm/packages/utils.py +0 -43
- truss/test_data/context_builder_image_test/test.py +0 -4
- truss/test_data/happy.ipynb +0 -54
- truss/test_data/model_load_failure_test/config.yaml +0 -2
- truss/test_data/test_concurrency_truss/config.yaml +0 -2
- truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
- truss/test_data/test_streaming_truss/config.yaml +0 -3
- truss/test_data/test_truss/config.yaml +0 -2
- truss/tests/server/common/test_termination_handler_middleware.py +0 -93
- truss/tests/server/control/test_model_container_patch_applier.py +0 -203
- truss/tests/server/core/server/common/test_util.py +0 -19
- truss/tests/server/test_model_wrapper.py +0 -87
- truss/util/data_structures.py +0 -16
- truss-0.10.0rc1.dist-info/RECORD +0 -216
- truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
- truss/{server/shared → base}/__init__.py +0 -0
- truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
- truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
- truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
- truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
- truss/{server → templates/server}/common/__init__.py +0 -0
- truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
- truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
- truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
- truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
- truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
- truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
- truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
- truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
- truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
- truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
- truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
- truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
- truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
- truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
- truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
- truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
- truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
- truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
- truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
- truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
- truss/{patch → truss_handle/patch}/constants.py +0 -0
- truss/{notebook.py → util/notebook.py} +0 -0
- {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import tempfile
|
|
3
|
+
import time
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
from truss.templates.shared.lazy_data_resolver import (
|
|
8
|
+
TRUSS_TRANSFER_AVAILABLE,
|
|
9
|
+
LazyDataResolverV2,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
LAZY_DATA_RESOLVER_PATH = Path("/bptr/bptr-manifest")
TARGET_FILE = Path("nested/config.json")


def write_bptr_manifest_to_file(expiration_timestamp: int = 2683764059):
    """Write a one-pointer bptr manifest to ``LAZY_DATA_RESOLVER_PATH``.

    The single pointer maps ``TARGET_FILE`` to a file pinned at a fixed commit
    of the truss GitHub repository and records its blake3 hash and a size of
    1482 bytes.

    Args:
        expiration_timestamp: Stored as the pointer's
            ``resolution.expiration_timestamp``; pass a timestamp in the past
            to produce an already-expired manifest.
    """
    digest = "8c6b2f215f0333437cdc3fe7c79be0c802847d2f2a0ccdc0bb251814e63cf375"
    source_url = (
        "https://raw.githubusercontent.com/basetenlabs/truss/"
        "00e01b679afbe353b0b2fe4de6b138d912bb7167/.circleci/config.yml"
    )
    resolution = {"url": source_url, "expiration_timestamp": expiration_timestamp}
    pointer = {
        "resolution": resolution,
        "uid": digest,
        "file_name": TARGET_FILE.as_posix(),
        "hashtype": "blake3",
        "hash": digest,
        "size": 1482,
    }
    # Truncates and overwrites whatever manifest is already at the fixed path.
    LAZY_DATA_RESOLVER_PATH.write_text(json.dumps({"pointers": [pointer]}))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@pytest.mark.skipif(not TRUSS_TRANSFER_AVAILABLE, reason="Truss Transfer not available")
def test_lazy_data_resolver_v2():
    """Exercise ``LazyDataResolverV2`` against the fixed bptr manifest path.

    Checks, in order:
    * ``fetch()`` raises while no manifest file exists;
    * ``TARGET_FILE`` is not created when no manifest has been written;
    * a valid manifest makes ``TARGET_FILE`` (1482 bytes) appear in the data dir;
    * an expired manifest makes ``fetch()`` raise.

    Skipped when the manifest's parent directory cannot be created (e.g. no
    write access to ``/bptr``). Any manifest written here is removed at the end.
    """
    # truss_transfer reads from LAZY_DATA_RESOLVER_PATH
    if LAZY_DATA_RESOLVER_PATH.exists():
        LAZY_DATA_RESOLVER_PATH.unlink()
    with pytest.raises(Exception):
        # LAZY_DATA_RESOLVER_PATH does not exist
        # should raise an exception
        LazyDataResolverV2(Path("/tmp")).fetch()

    try:
        # Create only the parent directory. Calling mkdir() on the manifest path
        # itself would turn it into a directory, and write_bptr_manifest_to_file()
        # could then never open it for writing.
        LAZY_DATA_RESOLVER_PATH.parent.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        pytest.skip(
            f"Unable to create {LAZY_DATA_RESOLVER_PATH} due to missing os permissions: {e}"
        )

    try:
        # without LAZY_DATA_RESOLVER_PATH -> does not create folder / file
        # NOTE(review): the manifest is absent here just as in the
        # pytest.raises check above, yet this call is expected not to raise;
        # confirm whether the difference is that /bptr now exists.
        with tempfile.TemporaryDirectory() as tempdir:
            data_dir = Path(tempdir)
            LazyDataResolverV2(data_dir).fetch()
            assert not (data_dir / TARGET_FILE).exists()

        # with LAZY_DATA_RESOLVER_PATH -> fetches data
        with tempfile.TemporaryDirectory() as tempdir:
            data_dir = Path(tempdir)
            write_bptr_manifest_to_file()
            # NOTE(review): fetch() is called again on the value returned by
            # fetch(); confirm LazyDataResolverV2.fetch returns an object that
            # has its own fetch() method.
            resolver = LazyDataResolverV2(data_dir).fetch()
            resolver.fetch()
            assert (data_dir / TARGET_FILE).exists()
            assert (data_dir / TARGET_FILE).stat().st_size == 1482

        # with expired LAZY_DATA_RESOLVER_PATH -> raises exception
        with tempfile.TemporaryDirectory() as tempdir:
            data_dir = Path(tempdir)
            write_bptr_manifest_to_file(expiration_timestamp=int(time.time()) - 1)
            resolver = LazyDataResolverV2(data_dir).fetch()
            with pytest.raises(Exception):
                resolver.fetch()
    finally:
        # Don't leave a manifest at the shared, fixed path for later tests.
        if LAZY_DATA_RESOLVER_PATH.is_file():
            LAZY_DATA_RESOLVER_PATH.unlink()


if __name__ == "__main__":
    test_lazy_data_resolver_v2()
|
|
@@ -2,7 +2,7 @@ import os
|
|
|
2
2
|
from contextlib import contextmanager
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
|
|
5
|
-
from truss.
|
|
5
|
+
from truss.templates.shared.secrets_resolver import SecretsResolver
|
|
6
6
|
|
|
7
7
|
CONFIG = {"secrets": {"secret_key": "default_secret_value"}}
|
|
8
8
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import Any
|
|
2
2
|
from unittest.mock import Mock
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
|
-
from truss.server.common.retry import retry
|
|
5
|
+
from truss.templates.server.common.retry import retry
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class FailForCallCount:
|
|
@@ -20,7 +20,7 @@ class FailForCallCount:
|
|
|
20
20
|
return self._call_count
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
def fail_for_call_count(count: int) ->
|
|
23
|
+
def fail_for_call_count(count: int) -> callable:
|
|
24
24
|
call_count = 0
|
|
25
25
|
|
|
26
26
|
def inner():
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
import time
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
from unittest.mock import MagicMock, Mock, patch
|
|
9
|
+
|
|
10
|
+
import opentelemetry.sdk.trace as sdk_trace
|
|
11
|
+
import pytest
|
|
12
|
+
import yaml
|
|
13
|
+
from starlette.requests import Request
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture
def anyio_backend():
    """Pin this module's ``@pytest.mark.anyio`` tests to the asyncio backend."""
    backend = "asyncio"
    return backend
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def app_path(truss_container_fs: Path, helpers: Any):
    """Yield the container ``app`` dir with a flaky model written into it.

    The generated ``model/model.py`` defines a ``Model`` whose ``load`` raises
    ``RuntimeError`` on its first two calls and succeeds afterwards, and whose
    ``predict`` returns the request unchanged. While the fixture is active the
    app dir is also passed to ``helpers.sys_path`` (presumably put on
    ``sys.path`` so ``model_wrapper`` can be imported by name — confirm in
    conftest).
    """
    app_dir = truss_container_fs / "app"
    model_py = app_dir / "model" / "model.py"
    flaky_model_source = """
class Model:
    def __init__(self):
        self.load_count = 0

    def load(self):
        self.load_count += 1
        if self.load_count <= 2:
            raise RuntimeError('Simulated error')

    def predict(self, request):
        return request
"""
    with helpers.file_content(model_py, flaky_model_source):
        with helpers.sys_path(app_dir):
            yield app_dir
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.mark.anyio
async def test_model_wrapper_load_error_once(app_path):
    """A model whose ``load()`` fails transiently still answers ``predict()``.

    Uses the ``app_path`` fixture's model, whose ``load`` raises on its first
    two calls.
    """
    if "model_wrapper" in sys.modules:
        model_wrapper_module = sys.modules["model_wrapper"]
        importlib.reload(model_wrapper_module)
    else:
        model_wrapper_module = importlib.import_module("model_wrapper")
    model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")
    config = yaml.safe_load((app_path / "config.yaml").read_text())
    # Scope the working-directory change to this test (via _change_directory,
    # as test_trt_llm_truss_predict does). A bare os.chdir() changes the cwd of
    # the whole process and would leave it inside this test's container dir
    # for every test that runs afterwards.
    with _change_directory(app_path):
        model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
        model_wrapper.load()
        # Allow load thread to execute
        time.sleep(1)
        output = await model_wrapper.predict({}, MagicMock(spec=Request))
        assert output == {}
        # NOTE(review): the fixture's load() raises while load_count <= 2, so
        # a successful load needs a third call. Expecting 2 here seems to rely
        # on predict() running before loading has finished — confirm, as this
        # may be timing-sensitive.
        assert model_wrapper._model.load_count == 2
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
    """With ``NUM_LOAD_RETRIES_TRUSS=0``, a failing ``load()`` marks the wrapper failed."""
    with helpers.env_var("NUM_LOAD_RETRIES_TRUSS", "0"):
        # (Re)load model_wrapper while the override is active — presumably the
        # retry count is read when the module is imported; confirm there.
        if "model_wrapper" in sys.modules:
            model_wrapper_module = sys.modules["model_wrapper"]
            importlib.reload(model_wrapper_module)
        else:
            model_wrapper_module = importlib.import_module("model_wrapper")
        model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")
        config = yaml.safe_load((app_path / "config.yaml").read_text())
        # Keep the cwd change scoped to this test instead of leaking it to the
        # rest of the process, as a bare os.chdir() would.
        with _change_directory(app_path):
            model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
            model_wrapper.load()
            # Allow load thread to execute
            time.sleep(1)
            assert model_wrapper.load_failed
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@pytest.mark.anyio
@pytest.mark.integration
async def test_model_wrapper_streaming_timeout(app_path):
    """A ``runtime.streaming_read_timeout`` override is kept in the wrapper config.

    NOTE(review): this only checks that the value set below is visible through
    ``ModelWrapper._config`` after ``load()``; it does not exercise an actual
    streaming read timing out.
    """
    if "model_wrapper" not in sys.modules:
        wrapper_mod = importlib.import_module("model_wrapper")
    else:
        wrapper_mod = sys.modules["model_wrapper"]
        importlib.reload(wrapper_mod)
    wrapper_cls = wrapper_mod.ModelWrapper

    config = yaml.safe_load((app_path / "config.yaml").read_text())
    config["runtime"]["streaming_read_timeout"] = 5  # seconds
    wrapper = wrapper_cls(config, sdk_trace.NoOpTracer())
    wrapper.load()
    runtime_section = wrapper._config.get("runtime")
    assert runtime_section.get("streaming_read_timeout") == 5
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pytest.mark.anyio
async def test_trt_llm_truss_init_extension(trt_llm_truss_container_fs, helpers):
    """Loading the trt_llm truss requests the ``trt_llm`` extension.

    ``_init_extension`` is patched to return a mock, so only the name of the
    requested extension (first positional argument) is checked.
    """
    app_dir = trt_llm_truss_container_fs / "app"
    packages_dir = trt_llm_truss_container_fs / "packages"
    with _clear_model_load_modules():
        with helpers.sys_paths(app_dir, packages_dir):
            wrapper_mod = importlib.import_module("model_wrapper")
            wrapper_cls = wrapper_mod.ModelWrapper
            config = yaml.safe_load((app_dir / "config.yaml").read_text())

            fake_extension = Mock()
            fake_extension.load = Mock()
            patcher = patch.object(
                wrapper_mod, "_init_extension", return_value=fake_extension
            )
            with patcher as init_extension_spy:
                wrapper = wrapper_cls(config, sdk_trace.NoOpTracer())
                wrapper.load()
                recorded_calls = init_extension_spy.call_args_list
                assert any(c.args[0] == "trt_llm" for c in recorded_calls), (
                    "Expected extension_name was not called"
                )
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@pytest.mark.anyio
async def test_trt_llm_truss_predict(trt_llm_truss_container_fs, helpers):
    """``predict()`` on the trt_llm truss is answered by the extension's engine.

    The extension is mocked so that ``model_args()`` yields a fake engine whose
    async ``predict`` returns a fixed value. The test checks that the wrapper
    loads the extension, asks it for model args, awaits the engine and returns
    its result.
    """
    app_dir = trt_llm_truss_container_fs / "app"
    packages_dir = trt_llm_truss_container_fs / "packages"
    with _clear_model_load_modules():
        with helpers.sys_paths(app_dir, packages_dir), _change_directory(app_dir):
            wrapper_mod = importlib.import_module("model_wrapper")
            wrapper_cls = wrapper_mod.ModelWrapper
            config = yaml.safe_load((app_dir / "config.yaml").read_text())

            expected_predict_response = "test"
            predict_calls = []

            # Parameter names are kept unchanged in case the wrapper inspects
            # the engine's predict signature.
            async def mock_predict(return_value, request):
                predict_calls.append((return_value, request))
                return expected_predict_response

            fake_engine = Mock(predict=mock_predict)
            fake_extension = Mock()
            fake_extension.load = Mock()
            fake_extension.model_args = Mock(return_value={"engine": fake_engine})
            with patch.object(
                wrapper_mod, "_init_extension", return_value=fake_extension
            ):
                wrapper = wrapper_cls(config, sdk_trace.NoOpTracer())
                wrapper.load()
                resp = await wrapper.predict({}, MagicMock(spec=Request))
                fake_extension.load.assert_called()
                fake_extension.model_args.assert_called()
                assert predict_calls
                assert resp == expected_predict_response
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.anyio
async def test_trt_llm_truss_missing_model_py(trt_llm_truss_container_fs, helpers):
    """With ``model/model.py`` deleted, the extension's model override serves predict.

    ``model_override()`` is mocked to return a fake engine whose async
    ``predict`` records that it ran and returns a sentinel.
    """
    app_dir = trt_llm_truss_container_fs / "app"
    # Remove the user model before anything is imported from the app dir.
    (app_dir / "model" / "model.py").unlink()

    pkg_dir = trt_llm_truss_container_fs / "packages"
    with _clear_model_load_modules(), helpers.sys_paths(
        app_dir, pkg_dir
    ), _change_directory(app_dir):
        wrapper_mod = importlib.import_module("model_wrapper")
        wrapper_cls = getattr(wrapper_mod, "ModelWrapper")
        truss_config = yaml.safe_load((app_dir / "config.yaml").read_text())

        sentinel_response = "test"
        engine_invocations = []

        # Signature (names and the Request annotation) is kept deliberately.
        async def mock_predict(return_value, request: Request):
            engine_invocations.append(request)
            return sentinel_response

        fake_engine = Mock(predict=mock_predict, spec=["predict"])
        fake_extension = Mock()
        fake_extension.load = Mock()
        fake_extension.model_override = Mock(return_value=fake_engine)
        with patch.object(
            wrapper_mod, "_init_extension", return_value=fake_extension
        ):
            wrapper = wrapper_cls(truss_config, sdk_trace.NoOpTracer())
            wrapper.load()
            response = await wrapper.predict({}, MagicMock(spec=Request))
            fake_extension.load.assert_called()
            fake_extension.model_override.assert_called()
            assert engine_invocations
            assert response == sentinel_response
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@pytest.mark.anyio
async def test_open_ai_completion_endpoints(open_ai_container_fs, helpers):
    """Each OpenAI-style wrapper endpoint returns its own name from the test model.

    The fixture model's ``predict``, ``completions`` and ``chat_completions``
    are expected to return their method name as a string.
    """
    app_dir = open_ai_container_fs / "app"
    with _clear_model_load_modules(), helpers.sys_paths(app_dir), _change_directory(
        app_dir
    ):
        wrapper_mod = importlib.import_module("model_wrapper")
        wrapper_cls = getattr(wrapper_mod, "ModelWrapper")
        truss_config = yaml.safe_load((app_dir / "config.yaml").read_text())

        wrapper = wrapper_cls(truss_config, sdk_trace.NoOpTracer())
        wrapper.load()

        fake_request = MagicMock(spec=Request)
        for endpoint in ("predict", "completions", "chat_completions"):
            result = await getattr(wrapper, endpoint)({}, fake_request)
            assert result == endpoint
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@contextmanager
|
|
218
|
+
def _change_directory(new_directory: Path):
|
|
219
|
+
original_directory = os.getcwd()
|
|
220
|
+
os.chdir(str(new_directory))
|
|
221
|
+
try:
|
|
222
|
+
yield
|
|
223
|
+
finally:
|
|
224
|
+
os.chdir(original_directory)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
@contextmanager
def _clear_model_load_modules():
    """Clear dangling references to the model and model_wrapper modules.

    The modules are removed before the ``with`` body, to clear debris left by
    earlier tests, and again afterwards, to clean up after ourselves. This is
    meant for cases where we simulate running a truss model in process, where
    these modules are loaded dynamically.

    The post-body cleanup runs in a ``finally`` block. Without it, a failing
    assertion or exception inside the ``with`` body would leave the modules in
    ``sys.modules`` and leak them into subsequent tests.
    """
    # These are left over by TrussModuleLoader used by local prediction tests.
    # TODO(pankaj) Find a way for TrussModuleLoader to clean up after itself.
    _remove_model_load_sys_modules()
    try:
        yield
    finally:
        _remove_model_load_sys_modules()
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _remove_model_load_sys_modules():
|
|
243
|
+
if "model" in sys.modules:
|
|
244
|
+
del sys.modules["model"]
|
|
245
|
+
if "model.model" in sys.modules:
|
|
246
|
+
del sys.modules["model.model"]
|
|
247
|
+
if "model_wrapper" in sys.modules:
|
|
248
|
+
del sys.modules["model_wrapper"]
|
|
@@ -2,7 +2,7 @@ import inspect
|
|
|
2
2
|
from typing import AsyncGenerator, Awaitable, Generator, Union
|
|
3
3
|
|
|
4
4
|
from pydantic import BaseModel
|
|
5
|
-
from truss.server.common.schema import TrussSchema
|
|
5
|
+
from truss.templates.server.common.schema import TrussSchema
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class ModelInput(BaseModel):
|
|
@@ -181,8 +181,7 @@ def test_truss_schema_union_sync():
|
|
|
181
181
|
def test_truss_schema_union_async():
|
|
182
182
|
class Model:
|
|
183
183
|
async def predict(
|
|
184
|
-
self,
|
|
185
|
-
request: ModelInput,
|
|
184
|
+
self, request: ModelInput
|
|
186
185
|
) -> Union[Awaitable[ModelOutput], AsyncGenerator[str, None]]:
|
|
187
186
|
if request.stream:
|
|
188
187
|
|
|
@@ -208,8 +207,7 @@ def test_truss_schema_union_async():
|
|
|
208
207
|
def test_truss_schema_union_async_non_pydantic():
|
|
209
208
|
class Model:
|
|
210
209
|
async def predict(
|
|
211
|
-
self,
|
|
212
|
-
request: ModelInput,
|
|
210
|
+
self, request: ModelInput
|
|
213
211
|
) -> Union[Awaitable[str], AsyncGenerator[str, None]]:
|
|
214
212
|
return "hello"
|
|
215
213
|
|
|
@@ -8,18 +8,21 @@ from multiprocessing import Process
|
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
|
|
10
10
|
import pytest
|
|
11
|
-
import yaml
|
|
12
|
-
from truss.server.common.truss_server import TrussServer
|
|
13
11
|
|
|
14
12
|
|
|
15
13
|
@pytest.mark.integration
|
|
16
|
-
def test_truss_server_termination(
|
|
14
|
+
def test_truss_server_termination(truss_container_fs):
|
|
17
15
|
port = 10123
|
|
18
16
|
|
|
19
17
|
def start_truss_server(stdout_capture_file_path):
|
|
20
18
|
sys.stdout = open(stdout_capture_file_path, "w")
|
|
21
|
-
|
|
22
|
-
|
|
19
|
+
app_path = truss_container_fs / "app"
|
|
20
|
+
sys.path.append(str(app_path))
|
|
21
|
+
os.chdir(app_path)
|
|
22
|
+
|
|
23
|
+
from truss_server import TrussServer
|
|
24
|
+
|
|
25
|
+
server = TrussServer(http_port=port, config_or_path=app_path / "config.yaml")
|
|
23
26
|
server.start()
|
|
24
27
|
|
|
25
28
|
stdout_capture_file = tempfile.NamedTemporaryFile()
|
truss/tests/test_build.py
CHANGED
|
@@ -1,63 +1,20 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
from
|
|
4
|
-
from truss.truss_spec import TrussSpec
|
|
1
|
+
from truss.base.truss_spec import TrussSpec
|
|
2
|
+
from truss.truss_handle.build import init_directory, load
|
|
3
|
+
from truss_chains.deployment import code_gen
|
|
5
4
|
|
|
6
5
|
|
|
7
6
|
def test_truss_init(tmp_path):
|
|
8
|
-
|
|
9
|
-
init(dir_name)
|
|
10
|
-
spec = TrussSpec(Path(dir_name))
|
|
7
|
+
spec = TrussSpec(init_directory(tmp_path))
|
|
11
8
|
assert spec.model_module_dir.exists()
|
|
12
9
|
assert spec.data_dir.exists()
|
|
13
10
|
assert spec.truss_dir == tmp_path
|
|
14
11
|
assert spec.config_path.exists()
|
|
15
12
|
|
|
16
13
|
|
|
17
|
-
def
|
|
18
|
-
tmp_path,
|
|
19
|
-
):
|
|
20
|
-
dir_path = tmp_path / "truss"
|
|
21
|
-
dir_name = str(dir_path)
|
|
22
|
-
|
|
23
|
-
# Init data files
|
|
24
|
-
data_path = tmp_path / "data.txt"
|
|
25
|
-
with data_path.open("w") as data_file:
|
|
26
|
-
data_file.write("test")
|
|
27
|
-
|
|
28
|
-
# Init requirements file
|
|
29
|
-
req_file_path = tmp_path / "requirements.txt"
|
|
30
|
-
requirements = [
|
|
31
|
-
"tensorflow==2.3.1",
|
|
32
|
-
"uvicorn==0.12.2",
|
|
33
|
-
]
|
|
34
|
-
with req_file_path.open("w") as req_file:
|
|
35
|
-
for req in requirements:
|
|
36
|
-
req_file.write(f"{req}\n")
|
|
14
|
+
def test_truss_init_with_python_dx(tmp_path):
|
|
15
|
+
init_directory(tmp_path, model_name="Test Model Name", python_config=True)
|
|
37
16
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
packages_path.mkdir()
|
|
41
|
-
packages_path_file_py = packages_path / "file.py"
|
|
42
|
-
packages_path_init_py = packages_path / "__init__.py"
|
|
43
|
-
pkg_files = [packages_path_init_py, packages_path_file_py]
|
|
44
|
-
for pkg_file in pkg_files:
|
|
45
|
-
with pkg_file.open("w") as fh:
|
|
46
|
-
fh.write("test")
|
|
17
|
+
generated_truss_dir = code_gen.gen_truss_model_from_source(tmp_path / "my_model.py")
|
|
18
|
+
truss_handle = load(generated_truss_dir)
|
|
47
19
|
|
|
48
|
-
|
|
49
|
-
dir_name,
|
|
50
|
-
data_files=[str(data_path)],
|
|
51
|
-
requirements_file=str(req_file_path),
|
|
52
|
-
bundled_packages=[str(packages_path)],
|
|
53
|
-
)
|
|
54
|
-
spec = TrussSpec(Path(dir_name))
|
|
55
|
-
assert spec.model_module_dir.exists()
|
|
56
|
-
assert spec.truss_dir == dir_path
|
|
57
|
-
assert spec.config_path.exists()
|
|
58
|
-
assert spec.data_dir.exists()
|
|
59
|
-
assert spec.bundled_packages_dir.exists()
|
|
60
|
-
assert (spec.data_dir / "data.txt").exists()
|
|
61
|
-
assert spec.requirements == requirements
|
|
62
|
-
assert (spec.bundled_packages_dir / "dep_pkg" / "__init__.py").exists()
|
|
63
|
-
assert (spec.bundled_packages_dir / "dep_pkg" / "file.py").exists()
|
|
20
|
+
assert truss_handle.spec.config.model_name == "Test Model Name"
|