truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. /truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. /truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,1480 @@
1
+ import abc
2
+ import ast
3
+ import atexit
4
+ import collections
5
+ import contextlib
6
+ import contextvars
7
+ import enum
8
+ import functools
9
+ import importlib.util
10
+ import inspect
11
+ import logging
12
+ import os
13
+ import pathlib
14
+ import pprint
15
+ import sys
16
+ import types
17
+ import warnings
18
+ from importlib.abc import Loader
19
+ from typing import (
20
+ Any,
21
+ Callable,
22
+ Iterable,
23
+ Iterator,
24
+ Mapping,
25
+ MutableMapping,
26
+ Optional,
27
+ Protocol,
28
+ Type,
29
+ TypeVar,
30
+ Union,
31
+ cast,
32
+ get_args,
33
+ get_origin,
34
+ )
35
+
36
+ import pydantic
37
+ from typing_extensions import ParamSpec
38
+
39
+ from truss_chains import definitions, utils
40
+
41
# Types accepted directly in Chainlet I/O annotations (see `_validate_io_type`).
_SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None, pydantic.BaseModel}
# Container types that may hold the simple types above in I/O annotations.
_SIMPLE_CONTAINERS = {list, dict}
# NOTE(review): presumably the item types allowed for streaming endpoints - the
# usage is not visible in this part of the file, confirm.
_STREAM_TYPES = {str, bytes}

# Documentation links that are embedded into user-facing error messages.
_DOCS_URL_CHAINING = (
    "https://docs.baseten.co/chains/concepts#depends-call-other-chainlets"
)
_DOCS_URL_LOCAL = "https://docs.baseten.co/chains/guide#local-development"
_DOCS_URL_STREAMING = "https://docs.baseten.co/chains/guide#streaming"

# A "neutral dummy" endpoint descriptor to use if validation fails; this allows
# to safely continue checking for more errors.
_DUMMY_ENDPOINT_DESCRIPTOR = definitions.EndpointAPIDescriptor(
    input_args=[], output_types=[], is_async=False, is_streaming=False
)


# Type variable bound to Chainlet classes.
ChainletT = TypeVar("ChainletT", bound=definitions.ABCChainlet)
# Parameter spec and return type of callables wrapped by decorators in this
# module (e.g. `raise_validation_errors_before`).
_P = ParamSpec("_P")
_R = TypeVar("_R")
61
+
62
+
63
+ # Error Collector ######################################################################
64
+
65
+
66
class _ErrorKind(str, enum.Enum):
    """Category label attached to each collected validation error.

    Members are also ``str`` instances; their values are generated by
    ``enum.auto()``. Error formatting uses the member *name*.
    """

    TYPE_ERROR = enum.auto()
    IO_TYPE_ERROR = enum.auto()
    MISSING_API_ERROR = enum.auto()
70
+
71
+
72
class _ErrorLocation(definitions.SafeModel):
    """Source position (and optionally Chainlet / method) a validation error refers to."""

    # Path of the source file the error refers to.
    src_path: str
    # Line number inside `src_path`, if known.
    line: Optional[int] = None
    # Name of the Chainlet the error belongs to, if any.
    chainlet_name: Optional[str] = None
    # Name of the method inside the Chainlet, if any.
    method_name: Optional[str] = None

    def __str__(self) -> str:
        """Formats as `path:line`, followed by `(Chainlet[.method])` when known."""
        position = f"{self.src_path}:{self.line}"
        # Without a Chainlet name there is nothing to append (a method name
        # alone is not rendered).
        if not self.chainlet_name:
            return position
        if self.method_name:
            return f"{position} ({self.chainlet_name}.{self.method_name})"
        return f"{position} ({self.chainlet_name})"
87
+
88
+
89
class _ValidationError(definitions.SafeModel):
    """A single validation finding: message, category and source location."""

    # Human-readable description of the problem.
    msg: str
    # Category of the problem.
    kind: _ErrorKind
    # Where in the user code the problem was found.
    location: _ErrorLocation

    def __str__(self) -> str:
        """Formats as `<location> [kind: <KIND_NAME>]: <msg>`."""
        kind_label = self.kind.name
        return f"{self.location} [kind: {kind_label}]: {self.msg}"
96
+
97
+
98
class _ErrorCollector:
    """Accumulates validation errors so they can be reported together.

    On construction the instance registers `maybe_display_errors` with `atexit`,
    so errors that are still pending at interpreter exit are written to stderr.
    """

    # Errors in the order they were collected.
    _errors: "list[_ValidationError]"

    def __init__(self) -> None:
        self._errors = []
        # This hook is for the case of just running the Chainlet file, without
        # making a push - we want to surface the errors at exit.
        atexit.register(self.maybe_display_errors)

    def clear(self) -> None:
        """Drops all collected errors."""
        self._errors.clear()

    def collect(self, error: "_ValidationError") -> None:
        """Appends `error` to the collected errors."""
        self._errors.append(error)

    @property
    def has_errors(self) -> bool:
        """Whether at least one error has been collected."""
        return bool(self._errors)

    @property
    def num_errors(self) -> int:
        """Number of collected errors."""
        return len(self._errors)

    def format_errors(self) -> str:
        """Returns the `str()` of every collected error, one per line."""
        return "\n".join(str(error) for error in self._errors)

    def maybe_display_errors(self) -> None:
        """Writes the formatted errors to stderr if there are any."""
        if self.has_errors:
            sys.stderr.write(self.format_errors())
            sys.stderr.write("\n")
132
+
133
+
134
# Module-wide collector shared by `_collect_error` and `raise_validation_errors`.
_global_error_collector = _ErrorCollector()
135
+
136
+
137
def _collect_error(msg: str, kind: _ErrorKind, location: _ErrorLocation):
    """Records one validation error in the module-wide error collector."""
    error = _ValidationError(msg=msg, kind=kind, location=location)
    _global_error_collector.collect(error)
141
+
142
+
143
def raise_validation_errors() -> None:
    """Raises all collected validation errors as one ``ChainsUsageError``.

    Does nothing when no errors have been collected.
    """
    if not _global_error_collector.has_errors:
        return

    formatted = _global_error_collector.format_errors()
    # Clear errors so `atexit` won't display them again.
    _global_error_collector.clear()
    raise definitions.ChainsUsageError(
        "The user defined code does not comply with the required spec, "
        f"please fix below:\n{formatted}"
    )
152
+
153
+
154
def raise_validation_errors_before(f: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator: raise pending validation errors, then call `f`.

    If validation errors were collected, they are raised as a combined
    ``ChainsUsageError`` and `f` is not invoked.
    """

    @functools.wraps(f)
    def checked_call(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        raise_validation_errors()
        result = f(*args, **kwargs)
        return result

    return checked_call
163
+
164
+
165
class _BaseProvisionMarker:
    """Base class for markers of objects the framework dependency-injects."""
167
+
168
+
169
class ContextDependencyMarker(_BaseProvisionMarker):
    """Marker for a context that is to be injected into a Chainlet's `__init__`.

    The marker is not meant to be used as a real object: accessing any attribute
    that is not found by normal lookup raises ``ChainsRuntimeError`` with usage
    instructions.
    """

    def __str__(self) -> str:
        return f"{self.__class__.__name__}"

    def __getattr__(self, item: str) -> Any:
        """Logs the attempted access and raises ``ChainsRuntimeError``.

        Only invoked when `item` is not found through regular attribute lookup.
        """
        logging.error(f"Attempting to access attribute `{item}` on `{self}`.")
        raise definitions.ChainsRuntimeError(
            "It seems `chains.depends_context()` was used, but not as an argument "
            # The trailing newline separates this sentence from the docs link
            # (matches the message of `ChainletDependencyMarker`).
            "to the `__init__` method of a Chainlet - This is not supported.\n"
            f"See {_DOCS_URL_CHAINING}.\n"
            "Example of correct `__init__` with context:\n"
            f"{_example_chainlet_code()}"
        )
182
+
183
+
184
class ChainletDependencyMarker(_BaseProvisionMarker):
    """Marker for a dependency Chainlet to be injected into a Chainlet's `__init__`.

    It stores the dependency's class and the RPC options. Accessing any other
    attribute raises ``ChainsRuntimeError`` with usage instructions.
    """

    # Class of the Chainlet that is depended on.
    chainlet_cls: Type[definitions.ABCChainlet]
    # Options for calling the dependency (this is what `__init__` actually sets).
    options: definitions.RPCOptions

    def __init__(
        self,
        chainlet_cls: Type[definitions.ABCChainlet],
        options: definitions.RPCOptions,
    ) -> None:
        self.chainlet_cls = chainlet_cls
        self.options = options

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.chainlet_cls.name})"

    def __getattr__(self, item: str) -> Any:
        """Logs the attempted access and raises ``ChainsRuntimeError``.

        Only invoked when `item` is not found through regular attribute lookup.
        """
        logging.error(f"Attempting to access attribute `{item}` on `{self}`.")
        raise definitions.ChainsRuntimeError(
            f"It seems `chains.depends({self.chainlet_cls.name})` was used, but "
            "not as an argument to the `__init__` method of a Chainlet - This is not "
            "supported. Dependency Chainlets must be passed as init arguments.\n"
            f"See {_DOCS_URL_CHAINING}.\n"
            "Example of correct `__init__` with dependencies:\n"
            f"{_example_chainlet_code()}"
        )
209
+
210
+
211
+ # Validation of Chainlet class definition ##############################################
212
+
213
+
214
@functools.cache
def _example_chainlet_code() -> str:
    """Returns the source code of the reference `HelloWorld` Chainlet.

    The text is embedded into error messages as an example of correct usage.
    The result (including the fallback placeholder) is cached.
    """
    # Note: this function requires all chains modules to be initialized, because
    # importing `reference_chainlet` triggers the full chainlet validation
    # process. To avoid circular import dependencies, this function should only
    # be called on erroneous code branches (which will not be triggered if
    # `reference_chainlet` itself is free of errors).
    try:
        from truss_chains.reference_code import reference_chainlet
    except AttributeError:
        # If `reference_chainlet` fails validation and this function is called
        # as a result of that, we have a circular import ("partially
        # initialized module ...").
        logging.error("`reference_chainlet` is broken.", exc_info=True, stack_info=True)
        return "<EXAMPLE CODE MISSING/BROKEN>"
    else:
        cls_name = reference_chainlet.HelloWorld.name
        return _get_cls_source(reference_chainlet.__file__, cls_name)
232
+
233
+
234
@functools.cache
def _example_model_code() -> str:
    """Returns the source code of the reference `HelloWorld` model.

    The result (including the fallback placeholder) is cached.
    """
    try:
        from truss_chains.reference_code import reference_model
    except AttributeError:
        logging.error("`reference_model` is broken.", exc_info=True, stack_info=True)
        return "<EXAMPLE CODE MISSING/BROKEN>"
    else:
        cls_name = reference_model.HelloWorld.name
        return _get_cls_source(reference_model.__file__, cls_name)
244
+
245
+
246
def _get_cls_source(src_path: str, target_class_name: str) -> str:
    """Return the source text of class `target_class_name` found in `src_path`.

    The file is parsed with `ast` and the first matching `ClassDef` encountered
    by `ast.walk` is used. The returned text spans the `class` line through the
    last line of the class body (decorator lines above `class` are not included).

    Args:
        src_path: Path of the Python source file to read.
        target_class_name: Name of the class whose definition is extracted.

    Returns:
        The class source joined with "\\n", or an empty string if no class with
        that name exists in the file.
    """
    # Python source files are UTF-8 by default; decoding with the platform locale
    # encoding can fail or garble non-ASCII characters on some systems.
    source = pathlib.Path(src_path).read_text(encoding="utf-8")
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef) and node.name == target_class_name:
            # `lineno` is 1-based, `end_lineno` is inclusive.
            lines = source.splitlines()
            return "\n".join(lines[node.lineno - 1 : node.end_lineno])
    return ""
257
+
258
+
259
def _instantiation_error_msg(cls_name: str, location: Optional[str] = None) -> str:
    """Build the user-facing message for an invalid Chainlet instantiation.

    Args:
        cls_name: Name of the Chainlet class that was instantiated.
        location: Optional location text; if truthy it is inserted on its own line.

    Returns:
        The message, listing possible fixes and ending with the example Chainlet code.
    """
    if location:
        location_format = f"{location}\n"
    else:
        location_format = ""
    pieces = [
        f"Error when instantiating Chainlet `{cls_name}`.\n",
        f"{location_format}",
        "Chainlets cannot be naively instantiated. Possible fixes:\n",
        "1. To use Chainlets as dependencies in other Chainlets ('chaining'), ",
        f"add them as init argument. See {_DOCS_URL_CHAINING}.\n",
        f"2. For local / debug execution, use the `{run_local.__name__}`-",
        f"context. See {_DOCS_URL_LOCAL}. You cannot use helper functions to ",
        "instantiate the Chain in this case.\n",
        "3. Push the chain and call the remote endpoint.\n",
        "Example of correct `__init__` with dependencies:\n",
        f"{_example_chainlet_code()}",
    ]
    return "".join(pieces)
274
+
275
+
276
def _validate_io_type(
    annotation: Any, param_name: str, location: _ErrorLocation
) -> None:
    """Check that `annotation` is an allowed Chainlet I/O type.

    Accepted are the members of `_SIMPLE_TYPES`, `pydantic.BaseModel` subclasses,
    and `types.GenericAlias` annotations whose origin is in `_SIMPLE_CONTAINERS`
    and whose item types are simple types or pydantic models. Any deeper nested
    or structured data must be typed as a pydantic model.

    Violations (including string-valued annotations) are reported through
    `_collect_error`; nothing is raised or returned.
    """
    container_names = [container.__name__ for container in _SIMPLE_CONTAINERS]
    simple_type_names = [
        "None" if simple is None else simple.__name__ for simple in _SIMPLE_TYPES
    ]

    if isinstance(annotation, str):
        _collect_error(
            f"A string-valued type annotation was found for `{param_name}` of type "
            f"`{annotation}`. Use only actual types objects and avoid "
            "`from __future__ import annotations` (if needed upgrade python).",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return

    if annotation in _SIMPLE_TYPES:
        return

    unsupported_msg = (
        f"Unsupported I/O type for `{param_name}` of type `{annotation}`. "
        "Supported are:\n"
        f"\t* simple types: {simple_type_names}\n"
        "\t* containers of these simple types, with annotated item types: "
        f"{container_names}, e.g. `dict[str, int]` (use built-in types, not "
        "`typing.Dict`).\n"
        "\t* For complicated / nested data structures: `pydantic` models."
    )

    if isinstance(annotation, types.GenericAlias):
        # A container is valid only if both the container itself and every item
        # type are allowed; at most one error is reported per annotation.
        is_valid = get_origin(annotation) in _SIMPLE_CONTAINERS and all(
            item in _SIMPLE_TYPES or utils.issubclass_safe(item, pydantic.BaseModel)
            for item in get_args(annotation)
        )
        if not is_valid:
            _collect_error(unsupported_msg, _ErrorKind.IO_TYPE_ERROR, location)
        return

    if not utils.issubclass_safe(annotation, pydantic.BaseModel):
        _collect_error(unsupported_msg, _ErrorKind.IO_TYPE_ERROR, location)
324
+
325
+
326
def _validate_streaming_output_type(
    annotation: Any, location: _ErrorLocation
) -> definitions.StreamingTypeDescriptor:
    """Validate an iterator return annotation of a streaming endpoint.

    The caller must pass an annotation whose origin is `collections.abc.Iterator`
    or `collections.abc.AsyncIterator` (asserted). The iterator must be
    parameterized with exactly one item type from `_STREAM_TYPES`; violations are
    reported through `_collect_error`.

    Returns:
        A `StreamingTypeDescriptor` for the annotation. If the item type is
        missing, `bytes` is used as a placeholder `arg_type`.
    """
    origin = get_origin(annotation)
    assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator)
    args = get_args(annotation)
    if len(args) < 1:
        stream_types = sorted(x.__name__ for x in _STREAM_TYPES)
        _collect_error(
            f"Iterators must be annotated with type (one of {stream_types}).",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return definitions.StreamingTypeDescriptor(
            raw=annotation, origin_type=origin, arg_type=bytes
        )

    assert len(args) == 1, "Iterator type annotations cannot have more than 1 arg."
    arg = args[0]
    if arg not in _STREAM_TYPES:
        # Note the trailing space after `stream_writer`: adjacent literals are
        # concatenated verbatim.
        msg = (
            "Streaming endpoints (containing `yield` statements) can only yield string "
            "or byte items. For streaming structured pydantic data, use `stream_writer` "
            "and `stream_reader` helpers.\n"
            f"See streaming docs: {_DOCS_URL_STREAMING}"
        )
        _collect_error(msg, _ErrorKind.IO_TYPE_ERROR, location)

    return definitions.StreamingTypeDescriptor(
        raw=annotation, origin_type=origin, arg_type=arg
    )
357
+
358
+
359
def _validate_method_signature(
    method_name: str, location: _ErrorLocation, params: list[inspect.Parameter]
) -> None:
    """Report an error unless the first parameter is the `self` argument.

    Args:
        method_name: Name used in the error message.
        location: Error location passed on to `_collect_error`.
        params: Parameters of the inspected function, in declaration order.
    """
    if params and params[0].name == definitions.SELF_ARG_NAME:
        return

    if not params:
        detail = "Got function with no arguments."
    else:
        detail = f"Got `{params[0].name}` as first argument."

    _collect_error(
        f"`{method_name}` must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as "
        f"first argument. {detail}",
        _ErrorKind.TYPE_ERROR,
        location,
    )
376
+
377
+
378
def _validate_endpoint_params(
    params: list[inspect.Parameter], location: _ErrorLocation
) -> list[definitions.InputArg]:
    """Validate the endpoint's parameters and describe them as `InputArg`s.

    The first parameter is checked to be `self` and then skipped. Every other
    parameter must carry a type annotation accepted by `_validate_io_type`.
    An `InputArg` is produced for each parameter even if validation reported an
    error; a parameter counts as optional when it has a default value.
    """
    _validate_method_signature(definitions.RUN_REMOTE_METHOD_NAME, location, params)
    described: list = []
    non_self_params = params[1:]  # Skip self argument.
    for param in non_self_params:
        is_unannotated = param.annotation == inspect.Parameter.empty
        if not is_unannotated:
            _validate_io_type(param.annotation, param.name, location)
        else:
            _collect_error(
                "Arguments of endpoints must have type annotations. "
                f"Parameter `{param.name}` has no type annotation.",
                _ErrorKind.IO_TYPE_ERROR,
                location,
            )
        described.append(
            definitions.InputArg(
                name=param.name,
                type=definitions.TypeDescriptor(raw=param.annotation),
                is_optional=param.default != inspect.Parameter.empty,
            )
        )
    return described
401
+
402
+
403
def _validate_endpoint_output_types(
    annotation: Any, signature, location: _ErrorLocation, is_streaming: bool
) -> list[definitions.TypeDescriptor]:
    """Validate the endpoint's return annotation and describe its output types.

    Args:
        annotation: The return annotation of the endpoint method.
        signature: The method's signature; only used for error messages.
        location: Error location passed on to `_collect_error`.
        is_streaming: Whether the method is a (sync or async) generator function.

    Returns:
        One descriptor per tuple item for `tuple[...]` annotations, otherwise a
        single-item list. An empty list if the annotation is missing.
    """
    if annotation == inspect.Parameter.empty:
        _collect_error(
            "Return values of endpoints must be type annotated. Got:\n"
            f"\t{location.method_name}{signature} -> !MISSING!",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return []

    origin = get_origin(annotation)
    returns_iterator = origin in (
        collections.abc.AsyncIterator,
        collections.abc.Iterator,
    )

    if origin is tuple:
        output_types = []
        for index, item_type in enumerate(get_args(annotation)):
            _validate_io_type(item_type, f"return_type[{index}]", location)
            output_types.append(definitions.TypeDescriptor(raw=item_type))
    elif returns_iterator:
        output_types = [_validate_streaming_output_type(annotation, location)]
        if not is_streaming:
            _collect_error(
                "If the endpoint returns an iterator (streaming), it must have `yield` "
                "statements.",
                _ErrorKind.IO_TYPE_ERROR,
                location,
            )
    else:
        _validate_io_type(annotation, "return_type", location)
        output_types = [definitions.TypeDescriptor(raw=annotation)]

    # A generator function must be annotated with an iterator return type.
    if is_streaming and not returns_iterator:
        _collect_error(
            "If the endpoint is streaming (has `yield` statements), the return type "
            "must be an iterator (e.g. `AsyncIterator[bytes]`). Got:\n"
            f"\t{location.method_name}{signature} -> {annotation}",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )

    return output_types
446
+
447
+
448
def _validate_and_describe_endpoint(
    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
) -> definitions.EndpointAPIDescriptor:
    """The "endpoint method" of a Chainlet must have the following signature:

    ```
    [async] def run_remote(
        self, [param_0: anno_0, param_1: anno_1 = default_1, ...]) -> ret_anno:
    ```

    * The name must be `run_remote` for Chainlets, or `predict` for Models.
    * It can be sync or async def (sync emits a `DeprecationWarning`).
    * The number and names of parameters are arbitrary, both positional and named
      parameters are ok.
    * All parameters and the return value must have type annotations. See
      `_validate_io_type` for valid types.
    * Streaming endpoints (containing `yield` statements) must be async and
      annotated with an iterator return type.

    Violations are reported through `_collect_error`. If the endpoint is missing
    or is not a plain function, `_DUMMY_ENDPOINT_DESCRIPTOR` is returned.
    """
    method_name = cls.endpoint_method_name
    if not hasattr(cls, method_name):
        _collect_error(
            f"{cls.entity_type}s must have a `{method_name}` method.",
            _ErrorKind.MISSING_API_ERROR,
            location,
        )
        return _DUMMY_ENDPOINT_DESCRIPTOR

    # This is the unbound method.
    endpoint_method = getattr(cls, method_name)

    # Check this before any source inspection: if it's not a function, it might
    # be a class var, for which `inspect.getsourcelines` raises `TypeError`.
    if not inspect.isfunction(endpoint_method):
        _collect_error(
            "Endpoints must be a method.",
            _ErrorKind.TYPE_ERROR,
            location.model_copy(update={"method_name": method_name}),
        )
        return _DUMMY_ENDPOINT_DESCRIPTOR

    line = inspect.getsourcelines(endpoint_method)[1]
    location = location.model_copy(update={"line": line, "method_name": method_name})

    signature = inspect.signature(endpoint_method)
    input_args = _validate_endpoint_params(
        list(signature.parameters.values()), location
    )

    if inspect.isasyncgenfunction(endpoint_method):
        is_async = True
        is_streaming = True
    elif inspect.iscoroutinefunction(endpoint_method):
        is_async = True
        is_streaming = False
    else:
        is_async = False
        is_streaming = inspect.isgeneratorfunction(endpoint_method)

    output_types = _validate_endpoint_output_types(
        signature.return_annotation, signature, location, is_streaming
    )

    if is_streaming and not is_async:
        _collect_error(
            "Streaming endpoints (containing `yield` statements) are only "
            "supported for async endpoints.",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )

    if not is_async:
        # Use the actual endpoint name, which differs between Chainlets and Models.
        warnings.warn(
            f"`{method_name}` must be an async (coroutine) function in future releases. "
            f"Replace `def {method_name}(...)` with `async def {method_name}(...)`. "
            "Local testing and execution can be done with "
            f"`asyncio.run(my_chainlet.{method_name}(...))`.\n"
            "Note on concurrency: previously sync functions were run in threads by the "
            "Truss server.\n"
            "For some frameworks this was **unsafe** (e.g. in torch the CUDA context "
            "is not thread-safe).\n"
            "Additionally, python threads hold the GIL and therefore might not give "
            "actual throughput gains.\n"
            "To achieve safe and performant concurrency, use framework-specific async "
            "APIs (e.g. AsyncLLMEngine for vLLM) or generic async batching such "
            "as https://github.com/hussein-awala/async-batcher.",
            DeprecationWarning,
            stacklevel=1,
        )

    return definitions.EndpointAPIDescriptor(
        name=method_name,
        input_args=input_args,
        output_types=output_types,
        is_async=is_async,
        is_streaming=is_streaming,
    )
540
+
541
+
542
def _get_generic_class_type(var):
    """Return the unsubscripted class for both `SomeGeneric` and `SomeGeneric[T]`."""
    generic_origin = get_origin(var)
    if generic_origin is None:
        # Not a parameterized generic: the value already is the plain class.
        return var
    return generic_origin
546
+
547
+
548
class _ChainletInitValidator:
    """The `__init__`-method of a Chainlet must have the following signature:

    ```
    def __init__(
        self,
        [dep_0: dep_0_type = truss_chains.depends(dep_0_class),]
        [dep_1: dep_1_type = truss_chains.depends(dep_1_class),]
        ...
        [dep_N: dep_N_type = truss_chains.depends(dep_N_class),]
        [context: truss_chains.Context = truss_chains.depends_context()]
    ) -> None:
    ```
    * The context argument is optionally trailing and must have a default constructed
      with the `depends_context` directive.
    * The names and number of Chainlet "dependency" arguments are arbitrary.
    * Default values for dependencies must be constructed with the `depends` directive
      to make the dependency injection work. The argument to `depends` must be a
      Chainlet class.
    * The type annotation for dependencies can be a Chainlet class, but it can also be
      a `Protocol` with an equivalent `run_remote` method (e.g. for getting correct
      type checks when providing fake Chainlets for local testing.). It may be
      omitted if the type is clear from the RHS.

    Validation runs in the constructor; problems are reported through
    `_collect_error`. The results are exposed as `has_context` and
    `validated_dependencies`.
    """

    _location: _ErrorLocation
    _cls: Type[definitions.ABCChainlet]
    # Whether `__init__` takes the trailing context argument.
    has_context: bool = False
    # Maps init argument names to the descriptor of the Chainlet they depend on.
    validated_dependencies: Mapping[str, definitions.DependencyDescriptor] = {}

    def __init__(
        self, cls: Type[definitions.ABCChainlet], location: _ErrorLocation
    ) -> None:
        """Validate `cls.__init__`; without a custom init there is nothing to check."""
        self._cls = cls
        if not self._cls.has_custom_init():
            self.has_context = False
            self.validated_dependencies = {}
            return
        line = inspect.getsourcelines(cls.__init__)[1]
        self._location = location.model_copy(
            update={"line": line, "method_name": "__init__"}
        )

        params = list(inspect.signature(cls.__init__).parameters.values())
        self._validate_args(params)

    def _validate_args(self, params: list[inspect.Parameter]) -> None:
        """Run all argument validations in order on the (mutated) `params` list."""
        # Each validation pops off "processed" arguments from the list.
        self._validate_self_arg(params)
        self._validate_context_arg(params)
        self._validate_dependencies(params)

    def _validate_self_arg(self, params: list[inspect.Parameter]) -> None:
        """Pop the first parameter and report an error unless it is `self`."""
        if len(params) == 0:
            _collect_error(
                "Methods must have first argument `self`, got no arguments.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return
        param = params.pop(0)
        if param.name != definitions.SELF_ARG_NAME:
            _collect_error(
                f"Methods must have first argument `self`, got `{param.name}`.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )

    def _validate_context_arg(self, params: list[inspect.Parameter]) -> None:
        """Validate and pop the optional trailing context parameter.

        The last parameter is treated as the context argument if it carries the
        context argument name. In that case `has_context` is set and the
        parameter is removed from `params`.
        """

        def make_context_error_msg():
            return (
                f"If `{self._cls.entity_type}` uses context for initialization, it "
                f"must have `{definitions.CONTEXT_ARG_NAME}` argument of type "
                f"`{definitions.DeploymentContext}` as the last argument.\n"
                f"Got arguments: `{params}`.\n"
                "Example of correct `__init__` with context:\n"
                f"{self._example_code()}"
            )

        if not params:
            return

        has_context = params[-1].name == definitions.CONTEXT_ARG_NAME
        has_context_marker = isinstance(params[-1].default, ContextDependencyMarker)
        # Name and marker must either both be present or both be absent.
        if has_context ^ has_context_marker:
            _collect_error(
                make_context_error_msg(), _ErrorKind.TYPE_ERROR, self._location
            )

        if not has_context:
            return

        self.has_context = True
        param = params.pop(-1)
        param_type = _get_generic_class_type(param.annotation)
        # We are lenient and allow omitting the type annotation for context.
        if (
            (param_type is not None)
            and (param_type != inspect.Parameter.empty)
            and (not utils.issubclass_safe(param_type, definitions.DeploymentContext))
        ):
            _collect_error(
                make_context_error_msg(), _ErrorKind.TYPE_ERROR, self._location
            )
        if not isinstance(param.default, ContextDependencyMarker):
            _collect_error(
                f"Incorrect default value `{param.default}` for `context` argument. "
                "Example of correct `__init__` with dependencies:\n"
                f"{self._example_code()}",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )

    def _validate_dependencies(self, params: list[inspect.Parameter]) -> None:
        """Validate the remaining parameters as dependencies and store descriptors."""
        used = set()
        dependencies = {}

        if params and not self._cls.supports_dependencies:
            _collect_error(
                f"The only supported argument to `__init__` for {self._cls.entity_type}s "
                f"is the optional context argument.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return
        for param in params:
            marker = self._validate_dependency_param(param)
            if marker is None:
                continue
            if marker.chainlet_cls in used:
                _collect_error(
                    f"The same Chainlet class cannot be used multiple times for "
                    f"different arguments. Got previously used "
                    f"`{marker.chainlet_cls}` for `{param.name}`.",
                    _ErrorKind.TYPE_ERROR,
                    self._location,
                )

            dependencies[param.name] = definitions.DependencyDescriptor(
                chainlet_cls=marker.chainlet_cls, options=marker.options
            )
            used.add(marker.chainlet_cls)

        self.validated_dependencies = dependencies

    def _validate_dependency_param(
        self, param: inspect.Parameter
    ) -> Optional[ChainletDependencyMarker]:
        """
        Returns a valid ChainletDependencyMarker if found, None otherwise.
        """
        # TODO: handle subclasses, unions, optionals, check default value etc.
        if param.name == definitions.CONTEXT_ARG_NAME:
            _collect_error(
                f"The init argument name `{definitions.CONTEXT_ARG_NAME}` is reserved for "
                "the optional context argument, which must be trailing if used. Example "
                "of correct `__init__` with context:\n"
                f"{self._example_code()}",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )

        if not isinstance(param.default, ChainletDependencyMarker):
            _collect_error(
                f"Any arguments of a Chainlet's __init__ (besides `context`) must have "
                "dependency Chainlets with default values from `chains.depends`-directive. "
                f"Got `{param}`.\n"
                f"Example of correct `__init__` with dependencies:\n"
                f"{self._example_code()}",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return None

        chainlet_cls = param.default.chainlet_cls
        if not utils.issubclass_safe(chainlet_cls, definitions.ABCChainlet):
            _collect_error(
                f"`chains.depends` must be used with a Chainlet class as argument, got "
                f"{chainlet_cls} instead.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
            return None
        # Check type annotation.
        # Also lenient with type annotation: since the RHS / default is asserted to be a
        # chainlet class, proper type inference is possible even without annotation.
        # TODO: `Protocol` is not a proper class and this might be version dependent.
        # Find a better way to inspect this.
        if not (
            param.annotation == inspect.Parameter.empty
            or utils.issubclass_safe(param.annotation, Protocol)  # type: ignore[arg-type]
            or utils.issubclass_safe(chainlet_cls, param.annotation)
        ):
            _collect_error(
                f"The type annotation for `{param.name}` must be a class/subclass of the "
                "Chainlet type specified by `chains.depends` or a compatible "
                f"`typing.Protocol`. Got `{param.annotation}`.",
                _ErrorKind.TYPE_ERROR,
                self._location,
            )
        return param.default  # The Marker.

    def _example_code(self) -> str:
        """Return example source matching the entity type (Model or Chainlet)."""
        # Not wrapped in `functools.cache`: caching a bound method would keep every
        # validator instance alive, and both helpers below are cached already.
        if self._cls.entity_type == "Model":
            return _example_model_code()
        return _example_chainlet_code()
755
+
756
+
757
def _validate_remote_config(
    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
):
    """Report an error unless `cls` has a `RemoteConfig`-typed remote config.

    The class variable is looked up by `definitions.REMOTE_CONFIG_NAME`. A missing
    attribute is treated like a wrongly typed one and reported through
    `_collect_error` instead of raising `AttributeError`.
    """
    # Default to `None` so that the "must have ..." error below also covers the
    # case where the class variable does not exist at all.
    remote_config = getattr(cls, definitions.REMOTE_CONFIG_NAME, None)
    if not isinstance(remote_config, definitions.RemoteConfig):
        _collect_error(
            f"{cls.entity_type}s must have a `{definitions.REMOTE_CONFIG_NAME}` class variable "
            f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` "
            f"for `{cls}`.",
            _ErrorKind.TYPE_ERROR,
            location,
        )
771
+
772
+
773
def _validate_health_check(
    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
) -> Optional[definitions.HealthCheckAPIDescriptor]:
    """Validate the optional `is_healthy` method and describe it.

    Expected signature:
    ```
    [async] def is_healthy(self) -> bool:
    ```
    * It can be sync or async def.
    * It must not define any parameters other than `self`.
    * It must be annotated to return `bool`.

    Returns:
        `None` if the class has no such attribute, if it is not a plain function,
        or if the return annotation is missing; otherwise a descriptor carrying
        whether the method is async. Violations are reported via `_collect_error`.
    """
    check_name = definitions.HEALTH_CHECK_METHOD_NAME
    if not hasattr(cls, check_name):
        return None

    method = getattr(cls, check_name)
    if not inspect.isfunction(method):
        _collect_error(
            f"`{definitions.HEALTH_CHECK_METHOD_NAME}` must be a method.",
            _ErrorKind.TYPE_ERROR,
            location,
        )
        return None

    # From here on, errors point at the method itself instead of the class.
    method_line = inspect.getsourcelines(method)[1]
    location = location.model_copy(
        update={"line": method_line, "method_name": check_name}
    )

    signature = inspect.signature(method)
    params = list(signature.parameters.values())
    _validate_method_signature(check_name, location, params)
    if len(params) > 1:
        _collect_error(
            f"`{definitions.HEALTH_CHECK_METHOD_NAME}` must have only one argument: `{definitions.SELF_ARG_NAME}`.",
            _ErrorKind.TYPE_ERROR,
            location,
        )

    return_annotation = signature.return_annotation
    if return_annotation == inspect.Parameter.empty:
        _collect_error(
            "Return value of health check must be type annotated. Got:\n"
            f"\t{location.method_name}{signature} -> !MISSING!",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return None
    if return_annotation is not bool:
        _collect_error(
            "Return value of health check must be a boolean. Got:\n"
            f"\t{location.method_name}{signature} -> {signature.return_annotation}",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )

    return definitions.HealthCheckAPIDescriptor(
        is_async=inspect.iscoroutinefunction(method)
    )
828
+
829
+
830
def validate_and_register_cls(cls: Type[definitions.ABCChainlet]) -> None:
    """Validate a Chainlet class and add its descriptor to the global registry.

    Note that validation errors will only be collected, not raised, and Chainlets
    with issues are still added to the registry. Use `raise_validation_errors` to
    assert all Chainlets are valid before performing operations that depend on
    these constraints.
    """
    src_path = os.path.abspath(inspect.getfile(cls))
    cls_line = inspect.getsourcelines(cls)[1]
    location = _ErrorLocation(
        src_path=src_path, line=cls_line, chainlet_name=cls.__name__
    )

    # The order of the checks determines the order of collected errors.
    _validate_remote_config(cls, location)
    init_validator = _ChainletInitValidator(cls, location)
    endpoint_descriptor = _validate_and_describe_endpoint(cls, location)
    health_check_descriptor = _validate_health_check(cls, location)

    chainlet_descriptor = definitions.ChainletAPIDescriptor(
        chainlet_cls=cls,
        dependencies=init_validator.validated_dependencies,
        has_context=init_validator.has_context,
        endpoint=endpoint_descriptor,
        src_path=src_path,
        health_check=health_check_descriptor,
    )
    formatted_descriptor = pprint.pformat(chainlet_descriptor, indent=4)
    logging.debug(f"Descriptor for {cls}:\n{formatted_descriptor}\n")
    _global_chainlet_registry.register_chainlet(chainlet_descriptor)
853
+
854
+
855
+ # Dependency-Injection / Registry ######################################################
856
+
857
+
858
class _ChainletRegistry:
    """Registry of Chainlet descriptors, keyed by class and by name."""

    # Because dependencies are required to be present when registering a Chainlet,
    # this dict contains natively a topological sorting of the dependency graph.
    _chainlets: collections.OrderedDict[
        Type[definitions.ABCChainlet], definitions.ChainletAPIDescriptor
    ]
    _name_to_cls: MutableMapping[str, Type[definitions.ABCChainlet]]

    def __init__(self) -> None:
        self.clear()

    def clear(self):
        """Drop all registered Chainlets."""
        self._chainlets = collections.OrderedDict()
        self._name_to_cls = {}

    def register_chainlet(self, chainlet_descriptor: definitions.ChainletAPIDescriptor):
        """Add a descriptor; its dependencies must already be registered.

        Raises:
            definitions.ChainsUsageError: if the Chainlet name is already taken.
        """
        for dependency in chainlet_descriptor.dependencies.values():
            # To depend on a Chainlet, the class must be defined (module initialized)
            # which entails that it has already been added to the registry.
            # This is an assertion, because unless users meddle with the internal
            # registry, it's not possible to depend on another chainlet before it's
            # also added to the registry.
            assert dependency.chainlet_cls in self._chainlets, (
                "Cannot depend on Chainlet. Available Chainlets: "
                f"{list(self._chainlets.keys())}"
            )

        # Classes are globally unique, but names are not: to prevent re-use /
        # overwriting of names, we must check this in addition.
        new_name = chainlet_descriptor.name
        if new_name in self._name_to_cls:
            conflict = self._name_to_cls[new_name]
            existing_source_path = self._chainlets[conflict].src_path
            raise definitions.ChainsUsageError(
                f"A Chainlet with name `{chainlet_descriptor.name}` was already "
                f"defined, Chainlet names must be globally unique.\n"
                f"Pre-existing in: `{existing_source_path}`\n"
                f"New conflict in: `{chainlet_descriptor.src_path}`."
            )

        new_cls = chainlet_descriptor.chainlet_cls
        self._chainlets[new_cls] = chainlet_descriptor
        self._name_to_cls[chainlet_descriptor.name] = new_cls

    def unregister_chainlet(self, chainlet_name: str) -> None:
        """Remove the Chainlet registered under `chainlet_name` (KeyError if absent)."""
        self._chainlets.pop(self._name_to_cls.pop(chainlet_name))

    @property
    def chainlet_descriptors(self) -> list[definitions.ChainletAPIDescriptor]:
        """All descriptors in registration order."""
        return [descriptor for descriptor in self._chainlets.values()]

    def get_descriptor(
        self, chainlet_cls: Type[definitions.ABCChainlet]
    ) -> definitions.ChainletAPIDescriptor:
        """Return the descriptor registered for `chainlet_cls` (KeyError if absent)."""
        return self._chainlets[chainlet_cls]

    def get_dependencies(
        self, chainlet: definitions.ChainletAPIDescriptor
    ) -> Iterable[definitions.ChainletAPIDescriptor]:
        """Return the registered descriptors of the direct dependencies of `chainlet`."""
        registered = self._chainlets[chainlet.chainlet_cls]
        resolved = []
        for dependency in registered.dependencies.values():
            resolved.append(self._chainlets[dependency.chainlet_cls])
        return resolved

    def get_chainlet_names(self) -> set[str]:
        """Return the names of all registered Chainlets."""
        return set(self._name_to_cls)
924
+
925
+
926
# Module-level registry instance used by `validate_and_register_cls` and the
# module-level accessor functions below.
_global_chainlet_registry = _ChainletRegistry()
927
+
928
+
929
def get_dependencies(
    chainlet: definitions.ChainletAPIDescriptor,
) -> Iterable[definitions.ChainletAPIDescriptor]:
    """Return the descriptors of `chainlet`'s dependencies from the global registry."""
    return _global_chainlet_registry.get_dependencies(chainlet)
933
+
934
+
935
def get_descriptor(
    chainlet_cls: Type[definitions.ABCChainlet],
) -> definitions.ChainletAPIDescriptor:
    """Return the descriptor registered for `chainlet_cls` in the global registry."""
    return _global_chainlet_registry.get_descriptor(chainlet_cls)
939
+
940
+
941
def get_ordered_descriptors() -> list[definitions.ChainletAPIDescriptor]:
    """Return all descriptors of the global registry in registration order."""
    return _global_chainlet_registry.chainlet_descriptors
943
+
944
+
945
+ # Chainlet class runtime utils #########################################################
946
+
947
+
948
def _determine_arguments(func: Callable, **kwargs):
    """Combine the given keyword arguments with the defaults of `func`.

    Parameters that are neither provided nor have a default are left out of the
    returned mapping (partial binding).
    """
    bound = inspect.signature(func).bind_partial(**kwargs)
    bound.apply_defaults()
    return bound.arguments
954
+
955
+
956
def ensure_args_are_injected(cls, original_init: Callable, kwargs) -> None:
    """Assert that all provision markers were replaced by actual objects.

    The effective arguments of `original_init` (provided `kwargs` merged with
    defaults) are inspected: the context argument must be a `DeploymentContext`
    and no other argument may still be a `_BaseProvisionMarker`.

    Raises:
        definitions.ChainsRuntimeError: if any argument was not injected.
    """
    final_args = _determine_arguments(original_init, **kwargs)
    for name, value in final_args.items():
        if name == definitions.CONTEXT_ARG_NAME:
            if not isinstance(value, definitions.DeploymentContext):
                logging.error(
                    f"When initializing {cls.entity_type} `{cls.name}`, for context "
                    f"argument an incompatible value was passed, value: `{value}`."
                )
                raise definitions.ChainsRuntimeError(_instantiation_error_msg(cls.name))
        # The argument is a dependency chainlet.
        elif isinstance(value, _BaseProvisionMarker):
            logging.error(
                f"When initializing {cls.entity_type} `{cls.name}`, for dependency "
                f"Chainlet argument `{name}` an incompatible value was passed, "
                f"value: `{value}`."
            )
            raise definitions.ChainsRuntimeError(_instantiation_error_msg(cls.name))
974
+
975
+
976
+ # Local Execution ######################################################################
977
+
978
# A variable to track the stack depth relative to `run_local` context manager.
# No default is given, so `.get()` raises `LookupError` while it is unset.
run_local_stack_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
    "run_local_stack_depth"
)

# Function names that are compared against stack frame names when the
# instantiation call stack is analyzed during local execution.
_INIT_LOCAL_NAME = "__init_local__"
_INIT_NAME = "__init__"
985
+
986
+
987
+ def _create_modified_init_for_local(
988
+ chainlet_descriptor: definitions.ChainletAPIDescriptor,
989
+ cls_to_instance: MutableMapping[
990
+ Type[definitions.ABCChainlet], definitions.ABCChainlet
991
+ ],
992
+ secrets: Mapping[str, str],
993
+ data_dir: Optional[pathlib.Path],
994
+ chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor],
995
+ ):
996
+ """Replaces the default argument values with local Chainlet instantiations.
997
+
998
+ If this patch is used, Chainlets can be functionally instantiated without
999
+ any init args (because the patched defaults are sufficient).
1000
+ """
1001
+
1002
+ def _detect_naive_instantiations(
1003
+ stack: list[inspect.FrameInfo], levels_below_run_local: int
1004
+ ) -> None:
1005
+ # The goal is to find cases where a chainlet is directly instantiated
1006
+ # in a place that is not immediately inside the `run_local`-contextmanager.
1007
+ # In particular chainlets being instantiated in the `__init__` or `run_remote`
1008
+ # methods of other chainlets (instead of being passed as dependencies with
1009
+ # `chains.depends()`).
1010
+ #
1011
+ # We look into the calls stack of any (wrapped) invocation of an
1012
+ # ABCChainlet-subclass's `__init__`.
1013
+ # We also cut off the "above" call stack, such that `run_local` (and anything
1014
+ # above that) is ignored, so it is possible to use `run_local` in nested code.
1015
+ #
1016
+ # A valid stack looks like this:
1017
+ # * `__init_local__` as deepest frame (which would then call
1018
+ # `__init_with_arg_check__` -> `__init__` if validation passes).
1019
+ # * If a chainlet has no base classes, this can *only* be called from
1020
+ # `__init_local__` - the part when the chainlet needs to be instantiated and
1021
+ # added to `cls_to_instance`.
1022
+ # * If a chainlet has other chainlets as base classes, they may call a chain
1023
+ # of `super().__init()`. Each will add a triple of
1024
+ # (__init__, __init_with_arg_check__, __init_local__) to the stack. While
1025
+ # these 3 init layers belong to the different base classes, the type of the
1026
+ # `self` arg is fixed.
1027
+ #
1028
+ # To detect invalid stacks we can rephrase this: `__init_local__` can only be
1029
+ # called under either of these conditions:
1030
+ # * From `__init_local__` when needing to populate `cls_to_instance`.
1031
+ # * From a subclass's `__init__` using `super().__init__()`. This means the
1032
+ # type (and instance) of the `self` arg in the calling `__init_local__` and
1033
+ # the invoked `__init__` must are identical. In the forbidden situation that
1034
+ # for example Chainlet `A` tries to create an instance of `B` inside its
1035
+ # `__init__` the `self` args are two different instances.
1036
+ substack = stack[:levels_below_run_local]
1037
+ parts = ["-------- Chainlet Instantiation Stack --------"]
1038
+ # Track the owner classes encountered in the stack to detect invalid scenarios
1039
+ transformed_stack = []
1040
+ for frame in substack:
1041
+ func_name = frame.function
1042
+ line_number = frame.lineno
1043
+ local_vars = frame.frame.f_locals
1044
+ init_owner_class = None
1045
+ self_value = None
1046
+ # Determine if "self" exists and extract the owner class
1047
+ if "self" in local_vars:
1048
+ self_value = local_vars["self"]
1049
+ if func_name == _INIT_NAME:
1050
+ try:
1051
+ name_parts = frame.frame.f_code.co_qualname.split(".") # type: ignore[attr-defined]
1052
+ except AttributeError: # `co_qualname` only in Python 3.11+.
1053
+ name_parts = []
1054
+ if len(name_parts) > 1:
1055
+ init_owner_class = name_parts[-2]
1056
+ elif func_name == _INIT_LOCAL_NAME:
1057
+ assert "init_owner_class" in local_vars, (
1058
+ f"`{_INIT_LOCAL_NAME}` must capture `init_owner_class`"
1059
+ )
1060
+ init_owner_class = local_vars["init_owner_class"].__name__
1061
+
1062
+ if init_owner_class:
1063
+ parts.append(
1064
+ f"{func_name}:{line_number} | type(self)=<"
1065
+ f"{self_value.__class__.__name__}> method of <"
1066
+ f"{init_owner_class}>"
1067
+ )
1068
+ else:
1069
+ parts.append(
1070
+ f"{func_name}:l{line_number} | type(self)=<"
1071
+ f"{self_value.__class__.__name__}>"
1072
+ )
1073
+ else:
1074
+ parts.append(f"{func_name}:l{line_number}")
1075
+
1076
+ transformed_stack.append((func_name, self_value, frame))
1077
+
1078
+ if len(parts) > 1:
1079
+ logging.debug("\n".join(parts))
1080
+
1081
+ # Analyze the stack after preparing relevant information.
1082
+ for i in range(len(transformed_stack) - 1):
1083
+ func_name, self_value, _ = transformed_stack[i]
1084
+ up_func_name, up_self_value, up_frame = transformed_stack[i + 1]
1085
+ if func_name != _INIT_LOCAL_NAME:
1086
+ continue # OK, we only validate `__init_local__` invocations.
1087
+ # We are in `__init_local__`. Now check who and how called it.
1088
+ if up_func_name == _INIT_LOCAL_NAME:
1089
+ # Note: in this case `self` in the current frame is different than
1090
+ # self in the parent frame, since a new instance is created.
1091
+ continue # Ok, populating `cls_to_instance`.
1092
+ if up_func_name == _INIT_NAME and self_value == up_self_value:
1093
+ continue # OK, call to `super().__init__()`.
1094
+
1095
+ # Everything else is invalid.
1096
+ code_context = up_frame.code_context
1097
+ assert code_context is not None
1098
+ location = (
1099
+ f"{up_frame.filename}:{up_frame.lineno} ({up_frame.function})\n"
1100
+ f" {code_context[0].strip()}"
1101
+ )
1102
+ raise definitions.ChainsRuntimeError(
1103
+ _instantiation_error_msg(chainlet_descriptor.name, location)
1104
+ )
1105
+
1106
+ __original_init__ = chainlet_descriptor.chainlet_cls.__init__
1107
+
1108
+ @functools.wraps(__original_init__)
1109
+ def __init_local__(self: definitions.ABCChainlet, **kwargs) -> None:
1110
+ logging.debug(f"Patched `__init__` of `{chainlet_descriptor.name}`.")
1111
+ stack_depth = run_local_stack_depth.get(None)
1112
+ assert stack_depth is not None, "__init_local__ is only called in context."
1113
+ stack = inspect.stack()
1114
+ current_stack_depth = len(stack)
1115
+ levels_below_run_local = current_stack_depth - stack_depth
1116
+ # Capture `init_owner_class` in locals, because we check it in
1117
+ # `_detect_naive_instantiations`.
1118
+ init_owner_class = chainlet_descriptor.chainlet_cls # noqa: F841
1119
+ _detect_naive_instantiations(stack, levels_below_run_local)
1120
+
1121
+ kwargs_mod = dict(kwargs)
1122
+ if (
1123
+ chainlet_descriptor.has_context
1124
+ and definitions.CONTEXT_ARG_NAME not in kwargs_mod
1125
+ ):
1126
+ kwargs_mod[definitions.CONTEXT_ARG_NAME] = definitions.DeploymentContext(
1127
+ secrets=secrets,
1128
+ data_dir=data_dir,
1129
+ chainlet_to_service=chainlet_to_service,
1130
+ )
1131
+ for arg_name, dep in chainlet_descriptor.dependencies.items():
1132
+ chainlet_cls = dep.chainlet_cls
1133
+ if arg_name in kwargs_mod:
1134
+ logging.debug(
1135
+ f"Use given instance for `{arg_name}` of type `{dep.name}`."
1136
+ )
1137
+ continue
1138
+ if chainlet_cls in cls_to_instance:
1139
+ logging.debug(
1140
+ f"Use previously created `{arg_name}` of type `{dep.name}`."
1141
+ )
1142
+ kwargs_mod[arg_name] = cls_to_instance[chainlet_cls]
1143
+ else:
1144
+ logging.debug(
1145
+ f"Create new instance for `{arg_name}` of type `{dep.name}`. "
1146
+ f"Calling patched __init__."
1147
+ )
1148
+ assert chainlet_cls.meta_data.init_is_patched
1149
+ # Dependency chainlets are instantiated here, using their __init__
1150
+ # that is patched for local.
1151
+ logging.info(f"Making first {dep.name}.")
1152
+ instance = chainlet_cls() # type: ignore # Here init args are patched.
1153
+ cls_to_instance[chainlet_cls] = instance
1154
+ kwargs_mod[arg_name] = instance
1155
+
1156
+ logging.debug(f"Calling original __init__ of {chainlet_descriptor.name}.")
1157
+ __original_init__(self, **kwargs_mod)
1158
+
1159
+ return __init_local__
1160
+
1161
+
1162
@contextlib.contextmanager
@raise_validation_errors_before
def run_local(
    secrets: Mapping[str, str],
    data_dir: Optional[pathlib.Path],
    chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor],
) -> Any:
    """Context to run Chainlets with dependency injection from local instances.

    While the context is active, the `__init__` of every chainlet class in the
    global registry is replaced by the local variant built by
    `_create_modified_init_for_local`. On exit (normal or exceptional) the
    original `__init__` methods and the `run_local_stack_depth` context variable
    are restored.

    Args:
        secrets: Forwarded to `_create_modified_init_for_local`.
        data_dir: Forwarded to `_create_modified_init_for_local`.
        chainlet_to_service: Forwarded to `_create_modified_init_for_local`.
    """
    type_to_instance: MutableMapping[
        Type[definitions.ABCChainlet], definitions.ABCChainlet
    ] = {}
    original_inits: MutableMapping[Type[definitions.ABCChainlet], Callable] = {}

    # Capture the stack depth when entering the context manager. The stack is used
    # to check that chainlets' `__init__` methods are only called within this context
    # manager, to flag naive instantiations.
    # NOTE: this must be evaluated in this frame (not in a helper), otherwise the
    # offset below is off.
    stack_depth = len(inspect.stack())
    # Subtract 2 levels: `run_local` (this) and `__enter__` (from @contextmanager).
    token = run_local_stack_depth.set(stack_depth - 2)
    try:
        # Patching happens inside the `try`, so that a failure half-way through
        # still restores the classes that were already patched.
        for chainlet_descriptor in _global_chainlet_registry.chainlet_descriptors:
            chainlet_cls = chainlet_descriptor.chainlet_cls
            # Record the original first, so the `finally` block can always undo.
            original_inits[chainlet_cls] = chainlet_cls.__init__
            init_for_local = _create_modified_init_for_local(
                chainlet_descriptor,
                type_to_instance,
                secrets,
                data_dir,
                chainlet_to_service,
            )
            chainlet_cls.__init__ = init_for_local  # type: ignore[method-assign]
            chainlet_cls.meta_data.init_is_patched = True
        yield
    finally:
        # Restore original classes to unpatched state.
        for chainlet_cls, original_init in original_inits.items():
            chainlet_cls.__init__ = original_init  # type: ignore[method-assign]
            chainlet_cls.meta_data.init_is_patched = False

        run_local_stack_depth.reset(token)
1203
+
1204
+
1205
+ ########################################################################################
1206
+
1207
+
1208
def entrypoint(
    cls_or_chain_name: Optional[Union[Type[ChainletT], str]] = None,
) -> Union[Callable[[Type[ChainletT]], Type[ChainletT]], Type[ChainletT]]:
    """Decorator to tag a Chainlet as an entrypoint.

    Can be used with or without chain name argument, i.e. as `@entrypoint`,
    `@entrypoint()` or `@entrypoint("chain name")`.

    Args:
        cls_or_chain_name: Either the class to tag (bare decorator usage), a
            chain name (the returned decorator then also sets `chain_name` in
            the class' metadata), or `None` (decorator factory without name).

    Returns:
        The tagged class for bare usage, otherwise a decorator that tags a class.
    """

    def decorator(cls: Type[ChainletT]) -> Type[ChainletT]:
        if not utils.issubclass_safe(cls, definitions.ABCChainlet):
            src_path = os.path.abspath(inspect.getfile(cls))
            line = inspect.getsourcelines(cls)[1]
            location = _ErrorLocation(src_path=src_path, line=line)
            _collect_error(
                "Only Chainlet classes can be marked as entrypoint.",
                _ErrorKind.TYPE_ERROR,
                location,
            )
            # The error is collected, not raised. Do not touch `meta_data` of a
            # non-chainlet: it may not exist, and an `AttributeError` here would
            # mask the collected validation error.
            return cls

        cls.meta_data.is_entrypoint = True
        if isinstance(cls_or_chain_name, str):
            cls.meta_data.chain_name = cls_or_chain_name
        return cls

    # Decorator used with arguments (`@entrypoint("name")`) or with empty
    # parentheses (`@entrypoint()`): hand back the actual decorator.
    if cls_or_chain_name is None or isinstance(cls_or_chain_name, str):
        return decorator

    return decorator(cls_or_chain_name)  # Decorator used without arguments
1235
+
1236
+
1237
class _ABCImporter(abc.ABC):
    """Imports a python source file and yields the target class defined in it.

    Subclasses define the expected base type of the target class and the errors
    to raise when the target cannot be determined unambiguously.
    """

    @classmethod
    @abc.abstractmethod
    def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
        """Returns the error to raise if `module_path` contains no entrypoint."""

    @classmethod
    @abc.abstractmethod
    def _multiple_entrypoints_error(
        cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
    ) -> ValueError:
        """Returns the error to raise if `module_path` has several entrypoints."""

    @classmethod
    @abc.abstractmethod
    def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
        """Returns the base type that importable targets must derive from."""

    @classmethod
    def _get_entrypoint_chainlets(cls, symbols) -> set[Type[definitions.ABCChainlet]]:
        """Returns all `symbols` that are target-type subclasses tagged as entrypoint."""
        return {
            sym
            for sym in symbols
            if utils.issubclass_safe(sym, cls._target_cls_type())
            and cast(definitions.ABCChainlet, sym).meta_data.is_entrypoint
        }

    @classmethod
    def _load_module(cls, module_path: pathlib.Path) -> tuple[types.ModuleType, Loader]:
        """Creates (without executing) the module for `module_path`.

        The module is registered in `sys.modules` under the file's stem and the
        parent directory of `module_path` is prepended to `sys.path`.

        Raises:
            ImportError: If `module_path` is not a file, no import spec/loader can
                be created, or a module of the same name is already registered.
        """
        module_name = module_path.stem  # Use the file's name as the module name
        if not os.path.isfile(module_path):
            raise ImportError(
                f"`{module_path}` is not a file. You must point to a python file where "
                f"the entrypoint is defined."
            )

        import_error_msg = f"Could not import `{module_path}`. Check path."
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if not spec or not spec.loader:
            raise ImportError(import_error_msg)

        module = importlib.util.module_from_spec(spec)
        module.__file__ = str(module_path)
        # Since the framework depends on tracking the source files via `inspect` and this
        # depends on the modules being properly registered in `sys.modules`, we have to
        # manually do this here (because importlib does not do it automatically). This
        # registration has to stay at least until the push command has finished.
        if module_name in sys.modules:
            raise ImportError(
                f"{import_error_msg} There is already a module in `sys.modules` "
                f"with name `{module_name}`. Overwriting that value is unsafe. "
                "Try renaming your source file."
            )

        sys.modules[module_name] = module
        # Add path for making absolute imports relative to the source_module's dir.
        sys.path.insert(0, str(module_path.parent))

        return module, spec.loader

    @classmethod
    def _cleanup_module_imports(
        cls,
        modules_before: set[str],
        modules_after: set[str],
        module_path: pathlib.Path,
    ):
        """Removes modules added between the two snapshots and the `sys.path` entry.

        `module_path` must be the same path that was given to `_load_module`, so
        that `str(module_path.parent)` matches the entry inserted into `sys.path`.
        """
        modules_diff = modules_after - modules_before
        # Apparently torch import leaves some side effects that cannot be reverted
        # by deleting the modules and would lead to a crash when another import
        # is attempted. Since torch is a common lib, we make this explicit special
        # case and just leave those modules.
        # TODO: this seems still brittle and other modules might cause similar problems.
        #  it would be good to find a more principled solution.
        modules_to_delete = {
            s for s in modules_diff if not (s.startswith("torch.") or s == "torch")
        }
        if torch_modules := modules_diff - modules_to_delete:
            logging.debug(
                f"Keeping torch modules after import context: {torch_modules}"
            )

        logging.debug(
            f"Deleting modules when exiting import context: {modules_to_delete}"
        )
        for mod in modules_to_delete:
            # The snapshot may be stale (taken before the context body ran), so
            # tolerate modules that have been removed in the meantime.
            sys.modules.pop(mod, None)
        try:
            sys.path.remove(str(module_path.parent))
        except ValueError:  # In case the value was already removed for whatever reason.
            pass

    @classmethod
    @contextlib.contextmanager
    def import_target(
        cls, module_path: pathlib.Path, target_name: Optional[str] = None
    ) -> Iterator[Type[definitions.ABCChainlet]]:
        """Imports `module_path` and yields the target class defined in it.

        The context manager ensures that modules imported by the Model/Chain
        are removed upon exit, and that chainlets registered during the import
        are unregistered again.

        I.e. aiming at making the import idempotent for common usages, although
        there could be additional side effects not accounted for by this
        implementation.

        Args:
            module_path: Python source file to import.
            target_name: Name of the target class in the module. If not given,
                the single class tagged as entrypoint is used.

        Raises:
            ImportError: If the module cannot be loaded (see `_load_module`).
            AttributeError: If `target_name` is not found in the module.
            TypeError: If the named target is not of the expected base type.
            ValueError: If no `target_name` is given and there is not exactly one
                entrypoint (errors are created by the subclass hooks).
        """
        resolved_module_path = pathlib.Path(module_path).resolve()
        modules_before = set(sys.modules.keys())
        chainlets_before = _global_chainlet_registry.get_chainlet_names()
        modules_after: set[str] = set()
        chainlets_after = set()

        # From here on `sys.modules` and `sys.path` are modified, so everything
        # below runs under `try/finally`.
        module, loader = cls._load_module(module_path)
        try:
            try:
                loader.exec_module(module)
                raise_validation_errors()
            finally:
                modules_after = set(sys.modules.keys())
                chainlets_after = _global_chainlet_registry.get_chainlet_names()

            if target_name:
                target_cls = getattr(module, target_name, None)
                if not target_cls:
                    raise AttributeError(
                        f"Target class `{target_name}` not found "
                        f"in `{resolved_module_path}`."
                    )
                if not utils.issubclass_safe(target_cls, cls._target_cls_type()):
                    raise TypeError(
                        f"Target `{target_cls}` is not a {cls._target_cls_type()}."
                    )
            else:
                module_vars = (getattr(module, name) for name in dir(module))
                entrypoints = cls._get_entrypoint_chainlets(module_vars)
                if not entrypoints:
                    raise cls._no_entrypoint_error(module_path)
                if len(entrypoints) > 1:
                    raise cls._multiple_entrypoints_error(module_path, entrypoints)
                target_cls = utils.expect_one(entrypoints)
            yield target_cls
        finally:
            # Pass the *same* path that `_load_module` received: it inserted
            # `str(module_path.parent)` into `sys.path`, and the resolved path may
            # be a different string (relative paths, symlinks), which would leave
            # the inserted entry behind.
            cls._cleanup_module_imports(modules_before, modules_after, module_path)
            for chainlet_name in chainlets_after - chainlets_before:
                _global_chainlet_registry.unregister_chainlet(chainlet_name)
1380
+
1381
+
1382
class ChainletImporter(_ABCImporter):
    """Importer whose targets are `ChainletBase` subclasses."""

    @classmethod
    def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
        """Builds the error for a module without any tagged Chainlet."""
        message = (
            "No `target_name` was specified and no Chainlet in "
            f"`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag "
            "one Chainlet or provide the Chainlet class name."
        )
        return ValueError(message)

    @classmethod
    def _multiple_entrypoints_error(
        cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
    ) -> ValueError:
        """Builds the error for a module with more than one tagged Chainlet."""
        found_names = [chainlet.name for chainlet in entrypoints]
        message = (
            "`target_name` was not specified and multiple Chainlets in "
            f"`{module_path}` were tagged with `@chains.mark_entrypoint`. Tag "
            "one Chainlet or provide the Chainlet class name. Found Chainlets: "
            f"\n{found_names}"
        )
        return ValueError(message)

    @classmethod
    def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
        """Targets must derive from `ChainletBase`."""
        return ChainletBase
1405
+
1406
+
1407
class ModelImporter(_ABCImporter):
    """Importer whose targets are `ModelBase` subclasses."""

    @classmethod
    def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
        """Builds the error for a module without any Model class."""
        message = (
            f"No Model class in `{module_path}` inherits from {cls._target_cls_type()}."
        )
        return ValueError(message)

    @classmethod
    def _multiple_entrypoints_error(
        cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
    ) -> ValueError:
        """Builds the error for a module with more than one Model class."""
        found_names = [model.name for model in entrypoints]
        message = (
            f"Multiple Model classes in `{module_path}` inherit from {cls._target_cls_type()}, "
            "but only one allowed. Found classes: "
            f"\n{found_names}"
        )
        return ValueError(message)

    @classmethod
    def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
        """Targets must derive from `ModelBase`."""
        return ModelBase
1427
+
1428
+
1429
class ChainletBase(definitions.ABCChainlet):
    """Base class for all chainlets.

    Inheriting from this class adds validations to make sure subclasses adhere to the
    chainlet pattern and facilitates remote chainlet deployment.

    Refer to `the docs <https://docs.baseten.co/chains/getting-started>`_ and this
    `example chainlet <https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py>`_
    for more guidance on how to create subclasses.
    """

    def __init_subclass__(cls, **kwargs) -> None:
        """Configures, validates and registers every newly defined subclass.

        If the subclass has a custom `__init__`, it is wrapped so that positional
        arguments are rejected and `ensure_args_are_injected` runs before the
        original `__init__` is invoked.
        """
        super().__init_subclass__(**kwargs)
        cls._framework_config = definitions.FrameworkConfig(
            entity_type="Chainlet",
            supports_dependencies=True,
            endpoint_method_name=definitions.RUN_REMOTE_METHOD_NAME,
        )
        # Metadata is per sub-class and never shared with the parent, e.g.
        # `mark_entrypoint` on a base must not carry over to its subclasses.
        cls.meta_data = definitions.ChainletMetadata()
        validate_and_register_cls(cls)  # Errors are collected, not raised!

        # The default init (from `object`) needs no argument checking.
        if not cls.has_custom_init():
            return

        user_init = cls.__init__

        @functools.wraps(user_init)
        def __init_with_arg_check__(self, *args, **init_kwargs):
            if args:
                raise definitions.ChainsRuntimeError("Only kwargs are allowed.")
            ensure_args_are_injected(cls, user_init, init_kwargs)
            # `args` is guaranteed to be empty at this point.
            user_init(self, **init_kwargs)

        cls.__init__ = __init_with_arg_check__  # type: ignore[method-assign]
1463
+
1464
+
1465
+ class ModelBase(definitions.ABCChainlet):
1466
+ """Base class for all standalone models.
1467
+
1468
+ Inheriting from this class adds validations to make sure subclasses adhere to the
1469
+ truss model pattern.
1470
+ """
1471
+
1472
+ def __init_subclass__(cls, **kwargs) -> None:
1473
+ super().__init_subclass__(**kwargs)
1474
+ cls._framework_config = definitions.FrameworkConfig(
1475
+ entity_type="Model",
1476
+ supports_dependencies=False,
1477
+ endpoint_method_name=definitions.MODEL_ENDPOINT_METHOD_NAME,
1478
+ )
1479
+ cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True)
1480
+ validate_and_register_cls(cls)