truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,66 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional
4
+ from urllib.parse import parse_qs, urlparse
5
+
6
+ from packaging import requirements # type: ignore
7
+
8
# Matches requirement lines that begin with a URL or VCS scheme
# (http, https, git, svn, hg, bzr) followed by either "+" or "://".
URL_PATTERN = re.compile(r"^(https?|git|svn|hg|bzr)(\+|:\/\/)")


def is_url_based_requirement(req: str) -> bool:
    """Tell whether a requirement line starts with a URL/VCS scheme.

    Args:
        req: A single requirement line; surrounding whitespace is ignored.

    Returns:
        ``True`` when the stripped line matches ``URL_PATTERN`` at its start,
        ``False`` otherwise.
    """
    candidate = req.strip()
    return URL_PATTERN.match(candidate) is not None
13
+
14
+
15
def get_egg_tag(req: str) -> Optional[List[str]]:
    """Return the ``egg`` values from a URL-based requirement's fragment.

    The line is stripped of surrounding whitespace before parsing, matching
    how ``is_url_based_requirement`` and ``identify_requirement_name`` treat
    their input. Without this, trailing whitespace on the line would end up
    inside the returned egg value.

    Args:
        req: A single requirement line.

    Returns:
        The list of values given for ``egg`` in the URL fragment (as produced
        by ``parse_qs``), or ``None`` when the requirement is not URL-based or
        its fragment carries no ``egg`` key.
    """
    req = req.strip()
    if not is_url_based_requirement(req):
        return None
    fragment_params = parse_qs(urlparse(req).fragment)
    return fragment_params.get("egg")
21
+
22
+
23
def reqs_by_name(reqs: List[str]) -> Dict[str, str]:
    """Index requirement lines by their identified name.

    Blank (whitespace-only) lines are skipped. When several lines resolve to
    the same name, the last one wins.

    Args:
        reqs: Requirement lines.

    Returns:
        Mapping from the name computed by ``identify_requirement_name`` to the
        original, unmodified requirement line.
    """
    indexed: Dict[str, str] = {}
    for line in reqs:
        if not line.strip():
            continue
        indexed[identify_requirement_name(line)] = line
    return indexed
25
+
26
+
27
def identify_requirement_name(req: str) -> str:
    """Derive an identifying name for a requirement line.

    * Non-URL requirements are parsed with ``packaging`` and the package
      name is returned.
    * URL requirements with a ``<vcs>+`` prefix (git, svn, hg, bzr) are
      identified by ``<vcs>+<netloc><path>``, where the path is cut at the
      first ``@``.
    * Any other URL requirement is identified by the whole stripped line.
    * If parsing fails, the whole stripped line is returned.

    Args:
        req: A single requirement line.

    Returns:
        The identifying name as described above.
    """
    stripped = req.strip()
    try:
        if not is_url_based_requirement(stripped):
            return requirements.Requirement(stripped).name  # type: ignore

        location = urlparse(stripped)
        # A package name can't be reliably read from a URL, so VCS
        # requirements are identified by their unique components instead.
        matched_vcs = next(
            (
                vcs
                for vcs in ("git", "svn", "hg", "bzr")
                if stripped.startswith(vcs + "+")
            ),
            None,
        )
        if matched_vcs is None:
            return stripped

        path_without_ref = location.path.split("@")[0]
        return f"{matched_vcs}+{location.netloc}{path_without_ref}"
    except (requirements.InvalidRequirement, ValueError):
        # Fall back to the whole line when it can't be parsed.
        return stripped
47
+
48
+
49
@dataclass
class RequirementMeta:
    """Parsed facts about a single requirement line."""

    # The requirement line as given.
    requirement: str
    # Identifying name; also the sole input to ``__hash__``.
    name: str
    # Whether the line starts with a URL/VCS scheme.
    is_url_based_requirement: bool
    # ``egg`` values from the URL fragment, if any.
    egg_tag: Optional[List[str]]

    def __hash__(self):
        """Hash on ``name`` only."""
        return hash(self.name)

    @staticmethod
    def from_req(req: str) -> "RequirementMeta":
        """Build a ``RequirementMeta`` by analysing ``req`` with the module helpers."""
        return RequirementMeta(
            requirement=req,
            name=identify_requirement_name(req),
            is_url_based_requirement=is_url_based_requirement(req),
            egg_tag=get_egg_tag(req),
        )
@@ -1,14 +1,20 @@
1
1
  import asyncio
2
2
  import os
3
- from pathlib import Path
4
3
 
5
4
  import uvicorn
6
- from truss.server.control.application import create_app
5
+ from application import create_app
7
6
 
8
7
  CONTROL_SERVER_PORT = int(os.environ.get("CONTROL_SERVER_PORT", "8080"))
9
8
  INFERENCE_SERVER_PORT = int(os.environ.get("INFERENCE_SERVER_PORT", "8090"))
10
9
 
11
10
 
11
def _identify_python_executable_path() -> str:
    """Return the interpreter path from the ``PYTHON_EXECUTABLE`` env var.

    Returns:
        The value of ``PYTHON_EXECUTABLE`` exactly as set in the environment.

    Raises:
        RuntimeError: If ``PYTHON_EXECUTABLE`` is not set.
    """
    executable = os.environ.get("PYTHON_EXECUTABLE")
    if executable is None:
        raise RuntimeError("Unable to find python, make sure it's installed.")
    return executable
16
+
17
+
12
18
  class ControlServer:
13
19
  def __init__(
14
20
  self,
@@ -29,8 +35,7 @@ class ControlServer:
29
35
  "inference_server_home": self._inf_serv_home,
30
36
  "inference_server_process_args": [
31
37
  self._python_executable_path,
32
- "-m",
33
- "truss.server.inference_server",
38
+ f"{self._inf_serv_home}/main.py",
34
39
  ],
35
40
  "control_server_host": "0.0.0.0",
36
41
  "control_server_port": self._control_server_port,
@@ -59,8 +64,8 @@ class ControlServer:
59
64
 
60
65
  if __name__ == "__main__":
61
66
  control_server = ControlServer(
62
- python_executable_path=os.environ.get("PYTHON_EXECUTABLE", default="python3"),
63
- inf_serv_home=os.environ.get("APP_HOME", default=str(Path.cwd())),
67
+ python_executable_path=_identify_python_executable_path(),
68
+ inf_serv_home=os.environ["APP_HOME"],
64
69
  control_server_port=CONTROL_SERVER_PORT,
65
70
  inference_server_port=INFERENCE_SERVER_PORT,
66
71
  )
@@ -0,0 +1,9 @@
1
+ dataclasses-json==0.5.7
2
+ truss==0.9.14 # TODO(marius/TaT): remove after TaT.
3
+ fastapi==0.114.1
4
+ uvicorn==0.24.0
5
+ uvloop==0.19.0
6
+ tenacity==8.1.0
7
+ httpx==0.27.0
8
+ python-json-logger==2.0.2
9
+ loguru==0.7.2
@@ -0,0 +1,28 @@
1
+ """
2
+ The `Model` class is an interface between the ML model that you're packaging and the model
3
+ server that you're running it on.
4
+
5
+ The main method to implement here is:
6
+ * `predict`: runs every time the model server is called. Include any logic for model
7
+ inference and return the model output.
8
+
9
+ See https://truss.baseten.co/quickstart for more.
10
+ """
11
+
12
+ import truss_chains as baseten
13
+
14
+
15
class Model(baseten.ModelBase):
    """Scaffold model: fill in loading and inference logic below."""

    # Configure resources for your model here.
    remote_config: baseten.RemoteConfig = baseten.RemoteConfig(name="{{ MODEL_NAME }}")

    def __init__(
        self,
        context: baseten.DeploymentContext = baseten.depends_context(),
    ) -> None:
        """Set up the model instance.

        Secrets can be accessed via the optional ``context`` argument.
        """
        # Load model here and assign to self._model.
        self._model = None

    def predict(self, input_field: int) -> int:
        """Run inference; this placeholder returns the input unchanged."""
        # Run model inference here
        return input_field
@@ -0,0 +1,42 @@
1
+ server {
2
+ # We use the proxy_read_timeout directive here (instead of proxy_send_timeout) as it sets the timeout for reading a response from the proxied server vs. setting a timeout for sending a request to the proxied server.
3
+ listen 8080;
4
+ client_max_body_size {{client_max_body_size}};
5
+
6
+ # Liveness endpoint override
7
+ location = / {
8
+ proxy_redirect off;
9
+ proxy_read_timeout 300s;
10
+
11
+ rewrite ^/$ {{liveness_endpoint}} break;
12
+
13
+ proxy_pass http://127.0.0.1:{{server_port}};
14
+ }
15
+ # Readiness endpoint override
16
+ location ~ ^/v1/models/model$ {
17
+ proxy_redirect off;
18
+ proxy_read_timeout 300s;
19
+
20
+ rewrite ^/v1/models/model$ {{readiness_endpoint}} break;
21
+
22
+ proxy_pass http://127.0.0.1:{{server_port}};
23
+ }
24
+ # Predict
25
+ location ~ ^/v1/models/model:predict$ {
26
+ proxy_redirect off;
27
+ proxy_read_timeout 300s;
28
+
29
+ rewrite ^/v1/models/model:predict$ {{server_endpoint}} break;
30
+
31
+ proxy_pass http://127.0.0.1:{{server_port}};
32
+ }
33
+
34
+ # Forward all other paths
35
+ location / {
36
+ proxy_redirect off;
37
+ proxy_read_timeout 300s;
38
+
39
+ proxy_pass http://127.0.0.1:{{server_port}};
40
+ }
41
+
42
+ }
@@ -0,0 +1,27 @@
1
+ [supervisord]
2
+ nodaemon=true ; Run supervisord in the foreground (useful for containers)
3
+ logfile=/dev/null ; Disable logging to file (send logs to /dev/null)
4
+ logfile_maxbytes=0 ; No size limit on logfile (since logging is disabled)
5
+
6
+ [program:model-server]
7
+ command={{start_command}} ; Command to start the model server (provided by Jinja variable)
8
+ startsecs=30 ; Wait 30 seconds before assuming the server is running
9
+ autostart=true ; Automatically start the program when supervisord starts
10
+ autorestart=true ; Always restart the program if it exits, no matter what the exit code
11
+ stdout_logfile=/dev/fd/1 ; Send stdout to the first file descriptor (stdout)
12
+ stdout_logfile_maxbytes=0 ; No size limit on stdout log
13
+ redirect_stderr=true ; Redirect stderr to stdout
14
+
15
+ [program:nginx]
16
+ command=nginx -g "daemon off;" ; Command to start nginx without daemonizing (keeps it in the foreground)
17
+ startsecs=0 ; Assume nginx starts immediately
18
+ autostart=true ; Automatically start nginx when supervisord starts
19
+ autorestart=true ; Always restart the program if it exits, no matter what the exit code
20
+ stdout_logfile=/dev/fd/1 ; Send nginx stdout to the first file descriptor (stdout)
21
+ stdout_logfile_maxbytes=0 ; No size limit on stdout log
22
+ redirect_stderr=true ; Redirect nginx stderr to stdout
23
+
24
+ [eventlistener:quit_on_failure]
25
+ events=PROCESS_STATE_FATAL ; Listen for fatal process state events
26
+ command=sh -c 'echo "READY"; read line; kill -15 1; echo "RESULT 2";'
27
+ ; Stop supervisord if a fatal event occurs (by sending SIGTERM to PID 1)
@@ -0,0 +1 @@
1
+ supervisor==4.2.5
@@ -0,0 +1,231 @@
1
+ import contextlib
2
+ import logging
3
+ import sys
4
+ import textwrap
5
+ from http import HTTPStatus
6
+ from types import TracebackType
7
+ from typing import Generator, Mapping, Optional, Tuple, Type, Union
8
+
9
+ import fastapi
10
+ import pydantic
11
+ import starlette.responses
12
+ from fastapi.responses import JSONResponse
13
+
14
# See https://github.com/basetenlabs/baseten/blob/master/docs/Error-Propagation.md
# Identifier of this service; emitted (zero-padded) in the
# `X-BASETEN-ERROR-SOURCE` response header.
_TRUSS_SERVER_SERVICE_ID = 4
# The values below are emitted in the `X-BASETEN-ERROR-CODE` response header.
# Used by `exception_handler` for exceptions without a dedicated branch.
_BASETEN_UNEXPECTED_ERROR = 500
# Used by `exception_handler` for model-side failures (missing model, model not
# ready, model definition errors, user code errors, passed-through HTTP errors).
_BASETEN_DOWNSTREAM_ERROR_CODE = 600
# Used for input parsing errors and for headers added to user responses.
_BASETEN_CLIENT_ERROR_CODE = 700
19
+
20
+
21
class ModelMissingError(Exception):
    """Signals that a model is missing at ``path``.

    Attributes:
        path: The value passed to the constructor, stored unchanged.
    """

    def __init__(self, path):
        super().__init__(path)
        self.path = path

    def __str__(self):
        # `__str__` must return a `str`; coerce so that path-like objects
        # (e.g. `pathlib.Path`) do not raise `TypeError` when the error is
        # formatted or logged.
        return str(self.path)
27
+
28
+
29
class ModelNotReady(RuntimeError):
    """Signals that the named model cannot serve requests yet.

    Attributes:
        model_name: Name of the model, as given to the constructor.
        error_msg: Human-readable message; includes ``detail`` when one was
            given and is non-empty.
    """

    def __init__(self, model_name: str, detail: Optional[str] = None):
        self.model_name = model_name
        base_message = f"Model with name {self.model_name} is not ready."
        # An empty or missing detail leaves the base message untouched.
        self.error_msg = " ".join((base_message, detail)) if detail else base_message

    def __str__(self):
        return self.error_msg
38
+
39
+
40
class InputParsingError(ValueError):
    """Signals that request input could not be parsed."""
42
+
43
+
44
class UserCodeError(Exception):
    """Wraps an arbitrary exception that escaped from user-provided code."""
46
+
47
+
48
class ModelDefinitionError(TypeError):
    """Signals that a user-defined truss model violates the expected contract."""
50
+
51
+
52
def _make_baseten_error_headers(error_code: int) -> Mapping[str, str]:
    """Build the Baseten error headers for a response.

    Args:
        error_code: Baseten error code to report.

    Returns:
        A mapping with the error source (this service's id, zero-padded to two
        digits) and the error code (zero-padded to three digits).
    """
    source_value = format(_TRUSS_SERVER_SERVICE_ID, "02")
    code_value = format(error_code, "03")
    return {"X-BASETEN-ERROR-SOURCE": source_value, "X-BASETEN-ERROR-CODE": code_value}
57
+
58
+
59
def add_error_headers_to_user_response(response: starlette.responses.Response) -> None:
    """Add the Baseten client-error headers to ``response`` in place."""
    client_error_headers = _make_baseten_error_headers(_BASETEN_CLIENT_ERROR_CODE)
    response.headers.update(client_error_headers)
61
+
62
+
63
def _make_baseten_response(
    http_status: int, info: Union[str, Exception], baseten_error_code: int
) -> fastapi.Response:
    """Create a JSON error response carrying the Baseten error headers.

    Args:
        http_status: HTTP status code of the response.
        info: Error description; exceptions are converted with ``str``.
        baseten_error_code: Code reported in the Baseten error headers.

    Returns:
        A ``JSONResponse`` whose body is ``{"error": <message>}``.
    """
    if isinstance(info, Exception):
        message = str(info)
    else:
        message = info
    error_headers = _make_baseten_error_headers(baseten_error_code)
    return JSONResponse(
        status_code=http_status, content={"error": message}, headers=error_headers
    )
72
+
73
+
74
async def exception_handler(_: fastapi.Request, exc: Exception) -> fastapi.Response:
    """Translate an exception into a JSON error response with Baseten headers.

    The first matching exception type decides HTTP status, message and Baseten
    error code. Types without a dedicated branch (this includes
    ``NotImplementedError``) end up in the fallback: HTTP 500 with the
    "unexpected error" code.
    """
    if isinstance(exc, ModelMissingError):
        status = HTTPStatus.NOT_FOUND.value
        info: Union[str, Exception] = exc
        error_code = _BASETEN_DOWNSTREAM_ERROR_CODE
    elif isinstance(exc, ModelNotReady):
        status = HTTPStatus.SERVICE_UNAVAILABLE.value
        info = exc
        error_code = _BASETEN_DOWNSTREAM_ERROR_CODE
    elif isinstance(exc, InputParsingError):
        status = HTTPStatus.BAD_REQUEST.value
        info = exc
        error_code = _BASETEN_CLIENT_ERROR_CODE
    elif isinstance(exc, ModelDefinitionError):
        status = HTTPStatus.PRECONDITION_FAILED.value
        info = f"{type(exc).__name__}: {str(exc)}"
        error_code = _BASETEN_DOWNSTREAM_ERROR_CODE
    elif isinstance(exc, UserCodeError):
        # The message of user code errors is not exposed in the response.
        status = HTTPStatus.INTERNAL_SERVER_ERROR.value
        info = "Internal Server Error"
        error_code = _BASETEN_DOWNSTREAM_ERROR_CODE
    elif isinstance(exc, fastapi.HTTPException):
        # This is a pass through, but additionally adds our custom error headers.
        status = exc.status_code
        info = exc.detail
        error_code = _BASETEN_DOWNSTREAM_ERROR_CODE
    else:
        # Fallback case.
        status = HTTPStatus.INTERNAL_SERVER_ERROR.value
        info = f"Unhandled exception: {type(exc).__name__}: {str(exc)}"
        error_code = _BASETEN_UNEXPECTED_ERROR
    return _make_baseten_response(status, info, error_code)
110
+
111
+
112
# Set of exception types grouped under the name `HANDLED_EXCEPTIONS`.
# NOTE(review): presumably these are the types registered with
# `exception_handler` on the server app; confirm against the caller.
# `NotImplementedError` has no dedicated branch in `exception_handler` and is
# answered by its fallback case.
HANDLED_EXCEPTIONS = {
    ModelMissingError,
    ModelNotReady,
    NotImplementedError,
    InputParsingError,
    UserCodeError,
    ModelDefinitionError,
    fastapi.HTTPException,
}
121
+
122
+
123
def filter_traceback(
    model_file_name: str,
) -> Union[
    Tuple[Type[BaseException], BaseException, TracebackType], Tuple[None, None, None]
]:
    """Return the current ``sys.exc_info()`` with a shortened traceback.

    The traceback is cut so that it starts at the last (innermost) frame whose
    code file name ends with ``model_file_name``.

    Args:
        model_file_name: Suffix to match against the file name of each frame.

    Returns:
        ``(exc_type, exc_value, traceback)``. The traceback is the original one
        if no frame matches; all three entries are ``None`` when no exception
        is being handled.
    """
    exc_type, exc_value, full_tb = sys.exc_info()

    # Walk the whole chain and remember the innermost matching entry.
    innermost_match: Optional[TracebackType] = None
    node: Optional[TracebackType] = full_tb
    while node is not None:
        if node.tb_frame.f_code.co_filename.endswith(model_file_name):
            innermost_match = node
        node = node.tb_next

    # Without a match (or without any traceback) keep the original value.
    selected_tb = full_tb if innermost_match is None else innermost_match
    return exc_type, exc_value, selected_tb  # type: ignore[return-value]
148
+
149
+
150
@contextlib.contextmanager
def intercept_exceptions(
    logger: logging.Logger, model_file_name: str
) -> Generator[None, None, None]:
    """Log exceptions raised in the managed block and normalize them.

    Args:
        logger: Logger that receives the error records.
        model_file_name: Passed to `filter_traceback` to shorten logged
            stack traces.

    Raises:
        fastapi.HTTPException: Re-raised as is after logging.
        UserCodeError: Raised in place of any other `Exception`, carrying
            `str()` of the original exception, which is set as its cause.
    """
    try:
        yield
    # Note that logger.error logs the stacktrace, such that the user can
    # debug this error from the logs.
    except fastapi.HTTPException as e:
        # Snapshot the exception info now: inside a nested `except` block,
        # `sys.exc_info()` would describe the nested failure instead.
        http_exc_info = filter_traceback(model_file_name)

        # TODO: we try to avoid any dependency of the truss server on chains, but for
        # the purpose of getting readable chained-stack traces in the server logs,
        # we have to add a special-case here.
        chains_error_logged = False
        try:
            if "user_stack_trace" in e.detail:
                from truss_chains import definitions

                chains_error = definitions.RemoteErrorDetail.model_validate(e.detail)
                # The formatted error contains a (potentially chained) stack trace
                # with all framework code removed, see
                # truss_chains/remote_chainlet/utils.py::response_raise_errors.
                logger.error(f"Chainlet raised Exception:\n{chains_error.format()}")
                chains_error_logged = True
        except Exception:
            # If we cannot import chains or parse the error: fall through to
            # the generic logging below. The failure must not replace the
            # original HTTPException, so it is deliberately not re-raised.
            chains_error_logged = False

        if chains_error_logged:
            # The customized stack trace is already printed above, so we
            # raise with a clear traceback.
            e.__traceback__ = None
            raise e from None

        logger.error("Model raised HTTPException", exc_info=http_exc_info)
        raise
    except Exception as exc:
        logger.error(
            "Internal Server Error", exc_info=filter_traceback(model_file_name)
        )
        raise UserCodeError(str(exc)) from exc
191
+
192
+
193
def _loc_to_dot_sep(loc: Tuple[Union[str, int], ...]) -> str:
    """Render a pydantic error location tuple as a readable path.

    Adapted from
    https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages.
    String parts are joined with `.` (attribute access), integer parts are
    rendered as `[i]` (indexing), e.g. ('items', 1, 'value') -> items[1].value.

    Raises:
        TypeError: If a part is neither `str` nor `int`.
    """
    rendered = []
    for position, element in enumerate(loc):
        if isinstance(element, str):
            # Only non-leading names get a separating dot.
            rendered.append(f".{element}" if position > 0 else element)
        elif isinstance(element, int):
            rendered.append(f"[{element}]")
        else:
            raise TypeError(f"Unexpected type: {element}.")
    return "".join(rendered)
209
+
210
+
211
def format_pydantic_validation_error(exc: pydantic.ValidationError) -> str:
    """Format a pydantic validation error as a multi-line message.

    The first line is a heading (singular or plural, depending on the error
    count); each following line describes one error as "`<location>`: <msg>.".

    Args:
        exc: The validation error to format.

    Returns:
        The formatted message. If any error entry cannot be processed, a
        single-line fallback based on `str(exc)` is returned instead.
    """
    if exc.error_count() == 1:
        parts = ["Input Parsing Error:"]
    else:
        parts = ["Input Parsing Errors:"]

    try:
        for error in exc.errors(include_url=False, include_input=False):
            loc = _loc_to_dot_sep(error["loc"])
            if error["msg"]:
                msg = error["msg"]
            else:
                # Without a message, show the remaining fields of the entry.
                error_no_loc = dict(error)
                error_no_loc.pop("loc")
                msg = str(error_no_loc)
            parts.append(textwrap.indent(f"`{loc}`: {msg}.", " "))
    except Exception:
        # Fallback in case any of the fields cannot be processed as expected.
        # `Exception` (not a bare `except`) so that `KeyboardInterrupt` and
        # `SystemExit` are not swallowed here.
        return f"Input Parsing Error, {str(exc)}"

    return "\n".join(parts)
@@ -7,6 +7,7 @@ to fail. This patch makes the download more defensive, addressing this issue. tq
7
7
  works when the Content-Length is None.
8
8
  Open PR (https://github.com/openai/whisper/pull/1366) - once merged we can remove the patch.
9
9
  """
10
+
10
11
  import os
11
12
  from typing import Union
12
13
 
@@ -2,8 +2,6 @@ import importlib.util
2
2
  import logging
3
3
  from pathlib import Path
4
4
 
5
- # Set up logging
6
- logging.basicConfig(level=logging.INFO)
7
5
  logger = logging.getLogger(__name__)
8
6
 
9
7
 
@@ -12,7 +10,7 @@ def apply_patches(enabled: bool, requirements: list):
12
10
  Apply patches to certain functions. The patches are contained in each patch module under 'patches' directory.
13
11
  If a patch cannot be applied, it logs the name of the function and the exception details.
14
12
  """
15
- PATCHES_DIR = Path(__file__).parent
13
+ PATCHES_DIR = Path(__file__).parent / "patches"
16
14
  if not enabled:
17
15
  return
18
16
  for requirement in requirements:
@@ -2,6 +2,7 @@ import time
2
2
  from typing import Callable
3
3
 
4
4
 
5
+ # TODO: replace with tenacity.
5
6
  def retry(
6
7
  fn: Callable,
7
8
  count: int,
@@ -2,12 +2,15 @@ from types import MappingProxyType
2
2
  from typing import (
3
3
  Any,
4
4
  AsyncGenerator,
5
+ AsyncIterator,
5
6
  Awaitable,
6
7
  Generator,
8
+ Iterator,
7
9
  List,
8
10
  Optional,
9
11
  Type,
10
12
  Union,
13
+ cast,
11
14
  get_args,
12
15
  get_origin,
13
16
  )
@@ -61,10 +64,7 @@ class TrussSchema(BaseModel):
61
64
 
62
65
  def _parse_input_type(input_parameters: MappingProxyType) -> Optional[type]:
63
66
  parameter_types = list(input_parameters.values())
64
-
65
- if len(parameter_types) > 1:
66
- return None
67
-
67
+ # In `ArgConfig.from_signature` the arguments are validated.
68
68
  input_type = parameter_types[0].annotation
69
69
 
70
70
  if _annotation_is_pydantic_model(input_type):
@@ -85,7 +85,7 @@ def _annotation_is_pydantic_model(annotation: Any) -> bool:
85
85
 
86
86
  def _parse_output_type(output_annotation: Any) -> Optional[OutputType]:
87
87
  """
88
- Therea are 4 possible cases for output_annotation:
88
+ There are 4 possible cases for output_annotation:
89
89
  1. Data object -- represented by a Pydantic BaseModel
90
90
  2. Streaming -- represented by a Generator or AsyncGenerator
91
91
  3. Async -- represented by an Awaitable
@@ -119,7 +119,7 @@ def _parse_output_type(output_annotation: Any) -> Optional[OutputType]:
119
119
  def _is_generator_type(annotation: Any) -> bool:
120
120
  base_type = get_origin(annotation)
121
121
  return isinstance(base_type, type) and issubclass(
122
- base_type, (Generator, AsyncGenerator)
122
+ base_type, (Generator, AsyncGenerator, Iterator, AsyncIterator)
123
123
  )
124
124
 
125
125
 
@@ -158,11 +158,13 @@ def _extract_pydantic_base_models(union_args: tuple) -> List[Type[BaseModel]]:
158
158
  2. Union[Awaitable[PydanticBaseModel], AsyncGenerator]
159
159
  So for Awaitables, we need to extract the base class from the Awaitable type
160
160
  """
161
+
161
162
  return [
162
- retrieve_base_class_from_awaitable(arg) if _is_awaitable_type(arg) else arg
163
- for arg in union_args
163
+ cast(Type[BaseModel], retrieve_base_class_from_awaitable(arg))
164
164
  if _is_awaitable_type(arg)
165
- or (isinstance(arg, type) and issubclass(arg, BaseModel))
165
+ else arg
166
+ for arg in union_args
167
+ if _is_awaitable_type(arg) or _annotation_is_pydantic_model(arg)
166
168
  ]
167
169
 
168
170