truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. /truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. /truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -1,150 +0,0 @@
1
- import json
2
- import os
3
- import subprocess
4
- import time
5
- from pathlib import Path
6
- from typing import AsyncGenerator, Optional
7
-
8
- import tritonclient.grpc.aio as grpcclient
9
- import tritonclient.http as httpclient
10
- from constants import (
11
- ENTRYPOINT_MODEL_NAME,
12
- GRPC_SERVICE_PORT,
13
- TENSORRT_LLM_MODEL_REPOSITORY_PATH,
14
- )
15
- from schema import ModelInput
16
- from utils import download_engine, prepare_model_repository
17
-
18
-
19
- class TritonServer:
20
- def __init__(self, grpc_port: int = 8001, http_port: int = 8003):
21
- self.grpc_port = grpc_port
22
- self.http_port = http_port
23
- self._server_process = None
24
-
25
- def create_model_repository(
26
- self,
27
- truss_data_dir: Path,
28
- engine_repository_path: Optional[str] = None,
29
- huggingface_auth_token: Optional[str] = None,
30
- ) -> None:
31
- if engine_repository_path:
32
- download_engine(
33
- engine_repository=engine_repository_path,
34
- fp=truss_data_dir,
35
- auth_token=huggingface_auth_token,
36
- )
37
- prepare_model_repository(truss_data_dir)
38
- return
39
-
40
- def start(self, tensor_parallelism: int = 1, env: dict = {}) -> None:
41
- def build_server_start_command() -> list:
42
- """
43
- Triton Inference Server has different startup commands depending on the tensor
44
- parallelism (TP) configuration. This function starts the server with the appropriate
45
- command.
46
- """
47
- base_command = [
48
- "tritonserver",
49
- f"--model-repository={TENSORRT_LLM_MODEL_REPOSITORY_PATH}",
50
- f"--grpc-port={str(self.grpc_port)}",
51
- f"--http-port={str(self.http_port)}",
52
- # "--log-verbose=1",
53
- ]
54
-
55
- if tensor_parallelism == 1:
56
- return base_command
57
-
58
- mpirun_command = ["mpirun", "--allow-run-as-root"]
59
- mpi_commands = []
60
- for i in range(tensor_parallelism):
61
- mpi_command = [
62
- "-n=1",
63
- *base_command,
64
- "--disable-auto-complete-config",
65
- f"--backend-config=python,shm-region-prefix-name=prefix{str(i)}_",
66
- ]
67
- mpi_commands.append(" ".join(mpi_command))
68
-
69
- combined_mpi_commands = " : ".join(mpi_commands)
70
- return mpirun_command + [combined_mpi_commands]
71
-
72
- command = build_server_start_command()
73
- self._server_process = subprocess.Popen( # type: ignore
74
- command,
75
- env={**os.environ, **env},
76
- )
77
- while not self.is_alive and not self.is_ready:
78
- time.sleep(2)
79
- return
80
-
81
- def stop(self):
82
- if self._server_process:
83
- if self.is_server_ready:
84
- self._server_process.kill()
85
- self._server_process = None
86
- return
87
-
88
- @property
89
- def is_alive(self) -> bool:
90
- try:
91
- http_client = httpclient.InferenceServerClient(
92
- url=f"localhost:{self.http_port}", verbose=False
93
- )
94
- return http_client.is_server_live()
95
- except ConnectionRefusedError:
96
- return False
97
-
98
- @property
99
- def is_ready(self) -> bool:
100
- try:
101
- http_client = httpclient.InferenceServerClient(
102
- url=f"localhost:{self.http_port}", verbose=False
103
- )
104
- return http_client.is_model_ready(model_name=ENTRYPOINT_MODEL_NAME)
105
- except ConnectionRefusedError:
106
- return False
107
-
108
-
109
- class TritonClient:
110
- def __init__(self, grpc_service_port: int = GRPC_SERVICE_PORT):
111
- self.grpc_service_port = grpc_service_port
112
- self._grpc_client = None
113
-
114
- def start_grpc_stream(self) -> grpcclient.InferenceServerClient:
115
- if self._grpc_client:
116
- return self._grpc_client
117
-
118
- self._grpc_client = grpcclient.InferenceServerClient(
119
- url=f"localhost:{self.grpc_service_port}", verbose=False
120
- )
121
- return self._grpc_client
122
-
123
- async def infer(
124
- self, model_input: ModelInput, model_name="ensemble"
125
- ) -> AsyncGenerator[str, None]:
126
- grpc_client_instance = self.start_grpc_stream()
127
- inputs = model_input.to_tensors()
128
-
129
- async def input_generator():
130
- yield {
131
- "model_name": model_name,
132
- "inputs": inputs,
133
- "request_id": model_input.request_id,
134
- }
135
-
136
- response_iterator = grpc_client_instance.stream_infer(
137
- inputs_iterator=input_generator(),
138
- )
139
-
140
- try:
141
- async for response in response_iterator:
142
- result, error = response
143
- if result:
144
- result = result.as_numpy("text_output")
145
- yield result[0].decode("utf-8")
146
- else:
147
- yield json.dumps({"status": "error", "message": error.message()})
148
-
149
- except grpcclient.InferenceServerException as e:
150
- print(f"InferenceServerException: {e}")
@@ -1,43 +0,0 @@
1
- from pathlib import Path
2
-
3
- from constants import TENSORRT_LLM_MODEL_REPOSITORY_PATH
4
- from huggingface_hub import snapshot_download
5
-
6
-
7
- def move_all_files(src: Path, dest: Path) -> None:
8
- """
9
- Moves all files from `src` to `dest` recursively.
10
- """
11
- for item in src.iterdir():
12
- dest_item = dest / item.name
13
- if item.is_dir():
14
- dest_item.mkdir(parents=True, exist_ok=True)
15
- move_all_files(item, dest_item)
16
- else:
17
- item.rename(dest_item)
18
-
19
-
20
- def prepare_model_repository(data_dir: Path) -> None:
21
- # Ensure the destination directory exists
22
- dest_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "tensorrt_llm" / "1"
23
- dest_dir.mkdir(parents=True, exist_ok=True)
24
-
25
- # Ensure empty version directory for `ensemble` model exists
26
- ensemble_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "ensemble" / "1"
27
- ensemble_dir.mkdir(parents=True, exist_ok=True)
28
-
29
- # Move all files and directories from data_dir to dest_dir
30
- move_all_files(data_dir, dest_dir)
31
-
32
-
33
- def download_engine(engine_repository: str, fp: Path, auth_token=None):
34
- """
35
- Downloads the specified engine from Hugging Face Hub.
36
- """
37
- snapshot_download(
38
- engine_repository,
39
- local_dir=fp,
40
- local_dir_use_symlinks=False,
41
- max_workers=4,
42
- **({"use_auth_token": auth_token} if auth_token is not None else {}),
43
- )
@@ -1,4 +0,0 @@
1
- from truss import init
2
-
3
- th = init("test_truss")
4
- th.docker_build_setup()
@@ -1,54 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 13,
6
- "id": "7a4aad7e",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- " import truss\n",
11
- " from sklearn.linear_model import LogisticRegression\n",
12
- " import numpy as np\n",
13
- " import pandas as pd\n",
14
- " import random\n",
15
- " import tempfile\n",
16
- " \n",
17
- " # 10 rows of 10 features\n",
18
- " features = random.choices(range(10), k=10)\n",
19
- " values = np.random.randint(0, 100, size=(10, 10))\n",
20
- " \n",
21
- " df = pd.DataFrame(values, columns=features)\n",
22
- "\n",
23
- " lr = LogisticRegression(random_state=0, max_iter=1000)\n",
24
- " lr.fit(df.values, df.iloc[:, 0])\n",
25
- " \n",
26
- " with tempfile.TemporaryDirectory() as tmp:\n",
27
- " tr = truss.create(lr, target_directory=tmp)\n",
28
- " \n",
29
- " assert truss.load(tmp).spec.yaml_string == tr.spec.yaml_string"
30
- ]
31
- }
32
- ],
33
- "metadata": {
34
- "kernelspec": {
35
- "display_name": "Python 3 (ipykernel)",
36
- "language": "python",
37
- "name": "python3"
38
- },
39
- "language_info": {
40
- "codemirror_mode": {
41
- "name": "ipython",
42
- "version": 3
43
- },
44
- "file_extension": ".py",
45
- "mimetype": "text/x-python",
46
- "name": "python",
47
- "nbconvert_exporter": "python",
48
- "pygments_lexer": "ipython3",
49
- "version": "3.9.9"
50
- }
51
- },
52
- "nbformat": 4,
53
- "nbformat_minor": 5
54
- }
@@ -1,2 +0,0 @@
1
- model_name: Test Loaf Failure
2
- python_version: py39
@@ -1,2 +0,0 @@
1
- runtime:
2
- predict_concurrency: 2
@@ -1,2 +0,0 @@
1
- model_name: Test Streaming Async Generator
2
- python_version: py39
@@ -1,3 +0,0 @@
1
-
2
- model_name: Test Streaming
3
- python_version: py39
@@ -1,2 +0,0 @@
1
- model_name: Test
2
- python_version: py39
@@ -1,93 +0,0 @@
1
- import multiprocessing
2
- import tempfile
3
- import time
4
- from pathlib import Path
5
- from typing import Awaitable, Callable, List
6
-
7
- import pytest
8
- from truss.server.common.termination_handler_middleware import (
9
- TerminationHandlerMiddleware,
10
- )
11
-
12
-
13
- async def noop(*args, **kwargs):
14
- return
15
-
16
-
17
- @pytest.mark.integration
18
- def test_termination_sequence_no_pending_requests(tmp_path):
19
- # Create middleware in separate process, on sending term signal to process,
20
- # it should print the right messages.
21
- def main_coro_gen(middleware: TerminationHandlerMiddleware):
22
- import asyncio
23
-
24
- async def main(*args, **kwargs):
25
- await middleware(1, call_next=noop)
26
- await asyncio.sleep(1)
27
- print("should not print due to termination")
28
-
29
- return main()
30
-
31
- _verify_term(main_coro_gen, ["stopped", "terminated"])
32
-
33
-
34
- @pytest.mark.integration
35
- def test_termination_sequence_with_pending_requests(tmp_path):
36
- def main_coro_gen(middleware: TerminationHandlerMiddleware):
37
- import asyncio
38
-
39
- async def main(*args, **kwargs):
40
- async def call_next(req):
41
- await asyncio.sleep(1.0)
42
- return "call_next_called"
43
-
44
- resp = await middleware(1, call_next=call_next)
45
- print(f"call_next response: {resp}")
46
- await asyncio.sleep(1)
47
- print("should not print due to termination")
48
-
49
- return main()
50
-
51
- _verify_term(
52
- main_coro_gen,
53
- [
54
- "stopped",
55
- "call_next response: call_next_called",
56
- "terminated",
57
- ],
58
- )
59
-
60
-
61
- def _verify_term(
62
- main_coro_gen: Callable[[TerminationHandlerMiddleware], Awaitable],
63
- expected_lines: List[str],
64
- ):
65
- def run(stdout_capture_file_path):
66
- import asyncio
67
- import os
68
- import signal
69
- import sys
70
-
71
- sys.stdout = open(stdout_capture_file_path, "w")
72
-
73
- def term():
74
- print("terminated", flush=True)
75
- os.kill(os.getpid(), signal.SIGKILL)
76
-
77
- middleware = TerminationHandlerMiddleware(
78
- on_stop=lambda: print("stopped", flush=True),
79
- on_term=term,
80
- termination_delay_secs=0.1,
81
- )
82
- asyncio.run(main_coro_gen(middleware))
83
-
84
- stdout_capture_file = tempfile.NamedTemporaryFile()
85
- proc = multiprocessing.Process(target=run, args=(stdout_capture_file.name,))
86
- proc.start()
87
- time.sleep(1)
88
- proc.terminate()
89
- proc.join(timeout=6.0)
90
- with Path(stdout_capture_file.name).open() as file:
91
- lines = [line.strip() for line in file]
92
-
93
- assert lines == expected_lines
@@ -1,203 +0,0 @@
1
- import os
2
- from unittest import mock
3
-
4
- import pytest
5
- from truss.server.control.patch.model_container_patch_applier import (
6
- ModelContainerPatchApplier,
7
- )
8
- from truss.server.control.patch.types import (
9
- Action,
10
- ConfigPatch,
11
- EnvVarPatch,
12
- ExternalDataPatch,
13
- ModelCodePatch,
14
- PackagePatch,
15
- Patch,
16
- PatchType,
17
- )
18
- from truss.truss_config import TrussConfig
19
-
20
-
21
- @pytest.fixture
22
- def patch_applier(tmp_truss_dir):
23
- return ModelContainerPatchApplier(tmp_truss_dir, mock.Mock())
24
-
25
-
26
- def test_patch_applier_model_code_patch_add(
27
- patch_applier: ModelContainerPatchApplier, tmp_truss_dir
28
- ):
29
- patch = Patch(
30
- type=PatchType.MODEL_CODE,
31
- body=ModelCodePatch(
32
- action=Action.ADD,
33
- path="dummy",
34
- content="",
35
- ),
36
- )
37
- patch_applier(patch, os.environ.copy())
38
- assert (tmp_truss_dir / "model" / "dummy").exists()
39
-
40
-
41
- def test_patch_applier_model_code_patch_remove(
42
- patch_applier: ModelContainerPatchApplier, tmp_truss_dir
43
- ):
44
- patch = Patch(
45
- type=PatchType.MODEL_CODE,
46
- body=ModelCodePatch(
47
- action=Action.REMOVE,
48
- path="model.py",
49
- ),
50
- )
51
- assert (tmp_truss_dir / "model" / "model.py").exists()
52
- patch_applier(patch, os.environ.copy())
53
- assert not (tmp_truss_dir / "model" / "model.py").exists()
54
-
55
-
56
- def test_patch_applier_model_code_patch_update(
57
- patch_applier: ModelContainerPatchApplier, tmp_truss_dir
58
- ):
59
- new_model_file_content = """
60
- class Model:
61
- pass
62
- """
63
- patch = Patch(
64
- type=PatchType.MODEL_CODE,
65
- body=ModelCodePatch(
66
- action=Action.UPDATE,
67
- path="model.py",
68
- content=new_model_file_content,
69
- ),
70
- )
71
- patch_applier(patch, os.environ.copy())
72
- assert (tmp_truss_dir / "model" / "model.py").read_text() == new_model_file_content
73
-
74
-
75
- def test_patch_applier_package_patch_add(
76
- patch_applier: ModelContainerPatchApplier, tmp_truss_dir
77
- ):
78
- patch = Patch(
79
- type=PatchType.PACKAGE,
80
- body=PackagePatch(
81
- action=Action.ADD,
82
- path="test_package/test.py",
83
- content="foobar",
84
- ),
85
- )
86
- patch_applier(patch, os.environ.copy())
87
- assert (tmp_truss_dir / "packages" / "test_package" / "test.py").exists()
88
-
89
-
90
- def test_patch_applier_package_patch_remove(
91
- patch_applier: ModelContainerPatchApplier,
92
- tmp_truss_dir,
93
- ):
94
- patch = Patch(
95
- type=PatchType.PACKAGE,
96
- body=PackagePatch(
97
- action=Action.REMOVE,
98
- path="test_package/test.py",
99
- ),
100
- )
101
- assert (tmp_truss_dir / "packages" / "test_package" / "test.py").exists()
102
- patch_applier(patch, os.environ.copy())
103
- assert not (tmp_truss_dir / "packages" / "test_package" / "test.py").exists()
104
-
105
-
106
- def test_patch_applier_package_patch_update(
107
- patch_applier: ModelContainerPatchApplier,
108
- tmp_truss_dir,
109
- ):
110
- new_package_content = """X = 2"""
111
- patch = Patch(
112
- type=PatchType.PACKAGE,
113
- body=PackagePatch(
114
- action=Action.UPDATE,
115
- path="test_package/test.py",
116
- content=new_package_content,
117
- ),
118
- )
119
- patch_applier(patch, os.environ.copy())
120
- assert (
121
- tmp_truss_dir / "packages" / "test_package" / "test.py"
122
- ).read_text() == new_package_content
123
-
124
-
125
- def test_patch_applier_config_patch_update(
126
- patch_applier: ModelContainerPatchApplier, tmp_truss_dir
127
- ):
128
- new_config_dict = {"model_name": "foobar"}
129
- patch = Patch(
130
- type=PatchType.CONFIG,
131
- body=ConfigPatch(
132
- action=Action.UPDATE,
133
- config=new_config_dict,
134
- ),
135
- )
136
- patch_applier(patch, os.environ.copy())
137
- new_config = TrussConfig.from_yaml(tmp_truss_dir / "config.yaml")
138
- assert new_config.model_name == "foobar"
139
-
140
-
141
- def test_patch_applier_env_var_patch_update(
142
- patch_applier: ModelContainerPatchApplier,
143
- ):
144
- env_var_dict = {"FOO": "BAR"}
145
- patch = Patch(
146
- type=PatchType.ENVIRONMENT_VARIABLE,
147
- body=EnvVarPatch(
148
- action=Action.UPDATE,
149
- item={"FOO": "BAR-PATCHED"},
150
- ),
151
- )
152
- patch_applier(patch, env_var_dict)
153
- assert env_var_dict["FOO"] == "BAR-PATCHED"
154
-
155
-
156
- def test_patch_applier_env_var_patch_add(
157
- patch_applier: ModelContainerPatchApplier,
158
- ):
159
- env_var_dict = {"FOO": "BAR"}
160
- patch = Patch(
161
- type=PatchType.ENVIRONMENT_VARIABLE,
162
- body=EnvVarPatch(
163
- action=Action.ADD,
164
- item={"BAR": "FOO"},
165
- ),
166
- )
167
- patch_applier(patch, env_var_dict)
168
- assert env_var_dict["FOO"] == "BAR"
169
- assert env_var_dict["BAR"] == "FOO"
170
-
171
-
172
- def test_patch_applier_env_var_patch_remove(
173
- patch_applier: ModelContainerPatchApplier,
174
- ):
175
- env_var_dict = {"FOO": "BAR"}
176
- patch = Patch(
177
- type=PatchType.ENVIRONMENT_VARIABLE,
178
- body=EnvVarPatch(
179
- action=Action.REMOVE,
180
- item={"FOO": "BAR"},
181
- ),
182
- )
183
- patch_applier(patch, env_var_dict)
184
- with pytest.raises(KeyError):
185
- _ = env_var_dict["FOO"]
186
-
187
-
188
- def test_patch_applier_external_data_patch_add(
189
- patch_applier: ModelContainerPatchApplier,
190
- tmp_truss_dir,
191
- ):
192
- patch = Patch(
193
- type=PatchType.EXTERNAL_DATA,
194
- body=ExternalDataPatch(
195
- action=Action.ADD,
196
- item={
197
- "url": "https://raw.githubusercontent.com/basetenlabs/truss/main/docs/favicon.svg",
198
- "local_data_path": "truss_icon",
199
- },
200
- ),
201
- )
202
- patch_applier(patch, os.environ.copy())
203
- assert (tmp_truss_dir / "data" / "truss_icon").exists()
@@ -1,19 +0,0 @@
1
- # This file doesn't test anything, but provides utilities for testing.
2
- from unittest import mock
3
-
4
-
5
- def model_supports_predict_proba():
6
- mock_not_predict_proba = mock.Mock(name="mock_not_predict_proba")
7
- mock_not_predict_proba.predict_proba.return_value = False
8
-
9
- mock_check_proba = mock.Mock(name="mock_check_proba")
10
- mock_check_proba.predict_proba.return_value = True
11
- mock_check_proba._check_proba.return_value = True
12
-
13
- mock_not_check_proba = mock.Mock(name="mock_not_check_proba")
14
- mock_not_check_proba.predict_proba.return_value = True
15
- mock_not_check_proba._check_proba.side_effect = AttributeError
16
-
17
- assert not model_supports_predict_proba(mock_not_predict_proba)
18
- assert model_supports_predict_proba(mock_check_proba)
19
- assert not model_supports_predict_proba(mock_not_check_proba)
@@ -1,87 +0,0 @@
1
- import importlib
2
- import sys
3
- import time
4
- from pathlib import Path
5
- from typing import Any
6
-
7
- import pytest
8
- import yaml
9
-
10
-
11
- @pytest.fixture
12
- def app_path(tmp_truss_dir: Path, helpers: Any):
13
- truss_container_app_path = tmp_truss_dir
14
- model_file_content = """
15
- class Model:
16
- def __init__(self):
17
- self.load_count = 0
18
- def load(self):
19
- self.load_count += 1
20
- if self.load_count <= 2:
21
- raise RuntimeError('Simulated error')
22
- def predict(self, request):
23
- return request
24
- """
25
- with helpers.file_content(
26
- truss_container_app_path / "model" / "model.py",
27
- model_file_content,
28
- ), helpers.sys_path(truss_container_app_path):
29
- yield truss_container_app_path
30
-
31
-
32
- # TODO: Make this test work
33
- @pytest.mark.skip(
34
- reason="Succeeds when tests in this file are run alone, but fails with the whole suit"
35
- )
36
- def test_model_wrapper_load_error_once(app_path):
37
- if "model_wrapper" in sys.modules:
38
- model_wrapper_module = sys.modules["model_wrapper"]
39
- importlib.reload(model_wrapper_module)
40
- else:
41
- model_wrapper_module = importlib.import_module("model_wrapper")
42
- model_wraper_class = getattr(model_wrapper_module, "ModelWrapper")
43
- config = yaml.safe_load((app_path / "config.yaml").read_text())
44
- model_wrapper = model_wraper_class(config)
45
- model_wrapper.load()
46
- # Allow load thread to execute
47
- time.sleep(1)
48
- output = model_wrapper.predict({})
49
- assert output == {}
50
- assert model_wrapper._model.load_count == 3
51
-
52
-
53
- # TODO: Make this test work
54
- @pytest.mark.skip(
55
- reason="Succeeds when tests in this file are run alone, but fails with the whole suit"
56
- )
57
- def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
58
- with helpers.env_var("NUM_LOAD_RETRIES_TRUSS", "0"):
59
- if "model_wrapper" in sys.modules:
60
- model_wrapper_module = sys.modules["model_wrapper"]
61
- importlib.reload(model_wrapper_module)
62
- else:
63
- model_wrapper_module = importlib.import_module("model_wrapper")
64
- model_wraper_class = getattr(model_wrapper_module, "ModelWrapper")
65
- config = yaml.safe_load((app_path / "config.yaml").read_text())
66
- model_wrapper = model_wraper_class(config)
67
- model_wrapper.load()
68
- # Allow load thread to execute
69
- time.sleep(1)
70
- assert model_wrapper.load_failed()
71
-
72
-
73
- @pytest.mark.integration
74
- async def test_model_wrapper_streaming_timeout(app_path):
75
- if "model_wrapper" in sys.modules:
76
- model_wrapper_module = sys.modules["model_wrapper"]
77
- importlib.reload(model_wrapper_module)
78
- else:
79
- model_wrapper_module = importlib.import_module("model_wrapper")
80
- model_wraper_class = getattr(model_wrapper_module, "ModelWrapper")
81
-
82
- # Create an instance of ModelWrapper with streaming_read_timeout set to 5 seconds
83
- config = yaml.safe_load((app_path / "config.yaml").read_text())
84
- config["runtime"]["streaming_read_timeout"] = 5
85
- model_wrapper = model_wraper_class(config)
86
- model_wrapper.load()
87
- assert model_wrapper._config.get("runtime").get("streaming_read_timeout") == 5
@@ -1,16 +0,0 @@
1
- from typing import Callable, Dict, Optional, TypeVar
2
-
3
- X = TypeVar("X")
4
- Y = TypeVar("Y")
5
- Z = TypeVar("Z")
6
-
7
-
8
- def transform_optional(x: Optional[X], fn: Callable[[X], Optional[Y]]) -> Optional[Y]:
9
- if x is None:
10
- return None
11
-
12
- return fn(x)
13
-
14
-
15
- def transform_keys(d: Dict[X, Z], fn: Callable[[X], Y]) -> Dict[Y, Z]:
16
- return {fn(key): value for key, value in d.items()}