truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic; see the registry's release page for more details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,871 @@
1
+ import abc
2
+ import concurrent.futures
3
+ import inspect
4
+ import json
5
+ import logging
6
+ import pathlib
7
+ import textwrap
8
+ import traceback
9
+ import uuid
10
+ from typing import (
11
+ TYPE_CHECKING,
12
+ Any,
13
+ Callable,
14
+ Dict,
15
+ Iterable,
16
+ Iterator,
17
+ Mapping,
18
+ Optional,
19
+ Type,
20
+ cast,
21
+ )
22
+
23
+ import requests
24
+ import tenacity
25
+ import watchfiles
26
+ from truss.local import local_config_handler
27
+ from truss.remote import remote_factory
28
+ from truss.remote.baseten import core as b10_core
29
+ from truss.remote.baseten import custom_types as b10_types
30
+ from truss.remote.baseten import error as b10_errors
31
+ from truss.remote.baseten import remote as b10_remote
32
+ from truss.remote.baseten import service as b10_service
33
+ from truss.truss_handle import truss_handle
34
+ from truss.util import log_utils
35
+ from truss.util import path as truss_path
36
+
37
+ from truss_chains import definitions, framework, utils
38
+ from truss_chains.deployment import code_gen
39
+
40
+ if TYPE_CHECKING:
41
+ from rich import console as rich_console
42
+ from rich import progress
43
+
44
+
45
def _get_ordered_dependencies(
    chainlets: Iterable[Type[definitions.ABCChainlet]],
) -> Iterable[definitions.ChainletAPIDescriptor]:
    """Gather all Chainlets needed and return them as a topologically ordered list.

    Starting from ``chainlets``, every transitive dependency reported by
    ``framework.get_dependencies`` is collected. The result is the subset of
    ``framework.get_ordered_descriptors()`` that is needed, in that order.

    Args:
        chainlets: Chainlet classes whose dependency closure is requested.

    Returns:
        A list with the descriptors of ``chainlets`` and all of their
        transitive dependencies, ordered as in
        ``framework.get_ordered_descriptors()``.
    """
    needed_chainlets: set[definitions.ChainletAPIDescriptor] = set()

    def add_needed_chainlets(chainlet: definitions.ChainletAPIDescriptor) -> None:
        # Stop at already-visited descriptors. Without this guard, shared
        # dependencies (diamond shapes) are re-traversed once per path through
        # the graph, which grows exponentially with depth.
        if chainlet in needed_chainlets:
            return
        needed_chainlets.add(chainlet)
        for dependency in framework.get_dependencies(chainlet):
            add_needed_chainlets(dependency)

    for chainlet_cls in chainlets:
        add_needed_chainlets(framework.get_descriptor(chainlet_cls))
    # Get dependencies in topological order by filtering the globally ordered list.
    return [
        descr
        for descr in framework.get_ordered_descriptors()
        if descr in needed_chainlets
    ]
65
+
66
+
67
def _get_chain_root(entrypoint: Type[definitions.ABCChainlet]) -> pathlib.Path:
    """Return the absolute directory that contains the source file of ``entrypoint``.

    This directory is used as the chain workspace: files under it are included
    as dependencies in the remote deployments and are importable there.
    """
    # TODO: revisit how chain root is inferred/specified, current might be brittle.
    source_file = pathlib.Path(inspect.getfile(entrypoint))
    workspace_dir = source_file.absolute().parent
    logging.info(
        f"Using chain workspace dir: `{workspace_dir}` (files under this dir will "
        "be included as dependencies in the remote deployments and are importable)."
    )
    return workspace_dir
75
+
76
+
77
class ChainService(abc.ABC):
    """Handle for a deployed chain.

    ``push`` creates and returns a ``ChainService``. It groups the services of
    every chainlet in the chain and offers helpers to query their status,
    invoke the entrypoint, etc.
    """

    _name: str
    _entrypoint_fake_json_data: Any

    def __init__(self, name: str):
        self._entrypoint_fake_json_data = None
        self._name = name

    @property
    def name(self) -> str:
        """Name of the chain, as given at construction."""
        return self._name

    @property
    @abc.abstractmethod
    def status_page_url(self) -> str:
        """Link to the chain's status page on Baseten."""

    @property
    @abc.abstractmethod
    def run_remote_url(self) -> str:
        """URL at which the entrypoint can be invoked."""

    @abc.abstractmethod
    def run_remote(self, json: Dict) -> Any:
        """Call the entrypoint with ``json`` as input.

        Returns:
            The JSON response."""

    @abc.abstractmethod
    def get_info(self) -> list[b10_types.DeployedChainlet]:
        """Query the status of every chainlet in the chain.

        Returns:
            One ``DeployedChainlet`` per chainlet."""

    @property
    def entrypoint_fake_json_data(self) -> Any:
        """Fake JSON example data matching the entrypoint's input schema.

        Nothing in this class fills this in; it must be assigned externally
        through the setter.

        Raises:
            ValueError: If fake data was not set (i.e. it is still ``None``).
        """
        data = self._entrypoint_fake_json_data
        if data is not None:
            return data
        raise ValueError("Fake data was not set.")

    @entrypoint_fake_json_data.setter
    def entrypoint_fake_json_data(self, data: Any) -> None:
        self._entrypoint_fake_json_data = data
135
+
136
+
137
class _ChainSourceGenerator:
    """Generates one truss directory per chainlet of a chain, based on push options."""

    def __init__(self, options: definitions.PushOptions) -> None:
        self._options = options

    @property
    def _use_local_chains_src(self) -> bool:
        # The flag is only read from local-docker options; every other push
        # target yields False.
        opts = self._options
        return (
            opts.use_local_chains_src
            if isinstance(opts, definitions.PushOptionsLocalDocker)
            else False
        )

    def generate_chainlet_artifacts(
        self, entrypoint: Type[definitions.ABCChainlet]
    ) -> tuple[b10_types.ChainletArtifact, list[b10_types.ChainletArtifact]]:
        """Generate truss artifacts for ``entrypoint`` and all of its dependencies.

        Args:
            entrypoint: The chainlet class serving as the chain's entrypoint.

        Returns:
            A pair ``(entrypoint_artifact, dependency_artifacts)``. The
            dependency artifacts keep the order produced by
            ``_get_ordered_dependencies``.

        Raises:
            definitions.ChainsUsageError: If two chainlets share a display name.
        """
        chain_root = _get_chain_root(entrypoint)
        seen_names: set[str] = set()
        entrypoint_artifact: Optional[b10_types.ChainletArtifact] = None
        dependency_artifacts: list[b10_types.ChainletArtifact] = []

        for descriptor in _get_ordered_dependencies([entrypoint]):
            display_name = descriptor.display_name
            if display_name in seen_names:
                raise definitions.ChainsUsageError(
                    f"Chainlet names must be unique. Found multiple Chainlets with the name: '{display_name}'."
                )
            seen_names.add(display_name)

            # Every deployment of the chain gets distinct models, so append a
            # short random suffix to the model name.
            suffix = str(uuid.uuid4()).split("-")[0]
            chainlet_dir = code_gen.gen_truss_chainlet(
                chain_root,
                self._options.chain_name,
                descriptor,
                f"{display_name}-{suffix}",
                self._use_local_chains_src,
            )
            artifact = b10_types.ChainletArtifact(
                truss_dir=chainlet_dir,
                name=descriptor.name,
                display_name=display_name,
            )

            is_entrypoint = descriptor.chainlet_cls == entrypoint
            if not is_entrypoint:
                dependency_artifacts.append(artifact)
                continue
            # The entrypoint must show up exactly once.
            assert entrypoint_artifact is None
            entrypoint_artifact = artifact

        assert entrypoint_artifact is not None
        return entrypoint_artifact, dependency_artifacts
195
+
196
+
197
@framework.raise_validation_errors_before
def push(
    entrypoint: Type[definitions.ABCChainlet],
    options: definitions.PushOptions,
    progress_bar: Optional[Type["progress.Progress"]] = None,
) -> Optional[ChainService]:
    """Generate chainlet trusses for ``entrypoint`` and deploy them.

    Args:
        entrypoint: The chainlet class serving as the chain's entrypoint.
        options: Push options; their concrete type selects the deployment target
            (Baseten or local docker).
        progress_bar: Optional progress-bar class, forwarded only to the
            Baseten deployment path.

    Returns:
        The created ``ChainService``, or ``None`` when
        ``options.only_generate_trusses`` is set.

    Raises:
        NotImplementedError: If ``options`` is of an unsupported type.
    """
    generator = _ChainSourceGenerator(options)
    entrypoint_artifact, dependency_artifacts = generator.generate_chainlet_artifacts(
        entrypoint
    )
    if options.only_generate_trusses:
        return None

    if isinstance(options, definitions.PushOptionsBaseten):
        return _create_baseten_chain(
            options, entrypoint_artifact, dependency_artifacts, progress_bar
        )
    if isinstance(options, definitions.PushOptionsLocalDocker):
        return _create_docker_chain(options, entrypoint_artifact, dependency_artifacts)
    raise NotImplementedError(options)
216
+
217
+
218
+ # Docker ###############################################################################
219
+
220
+
221
class DockerChainletService(b10_service.TrussService):
    """This service is for Chainlets (not for Chains).

    It targets a chainlet server reachable at ``http://localhost:<port>``.
    """

    def __init__(self, port: int, **kwargs):
        remote_url = f"http://localhost:{port}"

        super().__init__(remote_url, is_draft=False, **kwargs)

    def authenticate(self) -> Dict[str, str]:
        # No authentication headers are sent to the local container.
        return {}

    def _responds_ok(self) -> bool:
        """Return whether a GET on the service root URL answers with HTTP 200."""
        response = self._send_request(self._service_url, "GET")
        return response.status_code == 200

    def is_live(self) -> bool:
        """Liveness check: the service root answers a GET with HTTP 200."""
        return self._responds_ok()

    def is_ready(self) -> bool:
        """Readiness check: currently identical to the liveness check."""
        return self._responds_ok()

    @property
    def logs_url(self) -> str:
        raise NotImplementedError()

    @property
    def predict_url(self) -> str:
        return f"{self._service_url}/v1/models/model:predict"

    def poll_deployment_status(self, sleep_secs: int = 1) -> Iterator[str]:
        raise NotImplementedError()
254
+
255
+
256
def _push_service_docker(
    truss_dir: pathlib.Path,
    chainlet_display_name: str,
    options: definitions.PushOptionsLocalDocker,
    port: int,
) -> None:
    """Starts one Chainlet's Truss as a detached local Docker container.

    The Chain API key from ``options`` is added as a secret first. The container
    uses host networking, is published on ``port`` and is named with the
    Chainlet's display name as prefix. ``docker_run`` is asked to wait for the
    server to be ready.
    """
    handle = truss_handle.TrussHandle(truss_dir)
    handle.add_secret(
        definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key
    )
    handle.docker_run(
        container_name_prefix=chainlet_display_name,
        local_port=port,
        network="host",
        detach=True,
        wait_for_server_ready=True,
        disable_json_logging=True,
    )
272
+
273
+
274
class DockerChainService(ChainService):
    """A Chain whose Chainlets run as local Docker containers.

    Invocations are forwarded to the entrypoint Chainlet's service.
    """

    _entrypoint_service: DockerChainletService

    def __init__(self, name: str, entrypoint_service: DockerChainletService) -> None:
        super().__init__(name)
        self._entrypoint_service = entrypoint_service

    @property
    def run_remote_url(self) -> str:
        """URL to invoke the entrypoint."""
        entrypoint = self._entrypoint_service
        return entrypoint.predict_url

    def run_remote(self, json: Dict) -> Any:
        """Sends ``json`` to the entrypoint Chainlet.

        Returns:
            The JSON response.
        """
        # NOTE: `json` shadows the module-level `json` import inside this method
        # only; the name is kept for keyword-argument compatibility.
        entrypoint = self._entrypoint_service
        return entrypoint.predict(json)

    @property
    def status_page_url(self) -> str:
        """Not implemented for local Docker Chains."""
        raise NotImplementedError()

    def get_info(self) -> list[b10_types.DeployedChainlet]:
        """Not implemented for local Docker Chains."""
        raise NotImplementedError()
301
+
302
+
303
def _create_docker_chain(
    docker_options: definitions.PushOptionsLocalDocker,
    entrypoint_artifact: b10_types.ChainletArtifact,
    dependency_artifacts: list[b10_types.ChainletArtifact],
) -> DockerChainService:
    """Runs every Chainlet of a Chain as a local Docker container.

    Dependencies are started in the given order, and the entrypoint is started
    last. Each Chainlet gets a free local port. Before each container starts, the
    dynamic Chainlet config is rewritten. It maps display name to
    ``{"predict_url": ...}`` for every Chainlet handled so far, including the one
    about to start.

    Returns:
        A ``DockerChainService`` wrapping the entrypoint's container.
    """
    predict_urls: Dict[str, Dict[str, str]] = {}
    services_by_name: Dict[str, DockerChainletService] = {}

    for artifact in [*dependency_artifacts, entrypoint_artifact]:
        display_name = artifact.display_name
        port = utils.get_free_port()
        service = DockerChainletService(port)

        # URLs handed to the containers point to `host.docker.internal` instead
        # of `localhost`.
        internal_url = service.predict_url.replace("localhost", "host.docker.internal")
        predict_urls[display_name] = {"predict_url": internal_url}
        services_by_name[artifact.name] = service

        local_config_handler.LocalConfigHandler.set_dynamic_config(
            definitions.DYNAMIC_CHAINLET_CONFIG_KEY, json.dumps(predict_urls)
        )

        logging.info(f"Building Chainlet `{display_name}` docker image.")
        _push_service_docker(artifact.truss_dir, display_name, docker_options, port)
        logging.info(f"Pushed Chainlet `{display_name}` as docker container.")
        logging.debug(f"Internal model endpoint: `{predict_urls[display_name]}`")

    entrypoint_service = services_by_name[entrypoint_artifact.name]
    return DockerChainService(docker_options.chain_name, entrypoint_service)
345
+
346
+
347
+ # Baseten ##############################################################################
348
+
349
+
350
class BasetenChainService(ChainService):
    """Handle to a Chain deployment on Baseten."""

    _chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic
    _remote: b10_remote.BasetenRemote

    def __init__(
        self,
        name: str,
        chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic,
        remote: b10_remote.BasetenRemote,
    ) -> None:
        super().__init__(name)
        self._chain_deployment_handle = chain_deployment_handle
        self._remote = remote

    @property
    def run_remote_url(self) -> str:
        """URL to invoke the entrypoint."""
        return b10_service.URLConfig.invocation_url(
            self._remote.api.rest_api_url,
            b10_service.URLConfig.CHAIN,
            self._chain_deployment_handle.chain_id,
            self._chain_deployment_handle.chain_deployment_id,
            self._chain_deployment_handle.is_draft,
        )

    def run_remote(self, json_data: Dict) -> Any:
        """Invokes the entrypoint with JSON data.

        Args:
            json_data: Payload posted as the JSON request body.

        Returns:
            For chunked (streaming) responses, a generator of decoded text chunks.
            Otherwise the parsed JSON response. If the model is in a non-ready
            state, this is a JSON object with an ``error`` key.

        Raises:
            ValueError: If the request is rejected with HTTP 401.
        """
        import codecs

        headers = self._remote._auth_service.authenticate().header()
        response = requests.post(
            self.run_remote_url, json=json_data, headers=headers, stream=True
        )
        if response.status_code == 401:
            raise ValueError(
                f"Authentication failed with status code {response.status_code}"
            )

        if response.headers.get("transfer-encoding") == "chunked":
            # Case of streaming response, the backend does not set an encoding, so
            # manually decode the contents here.
            encoding = response.encoding or b10_service.DEFAULT_STREAM_ENCODING
            # An incremental decoder carries partial multi-byte characters over
            # chunk boundaries. Decoding each chunk on its own raised
            # UnicodeDecodeError whenever a character straddled two chunks.
            decoder = codecs.getincrementaldecoder(encoding)()

            def decode_content():
                for chunk in response.iter_content(
                    chunk_size=8192, decode_unicode=True
                ):
                    # Depending on the content-type of the response,
                    # iter_content will either emit a byte stream, or a stream
                    # of strings. Only decode in the bytes case.
                    if isinstance(chunk, bytes):
                        text = decoder.decode(chunk)
                        if text:
                            yield text
                    else:
                        yield chunk
                # Flush buffered bytes; raises if the stream ends mid-character.
                tail = decoder.decode(b"", final=True)
                if tail:
                    yield tail

            return decode_content()

        # Non-ready deployments answer with a JSON object carrying an `error` key,
        # which is returned as-is like any other payload. Parse only once, and do
        # not probe the payload with `"error" in ...`: that raised TypeError for
        # scalar JSON results (numbers, booleans, null).
        return response.json()

    @property
    def status_page_url(self) -> str:
        """Link to status page on Baseten."""
        return b10_service.URLConfig.status_page_url(
            self._remote.remote_url,
            b10_service.URLConfig.CHAIN,
            self._chain_deployment_handle.chain_id,
        )

    @tenacity.retry(
        stop=tenacity.stop_after_delay(300), wait=tenacity.wait_fixed(1), reraise=True
    )
    def get_info(self) -> list[b10_types.DeployedChainlet]:
        """Queries the statuses of all chainlets in the chain.

        Failures are retried every second for up to 300 seconds, and the last
        exception is re-raised.

        Returns:
            List of ``DeployedChainlet`` for each chainlet."""
        return self._remote.get_chainlets(
            self._chain_deployment_handle.chain_deployment_id
        )
437
+
438
+
439
def _create_baseten_chain(
    baseten_options: definitions.PushOptionsBaseten,
    entrypoint_artifact: b10_types.ChainletArtifact,
    dependency_artifacts: list[b10_types.ChainletArtifact],
    progress_bar: Optional[Type["progress.Progress"]],
):
    """Pushes all Chainlet artifacts to Baseten as one atomic Chain deployment.

    First ensures the ``BASETEN_API_SECRET_NAME`` secret exists on the remote.

    Returns:
        A ``BasetenChainService`` for the created deployment.
    """
    chain_name = baseten_options.chain_name
    logging.info(
        f"Pushing Chain '{chain_name}' to Baseten "
        f"(publish={baseten_options.publish}, environment={baseten_options.environment})."
    )
    remote = remote_factory.RemoteFactory.create(remote=baseten_options.remote)
    baseten_remote = cast(b10_remote.BasetenRemote, remote)
    _create_chains_secret_if_missing(baseten_remote)

    deployment_handle = baseten_remote.push_chain_atomic(
        chain_name=chain_name,
        entrypoint_artifact=entrypoint_artifact,
        dependency_artifacts=dependency_artifacts,
        publish=baseten_options.publish,
        environment=baseten_options.environment,
        progress_bar=progress_bar,
    )
    return BasetenChainService(chain_name, deployment_handle, baseten_remote)
466
+
467
+
468
def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote) -> None:
    """Creates the Chains API-key secret on Baseten unless it already exists.

    If created, the secret's value is the auth token of ``remote_provider.api``.
    """
    secret_name = definitions.BASETEN_API_SECRET_NAME
    existing_names = {
        secret["name"] for secret in remote_provider.api.get_all_secrets()["secrets"]
    }
    if secret_name in existing_names:
        return

    logging.info(
        "It seems you are using chains for the first time, since there "
        f"is no `{secret_name}` secret on baseten. "
        "Creating secret automatically."
    )
    remote_provider.api.upsert_secret(
        secret_name, remote_provider.api.auth_token.value
    )
480
+
481
+
482
+ # Watch / Live Patching ################################################################
483
+
484
+
485
def _create_watch_filter(root_dir: pathlib.Path):
    """Builds a ``watchfiles`` filter from the trussignore patterns of ``root_dir``.

    As a side effect, the ``watchfiles.main`` logger is disabled.

    Returns:
        Tuple ``(ignore_patterns, watch_filter)``. ``watch_filter`` accepts a
        path only if ``truss_path.is_ignored`` rejects it.
    """
    patterns = truss_path.load_trussignore_patterns_from_truss_dir(root_dir)

    def watch_filter(_: watchfiles.Change, path: str) -> bool:
        is_ignored = truss_path.is_ignored(pathlib.Path(path), patterns)
        return not is_ignored

    logging.getLogger("watchfiles.main").disabled = True
    return patterns, watch_filter
493
+
494
+
495
+ def _handle_intercepted_logs(logs: list[str], console: "rich_console.Console"):
496
+ if logs:
497
+ formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
498
+ console.print(f"Intercepted logs from importing source code:\n{formatted_logs}")
499
+
500
+
501
+ def _handle_import_error(
502
+ exception: Exception,
503
+ console: "rich_console.Console",
504
+ error_console: "rich_console.Console",
505
+ stack_trace: Optional[str] = None,
506
+ ):
507
+ error_console.print(
508
+ "Source files were changed, but pre-conditions for "
509
+ "live patching are not given. Most likely there is a "
510
+ "syntax error in the source files or names changed. "
511
+ "Try to fix the issue and save the file. Error:\n"
512
+ f"{textwrap.indent(str(exception), ' ' * 4)}"
513
+ )
514
+ if stack_trace:
515
+ error_console.print(stack_trace)
516
+
517
+ console.print(
518
+ "The watcher will continue and if you can resolve the "
519
+ "issue, subsequent patches might succeed.",
520
+ style="blue",
521
+ )
522
+
523
+
524
class _ModelWatcher:
    """Watches a Truss model's source and live-patches its development version."""

    _source: pathlib.Path
    _model_name: str
    _remote_provider: b10_remote.BasetenRemote
    _ignore_patterns: list[str]
    _watch_filter: Callable[[watchfiles.Change, str], bool]
    _console: "rich_console.Console"
    _error_console: "rich_console.Console"

    def __init__(
        self,
        source: pathlib.Path,
        model_name: str,
        remote_provider: b10_remote.BasetenRemote,
        console: "rich_console.Console",
        error_console: "rich_console.Console",
    ) -> None:
        """Stores the configuration and checks that a development model exists.

        The watch filter is built from the ignore patterns of the directory
        containing ``source``.

        Raises:
            b10_errors.RemoteError: If ``b10_core.get_dev_version`` finds no
                development version of ``model_name``.
        """
        self._source = source
        self._model_name = model_name
        self._remote_provider = remote_provider
        self._console = console
        self._error_console = error_console
        self._ignore_patterns, self._watch_filter = _create_watch_filter(
            source.absolute().parent
        )

        dev_version = b10_core.get_dev_version(self._remote_provider.api, model_name)
        if not dev_version:
            raise b10_errors.RemoteError(
                "No development model found. Run `truss push` then try again."
            )

    def _patch(self) -> None:
        """Regenerates the Truss from source and patches the development model.

        Logs intercepted during generation and patching are always printed. Any
        exception is reported via ``_handle_import_error`` instead of propagating,
        so the watcher keeps running.
        """
        exception_raised = None
        with log_utils.LogInterceptor() as log_interceptor, self._console.status(
            " Live Patching Model.\n", spinner="arrow3"
        ):
            try:
                gen_truss_path = code_gen.gen_truss_model_from_source(self._source)
                # Do not `return` here: an early return skipped printing the
                # intercepted logs below for every successful patch.
                self._remote_provider.patch(
                    gen_truss_path,
                    self._ignore_patterns,
                    self._console,
                    self._error_console,
                )
            except Exception as e:
                exception_raised = e
            finally:
                logs = log_interceptor.get_logs()

        _handle_intercepted_logs(logs, self._console)
        if exception_raised:
            _handle_import_error(exception_raised, self._console, self._error_console)

    def watch(self) -> None:
        """Patches once at startup, then again after every accepted file change.

        The directory containing ``source`` is watched. ``raise_interrupt=False``
        asks ``watchfiles`` not to re-raise a keyboard interrupt.
        """
        # Perform one initial patch at startup.
        self._patch()
        self._console.print("👀 Watching for new changes.", style="blue")

        # TODO(nikhil): Improve detection of directory structure, since right now
        # we assume a flat structure
        root_dir = self._source.absolute().parent
        for _ in watchfiles.watch(
            root_dir, watch_filter=self._watch_filter, raise_interrupt=False
        ):
            self._patch()
            self._console.print("👀 Watching for new changes.", style="blue")
591
+
592
+
593
class _Watcher:
    """Watches a Chain's source tree and live-patches its development deployment.

    On each detected change, the Chain is re-imported and Truss code is
    regenerated for every included Chainlet. Patches are submitted to a thread
    pool.
    """

    _source: pathlib.Path
    _entrypoint: Optional[str]
    _deployed_chain_name: str
    _remote_provider: b10_remote.BasetenRemote
    # Deployed Chainlets keyed by `name`, which is compared with local display
    # names.
    _chainlet_data: Mapping[str, b10_types.DeployedChainlet]
    _watch_filter: Callable[[watchfiles.Change, str], bool]
    _console: "rich_console.Console"
    _error_console: "rich_console.Console"
    _show_stack_trace: bool
    _included_chainlets: set[str]

    def __init__(
        self,
        source: pathlib.Path,
        entrypoint: Optional[str],
        name: Optional[str],
        remote: str,
        console: "rich_console.Console",
        error_console: "rich_console.Console",
        show_stack_trace: bool,
        included_chainlets: Optional[list[str]],
    ) -> None:
        """Resolves the Chain and its development deployment on the remote.

        Args:
            source: Source passed to ``ChainletImporter.import_target``.
            entrypoint: Optional entrypoint passed to ``import_target``.
            name: Name of the deployed Chain. Defaults to the entrypoint class name.
            remote: Remote name used to create the Baseten remote provider.
            console: Console for regular output.
            error_console: Console for error output.
            show_stack_trace: Whether stack traces of failed imports are printed.
            included_chainlets: Display names of the Chainlets to patch. All
                Chainlets are patched if empty or ``None``.

        Raises:
            definitions.ChainsDeploymentError: If requested Chainlets are unknown,
                the Chain or its development deployment is not found, or the local
                and deployed Chainlet names differ.
        """
        self._source = source
        self._entrypoint = entrypoint
        self._console = console
        self._error_console = error_console
        self._show_stack_trace = show_stack_trace
        self._remote_provider = cast(
            b10_remote.BasetenRemote, remote_factory.RemoteFactory.create(remote=remote)
        )
        with framework.ChainletImporter.import_target(
            source, entrypoint
        ) as entrypoint_cls:
            self._deployed_chain_name = name or entrypoint_cls.__name__
            self._chain_root = _get_chain_root(entrypoint_cls)
            chainlet_names = set(
                desc.display_name
                for desc in _get_ordered_dependencies([entrypoint_cls])
            )

        if included_chainlets:
            if not_matched := (set(included_chainlets) - chainlet_names):
                raise definitions.ChainsDeploymentError(
                    "Requested to watch specific chainlets, but did not find "
                    f"{not_matched} among available chainlets {chainlet_names}."
                )
            self._included_chainlets = set(included_chainlets)
        else:
            self._included_chainlets = chainlet_names

        chain_id = b10_core.get_chain_id_by_name(
            self._remote_provider.api, self._deployed_chain_name
        )
        if not chain_id:
            # Report the name that was looked up: `chain_id` is empty on this path,
            # so interpolating it produced an unhelpful "Chain `None`" message.
            raise definitions.ChainsDeploymentError(
                f"Chain `{self._deployed_chain_name}` was not found."
            )
        self._status_page_url = b10_service.URLConfig.status_page_url(
            self._remote_provider.remote_url, b10_service.URLConfig.CHAIN, chain_id
        )
        chain_deployment = b10_core.get_dev_chain_deployment(
            self._remote_provider.api, chain_id
        )
        if chain_deployment is None:
            raise definitions.ChainsDeploymentError(
                f"No development deployment was found for Chain `{chain_id}`. "
                "You cannot live-patch production deployments. Check the Chain's "
                f"status page for available deployments: {self._status_page_url}."
            )
        deployed_chainlets = self._remote_provider.get_chainlets(chain_deployment["id"])
        non_draft_chainlets = [
            chainlet.name for chainlet in deployed_chainlets if not chainlet.is_draft
        ]
        # Invariant: all Chainlets of a development deployment must be drafts.
        assert not (non_draft_chainlets), (
            "If the chain is draft, the oracles must be draft."
        )

        self._chainlet_data = {c.name: c for c in deployed_chainlets}
        self._assert_chainlet_names_same(chainlet_names)
        self._ignore_patterns, self._watch_filter = _create_watch_filter(
            self._chain_root
        )

    @property
    def _original_chainlet_names(self) -> set[str]:
        """Names of the Chainlets in the deployed development Chain."""
        return set(self._chainlet_data.keys())

    def _assert_chainlet_names_same(self, new_names: set[str]) -> None:
        """Raises if ``new_names`` differs from the deployed Chainlet names.

        Raises:
            definitions.ChainsDeploymentError: Listing missing and/or added names.
        """
        missing = self._original_chainlet_names - new_names
        added = new_names - self._original_chainlet_names
        if not (missing or added):
            return
        msg_parts = [
            "The deployed Chainlets and the Chainlets in the current workspace differ. "
            "Live patching is not possible if the set of Chainlet names differ."
        ]
        if missing:
            msg_parts.append(f"Chainlets missing in current workspace: {list(missing)}")
        if added:
            msg_parts.append(f"Chainlets added in current workspace: {list(added)}")

        raise definitions.ChainsDeploymentError("\n".join(msg_parts))

    def _code_gen_and_patch_thread(
        self, descr: definitions.ChainletAPIDescriptor
    ) -> tuple[b10_remote.PatchResult, list[str]]:
        """Generates the Truss for one Chainlet and patches its deployment.

        Runs in a worker thread.

        Returns:
            The patch result and the logs intercepted while generating/patching.
        """
        with log_utils.LogInterceptor() as log_interceptor:
            # TODO: Maybe try-except code_gen errors explicitly.
            chainlet_dir = code_gen.gen_truss_chainlet(
                self._chain_root,
                self._deployed_chain_name,
                descr,
                self._chainlet_data[descr.display_name].oracle_name,
                use_local_chains_src=False,
            )
            patch_result = self._remote_provider.patch_for_chainlet(
                chainlet_dir, self._ignore_patterns
            )
            logs = log_interceptor.get_logs()
        return patch_result, logs

    def _patch(self, executor: concurrent.futures.Executor) -> None:
        """Re-imports the Chain and patches all included Chainlets via ``executor``.

        Import or validation errors are reported via ``_handle_import_error`` and
        do not propagate. Exceptions raised inside patch threads surface from
        ``_check_patch_results``.
        """
        exception_raised = None
        stack_trace = ""
        with log_utils.LogInterceptor() as log_interceptor, self._console.status(
            " Live Patching Chain.\n", spinner="arrow3"
        ):
            # Handle import errors gracefully (e.g. if user saved file, but there
            # are syntax errors, undefined symbols etc.).
            try:
                with framework.ChainletImporter.import_target(
                    self._source, self._entrypoint
                ) as entrypoint_cls:
                    chainlet_descriptors = _get_ordered_dependencies([entrypoint_cls])
                    chain_root_new = _get_chain_root(entrypoint_cls)
                    assert chain_root_new == self._chain_root
                    self._assert_chainlet_names_same(
                        set(desc.display_name for desc in chainlet_descriptors)
                    )
                    future_to_display_name = {}
                    for chainlet_descr in chainlet_descriptors:
                        if chainlet_descr.display_name not in self._included_chainlets:
                            self._console.print(
                                f"⏩ Skipping patching `{chainlet_descr.display_name}`.",
                                style="grey50",
                            )
                            continue

                        future = executor.submit(
                            self._code_gen_and_patch_thread, chainlet_descr
                        )
                        future_to_display_name[future] = chainlet_descr.display_name
                    # Threads need to finish while inside the `import_target`-context.
                    done_futures = {
                        future_to_display_name[future]: future
                        for future in concurrent.futures.as_completed(
                            future_to_display_name
                        )
                    }
            except Exception as e:
                exception_raised = e
                stack_trace = traceback.format_exc()
            finally:
                logs = log_interceptor.get_logs()

        _handle_intercepted_logs(logs, self._console)
        if exception_raised:
            _handle_import_error(
                exception_raised,
                self._console,
                self._error_console,
                stack_trace=stack_trace if self._show_stack_trace else None,
            )
            return

        self._check_patch_results(done_futures)

    def _check_patch_results(
        self,
        display_name_to_done_future: Mapping[
            str, concurrent.futures.Future[tuple[b10_remote.PatchResult, list[str]]]
        ],
    ) -> None:
        """Prints one status line per patched Chainlet and a summary on failure."""
        has_errors = False
        for display_name, future in display_name_to_done_future.items():
            # It is not expected that code_gen_and_patch raises an exception, errors
            # should be handled by setting `b10_remote.PatchStatus`.
            # If an exception is raised anyway, it should bubble up the default way.
            patch_result, logs = future.result()
            if logs:
                formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
                # Rich markup: `[/grey70]` closes the style opened by `[grey70]`.
                logs_output = f" [grey70]Intercepted logs:\n{formatted_logs}[/grey70]"
            else:
                logs_output = ""

            if patch_result.status == b10_remote.PatchStatus.SUCCESS:
                self._console.print(
                    f"✅ Patched Chainlet `{display_name}`.{logs_output}", style="green"
                )
            elif patch_result.status == b10_remote.PatchStatus.SKIPPED:
                self._console.print(
                    f"💤 Nothing to do for Chainlet `{display_name}`.{logs_output}",
                    style="grey50",
                )
            else:
                has_errors = True
                self._error_console.print(
                    f"❌ Failed to patch Chainlet `{display_name}`. "
                    f"{patch_result.message}{logs_output}"
                )

        if has_errors:
            msg = (
                "Some Chainlets could not be live patched. See above error messages. "
                "The watcher will continue, and try patching new changes. However, the "
                "safest way to proceed and ensure a consistent state is to re-deploy "
                "the entire development Chain."
            )
            self._error_console.print(msg)

    def watch(self) -> None:
        """Patches once at startup, then after every accepted change below the
        Chain root. ``raise_interrupt=False`` asks ``watchfiles`` not to re-raise
        a keyboard interrupt."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Perform one initial patch at startup.
            self._patch(executor)
            self._console.print("👀 Watching for new changes.", style="blue")
            for _ in watchfiles.watch(
                self._chain_root, watch_filter=self._watch_filter, raise_interrupt=False
            ):
                self._patch(executor)
                self._console.print("👀 Watching for new changes.", style="blue")
824
+
825
+
826
@framework.raise_validation_errors_before
def watch(
    source: pathlib.Path,
    entrypoint: Optional[str],
    name: Optional[str],
    remote: str,
    console: "rich_console.Console",
    error_console: "rich_console.Console",
    show_stack_trace: bool,
    included_chainlets: Optional[list[str]],
) -> None:
    """Live-patches a Chain's development deployment whenever its source changes.

    Prints a start banner, then blocks in ``_Watcher.watch``. All arguments are
    forwarded unchanged to ``_Watcher``.
    """
    banner = (
        "👀 Starting to watch for Chain source code and applying live patches "
        "when changes are detected."
    )
    console.print(banner, style="blue")
    watcher = _Watcher(
        source=source,
        entrypoint=entrypoint,
        name=name,
        remote=remote,
        console=console,
        error_console=error_console,
        show_stack_trace=show_stack_trace,
        included_chainlets=included_chainlets,
    )
    watcher.watch()
855
+
856
+
857
def watch_model(
    source: pathlib.Path,
    model_name: str,
    remote_provider: b10_remote.TrussRemote,
    console: "rich_console.Console",
    error_console: "rich_console.Console",
):
    """Live-patches a Truss model's development version whenever its source changes.

    ``remote_provider`` is cast to ``BasetenRemote`` without a runtime check.
    This call blocks in ``_ModelWatcher.watch``.
    """
    baseten_remote = cast(b10_remote.BasetenRemote, remote_provider)
    _ModelWatcher(
        source=source,
        model_name=model_name,
        remote_provider=baseten_remote,
        console=console,
        error_console=error_console,
    ).watch()