truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,816 @@
1
+ """
2
+ Chains currently assumes that everything from the directory in which the entrypoint
3
+ is defined (i.e. sibling files and nested dirs) could be imported/used. e.g.:
4
+
5
+ workspace/
6
+ entrypoint.py
7
+ helper.py
8
+ some_package/
9
+ utils.py
10
+ sub_package/
11
+ ...
12
+
13
+ These sources are copied into truss's `/packages` and can be imported on the remote.
14
+ Using code *outside* of the workspace is not supported:
15
+
16
+ shared_lib/
17
+ common.py
18
+ workspace/
19
+ entrypoint.py
20
+ ...
21
+
22
+ `shared_lib` can only be imported on the remote if it's installed as a pip
23
+ requirement (site-package), it will not be copied from the local host.
24
+ """
25
+
26
+ import logging
27
+ import os
28
+ import pathlib
29
+ import re
30
+ import shlex
31
+ import shutil
32
+ import subprocess
33
+ import sys
34
+ import tempfile
35
+ import textwrap
36
+ from typing import Any, Iterable, Mapping, Optional, get_args, get_origin
37
+
38
+ import libcst
39
+ import truss
40
+ from truss.base import truss_config
41
+ from truss.contexts.image_builder import serving_image_builder
42
+ from truss.util import path as truss_path
43
+
44
+ from truss_chains import definitions, framework, utils
45
+
46
+ _INDENT = " " * 4
47
+ _REQUIREMENTS_FILENAME = "pip_requirements.txt"
48
+ _MODEL_FILENAME = "model.py"
49
+ _MODEL_CLS_NAME = "TrussChainletModel"
50
+ _TRUSS_GIT = "git+https://github.com/basetenlabs/truss.git"
51
+ _TRUSS_PIP_PATTERN = re.compile(
52
+ r"""
53
+ ^truss
54
+ (?:
55
+ \s*(==|>=|<=|!=|>|<)\s* # Version comparison operators
56
+ \d+(\.\d+)* # Version numbers (e.g., 1, 1.0, 1.0.0)
57
+ (?: # Optional pre-release or build metadata
58
+ (?:a|b|rc|dev)\d*
59
+ (?:\.post\d+)?
60
+ (?:\+[\w\.]+)?
61
+ )?
62
+ )?$
63
+ """,
64
+ re.VERBOSE,
65
+ )
66
+
67
+ _MODEL_SKELETON_FILE = (
68
+ pathlib.Path(__file__).parent.parent.resolve()
69
+ / "remote_chainlet"
70
+ / "model_skeleton.py"
71
+ )
72
+
73
+
74
+ def _indent(text: str, num: int = 1) -> str:
75
+ return textwrap.indent(text, _INDENT * num)
76
+
77
+
78
+ def _run_simple_subprocess(cmd: str) -> None:
79
+ process = subprocess.Popen(
80
+ shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE
81
+ )
82
+ _, stderr = process.communicate()
83
+ if process.returncode != 0:
84
+ raise ChildProcessError(f"Error: {stderr.decode()}")
85
+
86
+
87
+ def _format_python_file(file_path: pathlib.Path) -> None:
88
+ # Resolve import sorting and unused import issues.
89
+ _run_simple_subprocess(f"ruff check {file_path} --fix --select F401,I")
90
+ _run_simple_subprocess(f"ruff format {file_path}")
91
+
92
+
93
+ class _Source(definitions.SafeModelNonSerializable):
94
+ src: str
95
+ imports: set[str] = set()
96
+
97
+
98
+ def _update_src(new_source: _Source, src_parts: list[str], imports: set[str]) -> None:
99
+ src_parts.append(new_source.src)
100
+ imports.update(new_source.imports)
101
+
102
+
103
+ def _gen_pydantic_import_and_ref(raw_type: Any) -> _Source:
104
+ """Returns e.g. ("from sub_package import module", "module.OutputType")."""
105
+ if raw_type.__module__ == "__main__":
106
+ # Assuming that main is copied into package dir and can be imported.
107
+ module_obj = sys.modules[raw_type.__module__]
108
+ if not module_obj.__file__:
109
+ raise definitions.ChainsUsageError(
110
+ f"File-based python code required. `{raw_type}` does not have a file."
111
+ )
112
+
113
+ file = os.path.basename(module_obj.__file__)
114
+ assert file.endswith(".py")
115
+ module_name = file.replace(".py", "")
116
+ import_src = f"import {module_name}"
117
+ ref_src = f"{module_name}.{raw_type.__name__}"
118
+ else:
119
+ parts = raw_type.__module__.split(".")
120
+ ref_src = f"{parts[-1]}.{raw_type.__name__}"
121
+ if len(parts) > 1:
122
+ import_src = f"from {'.'.join(parts[:-1])} import {parts[-1]}"
123
+ else:
124
+ import_src = f"import {parts[0]}"
125
+
126
+ return _Source(src=ref_src, imports={import_src})
127
+
128
+
129
+ def _gen_nested_pydantic(raw_type: Any) -> _Source:
130
+ """Handles `list[PydanticModel]` and similar, correctly resolving imports
131
+ of model args that might be defined in other files."""
132
+ origin = get_origin(raw_type)
133
+ assert origin in framework._SIMPLE_CONTAINERS
134
+ container = _gen_type_import_and_ref(definitions.TypeDescriptor(raw=origin))
135
+ args = get_args(raw_type)
136
+ arg_parts = []
137
+ for arg in args:
138
+ arg_src = _gen_type_import_and_ref(definitions.TypeDescriptor(raw=arg))
139
+ arg_parts.append(arg_src.src)
140
+ container.imports.update(arg_src.imports)
141
+
142
+ container.src = f"{container.src}[{','.join(arg_parts)}]"
143
+ return container
144
+
145
+
146
+ def _gen_type_import_and_ref(type_descr: definitions.TypeDescriptor) -> _Source:
147
+ """Returns e.g. ("from sub_package import module", "module.OutputType")."""
148
+ if type_descr.is_pydantic:
149
+ return _gen_pydantic_import_and_ref(type_descr.raw)
150
+ if type_descr.has_pydantic_args:
151
+ return _gen_nested_pydantic(type_descr.raw)
152
+ if isinstance(type_descr.raw, type):
153
+ if not type_descr.raw.__module__ == "builtins":
154
+ raise TypeError(
155
+ f"{type_descr.raw} is not a builtin - cannot be rendered as source."
156
+ )
157
+ return _Source(src=type_descr.raw.__name__)
158
+
159
+ return _Source(src=str(type_descr.raw))
160
+
161
+
162
+ def _gen_streaming_type_import_and_ref(
163
+ stream_type: definitions.StreamingTypeDescriptor,
164
+ ) -> _Source:
165
+ """Unlike other `_gen`-helpers, this does not define a type, it creates a symbol."""
166
+ mod = stream_type.origin_type.__module__
167
+ arg = stream_type.arg_type.__name__
168
+ type_src = f"{mod}.{stream_type.origin_type.__name__}[{arg}]"
169
+ return _Source(src=type_src, imports={f"import {mod}"})
170
+
171
+
172
+ def _gen_chainlet_import_and_ref(
173
+ chainlet_descriptor: definitions.ChainletAPIDescriptor,
174
+ ) -> _Source:
175
+ """Returns e.g. ("from sub_package import module", "module.OutputType")."""
176
+ return _gen_pydantic_import_and_ref(chainlet_descriptor.chainlet_cls)
177
+
178
+
179
+ # I/O used by Stubs and Truss models ###################################################
180
+
181
+
182
+ def _get_input_model_name(chainlet_name: str) -> str:
183
+ return f"{chainlet_name}Input"
184
+
185
+
186
+ def _get_output_model_name(chainlet_name: str) -> str:
187
+ return f"{chainlet_name}Output"
188
+
189
+
190
+ def _gen_truss_input_pydantic(
191
+ chainlet_descriptor: definitions.ChainletAPIDescriptor,
192
+ ) -> _Source:
193
+ imports = {"import pydantic", "from typing import Optional"}
194
+ fields = []
195
+ for arg in chainlet_descriptor.endpoint.input_args:
196
+ type_ref = _gen_type_import_and_ref(arg.type)
197
+ imports.update(type_ref.imports)
198
+ if arg.is_optional:
199
+ fields.append(f"{arg.name}: Optional[{type_ref.src}] = None")
200
+ else:
201
+ fields.append(f"{arg.name}: {type_ref.src}")
202
+
203
+ if fields:
204
+ field_block = _indent("\n".join(fields))
205
+ else:
206
+ field_block = _indent("pass")
207
+
208
+ model_name = _get_input_model_name(chainlet_descriptor.name)
209
+ src = f"class {model_name}(pydantic.BaseModel):\n{field_block}"
210
+ return _Source(src=src, imports=imports)
211
+
212
+
213
def _gen_truss_output_pydantic(
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
    """Generates the pydantic output model (a `RootModel` alias) for a chainlet.

    A single output type is used directly as the root type; several output
    types are wrapped into `tuple[...]`.

    NOTE(review): an endpoint without any output types raises `IndexError`
    below - presumably this is ruled out by upstream validation; confirm.

    Returns:
        The alias assignment source together with the imports it needs.
    """
    imports = {"import pydantic"}
    fields: list[str] = []
    # The position of an output type is irrelevant here, so iterate directly
    # (the former `enumerate` index was never used).
    for output_type in chainlet_descriptor.endpoint.output_types:
        _update_src(_gen_type_import_and_ref(output_type), fields, imports)

    model_name = _get_output_model_name(chainlet_descriptor.name)
    if len(fields) > 1:
        root_type = f"tuple[{','.join(fields)}]"
    else:
        root_type = fields[0]
    src = f"{model_name} = pydantic.RootModel[{root_type}]"
    return _Source(src=src, imports=imports)
228
+
229
+
230
+ # Stub Gen #############################################################################
231
+
232
+
233
def _stub_endpoint_signature_src(
    endpoint: definitions.EndpointAPIDescriptor,
) -> _Source:
    """Generates the `def` line of the stub method mirroring `endpoint`.

    E.g.:
    ```
    async def run_remote(
        self, inputs: shared_chainlet.SplitTextInput, extra_arg: int
    ) -> tuple[shared_chainlet.SplitTextOutput, int]:
    ```

    Streaming endpoints use the streaming type as return annotation; otherwise
    a single output type is used as is and several are wrapped in `tuple[...]`.
    """
    imports = set()
    params = ["self"]
    for arg in endpoint.input_args:
        type_src = _gen_type_import_and_ref(arg.type)
        imports.update(type_src.imports)
        params.append(f"{arg.name}: {type_src.src}")

    if endpoint.is_streaming:
        stream_ref = _gen_streaming_type_import_and_ref(endpoint.streaming_type)
        imports.update(stream_ref.imports)
        return_annotation = stream_ref.src
    else:
        return_types: list[str] = []
        for output_type in endpoint.output_types:
            _update_src(_gen_type_import_and_ref(output_type), return_types, imports)
        return_annotation = (
            return_types[0]
            if len(return_types) == 1
            else f"tuple[{', '.join(return_types)}]"
        )

    keyword = "async def" if endpoint.is_async else "def"
    header = f"{keyword} {endpoint.name}({','.join(params)}) -> {return_annotation}:"
    return _Source(src=header, imports=imports)
269
+
270
+
271
def _stub_endpoint_body_src(
    endpoint: definitions.EndpointAPIDescriptor, chainlet_name: str
) -> _Source:
    """Generates the body of a stub method: remote call plus I/O wrapping.

    E.g. for an async, non-streaming endpoint:
    ```
    return (await self.predict_async(
        SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput)).root
    ```

    Endpoints without arguments send the literal `{}` as payload.

    Raises:
        NotImplementedError: for streaming endpoints that are not async.
    """
    call_kwargs = ", ".join(f"{arg.name}={arg.name}" for arg in endpoint.input_args)
    if call_kwargs:
        payload = f"{_get_input_model_name(chainlet_name)}({call_kwargs})"
    else:
        payload = "{}"

    if endpoint.is_streaming:
        if not endpoint.is_async:
            raise NotImplementedError(
                "`Streaming endpoints (containing `yield` statements) are only "
                "supported for async endpoints."
            )
        # String streams are decoded chunk-wise, other streams are passed on.
        if endpoint.streaming_type.is_string:
            yield_stmt = "yield data.decode()"
        else:
            yield_stmt = "yield data"
        lines = [
            f"async for data in await self.predict_async_stream({payload}):",
            _indent(yield_stmt),
        ]
    else:
        result_model = _get_output_model_name(chainlet_name)
        if endpoint.is_async:
            call = f"return (await self.predict_async({payload}, {result_model})).root"
        else:
            call = f"return self.predict_sync({payload}, {result_model}).root"
        lines = [call]

    no_imports: set[str] = set()
    return _Source(src="\n".join(lines), imports=no_imports)
316
+
317
+
318
def _gen_stub_src(chainlet: definitions.ChainletAPIDescriptor) -> _Source:
    """Generates the stub class for `chainlet` including its I/O models, e.g.:

    ```
    <IMPORTS>

    class SplitTextInput(pydantic.BaseModel):
        inputs: shared_chainlet.SplitTextInput
        extra_arg: int

    SplitTextOutput = pydantic.RootModel[tuple[shared_chainlet.SplitTextOutput,int]]

    class SplitText(stub.StubBase):
        async def run_remote(
            self, inputs: shared_chainlet.SplitTextInput, extra_arg: int
        ) -> tuple[shared_chainlet.SplitTextOutput, int]:
            return (await self.predict_async(
                SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput)).root
    ```

    The output model is omitted for streaming endpoints.
    """
    imports = {"from truss_chains.remote_chainlet import stub"}
    sections: list[str] = []

    # I/O models come first, because the stub method refers to them.
    _update_src(_gen_truss_input_pydantic(chainlet), sections, imports)
    if not chainlet.endpoint.is_streaming:
        _update_src(_gen_truss_output_pydantic(chainlet), sections, imports)

    signature = _stub_endpoint_signature_src(chainlet.endpoint)
    imports.update(signature.imports)
    body = _stub_endpoint_body_src(chainlet.endpoint, chainlet.name)
    imports.update(body.imports)

    sections += [
        f"class {chainlet.name}(stub.StubBase):",
        _indent(signature.src),
        _indent(body.src, 2),
        "\n",
    ]
    return _Source(src="\n".join(sections), imports=imports)
360
+
361
+
362
def _gen_stub_src_for_deps(
    dependencies: Iterable[definitions.ChainletAPIDescriptor],
) -> Optional[_Source]:
    """Generates source code and imports of the stub classes for `dependencies`.

    Returns:
        The combined stub source, or `None` if nothing was generated (no
        imports and no source parts).
    """
    imports: set[str] = set()
    stub_sources: list[str] = []
    for dependency in dependencies:
        _update_src(_gen_stub_src(dependency), stub_sources, imports)

    if imports or stub_sources:
        return _Source(src="\n".join(stub_sources), imports=imports)
    return None
374
+
375
+
376
+ # Truss Chainlet Gen ###################################################################
377
+
378
+
379
+ def _name_to_dirname(name: str) -> str:
380
+ """Make a string safe to use as a directory name."""
381
+ name = name.strip() # Remove leading and trailing spaces
382
+ name = re.sub(
383
+ r"[^\w.-]", "_", name
384
+ ) # Replace non-alphanumeric characters with underscores
385
+ name = re.sub(r"_+", "_", name) # Collapse multiple underscores into a single one
386
+ return name
387
+
388
+
389
def _make_chainlet_dir(
    chain_name: str,
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
    root: pathlib.Path,
) -> pathlib.Path:
    """Creates a fresh, empty directory for the generated code of one chainlet.

    The directory is
    `<root>/<GENERATED_CODE_DIR>/<sanitized chain name>/chainlet_<chainlet name>`.
    A directory left over at that location is deleted first.
    """
    safe_chain_name = _name_to_dirname(chain_name)
    chainlet_dir = root.joinpath(
        definitions.GENERATED_CODE_DIR,
        safe_chain_name,
        f"chainlet_{chainlet_descriptor.name}",
    )
    # Start from a clean slate, stale generated files must not survive.
    if chainlet_dir.exists():
        shutil.rmtree(chainlet_dir)
    chainlet_dir.mkdir(exist_ok=False, parents=True)
    return chainlet_dir
402
+
403
+
404
class _SpecifyChainletTypeAnnotation(libcst.CSTTransformer):
    """Inserts the concrete chainlet class into `_chainlet: definitions.ABCChainlet`.

    Every annotated assignment whose target is the plain name `_chainlet` gets
    its annotation replaced by the expression given at construction time.
    """

    def __init__(self, new_annotation: str) -> None:
        super().__init__()
        # Parse once, the resulting expression node is reused for each match.
        self._new_annotation = libcst.parse_expression(new_annotation)

    def leave_SimpleStatementLine(
        self,
        original_node: libcst.SimpleStatementLine,
        updated_node: libcst.SimpleStatementLine,
    ) -> libcst.SimpleStatementLine:
        def _retarget(statement: Any) -> Any:
            is_chainlet_declaration = (
                isinstance(statement, libcst.AnnAssign)
                and isinstance(statement.target, libcst.Name)
                and statement.target.value == "_chainlet"
            )
            if not is_chainlet_declaration:
                return statement
            return statement.with_changes(
                annotation=libcst.Annotation(annotation=self._new_annotation)
            )

        rewritten = tuple(_retarget(statement) for statement in updated_node.body)
        return updated_node.with_changes(body=rewritten)
430
+
431
+
432
def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
    """Generates the source of the `load` method of the truss model.

    The generated method logs a message and instantiates the user chainlet,
    passing one stub per dependency and, if the chainlet takes a context,
    `context=self._context` as last argument.
    """
    imports = {"from truss_chains.remote_chainlet import stub", "import logging"}

    # `dep.name` is the class name, while `arg_name` is the argument name.
    init_arg_parts = [
        f"{arg_name}=stub.factory({dep.name}, self._context)"
        for arg_name, dep in chainlet_descriptor.dependencies.items()
    ]
    if chainlet_descriptor.has_context:
        init_arg_parts.append("context=self._context")
    init_args = ", ".join(init_arg_parts)

    chainlet_ref = _gen_chainlet_import_and_ref(chainlet_descriptor)
    imports.update(chainlet_ref.imports)

    body_lines = [
        f"logging.info(f'Loading Chainlet `{chainlet_descriptor.name}`.')",
        f"self._chainlet = {chainlet_ref.src}({init_args})",
    ]
    body = _indent("\n".join(body_lines))
    return _Source(src=f"def load(self) -> None:\n{body}", imports=imports)
458
+
459
+
460
def _gen_health_check_src(
    health_check: definitions.HealthCheckAPIDescriptor,
) -> _Source:
    """Generates the source of the `is_healthy` method of the truss model.

    The generated method forwards to `self._chainlet.is_healthy()` (awaited for
    async health checks), guarded by a `hasattr(self, "_chainlet")` test.

    NOTE(review): the guard and the return statement are concatenated without
    an explicit separator and both use the default `_indent` level - whether
    the result is a one-line `if ...: return ...` depends on `_indent`; confirm.
    """
    if health_check.is_async:
        def_str, maybe_await = "async def", "await "
    else:
        def_str, maybe_await = "def", ""

    signature = f"{def_str} is_healthy(self) -> Optional[bool]:\n"
    guard = _indent('if hasattr(self, "_chainlet"):')
    delegate = _indent(f"return {maybe_await}self._chainlet.is_healthy()")
    return _Source(src=f"{signature}{guard}{delegate}")
472
+
473
+
474
def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
    """Generates the source of the `predict` method of the truss model.

    The generated method invokes the chainlet's endpoint inside
    `utils.predict_context(request)`. Streaming endpoints re-yield the chunks
    of the result; other endpoints wrap the result in the output model.
    """
    endpoint = chainlet_descriptor.endpoint
    imports: set[str] = {
        "from truss_chains.remote_chainlet import stub",
        "from truss_chains.remote_chainlet import utils",
    }
    input_model_name = _get_input_model_name(chainlet_descriptor.name)
    if endpoint.is_streaming:
        streaming_src = _gen_streaming_type_import_and_ref(endpoint.streaming_type)
        imports.update(streaming_src.imports)
        output_type_name = streaming_src.src
    else:
        output_type_name = _get_output_model_name(chainlet_descriptor.name)
    imports.add("import starlette.requests")

    def_str = "async def" if endpoint.is_async else "def"
    # A streaming endpoint is iterated below, so its result is not awaited.
    maybe_await = "await " if endpoint.is_async and not endpoint.is_streaming else ""
    run_remote = endpoint.name

    lines: list[str] = [
        f"{def_str} predict(self, inputs: {input_model_name}, "
        f"request: starlette.requests.Request) -> {output_type_name}:",
        # Error handling context manager.
        _indent("with utils.predict_context(request):"),
        # See docs of `pydantic_set_field_dict` for why this is needed.
        _indent(
            f"result = {maybe_await}self._chainlet.{run_remote}"
            "(**utils.pydantic_set_field_dict(inputs))",
            2,
        ),
    ]
    if endpoint.is_streaming:
        # Streaming returns a raw iterator, no pydantic model. The loop is
        # emitted nested inside the `predict_context` block.
        lines.append(_indent("async for chunk in result:", 2))
        lines.append(_indent("yield chunk", 3))
    else:
        lines.append(_indent(f"return {output_type_name}(result)"))
    return _Source(src="\n".join(lines), imports=imports)
522
+
523
+
524
def _gen_truss_chainlet_model(
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
    """Generates the truss model class wrapping the chainlet.

    The model class is taken from the skeleton file, extended by the generated
    `load` and `predict` methods (and `is_healthy`, if the chainlet defines a
    health check), and its `_chainlet` annotation is set to the concrete
    chainlet class. The skeleton's import statements are part of the returned
    imports.
    """
    skeleton_tree = libcst.parse_module(_MODEL_SKELETON_FILE.read_text())

    # Carry over each top-level import line of the skeleton.
    imports: set[str] = set()
    for node in skeleton_tree.body:
        if not isinstance(node, libcst.SimpleStatementLine):
            continue
        if any(
            isinstance(stmt, (libcst.Import, libcst.ImportFrom)) for stmt in node.body
        ):
            imports.add(libcst.Module(body=[node]).code)

    class_definition: libcst.ClassDef = utils.expect_one(
        node
        for node in skeleton_tree.body
        if isinstance(node, libcst.ClassDef) and node.name.value == _MODEL_CLS_NAME
    )

    load_src = _gen_load_src(chainlet_descriptor)
    imports.update(load_src.imports)
    predict_src = _gen_predict_src(chainlet_descriptor)
    imports.update(predict_src.imports)

    members: list[Any] = [*class_definition.body.body]
    members.append(libcst.parse_statement(load_src.src))
    members.append(libcst.parse_statement(predict_src.src))

    health_check = chainlet_descriptor.health_check
    if health_check is not None:
        health_check_src = _gen_health_check_src(health_check)
        members.append(libcst.parse_statement(health_check_src.src))

    user_chainlet_ref = _gen_chainlet_import_and_ref(chainlet_descriptor)
    imports.update(user_chainlet_ref.imports)

    class_definition = class_definition.with_changes(
        body=libcst.IndentedBlock(body=members)
    )
    class_definition = class_definition.visit(  # type: ignore[assignment]
        _SpecifyChainletTypeAnnotation(user_chainlet_ref.src)
    )
    return _Source(src=libcst.Module(body=[class_definition]).code, imports=imports)
567
+
568
+
569
def _gen_truss_chainlet_file(
    chainlet_dir: pathlib.Path,
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
    dependencies: Iterable[definitions.ChainletAPIDescriptor],
) -> pathlib.Path:
    """Generates code that wraps a Chainlet as a truss-compatible model.

    The file contains, in this order: the stubs of all `dependencies`, the
    input model, the output model (omitted for streaming endpoints) and the
    truss model class. It is written into the model module directory below
    `chainlet_dir` (next to an empty `__init__.py`) and formatted afterwards.

    Returns:
        The path of the generated model file.
    """
    file_path = chainlet_dir / truss_config.DEFAULT_MODEL_MODULE_DIR / _MODEL_FILENAME
    file_path.parent.mkdir(parents=True, exist_ok=True)
    (chainlet_dir / truss_config.DEFAULT_MODEL_MODULE_DIR / "__init__.py").touch()
    imports: set[str] = set()
    src_parts: list[str] = []

    if maybe_stub_src := _gen_stub_src_for_deps(dependencies):
        _update_src(maybe_stub_src, src_parts, imports)

    input_src = _gen_truss_input_pydantic(chainlet_descriptor)
    _update_src(input_src, src_parts, imports)
    if not chainlet_descriptor.endpoint.is_streaming:
        output_src = _gen_truss_output_pydantic(chainlet_descriptor)
        _update_src(output_src, src_parts, imports)
    model_src = _gen_truss_chainlet_model(chainlet_descriptor)
    _update_src(model_src, src_parts, imports)

    # Sets of strings iterate in an order that can differ between interpreter
    # runs; sorting makes the written file reproducible.
    imports_str = "\n".join(sorted(imports))
    src_str = "\n".join(src_parts)
    file_path.write_text(f"{imports_str}\n{src_str}")
    _format_python_file(file_path)
    return file_path
597
+
598
+
599
+ # Truss Gen ############################################################################
600
+
601
+
602
def _make_requirements(image: definitions.DockerImage) -> list[str]:
    """Merges file- and list-based requirements and adds truss if not present.

    Lines of the requirements file are stripped; blank lines and comment lines
    (starting with `#`) are dropped. If no requirement refers to truss (neither
    a pip requirement matching `_TRUSS_PIP_PATTERN` nor one containing
    `_TRUSS_GIT`), `truss==<locally installed version>` is added. Requirements
    that already refer to truss are kept, but a warning is logged.

    Returns:
        The de-duplicated requirements in sorted order.
    """
    pip_requirements: set[str] = set()
    if image.pip_requirements_file:
        requirements_text = pathlib.Path(
            image.pip_requirements_file.abs_path
        ).read_text()
        for line in requirements_text.splitlines():
            requirement = line.strip()
            # Blank lines would otherwise end up as empty "requirements".
            if requirement and not requirement.startswith("#"):
                pip_requirements.add(requirement)
    pip_requirements.update(image.pip_requirements)

    truss_pypi = next(
        (req for req in pip_requirements if _TRUSS_PIP_PATTERN.match(req)), None
    )
    truss_git = next((req for req in pip_requirements if _TRUSS_GIT in req), None)

    if truss_git:
        logging.warning(
            "The chainlet contains a truss version from github as a pip_requirement:\n"
            f"\t{truss_git}\n"
            "This could result in inconsistencies between the deploying client and the "
            "deployed chainlet. This is not recommended for production chains."
        )
    if truss_pypi:
        logging.warning(
            "The chainlet contains a pinned truss version as a pip_requirement:\n"
            f"\t{truss_pypi}\n"
            "This could result in inconsistencies between the deploying client and the "
            "deployed chainlet. This is not recommended for production chains. If "
            "`truss` is not manually added as a requirement, the same version as "
            "locally installed will be automatically added and ensure compatibility."
        )

    if not (truss_git or truss_pypi):
        truss_pip = f"truss=={truss.version()}"
        logging.debug(
            f"Truss not found in pip requirements, auto-adding: `{truss_pip}`."
        )
        pip_requirements.add(truss_pip)

    return sorted(pip_requirements)
646
+
647
+
648
def _inplace_fill_base_image(
    image: definitions.DockerImage, mutable_truss_config: truss_config.TrussConfig
) -> None:
    """Transfers the base image settings of `image` into the truss config.

    A `BasetenImage` sets the config's python version; a `CustomImage` sets the
    config's base image (including the python executable path, if given). Any
    other kind of base image leaves the config untouched.

    Raises:
        NotImplementedError: if the base image is given as a plain string.
    """
    base_image = image.base_image
    if isinstance(base_image, definitions.BasetenImage):
        mutable_truss_config.python_version = base_image.value
        return

    if isinstance(base_image, definitions.CustomImage):
        mutable_truss_config.base_image = truss_config.BaseImage(
            image=base_image.image, docker_auth=base_image.docker_auth
        )
        executable_path = base_image.python_executable_path
        if executable_path:
            mutable_truss_config.base_image.python_executable_path = executable_path
        return

    if isinstance(base_image, str):  # This option is deprecated.
        raise NotImplementedError(
            "Specifying docker base image as string is deprecated"
        )
665
+
666
+
667
def _write_truss_config_yaml(
    chainlet_dir: pathlib.Path,
    chains_config: definitions.RemoteConfig,
    chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
    model_name: str,
    use_local_chains_src: bool,
):
    """Generates the truss config for a Chainlet and writes it to `chainlet_dir`.

    Besides the config YAML, the merged pip requirements are written to a
    requirements file inside `chainlet_dir`, which the config references by its
    relative file name.
    """
    options = chains_config.options
    docker_image = chains_config.docker_image

    config = truss_config.TrussConfig()
    config.model_name = model_name
    config.model_class_filename = _MODEL_FILENAME
    config.model_class_name = _MODEL_CLS_NAME
    config.runtime.enable_tracing_data = options.enable_b10_tracing
    config.environment_variables = dict(options.env_variables)
    config.runtime.health_checks = options.health_checks

    # Compute.
    compute = chains_config.get_compute_spec()
    config.resources.cpu = str(compute.cpu_count)
    config.resources.memory = str(compute.memory)
    config.resources.accelerator = compute.accelerator
    config.resources.use_gpu = bool(compute.accelerator.count)
    config.runtime.predict_concurrency = compute.predict_concurrency

    # Image.
    _inplace_fill_base_image(docker_image, config)
    requirements = _make_requirements(docker_image)
    # TODO: `pip_requirements` will add server requirements which give version
    #   conflicts. Check if that's still the case after relaxing versions.
    requirements_path = chainlet_dir / _REQUIREMENTS_FILENAME
    requirements_path.write_text("\n".join(requirements))
    # Absolute paths don't work with remote build.
    config.requirements_file = _REQUIREMENTS_FILENAME
    config.system_packages = docker_image.apt_requirements
    for ext_dir in docker_image.external_package_dirs or []:
        config.external_package_dirs.append(ext_dir.abs_path)
    config.use_local_chains_src = use_local_chains_src

    # Assets.
    assets = chains_config.get_asset_spec()
    config.secrets = assets.secrets
    api_secret_name = definitions.BASETEN_API_SECRET_NAME
    if api_secret_name in config.secrets:
        logging.info(
            f"Chains automatically add {definitions.BASETEN_API_SECRET_NAME} "
            "to secrets - no need to manually add it."
        )
    else:
        config.secrets[api_secret_name] = definitions.SECRET_DUMMY
    config.model_cache.models = assets.cached
    config.external_data = truss_config.ExternalData(items=assets.external_data)

    # Metadata.
    chains_metadata: definitions.TrussMetadata = definitions.TrussMetadata(
        chainlet_to_service=chainlet_to_service
    )
    metadata_key = definitions.TRUSS_CONFIG_CHAINS_KEY
    config.model_metadata[metadata_key] = chains_metadata.model_dump()

    config_path = chainlet_dir / serving_image_builder.CONFIG_FILE
    config.write_to_yaml_file(config_path, verbose=True)
726
+
727
+
728
def gen_truss_model_from_source(
    model_src: pathlib.Path, use_local_chains_src: bool = False
) -> pathlib.Path:
    """Imports the model defined in `model_src` and generates a truss for it.

    The directory containing `model_src` is used as model root and the
    entrypoint's `display_name` as model name. Generation runs while the import
    context of the source file is still active.

    Returns:
        The directory of the generated truss.
    """
    # TODO(nikhil): Improve detection of directory structure, since right now
    #   we assume a flat structure
    source_dir = model_src.absolute().parent
    with framework.ModelImporter.import_target(model_src) as entrypoint_cls:
        model_descriptor = framework.get_descriptor(entrypoint_cls)
        truss_dir = gen_truss_model(
            model_root=source_dir,
            model_name=entrypoint_cls.display_name,
            model_descriptor=model_descriptor,
            use_local_chains_src=use_local_chains_src,
        )
        return truss_dir
742
+
743
+
744
def gen_truss_model(
    model_root: pathlib.Path,
    model_name: str,
    model_descriptor: definitions.ChainletAPIDescriptor,
    use_local_chains_src: bool = False,
) -> pathlib.Path:
    """Generates a truss for a single model.

    Forwards to `gen_truss_chainlet`, using the model root as chain root and
    the model name as chain name.

    Returns:
        The directory of the generated truss.
    """
    truss_dir = gen_truss_chainlet(
        chain_root=model_root,
        chain_name=model_name,
        chainlet_descriptor=model_descriptor,
        use_local_chains_src=use_local_chains_src,
    )
    return truss_dir
756
+
757
+
758
def gen_truss_chainlet(
    chain_root: pathlib.Path,
    chain_name: str,
    chainlet_descriptor: definitions.ChainletAPIDescriptor,
    model_name: Optional[str] = None,
    use_local_chains_src: bool = False,
) -> pathlib.Path:
    """Generates the complete truss directory for one chainlet.

    The directory is created below the system's temp directory and receives
    the truss config, a copy of the chain root as bundled package, the
    generated model file, the optional user data directory and a debugging
    copy of the model file.

    Args:
        chain_root: Directory with the chain's sources; copied into the truss.
        chain_name: Name of the chain; used for the directory layout and as
            model name if `model_name` is not given.
        chainlet_descriptor: Describes the chainlet to generate code for.
        model_name: Optional model name written to the truss config.
        use_local_chains_src: Passed on to the truss config.

    Returns:
        The directory of the generated truss.

    Raises:
        definitions.ChainsUsageError: if a python file directly in
            `chain_root` contains `-` in its name or uses the reserved model
            file name.
    """
    # Validate user sources first, so that an invalid chain root fails before
    # any directory is created or any file is copied.
    for file in chain_root.glob("*.py"):
        if "-" in file.name:
            raise definitions.ChainsUsageError(
                f"Python file `{file}` contains `-`, use `_` instead."
            )
        if file.name == _MODEL_FILENAME:
            raise definitions.ChainsUsageError(
                f"Python file name `{_MODEL_FILENAME}` is reserved and cannot be used."
            )

    # Filter needed services and customize options.
    dep_services = {
        dep.name: definitions.ServiceDescriptor(
            name=dep.name, display_name=dep.display_name, options=dep.options
        )
        for dep in chainlet_descriptor.dependencies.values()
    }
    gen_root = pathlib.Path(tempfile.gettempdir())
    chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
    logging.info(
        f"Code generation for {chainlet_descriptor.chainlet_cls.entity_type} `{chainlet_descriptor.name}` "
        f"in `{chainlet_dir}`."
    )
    remote_config = chainlet_descriptor.chainlet_cls.remote_config
    _write_truss_config_yaml(
        chainlet_dir=chainlet_dir,
        chains_config=remote_config,
        model_name=model_name or chain_name,
        chainlet_to_service=dep_services,
        use_local_chains_src=use_local_chains_src,
    )
    # This assumes all imports are absolute w.r.t chain root (or site-packages).
    truss_path.copy_tree_path(
        chain_root, chainlet_dir / truss_config.DEFAULT_BUNDLED_PACKAGES_DIR
    )
    chainlet_file = _gen_truss_chainlet_file(
        chainlet_dir,
        chainlet_descriptor,
        framework.get_dependencies(chainlet_descriptor),
    )
    if remote_config.docker_image.data_dir:
        data_dir = chainlet_dir / truss_config.DEFAULT_DATA_DIRECTORY
        data_dir.mkdir(parents=True, exist_ok=True)
        user_data_dir = remote_config.docker_image.data_dir.abs_path
        shutil.copytree(user_data_dir, data_dir, dirs_exist_ok=True)

    # Copy model file s.t. during debugging imports are properly resolved.
    shutil.copy(
        chainlet_file,
        chainlet_file.parent.parent
        / truss_config.DEFAULT_BUNDLED_PACKAGES_DIR
        / "_model_dbg.py",
    )
    return chainlet_dir