truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,961 @@
1
+ import asyncio
2
+ import dataclasses
3
+ import enum
4
+ import importlib
5
+ import importlib.util
6
+ import inspect
7
+ import json
8
+ import logging
9
+ import os
10
+ import pathlib
11
+ import sys
12
+ import time
13
+ import weakref
14
+ from contextlib import asynccontextmanager
15
+ from functools import cached_property
16
+ from multiprocessing import Lock
17
+ from pathlib import Path
18
+ from threading import Thread
19
+ from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union, cast
20
+
21
+ import opentelemetry.sdk.trace as sdk_trace
22
+ import pydantic
23
+ import starlette.requests
24
+ import starlette.responses
25
+ from anyio import Semaphore, to_thread
26
+ from common import errors, tracing
27
+ from common.patches import apply_patches
28
+ from common.retry import retry
29
+ from common.schema import TrussSchema
30
+ from opentelemetry import trace
31
+ from shared import dynamic_config_resolver, serialization
32
+ from shared.lazy_data_resolver import LazyDataResolver
33
+ from shared.secrets_resolver import SecretsResolver
34
+
35
+ if sys.version_info >= (3, 9):
36
+ from typing import AsyncGenerator, Generator
37
+ else:
38
+ from typing_extensions import AsyncGenerator, Generator
39
+
40
# Module-level configuration values. Their consumers are not visible in this
# chunk, so only what the assignments themselves establish is described here.
MODEL_BASENAME = "model"

# Read once at import time from the `NUM_LOAD_RETRIES_TRUSS` environment
# variable; falls back to 1 when the variable is unset.
NUM_LOAD_RETRIES = int(os.environ.get("NUM_LOAD_RETRIES_TRUSS", "1"))

# Durations, in seconds (per the `_SECS` suffix).
STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS = 60
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30

DEFAULT_PREDICT_CONCURRENCY = 1

# Names related to extensions.
EXTENSIONS_DIR_NAME = "extensions"
EXTENSION_CLASS_NAME = "Extension"
EXTENSION_FILE_NAME = "extension"
TRT_LLM_EXTENSION_NAME = "trt_llm"
50
+
51
+
52
class MethodName(str, enum.Enum):
    """Names of the model methods this module knows about.

    Because of the ``str`` mixin every member is itself a string equal to its
    value, so a member can be used wherever the plain attribute name is expected.
    """

    CHAT_COMPLETIONS = "chat_completions"
    COMPLETIONS = "completions"
    IS_HEALTHY = "is_healthy"
    POSTPROCESS = "postprocess"
    PREDICT = "predict"
    PREPROCESS = "preprocess"
    SETUP_ENVIRONMENT = "setup_environment"
60
+
61
+
62
# Values a model method may receive as its input.
InputType = Union[
    serialization.JSONType,
    serialization.MsgPackType,
    pydantic.BaseModel,
]

# Values a model method may produce: serializable data, a sync or async byte
# generator, a starlette response, or a pydantic model.
OutputType = Union[
    serialization.JSONType,
    serialization.MsgPackType,
    Generator[bytes, None, None],
    AsyncGenerator[bytes, None],
    "starlette.responses.Response",
    pydantic.BaseModel,
]

# A model method: any call signature, returning an output directly or an
# awaitable that resolves to one.
ModelFn = Callable[..., Union[OutputType, Awaitable[OutputType]]]
72
+
73
+
74
@asynccontextmanager
async def deferred_semaphore_and_span(
    semaphore: Semaphore, span: trace.Span
) -> AsyncGenerator[Callable[[], Callable[[], None]], None]:
    """
    Acquire `semaphore` and optionally hand its release (and the ending of `span`)
    over to the caller.

    The context yields a zero-argument function. Calling it marks the cleanup as
    deferred and returns a second callable which releases the semaphore and ends
    the span; the caller is then responsible for invoking that callable.
    If the yielded function is never called, the semaphore is released and the
    span is ended when the context exits.
    """
    await semaphore.acquire()
    # NOTE(review): `trace.use_span` is documented as a context manager and its
    # result is neither entered nor stored here, so this call presumably does not
    # make `span` the current span - confirm the intent before changing it.
    trace.use_span(span, end_on_exit=False)
    cleanup_handed_off = False

    def release_and_end() -> None:
        # Same order as always: free the slot first, then close the span.
        semaphore.release()
        span.end()

    def defer() -> Callable[[], None]:
        nonlocal cleanup_handed_off
        cleanup_handed_off = True
        return release_and_end

    try:
        yield defer
    finally:
        # Only clean up here when nobody took ownership of the cleanup.
        if not cleanup_handed_off:
            release_and_end()
103
+
104
+
105
# The positional-argument tuples that can be passed to a model method:
# inputs only, inputs followed by the request, or the request only.
_ArgsType = Union[
    Tuple[Any],
    Tuple[Any, starlette.requests.Request],
    Tuple[starlette.requests.Request],
]
110
+
111
+
112
class _Sentinel:
    """Marker type; its repr identifies it as the end-of-queue marker."""

    def __repr__(self) -> str:
        return "<Sentinel End of Queue>"


# The shared marker instance.
SENTINEL = _Sentinel()
118
+
119
+
120
def _is_request_type(obj: Any) -> bool:
    """Return True if `obj` is a subclass of `starlette.requests.Request`.

    Anything for which the subclass check itself fails yields False.
    """
    try:
        # `issubclass` raises (rather than returning False) when `obj` is not a type.
        is_request = issubclass(obj, starlette.requests.Request)
    except Exception:
        is_request = False
    return is_request
126
+
127
+
128
class ArgConfig(enum.Enum):
    """Which positional arguments a model method takes (inputs and/or request)."""

    NONE = enum.auto()
    INPUTS_ONLY = enum.auto()
    REQUEST_ONLY = enum.auto()
    INPUTS_AND_REQUEST = enum.auto()

    @classmethod
    def from_signature(
        cls, signature: inspect.Signature, method_name: str
    ) -> "ArgConfig":
        """Classify `signature` by parameter count and request annotations.

        Args:
            signature: Signature of the method to classify.
            method_name: Name used in error messages.

        Returns:
            `NONE` for no parameters; `REQUEST_ONLY` or `INPUTS_ONLY` for a single
            parameter (depending on whether it is annotated as a request type);
            `INPUTS_AND_REQUEST` for two parameters.

        Raises:
            errors.ModelDefinitionError: if there are more than two parameters, or
                with two parameters the first is annotated as a request or the
                second is not annotated as a request.
        """
        params = tuple(signature.parameters.values())
        count = len(params)

        if count == 0:
            return cls.NONE

        if count == 1:
            (only,) = params
            return (
                cls.REQUEST_ONLY
                if _is_request_type(only.annotation)
                else cls.INPUTS_ONLY
            )

        if count > 2:
            raise errors.ModelDefinitionError(
                f"`{method_name}` method cannot have more than two arguments. "
                f"Got: {signature}"
            )

        # Exactly two parameters: the first can be anything except the request,
        # the second has to be the (type annotated) request.
        first, second = params
        if first.annotation and _is_request_type(first.annotation):
            raise errors.ModelDefinitionError(
                f"`{method_name}` method with two arguments is not allowed to "
                "have request as first argument, request must be second. "
                f"Got: {signature}"
            )
        if not second.annotation or not _is_request_type(second.annotation):
            raise errors.ModelDefinitionError(
                f"`{method_name}` method with two arguments must have request as "
                f"second argument (type annotated). Got: {signature} "
            )
        return cls.INPUTS_AND_REQUEST

    @classmethod
    def prepare_args(
        cls,
        inputs: Any,
        request: starlette.requests.Request,
        descriptor: "MethodDescriptor",
    ) -> _ArgsType:
        """Build the positional-argument tuple matching `descriptor.arg_config`.

        Raises:
            NotImplementedError: for any arg config other than `INPUTS_ONLY`,
                `REQUEST_ONLY` or `INPUTS_AND_REQUEST` (this includes `NONE`).
        """
        config = descriptor.arg_config
        if config == ArgConfig.INPUTS_ONLY:
            return (inputs,)
        if config == ArgConfig.REQUEST_ONLY:
            return (request,)
        if config == ArgConfig.INPUTS_AND_REQUEST:
            return (inputs, request)
        raise NotImplementedError(f"Arg config {descriptor.arg_config}.")
185
+
186
+
187
@dataclasses.dataclass
class MethodDescriptor:
    """Facts about one model method, gathered by inspecting the callable."""

    # True when the method is a coroutine function (`async def` without `yield`).
    is_async: bool
    # True when the method is a sync or async generator function.
    is_generator: bool
    # Which positional arguments the method takes.
    arg_config: ArgConfig
    method_name: MethodName
    # The inspected callable itself.
    method: ModelFn

    @classmethod
    def from_method(cls, method: Any, method_name: MethodName) -> "MethodDescriptor":
        """Inspect `method` and build its descriptor.

        Errors raised by `ArgConfig.from_signature` for an unsupported signature
        propagate to the caller.
        """
        async_flag = cls._is_async(method)
        generator_flag = cls._is_generator(method)
        arg_config = ArgConfig.from_signature(inspect.signature(method), method_name)
        return cls(
            is_async=async_flag,
            is_generator=generator_flag,
            arg_config=arg_config,
            method_name=method_name,
            # ArgConfig ensures that the Callable has an appropriate signature.
            method=cast(ModelFn, method),
        )

    @classmethod
    def _is_async(cls, method: Any) -> bool:
        """Return True only for coroutine functions.

        Async generator functions are deliberately not counted: they cannot be
        awaited and have to be consumed with `async for`.
        """
        return inspect.iscoroutinefunction(method)

    @classmethod
    def _is_generator(cls, method: Any) -> bool:
        """Return True for sync and async generator functions alike."""
        if inspect.isgeneratorfunction(method):
            return True
        return inspect.isasyncgenfunction(method)
215
+
216
+
217
+ @dataclasses.dataclass
218
+ class ModelDescriptor:
219
+ preprocess: Optional[MethodDescriptor]
220
+ predict: MethodDescriptor
221
+ postprocess: Optional[MethodDescriptor]
222
+ truss_schema: Optional[TrussSchema]
223
+ setup_environment: Optional[MethodDescriptor]
224
+ is_healthy: Optional[MethodDescriptor]
225
+ completions: Optional[MethodDescriptor]
226
+ chat_completions: Optional[MethodDescriptor]
227
+
228
@cached_property
def skip_input_parsing(self) -> bool:
    """True when `predict` takes only the request and `preprocess`, if present,
    takes only the request as well."""
    if self.predict.arg_config != ArgConfig.REQUEST_ONLY:
        return False
    if not self.preprocess:
        return True
    return self.preprocess.arg_config == ArgConfig.REQUEST_ONLY
233
+
234
@classmethod
def _gen_truss_schema(
    cls,
    model_cls: Any,
    predict: MethodDescriptor,
    preprocess: Optional[MethodDescriptor],
    postprocess: Optional[MethodDescriptor],
) -> TrussSchema:
    """Build a `TrussSchema` from the model's method signatures.

    Input parameters come from `preprocess` when a preprocess descriptor is
    given, otherwise from `predict`. The return annotation comes from
    `postprocess` when a postprocess descriptor is given, otherwise from
    `predict`. The `predict` descriptor argument itself is not read.
    """
    input_method = model_cls.preprocess if preprocess else model_cls.predict
    parameters = inspect.signature(input_method).parameters

    output_method = model_cls.postprocess if postprocess else model_cls.predict
    return_annotation = inspect.signature(output_method).return_annotation

    return TrussSchema.from_signature(parameters, return_annotation)
255
+
256
+ @classmethod
257
+ def _safe_extract_descriptor(
258
+ cls, model_cls: Any, method_name: MethodName
259
+ ) -> Union[MethodDescriptor, None]:
260
+ if hasattr(model_cls, method_name):
261
+ return MethodDescriptor.from_method(
262
+ method=getattr(model_cls, method_name), method_name=method_name
263
+ )
264
+ return None
265
+
266
+ @classmethod
267
+ def from_model(cls, model_cls) -> "ModelDescriptor":
268
+ preprocess = cls._safe_extract_descriptor(model_cls, MethodName.PREPROCESS)
269
+ predict = cls._safe_extract_descriptor(model_cls, MethodName.PREDICT)
270
+ if predict is None:
271
+ raise errors.ModelDefinitionError(
272
+ f"Truss model must have a `{MethodName.PREDICT}` method."
273
+ )
274
+ elif preprocess and predict.arg_config == ArgConfig.REQUEST_ONLY:
275
+ raise errors.ModelDefinitionError(
276
+ f"When using `{MethodName.PREPROCESS}`, the {MethodName.PREDICT} method "
277
+ f"cannot only have the request argument (because the result of `{MethodName.PREPROCESS}` "
278
+ "would be discarded)."
279
+ )
280
+
281
+ postprocess = cls._safe_extract_descriptor(model_cls, MethodName.POSTPROCESS)
282
+ if postprocess and postprocess.arg_config == ArgConfig.REQUEST_ONLY:
283
+ raise errors.ModelDefinitionError(
284
+ f"The `{MethodName.POSTPROCESS}` method cannot only have the request "
285
+ f"argument (because the result of `{MethodName.PREDICT}` would be discarded)."
286
+ )
287
+ setup = cls._safe_extract_descriptor(model_cls, MethodName.SETUP_ENVIRONMENT)
288
+ completions = cls._safe_extract_descriptor(model_cls, MethodName.COMPLETIONS)
289
+ chats = cls._safe_extract_descriptor(model_cls, MethodName.CHAT_COMPLETIONS)
290
+ is_healthy = cls._safe_extract_descriptor(model_cls, MethodName.IS_HEALTHY)
291
+ if is_healthy and is_healthy.arg_config != ArgConfig.NONE:
292
+ raise errors.ModelDefinitionError(
293
+ f"`{MethodName.IS_HEALTHY}` must have only one argument: `self`."
294
+ )
295
+
296
+ truss_schema = cls._gen_truss_schema(
297
+ model_cls=model_cls,
298
+ predict=predict,
299
+ preprocess=preprocess,
300
+ postprocess=postprocess,
301
+ )
302
+ return cls(
303
+ preprocess=preprocess,
304
+ predict=predict,
305
+ postprocess=postprocess,
306
+ truss_schema=truss_schema,
307
+ setup_environment=setup,
308
+ is_healthy=is_healthy,
309
+ completions=completions,
310
+ chat_completions=chats,
311
+ )
312
+
313
+
314
+ class ModelWrapper:
315
+ _config: Dict
316
+ _tracer: sdk_trace.Tracer
317
+ _maybe_model: Optional[Any]
318
+ _maybe_model_descriptor: Optional[ModelDescriptor]
319
+ _logger: logging.Logger
320
+ _status: "ModelWrapper.Status"
321
+ _predict_semaphore: Semaphore
322
+ _poll_for_environment_updates_task: Optional[asyncio.Task]
323
+ _environment: Optional[dict]
324
+
325
class Status(enum.Enum):
    # Load lifecycle of the wrapped model.
    NOT_READY = 0  # Initial state assigned in `__init__`.
    LOADING = 1  # Set by `load()` once it holds the load lock.
    READY = 2  # Set by `load()` after `_load_impl()` returned without raising.
    FAILED = 3  # Set by `load()` when `_load_impl()` raised an exception.
330
+
331
def __init__(self, config: Dict, tracer: sdk_trace.Tracer):
    """Store the config and tracer and set up the not-yet-loaded state.

    Args:
        config: Truss config dict; `runtime.predict_concurrency` (when present)
            sizes the predict semaphore.
        tracer: Tracer used to create spans around model calls.
    """
    self._config = config
    self._tracer = tracer
    self._maybe_model = None
    self._maybe_model_descriptor = None
    # The logger must carry the server's JSON logging setup in its handlers,
    # also inside the loading thread. A newly created logger instance would not
    # carry that setup into the thread, and an unspecified `getLogger` may
    # return a non-compliant logger if a dependency overrides the root logger
    # (e.g. https://github.com/numpy/numpy/issues/24213). Hence the uvicorn
    # logger, which is set up in `truss_server`, is used.
    self._logger = logging.getLogger("uvicorn")
    self.name = MODEL_BASENAME
    self._load_lock = Lock()
    self._status = ModelWrapper.Status.NOT_READY
    runtime_settings = self._config.get("runtime", {})
    max_concurrent_predicts = runtime_settings.get(
        "predict_concurrency", DEFAULT_PREDICT_CONCURRENCY
    )
    self._predict_semaphore = Semaphore(max_concurrent_predicts)
    self._poll_for_environment_updates_task = None
    self._environment = None
353
+
354
@property
def _model(self) -> Any:
    """Return the loaded model instance.

    Presence is decided with an explicit `None` check, so a loaded model
    object that happens to be falsy (for example because it defines
    `__bool__` or `__len__`) is still returned instead of being reported as
    not ready.

    Raises:
        errors.ModelNotReady: if no model instance has been set yet.
    """
    if self._maybe_model is None:
        raise errors.ModelNotReady(self.name)
    return self._maybe_model
360
+
361
@property
def model_descriptor(self) -> ModelDescriptor:
    """Return the descriptor of the loaded model.

    Raises:
        errors.ModelNotReady: if no descriptor has been set yet.
    """
    descriptor = self._maybe_model_descriptor
    if not descriptor:
        raise errors.ModelNotReady(self.name)
    return descriptor
367
+
368
@property
def load_failed(self) -> bool:
    """True when the load status is `FAILED`."""
    failed_state = ModelWrapper.Status.FAILED
    return self._status is failed_state
371
+
372
@property
def ready(self) -> bool:
    """True when the load status is `READY`."""
    ready_state = ModelWrapper.Status.READY
    return self._status is ready_state
375
+
376
@property
def _model_file_name(self) -> str:
    """File name of the model class module, read from `model_class_filename` in the config."""
    file_name = self._config["model_class_filename"]
    return file_name
379
+
380
def start_load_thread(self):
    """Run `load` on a new thread, but only while the status is `NOT_READY`."""
    # Don't retry failed loads.
    if self._status != ModelWrapper.Status.NOT_READY:
        return
    Thread(target=self.load).start()
385
+
386
def load(self):
    """Load the model once and record the outcome in `self._status`.

    Returns immediately when the model is already ready. Otherwise the load
    lock is taken, the status is set to `LOADING` and `_load_impl()` runs.
    On success the status becomes `READY`; any exception is logged and the
    status becomes `FAILED` (the exception is not re-raised).
    """
    if self.ready:
        return
    # if we are already loading, block on acquiring the lock;
    # this worker will return 503 while the worker with the lock is loading
    with self._load_lock:
        # Another caller may have completed the load while this one was
        # waiting for the lock; check again to avoid loading a second time.
        if self.ready:
            return
        self._status = ModelWrapper.Status.LOADING
        self._logger.info("Executing model.load()...")
        try:
            start_time = time.perf_counter()
            self._load_impl()
            self._status = ModelWrapper.Status.READY
            self._logger.info(
                f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
            )
        except Exception:
            self._logger.exception("Exception while loading model")
            self._status = ModelWrapper.Status.FAILED
404
+
405
def _load_impl(self):
    """Instantiate the model, build its descriptor and run its `load` hook.

    Steps, in order:
      1. Ensure the `data` directory exists and, if bundled packages are
         configured and `/packages` exists, make it importable.
      2. Resolve secrets, apply library patches, then create and load all
         extensions.
      3. Import the model class file and instantiate the model class, or,
         when that file is absent and the TRT-LLM extension is available,
         use the extension's model override.
      4. Build the `ModelDescriptor`, run `setup_environment` before load if
         the model defines it, and finally call the model's `load` with
         retries.

    Raises:
        ImportError: if the model file cannot be imported.
        RuntimeError: if neither a model file nor the TRT-LLM extension exists.
    """
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)

    if "bundled_packages_dir" in self._config:
        bundled_packages_path = Path("/packages")
        if bundled_packages_path.exists():
            bundled_packages_entry = str(bundled_packages_path)
            # Avoid piling up duplicate entries if loading runs more than once.
            if bundled_packages_entry not in sys.path:
                sys.path.append(bundled_packages_entry)

    secrets = SecretsResolver.get_secrets(self._config)
    lazy_data_resolver = LazyDataResolver(data_dir)

    apply_patches(
        self._config.get("apply_library_patches", True),
        self._config["requirements"],
    )

    extensions = _init_extensions(
        self._config, data_dir, secrets, lazy_data_resolver
    )
    for extension in extensions.values():
        extension.load()

    model_class_file_path = (
        Path(self._config["model_module_dir"])
        / self._config["model_class_filename"]
    )
    if model_class_file_path.exists():
        self._logger.info("Loading truss model from file.")
        module_path = pathlib.Path(model_class_file_path).resolve()
        module_name = module_path.stem  # Use the file's name as the module name
        if not os.path.isfile(module_path):
            raise ImportError(
                f"`{module_path}` is not a file. You must point to a python file where "
                "the entrypoint chainlet is defined."
            )
        import_error_msg = f"Could not import `{module_path}`. Check path."
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if not spec:
            raise ImportError(import_error_msg)
        if not spec.loader:
            raise ImportError(import_error_msg)
        module = importlib.util.module_from_spec(spec)
        try:
            spec.loader.exec_module(module)
        except ImportError as e:
            # Give a targeted hint for the most common failure; anything else
            # is re-raised unchanged below.
            if "attempted relative import" in str(e):
                raise ImportError(
                    f"During import of `{model_class_file_path}`. "
                    f"Since Truss v0.9.36 relative imports (starting with '.') in "
                    "the top-level model file are no longer supported. Please "
                    "replace them with absolute imports. For guidance on importing "
                    "custom packages refer to our documentation "
                    "https://docs.baseten.co/truss-reference/config#packages"
                ) from e

            raise

        model_class = getattr(module, self._config["model_class_name"])
        model_init_params = _prepare_init_args(
            model_class, self._config, data_dir, secrets, lazy_data_resolver
        )
        signature = inspect.signature(model_class)
        # Hand each extension's arguments to the model if its constructor
        # accepts a keyword named after the extension.
        for ext_name, ext in extensions.items():
            if _signature_accepts_keyword_arg(signature, ext_name):
                model_init_params[ext_name] = ext.model_args()
        self._maybe_model = model_class(**model_init_params)

    elif TRT_LLM_EXTENSION_NAME in extensions:
        self._logger.info("Loading TRT LLM extension as model.")
        # trt_llm extension allows model.py to be absent. It supplies its
        # own model class in that case.
        # Look the extension up with the same constant used in the guard above.
        trt_llm_extension = extensions[TRT_LLM_EXTENSION_NAME]
        self._maybe_model = trt_llm_extension.model_override()
    else:
        raise RuntimeError("No module class file found")

    self._maybe_model_descriptor = ModelDescriptor.from_model(self._model)

    if self._maybe_model_descriptor.setup_environment:
        self._initialize_environment_before_load()

    if hasattr(self._model, "load"):
        retry(
            self._model.load,
            NUM_LOAD_RETRIES,
            self._logger.warning,
            "Failed to load model.",
            gap_seconds=1.0,
        )
495
+
496
def setup_polling_for_environment_updates(self):
    """Schedule `poll_for_environment_updates` as an asyncio task and keep a reference to it."""
    polling_coroutine = self.poll_for_environment_updates()
    self._poll_for_environment_updates_task = asyncio.create_task(polling_coroutine)
500
+
501
def _initialize_environment_before_load(self):
    """Call the model's `setup_environment` synchronously with the current environment.

    Does nothing when the dynamic config holds no environment value.
    """
    raw_environment = dynamic_config_resolver.get_dynamic_config_value_sync(
        dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
    )
    if not raw_environment:
        return
    parsed_environment = json.loads(raw_environment)
    self._logger.info(
        f"Executing model.setup_environment with environment: {parsed_environment}"
    )
    # TODO: Support calling an async setup_environment() here once we support async load()
    self._model.setup_environment(parsed_environment)
    self._environment = parsed_environment
513
+
514
async def setup_environment(self, environment: Optional[dict]):
    """Invoke the model's `setup_environment` hook, if the model defines one.

    An async hook is awaited directly; a sync hook is run in a worker thread.
    """
    hook = self.model_descriptor.setup_environment
    if not hook:
        return None
    self._logger.info(
        f"Executing model.setup_environment with environment: {environment}"
    )
    if not hook.is_async:
        return await to_thread.run_sync(self._model.setup_environment, environment)
    return await self._model.setup_environment(environment)
525
+
526
async def poll_for_environment_updates(self) -> None:
    """Periodically re-run `setup_environment` when the environment config changes.

    The loop sleeps between iterations, waits until the model is ready and
    stops for good if the model has no `setup_environment` hook. A change is
    detected via the config file's modification time; the hook only runs when
    the parsed environment differs from the one last applied. Exceptions are
    logged and do not stop the loop.
    """
    environment_key = dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
    config_path = dynamic_config_resolver.get_dynamic_config_file_path(
        environment_key
    )
    seen_mtime = None

    while True:
        # Give control back to the event loop while waiting for environment updates
        await asyncio.sleep(POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS)

        # Wait for load to finish before checking for environment updates
        if not self.ready:
            continue

        # Skip polling if no setup_environment implementation provided
        if not self.model_descriptor.setup_environment:
            break

        if not config_path.exists():
            continue

        try:
            current_mtime = os.path.getmtime(config_path)
            if seen_mtime and seen_mtime == current_mtime:
                continue
            raw_environment = (
                await dynamic_config_resolver.get_dynamic_config_value_async(
                    environment_key
                )
            )
            if not raw_environment:
                continue
            seen_mtime = current_mtime
            parsed_environment = json.loads(raw_environment)
            # Avoid rerunning `setup_environment` with the same environment
            if self._environment != parsed_environment:
                await self.setup_environment(parsed_environment)
                self._environment = parsed_environment
        except Exception as exc:
            self._logger.exception(
                f"Exception while setting up environment: {str(exc)}",
                exc_info=errors.filter_traceback(self._model_file_name),
            )
565
+
566
async def is_healthy(self) -> Optional[bool]:
    """Run the model's `is_healthy` check.

    Returns:
        None when the model has no `is_healthy` method or loading failed;
        otherwise the value returned by the model's check, or False if the
        check raised an exception.
    """
    health_hook = self.model_descriptor.is_healthy
    if not health_hook or self.load_failed:
        # return early with None if model does not have is_healthy method or load failed
        return None

    health: Optional[bool] = None
    try:
        if not health_hook.is_async:
            # Offload sync functions to thread, to not block event loop.
            health = await to_thread.run_sync(self._model.is_healthy)
        else:
            health = await self._model.is_healthy()
    except Exception as exc:
        health = False
        self._logger.exception(
            f"Exception while checking if model is healthy: {str(exc)}",
            exc_info=errors.filter_traceback(self._model_file_name),
        )

    # self.ready evaluates to True when the model's load function has completed,
    # we will only log health check failures to model logs when the model's load has completed
    if not health and self.ready:
        self._logger.warning("Health check failed.")
    return health
589
+
590
async def preprocess(
    self, inputs: InputType, request: starlette.requests.Request
) -> Any:
    """Run the model's `preprocess` hook; callers must check beforehand that it exists."""
    preprocess_descriptor = self.model_descriptor.preprocess
    assert preprocess_descriptor, (
        f"`{MethodName.PREPROCESS}` must only be called if model has it."
    )
    result = await self._execute_async_model_fn(
        inputs, request, preprocess_descriptor
    )
    return result
598
+
599
async def predict(
    self, inputs: Any, request: starlette.requests.Request
) -> Union[OutputType, Any]:
    """Run the model's `predict` hook.

    The result can be a serializable data structure, a byte generator, a
    response object, or (when postprocessing is used) anything at all; in
    that last case postprocessing has to produce something serializable.
    """
    predict_descriptor = self.model_descriptor.predict
    result = await self._execute_async_model_fn(inputs, request, predict_descriptor)
    return result
607
+
608
async def postprocess(
    self, result: Union[InputType, Any], request: starlette.requests.Request
) -> OutputType:
    """Run the model's `postprocess` hook; callers must check beforehand that it exists.

    Postprocessing handles regular outputs of `predict`, but not generators
    and response objects: for those `predict` has to return directly and
    this step is skipped. The result type can be the same as for predict.
    """
    postprocess_descriptor = self.model_descriptor.postprocess
    assert postprocess_descriptor, (
        f"`{MethodName.POSTPROCESS}` must only be called if model has it."
    )
    processed = await self._execute_async_model_fn(
        result, request, postprocess_descriptor
    )
    return processed
620
+
621
async def _write_response_to_queue(
    self,
    queue: asyncio.Queue,
    generator: AsyncGenerator[bytes, None],
    span: trace.Span,
) -> None:
    """Drain `generator` into `queue` and always finish with `SENTINEL`.

    Exceptions raised while iterating are logged and not propagated; the
    sentinel is enqueued in either case so that a reader of the queue can stop.
    """
    with tracing.section_as_event(span, "write_response_to_queue"):
        try:
            async for produced_chunk in generator:
                await queue.put(produced_chunk)
        except Exception as exc:
            self._logger.exception(
                f"Exception while generating streamed response: {str(exc)}",
                exc_info=errors.filter_traceback(self._model_file_name),
            )
        finally:
            await queue.put(SENTINEL)
638
+
639
async def _stream_with_background_task(
    self,
    generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
    span: trace.Span,
    trace_ctx: trace.Context,
    cleanup_fn: Callable[[], None],
) -> AsyncGenerator[bytes, None]:
    """Stream `generator` through a queue that is filled by a background task.

    Args:
        generator: Sync or async generator producing the response chunks.
        span: Span handed to the background producer task.
        trace_ctx: Tracing context under which the consumer span is started.
        cleanup_fn: Called once the producer task is done.

    Returns:
        An async generator yielding the buffered chunks. Waiting for the next
        chunk longer than the streaming read timeout makes `asyncio.wait_for`
        raise a timeout error inside the returned generator.
    """
    # The streaming read timeout is the maximum time allowed between two
    # streamed chunks before a timeout is triggered.
    streaming_read_timeout = self._config.get("runtime", {}).get(
        "streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
    )
    async_generator = _force_async_generator(generator)
    # To ensure that a partial read from a client does not keep the semaphore
    # claimed, we write all the data from the stream to the queue as it is produced,
    # irrespective of how fast it is consumed.
    # We then return a new generator that reads from the queue, and then
    # exits the semaphore block.
    response_queue: asyncio.Queue = asyncio.Queue()

    # `write_response_to_queue` keeps running in the background until completion.
    gen_task = asyncio.create_task(
        self._write_response_to_queue(response_queue, async_generator, span)
    )
    # Defer the release of the semaphore until the write_response_to_queue task
    # is done.
    gen_task.add_done_callback(lambda _: cleanup_fn())

    # The gap between responses in a stream must be < streaming_read_timeout
    # TODO: this whole buffering might be superfluous and sufficiently done
    # by the FastAPI server already. See `test_limit_concurrency_with_sse`.
    async def _buffered_response_generator() -> AsyncGenerator[bytes, None]:
        # `span` is tied to the "producer" `gen_task` which might complete before
        # the "consumer" part here finishes, therefore a dedicated span is required.
        # Because all of this code is inside a `detach_context` block, we
        # explicitly propagate the tracing context for this span.
        with self._tracer.start_as_current_span(
            "buffered-response-generator", context=trace_ctx
        ):
            while True:
                chunk = await asyncio.wait_for(
                    response_queue.get(), timeout=streaming_read_timeout
                )
                # Identity check: the producer enqueues the very `SENTINEL`
                # object, and `==` would invoke the chunk's own `__eq__`,
                # which may misfire or raise for user-provided chunk types.
                if chunk is SENTINEL:
                    return
                yield chunk

    return _buffered_response_generator()
686
+
687
async def _execute_async_model_fn(
    self,
    inputs: Union[InputType, Any],
    request: starlette.requests.Request,
    descriptor: MethodDescriptor,
) -> OutputType:
    """Call the model function described by `descriptor` in the appropriate way.

    Generator functions are called and their generator object is returned
    without awaiting; coroutine functions are awaited; plain functions are run
    in a worker thread. Everything runs inside `errors.intercept_exceptions`.
    """
    call_args = ArgConfig.prepare_args(inputs, request, descriptor)
    with errors.intercept_exceptions(self._logger, self._model_file_name):
        if descriptor.is_generator:
            # Even for async generators, don't await here.
            outcome = descriptor.method(*call_args)
        elif descriptor.is_async:
            outcome = await cast(
                Awaitable[OutputType], descriptor.method(*call_args)
            )
        else:
            outcome = await to_thread.run_sync(descriptor.method, *call_args)
        return outcome
701
+
702
async def _process_model_fn(
    self,
    inputs: InputType,
    request: starlette.requests.Request,
    descriptor: MethodDescriptor,
) -> OutputType:
    """
    Wraps the execution of any model code other than `predict`.

    The call runs inside a `call-<method name>` span with detached tracing
    context. Generator results are passed on to `_handle_generator_response`;
    any other result is returned as is.
    """
    fn_span = self._tracer.start_span(f"call-{descriptor.method_name}")
    # TODO(nikhil): Make it easier to start a section with detached context.
    with tracing.section_as_event(fn_span, descriptor.method_name):
        with tracing.detach_context() as detached_ctx:
            outcome = await self._execute_async_model_fn(inputs, request, descriptor)

            is_streaming = inspect.isgenerator(outcome) or inspect.isasyncgen(
                outcome
            )
            if not is_streaming:
                return outcome

            return await self._handle_generator_response(
                request, outcome, fn_span, detached_ctx
            )
724
+
725
def _should_gather_generator(self, request: starlette.requests.Request) -> bool:
    """Decide whether a generator result should be collected instead of streamed.

    Returns True only when the `accept` header equals `application/json` and
    the user agent does not contain "openai" (case-insensitive).
    """
    # The OpenAI SDK sends an accept header for JSON even in a streaming context,
    # but we need to stream results back for client compatibility. Luckily,
    # we can differentiate by looking at the user agent (e.g. OpenAI/Python 1.61.0)
    agent = request.headers.get("user-agent", "").lower()
    if "openai" in agent:
        return False
    # TODO(nikhil): determine if we can safely deprecate this behavior.
    accepted = request.headers.get("accept")
    return accepted == "application/json"
734
+
735
async def _handle_generator_response(
    self,
    request: starlette.requests.Request,
    generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
    span: trace.Span,
    trace_ctx: trace.Context,
    get_cleanup_fn: Callable[[], Callable[[], None]] = lambda: lambda: None,
):
    """Either gather the generator into one result or stream it.

    `get_cleanup_fn` is only invoked on the streaming path; the function it
    returns is handed to `_stream_with_background_task` as `cleanup_fn`.
    """
    if not self._should_gather_generator(request):
        cleanup = get_cleanup_fn()
        return await self._stream_with_background_task(
            generator, span, trace_ctx, cleanup_fn=cleanup
        )
    return await _gather_generator(generator)
749
+
750
async def completions(
    self, inputs: InputType, request: starlette.requests.Request
) -> OutputType:
    """Run the model's `completions` hook; callers must check beforehand that it exists."""
    completions_descriptor = self.model_descriptor.completions
    assert completions_descriptor, (
        f"`{MethodName.COMPLETIONS}` must only be called if model has it."
    )

    outcome = await self._process_model_fn(inputs, request, completions_descriptor)
    return outcome
759
+
760
async def chat_completions(
    self, inputs: InputType, request: starlette.requests.Request
) -> OutputType:
    """Run the model's `chat_completions` hook; callers must check beforehand that it exists."""
    chat_descriptor = self.model_descriptor.chat_completions
    assert chat_descriptor, (
        f"`{MethodName.CHAT_COMPLETIONS}` must only be called if model has it."
    )

    outcome = await self._process_model_fn(inputs, request, chat_descriptor)
    return outcome
769
+
770
async def __call__(
    self, inputs: Optional[InputType], request: starlette.requests.Request
) -> OutputType:
    """
    Returns result from: preprocess -> predictor -> postprocess.

    `preprocess` and `postprocess` only run if the model defines them.
    `predict` runs while holding the predict semaphore. Generator results and
    `starlette` response objects returned by `predict` are returned directly
    (postprocessing is not allowed for them and raises
    `errors.ModelDefinitionError`).
    """
    if self.model_descriptor.preprocess:
        with self._tracer.start_as_current_span("call-pre") as span_pre:
            # TODO(nikhil): Make it easier to start a section with detached context.
            with tracing.section_as_event(
                span_pre, "preprocess"
            ), tracing.detach_context():
                preprocess_result = await self.preprocess(inputs, request)
    else:
        # Without a preprocess hook the raw inputs go straight to predict.
        preprocess_result = inputs

    span_predict = self._tracer.start_span("call-predict")
    async with deferred_semaphore_and_span(
        self._predict_semaphore, span_predict
    ) as get_defer_fn:
        # TODO(nikhil): Make it easier to start a section with detached context.
        with tracing.section_as_event(
            span_predict, "predict"
        ), tracing.detach_context() as detached_ctx:
            # To prevent span pollution, we need to make sure spans created by user
            # code don't inherit context from our spans (which happens even if
            # different tracer instances are used).
            # Therefore, predict is run in `detach_context`.
            # There is one caveat with streaming predictions though:
            # The context manager only detaches spans that are created outside
            # the generator loop that yields the stream (because the parts of the
            # loop body will be executed in a "deferred" way (same reasoning as for
            # using `deferred_semaphore_and_span`)). We assume here that
            # creating spans inside the loop body is very unlikely. In order to
            # exactly handle that case we would need to apply `detach_context`
            # around each `next`-invocation that consumes the generator, which is
            # prohibitive.
            predict_result = await self.predict(preprocess_result, request)

        if inspect.isgenerator(predict_result) or inspect.isasyncgen(
            predict_result
        ):
            if self.model_descriptor.postprocess:
                # Raised inside `intercept_exceptions` so the error is handled
                # like other model errors.
                with errors.intercept_exceptions(
                    self._logger, self._model_file_name
                ):
                    raise errors.ModelDefinitionError(
                        "If the predict function returns a generator (streaming), "
                        "you cannot use postprocessing. Include all processing in "
                        "the predict method."
                    )

            # `get_defer_fn` is passed uncalled; it is only invoked if the
            # response is actually streamed.
            return await self._handle_generator_response(
                request,
                predict_result,
                span_predict,
                detached_ctx,
                get_cleanup_fn=get_defer_fn,
            )

        if isinstance(predict_result, starlette.responses.Response):
            if self.model_descriptor.postprocess:
                with errors.intercept_exceptions(
                    self._logger, self._model_file_name
                ):
                    raise errors.ModelDefinitionError(
                        "If the predict function returns a response object, "
                        "you cannot use postprocessing."
                    )
            if isinstance(predict_result, starlette.responses.StreamingResponse):
                # Defer the semaphore release, using a weakref on the response.
                # This might keep the semaphore longer than using "native" truss
                # streaming, because here the criterion is not the production of
                # data by the generator, but the span of handling the request by
                # the fastAPI server.
                weakref.finalize(predict_result, get_defer_fn())

            return predict_result

    if self.model_descriptor.postprocess:
        with self._tracer.start_as_current_span("call-post") as span_post:
            # TODO(nikhil): Make it easier to start a section with detached context.
            with tracing.section_as_event(
                span_post, "postprocess"
            ), tracing.detach_context():
                postprocess_result = await self.postprocess(predict_result, request)
            return postprocess_result
    else:
        return predict_result
859
+
860
+
861
async def _gather_generator(
    predict_result: Union[AsyncGenerator[bytes, None], Generator[bytes, None, None]],
) -> str:
    """Consume the whole generator and concatenate `str(chunk)` of every chunk."""
    pieces = []
    async for chunk in _force_async_generator(predict_result):
        pieces.append(str(chunk))
    return "".join(pieces)
867
+
868
+
869
def _force_async_generator(gen: Union[Generator, AsyncGenerator]) -> AsyncGenerator:
    """
    Takes a generator, and converts it into an async generator if it is not already.

    Async generators are returned unchanged. A sync generator is wrapped so
    that each `next` call runs in a worker thread.
    """
    if inspect.isasyncgen(gen):
        return gen

    async def _convert_generator_to_async():
        """
        Runs each iteration of the generator in an offloaded thread, to ensure
        the main loop is not blocked, and yield to create an async generator.
        """
        while True:
            # Note that this is the equivalent of running:
            # next(gen, SENTINEL) on a separate thread,
            # ensuring that if there is anything blocking in the generator,
            # it does not block the main loop.
            chunk = await to_thread.run_sync(next, gen, SENTINEL)
            # Identity check: `next` hands back the very `SENTINEL` object on
            # exhaustion, and `==` would invoke the chunk's own `__eq__`,
            # which may misfire or raise for user-provided chunk types.
            if chunk is SENTINEL:
                return
            yield chunk

    return _convert_generator_to_async()
892
+
893
+
894
def _signature_accepts_keyword_arg(signature: inspect.Signature, kwarg: str) -> bool:
    """True if `kwarg` is a named parameter of `signature` or `signature` takes `**kwargs`."""
    if kwarg in signature.parameters:
        return True
    return _signature_accepts_kwargs(signature)
896
+
897
+
898
def _signature_accepts_kwargs(signature: inspect.Signature) -> bool:
    """True if `signature` declares a `**kwargs`-style (VAR_KEYWORD) parameter."""
    return any(
        parameter.kind == inspect.Parameter.VAR_KEYWORD
        for parameter in signature.parameters.values()
    )
903
+
904
+
905
def _elapsed_ms(since_micro_seconds: float) -> int:
    """Return the whole milliseconds elapsed since an earlier `time.perf_counter()` reading.

    Note: despite the parameter's name, the value is used as a
    `time.perf_counter()` reading, i.e. fractional seconds.
    """
    elapsed_seconds = time.perf_counter() - since_micro_seconds
    return int(elapsed_seconds * 1000)
907
+
908
+
909
def _init_extensions(config, data_dir, secrets, lazy_data_resolver):
    """Instantiate one extension per sub-directory of the extensions directory.

    Returns:
        A dict mapping each extension directory name to its extension instance;
        empty when the extensions directory next to this file does not exist.
    """
    extensions_root = Path(__file__).parent / EXTENSIONS_DIR_NAME
    if not extensions_root.exists():
        return {}
    return {
        entry.name: _init_extension(
            entry.name, config, data_dir, secrets, lazy_data_resolver
        )
        for entry in extensions_root.iterdir()
        if entry.is_dir()
    }
921
+
922
+
923
def _init_extension(extension_name: str, config, data_dir, secrets, lazy_data_resolver):
    """Import the module of the named extension and instantiate its extension class.

    Only those of `config`, `data_dir`, `secrets` and `lazy_data_resolver`
    that the class constructor accepts are passed to it (see
    `_prepare_init_args`).
    """
    module_name = f"{EXTENSIONS_DIR_NAME}.{extension_name}.{EXTENSION_FILE_NAME}"
    module = importlib.import_module(module_name)
    extension_cls = getattr(module, EXTENSION_CLASS_NAME)
    constructor_kwargs = _prepare_init_args(
        extension_cls,
        config=config,
        data_dir=data_dir,
        secrets=secrets,
        lazy_data_resolver=lazy_data_resolver,
    )
    return extension_cls(**constructor_kwargs)
936
+
937
+
938
def _prepare_init_args(klass, config, data_dir, secrets, lazy_data_resolver):
    """Prepares init params based on signature.

    Used to pass params to extension and model class' __init__ function. Only
    keyword arguments that the signature of `klass` accepts are included; for
    `lazy_data_resolver` the result of its `fetch()` call is passed, and for
    `environment` the parsed dynamic config value (or None if there is none).
    """
    signature = inspect.signature(klass)

    def accepts(name):
        return _signature_accepts_keyword_arg(signature, name)

    init_kwargs = {}
    plain_candidates = (
        ("config", config),
        ("data_dir", data_dir),
        ("secrets", secrets),
    )
    for name, value in plain_candidates:
        if accepts(name):
            init_kwargs[name] = value

    if accepts("lazy_data_resolver"):
        init_kwargs["lazy_data_resolver"] = lazy_data_resolver.fetch()

    if accepts("environment"):
        raw_environment = dynamic_config_resolver.get_dynamic_config_value_sync(
            dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
        )
        init_kwargs["environment"] = (
            json.loads(raw_environment) if raw_environment else None
        )
    return init_kwargs