truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,332 @@
1
+ import asyncio
2
+ import builtins
3
+ import contextlib
4
+ import contextvars
5
+ import json
6
+ import logging
7
+ import sys
8
+ import textwrap
9
+ import threading
10
+ import traceback
11
+ from typing import Dict, Iterator, Mapping, NoReturn, Optional, Type, TypeVar
12
+
13
+ import aiohttp
14
+ import fastapi
15
+ import httpx
16
+ import pydantic
17
+ import starlette.requests
18
+ from truss.templates.shared import dynamic_config_resolver
19
+
20
+ from truss_chains import definitions
21
+
22
# Generic type variable; presumably used by generic helpers further down this
# module (not referenced in the visible part) — verify before removing.
T = TypeVar("T")
23
+
24
+
25
def populate_chainlet_service_predict_urls(
    chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
) -> Mapping[str, definitions.DeployedServiceDescriptor]:
    """Attach predict URLs from the dynamic Chainlet config to dependency services.

    Args:
        chainlet_to_service: Service descriptors of this Chainlet's dependencies,
            keyed by Chainlet name.

    Returns:
        A dict with the same keys, mapping to `DeployedServiceDescriptor`s whose
        `predict_url` comes from the dynamic Chainlet config. When there are no
        dependencies, an empty dict is returned without reading the config.

    Raises:
        definitions.MissingDependencyError: If the dynamic Chainlet config value is
            missing/empty, or has no entry for a dependency's `display_name`.
    """
    deployed: Dict[str, definitions.DeployedServiceDescriptor] = {}
    if not chainlet_to_service:
        # No dependencies -> nothing to resolve, so don't require the config.
        return deployed

    raw_config = dynamic_config_resolver.get_dynamic_config_value_sync(
        definitions.DYNAMIC_CHAINLET_CONFIG_KEY
    )
    if not raw_config:
        raise definitions.MissingDependencyError(
            f"No '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}' "
            "found. Cannot override Chainlet configs."
        )

    # Maps backend Chainlet names to per-Chainlet configs holding `predict_url`.
    config_by_backend_name = json.loads(raw_config)

    for chainlet_name, descriptor in chainlet_to_service.items():
        # NOTE: The Chainlet `display_name` in the Truss CLI corresponds to the
        # Chainlet `name` in the backend. The dynamic Chainlet config is keyed on
        # the backend name, so lookups must use the descriptor's `display_name`.
        backend_name = descriptor.display_name
        if backend_name not in config_by_backend_name:
            raise definitions.MissingDependencyError(
                f"Chainlet '{backend_name}' not found in "
                f"'{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}'. "
                f"Dynamic Chainlet config keys: {list(config_by_backend_name)}."
            )

        deployed[chainlet_name] = definitions.DeployedServiceDescriptor(
            display_name=backend_name,
            name=descriptor.name,
            options=descriptor.options,
            predict_url=config_by_backend_name[backend_name]["predict_url"],
        )

    return deployed
70
+
71
+
72
class AsyncSafeCounter:
    """Integer counter whose updates are serialized through an `asyncio.Lock`.

    `increment` / `decrement` return the value after the update. Used as an
    async context manager, it increments on entry (yielding the new value) and
    decrements on exit.
    """

    def __init__(self, initial: int = 0) -> None:
        self._lock = asyncio.Lock()
        self._counter = initial

    async def _adjust(self, delta: int) -> int:
        # Apply `delta` under the lock and report the resulting value.
        async with self._lock:
            self._counter += delta
            return self._counter

    async def increment(self) -> int:
        return await self._adjust(1)

    async def decrement(self) -> int:
        return await self._adjust(-1)

    async def __aenter__(self) -> int:
        return await self._adjust(1)

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self._adjust(-1)
92
+
93
+
94
class ThreadSafeCounter:
    """Integer counter whose updates are serialized through a `threading.Lock`.

    `increment` / `decrement` return the value after the update. Used as a
    context manager, it increments on entry (yielding the new value) and
    decrements on exit.
    """

    def __init__(self, initial: int = 0) -> None:
        self._lock = threading.Lock()
        self._counter = initial

    def _adjust(self, delta: int) -> int:
        # Apply `delta` under the lock and report the resulting value.
        with self._lock:
            self._counter += delta
            return self._counter

    def increment(self) -> int:
        return self._adjust(1)

    def decrement(self) -> int:
        return self._adjust(-1)

    def __enter__(self) -> int:
        return self._adjust(1)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self._adjust(-1)
114
+
115
+
116
+ _trace_parent_context: contextvars.ContextVar[str] = contextvars.ContextVar(
117
+ "trace_parent"
118
+ )
119
+
120
+
121
@contextlib.contextmanager
def _trace_parent(request: starlette.requests.Request) -> Iterator[None]:
    """Bind the request's OTEL trace-parent header for the duration of the block.

    If the request lacks the `definitions.OTEL_TRACE_PARENT_HEADER_KEY` header,
    the empty string is bound instead. The previous value is restored on exit,
    including when the block raises.
    """
    header_value = request.headers.get(definitions.OTEL_TRACE_PARENT_HEADER_KEY, "")
    with trace_parent_raw(header_value):
        yield
130
+
131
+
132
@contextlib.contextmanager
def trace_parent_raw(trace_parent: str) -> Iterator[None]:
    """Bind `trace_parent` as the current trace-parent value within the block.

    The value (or unset state) that was active before entering is restored on
    exit, also when the block raises.
    """
    reset_token = _trace_parent_context.set(trace_parent)
    try:
        yield
    finally:
        # Undo exactly the `set` above, restoring the prior state.
        _trace_parent_context.reset(reset_token)
139
+
140
+
141
def get_trace_parent() -> Optional[str]:
    """Return the trace-parent value bound in the current context.

    Returns:
        The value bound by the active `_trace_parent` / `trace_parent_raw` block.
        If nothing is bound in the current context, returns the empty string,
        the same value `_trace_parent` stores when a request carries no
        trace-parent header.
    """
    # `ContextVar.get()` without a default raises `LookupError` when the variable
    # is unset (e.g. outside any `_trace_parent` / `trace_parent_raw` block).
    # Fall back to "" so callers such as error-message builders don't crash.
    return _trace_parent_context.get("")
143
+
144
+
145
def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, object]:
    """Like `BaseModel.model_dump(exclude_unset=True)`, but only top-level.

    This is used to get kwargs for invoking a function, while dropping fields for
    which no value was explicitly set in the pydantic model. A field is considered
    unset if its key was not present in the incoming JSON request (from which the
    model was parsed/initialized) and the pydantic model has a default value, such
    as `None`.

    By dropping these unset fields, the default values from the function
    definition are used instead. This ensures correct handling of arguments for
    which the function has a default, such as in the case of `run_remote`. If the
    model has an optional field defaulting to `None`, this approach distinguishes
    between the user explicitly passing `None` and the field being unset in the
    request.

    Returns:
        A dict mapping each explicitly-set field name to its attribute value on
        `obj`. Values are returned as-is (nested models are not dumped), so the
        value type is arbitrary rather than `BaseModel`.
    """
    return {name: getattr(obj, name) for name in obj.model_fields_set}
162
+
163
+
164
+ # Error Propagation Utils. #############################################################
165
+
166
+
167
def _handle_exception(exception: Exception) -> NoReturn:
    """Raise `HTTPException` with `RemoteErrorDetail`.

    Packs the exception's class name, module name, message and a stack trace
    filtered down to user-defined code into a `definitions.RemoteErrorDetail`.
    Raises a `fastapi.HTTPException` with status 500 whose `detail` is that
    model's `model_dump()`. The original exception is chained as `__cause__`.

    Raises:
        fastapi.HTTPException: Always.
    """
    # Same as the explicit `hasattr` check: `None` if the attribute is missing.
    exception_module_name = getattr(exception, "__module__", None)

    error_stack = traceback.extract_tb(exception.__traceback__)
    # Filter everything before (model.py) and after (stubs, error handling) so that only
    # user-defined code remains. See test_e2e.py::test_chain for expected results.
    model_predict_index = 0
    first_stub_index = len(error_stack)
    for i, frame in enumerate(error_stack):
        if frame.filename.endswith("model/model.py") and frame.name == "predict":
            # Keep only frames after the `predict` entry point.
            model_predict_index = i + 1
        if frame.filename.endswith("remote_chainlet/stub.py") and frame.name.startswith(
            "predict"  # predict sync|async|stream.
        ):
            # Drop the stub frame and the one directly before it. Clamp at 0: for a
            # stub frame at index 0, `i - 1` would be -1, and slicing up to -1
            # keeps every frame except the last instead of none.
            first_stub_index = max(i - 1, 0)
            break

    final_tb = error_stack[model_predict_index:first_stub_index]
    stack = [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb]
    error = definitions.RemoteErrorDetail(
        exception_cls_name=exception.__class__.__name__,
        exception_module_name=exception_module_name,
        exception_message=str(exception),
        user_stack_trace=stack,
    )
    raise fastapi.HTTPException(
        status_code=500, detail=error.model_dump()
    ) from exception
199
+
200
+
201
@contextlib.contextmanager
def _exception_to_http_error() -> Iterator[None]:
    """Re-raise any `Exception` escaping the block via `_handle_exception`.

    `_handle_exception` turns it into a `fastapi.HTTPException` (status 500)
    chained to the original. Exceptions that do not derive from `Exception`
    (e.g. `KeyboardInterrupt`) propagate unchanged.
    """
    try:
        yield
    except Exception as caught:
        _handle_exception(caught)
207
+
208
+
209
def _resolve_exception_class(error: definitions.RemoteErrorDetail) -> Type[Exception]:
    """Resolve the local exception class named in a remote error.

    Looks the class up in `builtins` when `error.exception_module_name` is `None`.
    Otherwise it looks in the already-imported module of that name in
    `sys.modules`; nothing is imported here.

    Falls back to `definitions.GenericRemoteException` in three cases:
    - nothing is found;
    - the object found is not an `Exception` subclass (e.g. a function such as
      `builtins.print`, which the caller could not instantiate and raise);
    - the class is a `pydantic.ValidationError`, which cannot be re-raised naively.
    """
    candidate: object = None
    if error.exception_module_name is None:
        candidate = getattr(builtins, error.exception_cls_name, None)
    elif mod := sys.modules.get(error.exception_module_name):
        candidate = getattr(mod, error.exception_cls_name, None)

    # `getattr` can return any attribute. Only exception classes are usable, and
    # `issubclass` on a non-class would itself raise `TypeError`.
    if not (isinstance(candidate, type) and issubclass(candidate, Exception)):
        logging.warning(
            f"Could not resolve exception with name `{error.exception_cls_name}` "
            f"and module `{error.exception_module_name}` - fall back to "
            f"`{definitions.GenericRemoteException.__name__}`."
        )
        return definitions.GenericRemoteException

    if issubclass(candidate, pydantic.ValidationError):
        # Cannot re-raise naively.
        # https://github.com/pydantic/pydantic/issues/6734.
        return definitions.GenericRemoteException

    return candidate
233
+
234
+
235
def _handle_response_error(response_json: dict, base_msg: str):
    """Raise a local exception describing the error payload of a failed chainlet call.

    Args:
        response_json: Decoded JSON body of the error response. Its ``"error"``
            field must be either a serialized `definitions.RemoteErrorDetail` or,
            for old truss models, a plain string.
        base_msg: Header line describing the failed call.

    Raises:
        ValueError: If the body has no ``"error"`` field (or is not a JSON object),
            or if the field has an unrecognized shape.
        definitions.GenericRemoteException: If the error is a plain string, or if
            the remote exception class cannot be resolved or instantiated locally.
        Exception: Otherwise, an instance of the resolved remote exception class
            whose message contains the chained remote traceback.
    """
    try:
        error_json = response_json["error"]
    except (KeyError, TypeError) as e:
        # `TypeError` covers bodies that decoded to a non-object (list, string, ...).
        logging.error(f"response_json: {response_json}")
        raise ValueError(
            f"{base_msg}. Could not get `error` field from JSON response."
        ) from e

    try:
        error = definitions.RemoteErrorDetail.model_validate(error_json)
    except pydantic.ValidationError as e:
        if isinstance(error_json, str):
            msg = f"{base_msg}: '{error_json}'"
            raise definitions.GenericRemoteException(msg) from None
        raise ValueError(
            f"{base_msg}: Could not parse chainlet error. Error details are expected "
            "to be either a plain string (old truss models) or a serialized "
            f"`{definitions.RemoteErrorDetail.__name__}`, got:\n{repr(error_json)}"
        ) from e

    exception_cls = _resolve_exception_class(error)
    # Indent the remote error so nested chainlet errors render as a tree and close
    # the last line with a corner. `textwrap.indent` skips whitespace-only lines,
    # hence the `startswith` check. An empty formatted error must not crash here.
    format_lines = textwrap.indent(error.format(), "│ ").splitlines()
    if format_lines and format_lines[-1].startswith("│"):
        format_lines[-1] = f"╰{format_lines[-1][1:]}"
    error_format = "\n".join(format_lines)
    msg = (
        f"(showing chained remote errors, root error at the bottom)\n"
        f"├─ {base_msg}\n"
        f"{error_format}"
    )
    try:
        exception = exception_cls(msg)
    except TypeError:
        # Some exception classes require extra constructor arguments (e.g.
        # `UnicodeDecodeError`); don't mask the remote error with a `TypeError`.
        logging.warning(
            f"Could not instantiate `{exception_cls.__name__}` with a message - "
            f"fall back to `{definitions.GenericRemoteException.__name__}`."
        )
        exception = definitions.GenericRemoteException(msg)
    raise exception
267
+
268
+
269
def _make_base_error_message(remote_name: str, http_status: int) -> str:
    """Build the header line for errors from a failed call to `remote_name`.

    The line contains the HTTP status and the value of `get_trace_parent()`.
    """
    trace_id = get_trace_parent()
    return (
        f"Error calling dependency Chainlet `{remote_name}`, "
        f"HTTP status={http_status}, trace ID=`{trace_id}`."
    )
274
+
275
+
276
def response_raise_errors(response: httpx.Response, remote_name: str) -> None:
    """In case of error, raise it.

    If the response error contains `RemoteErrorDetail`, it tries to re-raise
    the same exception that was raised remotely and falls back to
    `GenericRemoteException` if the exception class could not be resolved.

    Exception messages are chained to trace back to the root cause, i.e. the first
    Chainlet that raised an exception. E.g. the message might look like this
    (illustrative; the exact traceback layout comes from `RemoteErrorDetail.format`):

    ```
    Chainlet-Traceback (most recent call last):
      File "/packages/itest_chain.py", line 132, in run_remote
        value = self._accumulate_parts(text_parts.parts)
      File "/packages/itest_chain.py", line 144, in _accumulate_parts
        value += self._text_to_num.run_remote(part)
    ValueError: (showing chained remote errors, root error at the bottom)
    ├─ Error calling dependency Chainlet `TextToNum`, HTTP status=500, trace ID=`...`.
    │ Chainlet-Traceback (most recent call last):
    │   File "/packages/itest_chain.py", line 87, in run_remote
    │     generated_text = self._replicator.run_remote(data)
    │ ValueError: (showing chained remote errors, root error at the bottom)
    │ ├─ Error calling dependency Chainlet `TextReplicator`, HTTP status=500, trace ID=`...`.
    │ │ Chainlet-Traceback (most recent call last):
    │ │   File "/packages/itest_chain.py", line 52, in run_remote
    │ │     validate_data(data)
    │ │   File "/packages/itest_chain.py", line 36, in validate_data
    │ │     raise ValueError(f"This input is too long: {len(data)}.")
    ╰ ╰ ValueError: This input is too long: 100.
    ```

    Raises:
        ValueError: If the error response body cannot be decoded as JSON.
        Exception: Whatever `_handle_response_error` raises for the error payload.
    """
    if not response.is_error:
        return
    base_msg = _make_base_error_message(remote_name, response.status_code)
    try:
        response_json = response.json()
    except Exception as e:
        raise ValueError(base_msg) from e
    _handle_response_error(response_json, base_msg)
314
+
315
+
316
async def async_response_raise_errors(
    response: aiohttp.ClientResponse, remote_name: str
) -> None:
    """Async (aiohttp) version of `response_raise_errors`.

    Any status >= 400 is treated as an error. The body is decoded as JSON
    regardless of its `Content-Type` header (`content_type=None`), so error
    payloads are handled the same way as in the httpx-based variant.

    Raises:
        ValueError: If the error response body cannot be decoded as JSON.
        Exception: Whatever `_handle_response_error` raises for the error payload.
    """
    if response.status < 400:
        return
    base_msg = _make_base_error_message(remote_name, response.status)
    try:
        response_json = await response.json(content_type=None)
    except Exception as e:
        raise ValueError(base_msg) from e
    _handle_response_error(response_json, base_msg)
327
+
328
+
329
@contextlib.contextmanager
def predict_context(request: starlette.requests.Request) -> Iterator[None]:
    """Context for serving a predict request.

    Enters `_trace_parent(request)` (defined elsewhere; presumably it exposes the
    request's trace parent to `get_trace_parent` - verify). Inside it, exceptions
    raised by the managed block are converted into HTTP errors by
    `_exception_to_http_error`.
    """
    with _trace_parent(request):
        with _exception_to_http_error():
            yield
@@ -0,0 +1,378 @@
1
+ import asyncio
2
+ import dataclasses
3
+ import enum
4
+ import struct
5
+ import sys
6
+ from collections.abc import AsyncIterator
7
+ from typing import Generic, Optional, Protocol, Type, TypeVar, overload
8
+
9
+ import pydantic
10
+
11
# Every frame tag is a uint8 delimiter followed by a big-endian uint32 length.
_TAG_SIZE = struct.calcsize(">BI")  # == 5

_T = TypeVar("_T")

if sys.version_info < (3, 10):
    # `anext` only became a builtin in Python 3.10; minimal fallback for older ones.
    async def anext(iterable: AsyncIterator[_T]) -> _T:
        return await iterable.__anext__()
19
+
20
+
21
# Note on the (verbose) typing in this module: we want exact typing of the reader and
# writer helpers, while also giving users the flexibility to leave out header/footer
# if not needed.
# Constraining the header/footer types to be pydantic models while also letting them
# be optional is not well supported by typing tools (the missing feature is using
# type variables as constraints on other type variables).
#
# A functional, yet verbose, workaround that gives correct variadic type inference is
# to use intermediate type variables (`HeaderT` <-> `HeaderTT`) in conjunction with
# overloads that map out all usage combinations (the overloads essentially allow
# "conditional" binding of type vars). These overloads also allow conditionally using
# granular reader/writer subclasses that only have the read/write methods for the
# configured data types, implemented DRY with mixin classes.
ItemT = TypeVar("ItemT", bound=pydantic.BaseModel)
HeaderT = TypeVar("HeaderT", bound=pydantic.BaseModel)
FooterT = TypeVar("FooterT", bound=pydantic.BaseModel)

# Since header/footer could also be `None`, we need an extra type variable that
# can assume either `Type[HeaderT]` or `None` - `Type[None]` causes issues.
HeaderTT = TypeVar("HeaderTT")
FooterTT = TypeVar("FooterTT")
42
+
43
+
44
@dataclasses.dataclass
class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]):
    """The pydantic model types that make up a typed stream.

    Create instances through `stream_types(...)`; that indirection is what gives
    correct generic typing for the optional header/footer.
    """

    # Model type of each streamed item.
    item_type: Type[ItemT]
    # Either `Type[HeaderT]` or `None` when the stream has no header.
    header_type: HeaderTT
    # Either `Type[FooterT]` or `None` when the stream has no footer.
    footer_type: FooterTT
49
+
50
+
51
@overload
def stream_types(
    item_type: Type[ItemT], *, header_type: Type[HeaderT], footer_type: Type[FooterT]
) -> StreamTypes[ItemT, HeaderT, FooterT]: ...


@overload
def stream_types(
    item_type: Type[ItemT], *, header_type: Type[HeaderT]
) -> StreamTypes[ItemT, HeaderT, None]: ...


@overload
def stream_types(
    item_type: Type[ItemT], *, footer_type: Type[FooterT]
) -> StreamTypes[ItemT, None, FooterT]: ...


@overload
def stream_types(item_type: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ...


def stream_types(
    item_type: Type[ItemT],
    *,
    header_type: Optional[Type[HeaderT]] = None,
    footer_type: Optional[Type[FooterT]] = None,
) -> StreamTypes:
    """Bundle the pydantic model types of a stream: items plus optional header/footer.

    Use this factory instead of constructing `StreamTypes` directly; the overloads
    above let type checkers bind `None` for omitted header/footer types.
    """
    return StreamTypes(
        item_type=item_type, header_type=header_type, footer_type=footer_type
    )
83
+
84
+
85
+ # Common ###############################################################################
86
+
87
+
88
class _Delimiter(enum.IntEnum):
    """Kind of a stream frame; its value is packed as the uint8 at the start of a tag."""

    NOT_SET = 1  # Writer state before anything has been sent.
    HEADER = 2
    ITEM = 3
    FOOTER = 4
    END = 5  # Returned by the reader when the byte source is exhausted.
94
+
95
+
96
class _Streamer(Generic[ItemT, HeaderTT, FooterTT]):
    """Common base of stream readers and writers; holds the configured types."""

    _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]

    def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None:
        # Shared by all reader/writer subclasses and mixins.
        self._stream_types = types
101
+
102
+
103
+ # Reading ##############################################################################
104
+
105
+
106
+ class _ByteReader:
107
+ """Helper to provide `readexactly` API for an async bytes iterator."""
108
+
109
+ def __init__(self, source: AsyncIterator[bytes]) -> None:
110
+ self._source = source
111
+ self._buffer = bytearray()
112
+
113
+ async def readexactly(self, num_bytes: int) -> bytes:
114
+ while len(self._buffer) < num_bytes:
115
+ try:
116
+ chunk = await anext(self._source)
117
+ except StopAsyncIteration:
118
+ break
119
+ self._buffer.extend(chunk)
120
+
121
+ if len(self._buffer) < num_bytes:
122
+ if len(self._buffer) == 0:
123
+ raise EOFError()
124
+ raise asyncio.IncompleteReadError(self._buffer, num_bytes)
125
+
126
+ result = bytes(self._buffer[:num_bytes])
127
+ del self._buffer[:num_bytes]
128
+ return result
129
+
130
+
131
class _StreamReaderProtocol(Protocol[ItemT, HeaderTT, FooterTT]):
    """Structural type of the reader state that the read mixins rely on.

    Used to annotate `self` in mixin methods so they can access `_read` and
    `_footer_data`, which only `_StreamReader` provides.
    """

    _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]
    _footer_data: Optional[bytes]

    async def _read(self) -> tuple[_Delimiter, bytes]:
        """Read the next frame as `(delimiter, payload)`."""
136
+
137
+
138
class _StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]):
    """Reads a framed stream of pydantic models from an async iterator of bytes.

    Each frame is a `_TAG_SIZE`-byte tag (uint8 `_Delimiter`, big-endian uint32
    payload length) followed by the JSON payload. If the byte source ends exactly at
    a frame boundary, this is reported as `_Delimiter.END`.
    """

    _stream: _ByteReader
    # Footer payload buffered by `read_items` when it hits the footer frame;
    # `None` if no footer has been seen (yet).
    _footer_data: Optional[bytes]

    def __init__(
        self,
        types: StreamTypes[ItemT, HeaderTT, FooterTT],
        stream: AsyncIterator[bytes],
    ) -> None:
        super().__init__(types)
        self._stream = _ByteReader(stream)
        self._footer_data = None

    @staticmethod
    def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]:
        """Decode a frame tag; raises `ValueError` for an unknown delimiter value."""
        enum_value, length = struct.unpack(">BI", tag)
        return _Delimiter(enum_value), length

    async def _read(self) -> tuple[_Delimiter, bytes]:
        """Read the next frame and return `(delimiter, payload)`.

        Returns `(_Delimiter.END, b"")` at a clean end of the stream.

        Raises:
            asyncio.IncompleteReadError: If the stream ends inside a tag or payload.
        """
        try:
            tag = await self._stream.readexactly(_TAG_SIZE)
        # `IncompleteReadError` subclasses `EOFError`, so it must be re-raised before
        # the `EOFError` clause: reading nothing is a clean end, a partial tag is not.
        except asyncio.IncompleteReadError:
            raise
        except EOFError:
            return _Delimiter.END, b""

        delimiter, length = self._unpack_tag(tag)
        if not length:
            return delimiter, b""
        try:
            data_bytes = await self._stream.readexactly(length)
        except asyncio.IncompleteReadError:
            raise
        except EOFError as e:
            # The tag announced a payload but the source ended right after it:
            # report this as truncation, like a partially received payload.
            raise asyncio.IncompleteReadError(b"", length) from e
        return delimiter, data_bytes

    async def read_items(self) -> AsyncIterator[ItemT]:
        """Yield the stream's items until the footer or the end of the stream.

        A footer frame reached here is buffered so that `read_footer` (on readers
        that have it) can still return it.

        Raises:
            ValueError: If a header frame is pending (call `read_header` first) or
                if an unexpected frame kind appears among the items.
        """
        delimiter, data_bytes = await self._read()
        if delimiter == _Delimiter.HEADER:
            raise ValueError(
                "Called `read_items`, but there the stream contains header data, which "
                "is not consumed. Call `read_header` first or remove sending a header."
            )
        # Zero or more items; only after each read do we know whether more follow.
        while delimiter == _Delimiter.ITEM:
            yield self._stream_types.item_type.model_validate_json(data_bytes)
            delimiter, data_bytes = await self._read()

        if delimiter == _Delimiter.FOOTER:
            self._footer_data = data_bytes
        elif delimiter != _Delimiter.END:
            # Don't rely on `assert` for validating externally produced data.
            raise ValueError(
                f"Unexpected `{delimiter.name}` frame while reading stream items."
            )
192
+
193
+
194
class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterTT]):
    """Adds `read_header` to readers whose stream types include a header."""

    async def read_header(
        self: _StreamReaderProtocol[ItemT, HeaderT, FooterTT],
    ) -> HeaderT:
        """Read and parse the header frame; call this before `read_items`.

        Raises:
            ValueError: If the next frame is not a header.
        """
        delimiter, payload = await self._read()
        if delimiter == _Delimiter.HEADER:
            header_model = self._stream_types.header_type
            return header_model.model_validate_json(payload)
        raise ValueError("Stream does not contain header.")
202
+
203
+
204
class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]):
    """Adds `read_footer` to readers whose stream types include a footer."""

    _footer_data: Optional[bytes]

    async def read_footer(
        self: _StreamReaderProtocol[ItemT, HeaderTT, FooterT],
    ) -> FooterT:
        """Return the parsed footer.

        Uses the payload buffered by a previous `read_items` call if present;
        otherwise reads the next frame, which must be a footer. The buffer is
        cleared once parsing succeeds.

        Raises:
            ValueError: If nothing is buffered and the next frame is not a footer.
        """
        if self._footer_data is None:
            delimiter, payload = await self._read()
            if delimiter == _Delimiter.FOOTER:
                self._footer_data = payload
            else:
                raise ValueError("Stream does not contain footer.")

        footer_model = self._stream_types.footer_type
        parsed = footer_model.model_validate_json(self._footer_data)
        self._footer_data = None  # Consumed.
        return parsed
219
+
220
+
221
class StreamReaderWithHeader(
    _StreamReader[ItemT, HeaderT, FooterTT],
    _HeaderReadMixin[ItemT, HeaderT, FooterTT],
):
    """Reader offering `read_header` and `read_items`.

    `stream_reader` returns it when only a header type is configured.
    """


class StreamReaderWithFooter(
    _StreamReader[ItemT, HeaderTT, FooterT],
    _FooterReadMixin[ItemT, HeaderTT, FooterT],
):
    """Reader offering `read_items` and `read_footer`.

    `stream_reader` returns it when only a footer type is configured.
    """


class StreamReaderFull(
    _StreamReader[ItemT, HeaderT, FooterT],
    _HeaderReadMixin[ItemT, HeaderT, FooterT],
    _FooterReadMixin[ItemT, HeaderT, FooterT],
):
    """Reader offering `read_header`, `read_items` and `read_footer`.

    `stream_reader` returns it when both header and footer types are configured.
    """
236
+
237
+
238
@overload
def stream_reader(
    types: StreamTypes[ItemT, None, None], stream: AsyncIterator[bytes]
) -> _StreamReader[ItemT, None, None]: ...


@overload
def stream_reader(
    types: StreamTypes[ItemT, HeaderT, None], stream: AsyncIterator[bytes]
) -> StreamReaderWithHeader[ItemT, HeaderT, None]: ...


@overload
def stream_reader(
    types: StreamTypes[ItemT, None, FooterT], stream: AsyncIterator[bytes]
) -> StreamReaderWithFooter[ItemT, None, FooterT]: ...


@overload
def stream_reader(
    types: StreamTypes[ItemT, HeaderT, FooterT], stream: AsyncIterator[bytes]
) -> StreamReaderFull[ItemT, HeaderT, FooterT]: ...


def stream_reader(
    types: StreamTypes[ItemT, HeaderTT, FooterTT], stream: AsyncIterator[bytes]
) -> _StreamReader:
    """Create a reader for `stream` whose class matches the configured types.

    `read_header` / `read_footer` are only available when the corresponding type
    in `types` is not `None`.
    """
    has_header = types.header_type is not None
    has_footer = types.footer_type is not None
    if has_header and has_footer:
        return StreamReaderFull(types, stream)
    if has_header:
        return StreamReaderWithHeader(types, stream)
    if has_footer:
        return StreamReaderWithFooter(types, stream)
    return _StreamReader(types, stream)
273
+
274
+
275
+ # Writing ##############################################################################
276
+
277
+
278
class _StreamWriterProtocol(Protocol[ItemT, HeaderTT, FooterTT]):
    """Structural type of the writer state that the write mixins rely on.

    Used to annotate `self` in mixin methods so they can use `_last_sent` and
    `_serialize`, which only `_StreamWriter` provides.
    """

    _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]
    _last_sent: _Delimiter

    def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes:
        """Frame `obj` as a chunk of kind `delimiter`."""
283
+
284
+
285
class _StreamWriter(_Streamer[ItemT, HeaderTT, FooterTT]):
    """Serializes pydantic models into framed chunks of a typed stream.

    Each chunk is a tag (uint8 `_Delimiter`, big-endian uint32 payload length)
    followed by the model's JSON encoding. The kind of the last chunk produced is
    tracked in `_last_sent` to reject out-of-order writes.
    """

    _last_sent: _Delimiter

    def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None:
        super().__init__(types)  # Stores `types` as `self._stream_types`.
        self._last_sent = _Delimiter.NOT_SET

    @staticmethod
    def _pack_tag(delimiter: _Delimiter, length: int) -> bytes:
        """Encode the tag for a `delimiter` frame with a `length`-byte payload."""
        return struct.pack(">BI", delimiter.value, length)

    def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes:
        """Frame the JSON encoding of `obj` as a `delimiter` chunk."""
        data_bytes = obj.model_dump_json().encode()
        # Return `bytes`, as annotated. Starlette cannot handle a `bytearray` chunk,
        # but `bytes` (like `memoryview`) works.
        return self._pack_tag(delimiter, len(data_bytes)) + data_bytes

    def yield_item(self, item: ItemT) -> bytes:
        """Serialize `item` as the next chunk of the stream.

        Raises:
            ValueError: If a footer was already sent or the stream is closed.
        """
        if self._last_sent in (_Delimiter.FOOTER, _Delimiter.END):
            raise ValueError("Cannot yield item after sending footer / closing stream.")
        self._last_sent = _Delimiter.ITEM
        return self._serialize(item, _Delimiter.ITEM)
307
+
308
+
309
class _HeaderWriteMixin(_Streamer[ItemT, HeaderT, FooterTT]):
    """Adds `yield_header` to writers whose stream types include a header."""

    def yield_header(
        self: _StreamWriterProtocol[ItemT, HeaderT, FooterTT], header: HeaderT
    ) -> bytes:
        """Serialize `header`; it must be the first chunk of the stream.

        Raises:
            ValueError: If any chunk has already been produced.
        """
        if self._last_sent == _Delimiter.NOT_SET:
            self._last_sent = _Delimiter.HEADER
            return self._serialize(header, _Delimiter.HEADER)
        raise ValueError("Cannot yield header after other data has been sent.")
317
+
318
+
319
class _FooterWriteMixin(_Streamer[ItemT, HeaderTT, FooterT]):
    """Adds `yield_footer` to writers whose stream types include a footer."""

    def yield_footer(
        self: _StreamWriterProtocol[ItemT, HeaderTT, FooterT], footer: FooterT
    ) -> bytes:
        """Serialize `footer`; afterwards no more items can be yielded.

        Raises:
            ValueError: If the stream has been closed.
        """
        # NOTE(review): a second footer is not rejected here - confirm if intended.
        stream_closed = self._last_sent == _Delimiter.END
        if stream_closed:
            raise ValueError("Cannot yield footer after closing stream.")
        self._last_sent = _Delimiter.FOOTER
        return self._serialize(footer, _Delimiter.FOOTER)
327
+
328
+
329
class StreamWriterWithHeader(
    _StreamWriter[ItemT, HeaderT, FooterTT],
    _HeaderWriteMixin[ItemT, HeaderT, FooterTT],
):
    """Writer offering `yield_header` and `yield_item`.

    `stream_writer` returns it when only a header type is configured.
    """


class StreamWriterWithFooter(
    _StreamWriter[ItemT, HeaderTT, FooterT],
    _FooterWriteMixin[ItemT, HeaderTT, FooterT],
):
    """Writer offering `yield_item` and `yield_footer`.

    `stream_writer` returns it when only a footer type is configured.
    """


class StreamWriterFull(
    _StreamWriter[ItemT, HeaderT, FooterT],
    _HeaderWriteMixin[ItemT, HeaderT, FooterT],
    _FooterWriteMixin[ItemT, HeaderT, FooterT],
):
    """Writer offering `yield_header`, `yield_item` and `yield_footer`.

    `stream_writer` returns it when both header and footer types are configured.
    """
344
+
345
+
346
@overload
def stream_writer(
    types: StreamTypes[ItemT, None, None],
) -> _StreamWriter[ItemT, None, None]: ...


@overload
def stream_writer(
    types: StreamTypes[ItemT, HeaderT, None],
) -> StreamWriterWithHeader[ItemT, HeaderT, None]: ...


@overload
def stream_writer(
    types: StreamTypes[ItemT, None, FooterT],
) -> StreamWriterWithFooter[ItemT, None, FooterT]: ...


@overload
def stream_writer(
    types: StreamTypes[ItemT, HeaderT, FooterT],
) -> StreamWriterFull[ItemT, HeaderT, FooterT]: ...


def stream_writer(types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> _StreamWriter:
    """Create a writer whose class matches the configured types.

    `yield_header` / `yield_footer` are only available when the corresponding type
    in `types` is not `None`.
    """
    has_header = types.header_type is not None
    has_footer = types.footer_type is not None
    if has_header and has_footer:
        return StreamWriterFull(types)
    if has_header:
        return StreamWriterWithHeader(types)
    if has_footer:
        return StreamWriterWithFooter(types)
    return _StreamWriter(types)