truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic; see the original registry diff page for more details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
truss/cli/cli.py CHANGED
@@ -3,29 +3,62 @@ import json
3
3
  import logging
4
4
  import os
5
5
  import sys
6
+ import time
7
+ import warnings
6
8
  from functools import wraps
7
9
  from pathlib import Path
8
- from typing import Callable, Optional
10
+ from typing import Callable, List, Optional, Tuple, Union
9
11
 
10
12
  import rich
13
+ import rich.live
14
+ import rich.logging
15
+ import rich.spinner
16
+ import rich.table
17
+ import rich.traceback
11
18
  import rich_click as click
12
19
  import truss
13
- from truss.cli.console import console
14
- from truss.cli.create import ask_name
20
+ from InquirerPy import inquirer
21
+ from rich import progress
22
+ from rich.console import Console
23
+ from truss.base.constants import (
24
+ PRODUCTION_ENVIRONMENT_NAME,
25
+ TRTLLM_MIN_MEMORY_REQUEST_GI,
26
+ )
27
+ from truss.base.errors import RemoteNetworkError
28
+ from truss.base.trt_llm_config import TrussTRTLLMQuantizationType
29
+ from truss.base.truss_config import Build, ModelServer
30
+ from truss.cli.remote_cli import (
31
+ inquire_model_name,
32
+ inquire_remote_config,
33
+ inquire_remote_name,
34
+ )
15
35
  from truss.remote.baseten.core import (
36
+ ACTIVE_STATUS,
37
+ DEPLOYING_STATUSES,
16
38
  ModelId,
17
39
  ModelIdentifier,
18
40
  ModelName,
19
41
  ModelVersionId,
20
42
  )
21
43
  from truss.remote.baseten.service import BasetenService
22
- from truss.remote.remote_cli import inquire_model_name, inquire_remote_name
44
+ from truss.remote.baseten.utils.status import get_displayable_status
23
45
  from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
24
- from truss.truss_config import Build, ModelServer
25
- from truss.truss_handle import TrussHandle
26
-
27
- logging.basicConfig(level=logging.INFO)
28
-
46
+ from truss.trt_llm.config_checks import (
47
+ is_missing_secrets_for_trt_llm_builder,
48
+ memory_updated_for_trt_llm_builder,
49
+ uses_trt_llm_builder,
50
+ )
51
+ from truss.truss_handle.build import cleanup as _cleanup
52
+ from truss.truss_handle.build import init_directory as _init
53
+ from truss.truss_handle.build import load
54
+ from truss.util import docker
55
+ from truss.util.log_utils import LogInterceptor
56
+
57
+ rich.spinner.SPINNERS["deploying"] = {"interval": 500, "frames": ["👾 ", " 👾"]}
58
+ rich.spinner.SPINNERS["building"] = {"interval": 500, "frames": ["🛠️ ", " 🛠️"]}
59
+ rich.spinner.SPINNERS["loading"] = {"interval": 500, "frames": ["⏱️ ", " ⏱️"]}
60
+ rich.spinner.SPINNERS["active"] = {"interval": 500, "frames": ["💚 ", " 💚"]}
61
+ rich.spinner.SPINNERS["failed"] = {"interval": 500, "frames": ["😤 ", " 😤"]}
29
62
 
30
63
  click.rich_click.COMMAND_GROUPS = {
31
64
  "truss": [
@@ -33,26 +66,31 @@ click.rich_click.COMMAND_GROUPS = {
33
66
  "name": "Main usage",
34
67
  "commands": ["init", "push", "watch", "predict"],
35
68
  "table_styles": { # type: ignore
36
- "row_styles": ["green"],
69
+ "row_styles": ["green"]
37
70
  },
38
71
  },
39
72
  {
40
73
  "name": "Advanced Usage",
41
74
  "commands": ["image", "container", "cleanup"],
42
75
  "table_styles": { # type: ignore
43
- "row_styles": ["yellow"],
76
+ "row_styles": ["yellow"]
77
+ },
78
+ },
79
+ {
80
+ "name": "Chains",
81
+ "commands": ["chains"],
82
+ "table_styles": { # type: ignore
83
+ "row_styles": ["red"]
44
84
  },
45
85
  },
46
86
  ]
47
87
  }
48
88
 
89
+ console = Console()
49
90
 
50
- def echo_output(f: Callable[..., object]):
51
- @wraps(f)
52
- def wrapper(*args, **kwargs):
53
- click.echo(f(*args, **kwargs))
91
+ error_console = Console(stderr=True, style="bold red")
54
92
 
55
- return wrapper
93
+ is_humanfriendly_log_level = True
56
94
 
57
95
 
58
96
  def error_handling(f: Callable[..., object]):
@@ -63,11 +101,74 @@ def error_handling(f: Callable[..., object]):
63
101
  except click.UsageError as e:
64
102
  raise e # You can re-raise the exception or handle it different
65
103
  except Exception as e:
66
- click.secho(f"ERROR: {e}", fg="red")
104
+ if is_humanfriendly_log_level:
105
+ console.print(
106
+ f"[bold red]ERROR {type(e).__name__}[/bold red]: {e}",
107
+ highlight=True,
108
+ )
109
+ else:
110
+ console.print_exception(show_locals=True)
67
111
 
68
112
  return wrapper
69
113
 
70
114
 
115
+ _HUMANFRIENDLY_LOG_LEVEL = "humanfriendly"
116
+ _log_level_str_to_level = {
117
+ _HUMANFRIENDLY_LOG_LEVEL: logging.INFO,
118
+ "W": logging.WARNING,
119
+ "WARNING": logging.WARNING,
120
+ "I": logging.INFO,
121
+ "INFO": logging.INFO,
122
+ "D": logging.DEBUG,
123
+ "DEBUG": logging.DEBUG,
124
+ }
125
+
126
+
127
+ def _set_logging_level(log_level: Union[str, int]) -> None:
128
+ if isinstance(log_level, str):
129
+ level = _log_level_str_to_level[log_level]
130
+ else:
131
+ level = log_level
132
+ root_logger = logging.getLogger()
133
+ root_logger.setLevel(level)
134
+
135
+ if log_level == _HUMANFRIENDLY_LOG_LEVEL:
136
+ rich_handler = rich.logging.RichHandler(
137
+ show_time=False, show_level=False, show_path=False
138
+ )
139
+ else:
140
+ # Rich handler adds time, levels, file location etc.
141
+ rich_handler = rich.logging.RichHandler()
142
+ global is_humanfriendly_log_level
143
+ is_humanfriendly_log_level = False
144
+
145
+ root_logger.handlers = [] # Clear existing handlers
146
+ root_logger.addHandler(rich_handler)
147
+ # Enable deprecation warnings raised in this module.
148
+ warnings.filterwarnings(
149
+ "default", category=DeprecationWarning, module="^truss\.cli\\b"
150
+ )
151
+
152
+
153
+ def log_level_option(f):
154
+ def callback(ctx, param, value):
155
+ _set_logging_level(value)
156
+ return value
157
+
158
+ return click.option(
159
+ "--log",
160
+ default=_HUMANFRIENDLY_LOG_LEVEL,
161
+ expose_value=False,
162
+ help="Customizes logging.",
163
+ type=click.Choice(list(_log_level_str_to_level.keys()), case_sensitive=False),
164
+ callback=callback,
165
+ )(f)
166
+
167
+
168
+ def _format_link(text: str) -> str:
169
+ return f"[link={text}]{text}[/link]"
170
+
171
+
71
172
  def print_help() -> None:
72
173
  ctx = click.get_current_context()
73
174
  click.echo(ctx.get_help())
@@ -76,6 +177,7 @@ def print_help() -> None:
76
177
  @click.group(name="truss", invoke_without_command=True) # type: ignore
77
178
  @click.pass_context
78
179
  @click.version_option(truss.version())
180
+ @log_level_option
79
181
  def truss_cli(ctx) -> None:
80
182
  """truss: The simplest way to serve models in production"""
81
183
  if not ctx.invoked_subcommand:
@@ -85,13 +187,11 @@ def truss_cli(ctx) -> None:
85
187
  @click.group()
86
188
  def container():
87
189
  """Subcommands for truss container"""
88
- pass
89
190
 
90
191
 
91
192
  @click.group()
92
193
  def image():
93
194
  """Subcommands for truss image"""
94
- pass
95
195
 
96
196
 
97
197
  @truss_cli.command()
@@ -103,23 +203,36 @@ def image():
103
203
  default=ModelServer.TrussServer.value,
104
204
  type=click.Choice([server.value for server in ModelServer]),
105
205
  )
206
+ @click.option("-n", "--name", type=click.STRING)
207
+ @click.option(
208
+ "--python-config/--no-python-config",
209
+ type=bool,
210
+ default=False,
211
+ help="Uses the code first tooling to build models.",
212
+ )
213
+ @log_level_option
106
214
  @error_handling
107
- def init(target_directory, backend) -> None:
215
+ def init(target_directory, backend, name, python_config) -> None:
108
216
  """Create a new truss.
109
217
 
110
218
  TARGET_DIRECTORY: A Truss is created in this directory
111
219
  """
112
220
  if os.path.isdir(target_directory):
113
221
  raise click.ClickException(
114
- f'Error: Directory "{target_directory}" already exists and cannot be overwritten.'
222
+ f"Error: Directory `{target_directory}` already exists "
223
+ "and cannot be overwritten."
115
224
  )
116
225
  tr_path = Path(target_directory)
117
226
  build_config = Build(model_server=ModelServer[backend])
118
- model_name = ask_name()
119
- truss.init(
227
+ if name:
228
+ model_name = name
229
+ else:
230
+ model_name = inquire_model_name()
231
+ _init(
120
232
  target_directory=target_directory,
121
233
  build_config=build_config,
122
234
  model_name=model_name,
235
+ python_config=python_config,
123
236
  )
124
237
  click.echo(f"Truss {model_name} was created in {tr_path.absolute()}")
125
238
 
@@ -127,6 +240,7 @@ def init(target_directory, backend) -> None:
127
240
  @image.command() # type: ignore
128
241
  @click.argument("build_dir")
129
242
  @click.argument("target_directory", required=False)
243
+ @log_level_option
130
244
  @error_handling
131
245
  def build_context(build_dir, target_directory: str) -> None:
132
246
  """
@@ -143,9 +257,16 @@ def build_context(build_dir, target_directory: str) -> None:
143
257
  @image.command() # type: ignore
144
258
  @click.argument("target_directory", required=False)
145
259
  @click.argument("build_dir", required=False)
146
- @error_handling
147
260
  @click.option("--tag", help="Docker image tag")
148
- def build(target_directory: str, build_dir: Path, tag) -> None:
261
+ @click.option(
262
+ "--use_host_network",
263
+ is_flag=True,
264
+ default=False,
265
+ help="Use host network for docker build",
266
+ )
267
+ @log_level_option
268
+ @error_handling
269
+ def build(target_directory: str, build_dir: Path, tag, use_host_network) -> None:
149
270
  """
150
271
  Builds the docker image for a Truss.
151
272
 
@@ -156,6 +277,9 @@ def build(target_directory: str, build_dir: Path, tag) -> None:
156
277
  tr = _get_truss_from_directory(target_directory=target_directory)
157
278
  if build_dir:
158
279
  build_dir = Path(build_dir)
280
+ if use_host_network:
281
+ tr.build_serving_docker_image(build_dir=build_dir, tag=tag, network="host")
282
+ return
159
283
  tr.build_serving_docker_image(build_dir=build_dir, tag=tag)
160
284
 
161
285
 
@@ -168,13 +292,16 @@ def build(target_directory: str, build_dir: Path, tag) -> None:
168
292
  "--attach", is_flag=True, default=False, help="Flag for attaching the process"
169
293
  )
170
294
  @click.option(
171
- "--cache/--no-cache",
295
+ "--use_host_network",
172
296
  is_flag=True,
173
- default=True,
174
- help="Flag for caching build or not",
297
+ default=False,
298
+ help="Use host network for docker build",
175
299
  )
300
+ @log_level_option
176
301
  @error_handling
177
- def run(target_directory: str, build_dir: Path, tag, port, attach, cache) -> None:
302
+ def run(
303
+ target_directory: str, build_dir: Path, tag, port, attach, use_host_network
304
+ ) -> None:
178
305
  """
179
306
  Runs the docker image for a Truss.
180
307
 
@@ -190,38 +317,74 @@ def run(target_directory: str, build_dir: Path, tag, port, attach, cache) -> Non
190
317
  click.confirm(
191
318
  f"Container already exists at {urls}. Are you sure you want to continue?"
192
319
  )
193
- tr.docker_run(
194
- build_dir=build_dir, tag=tag, local_port=port, detach=not attach, cache=cache
195
- )
320
+ if use_host_network:
321
+ tr.docker_run(
322
+ build_dir=build_dir,
323
+ tag=tag,
324
+ local_port=port,
325
+ detach=not attach,
326
+ network="host",
327
+ )
328
+ return
329
+ tr.docker_run(build_dir=build_dir, tag=tag, local_port=port, detach=not attach)
196
330
 
197
331
 
198
332
@truss_cli.command()
@click.option(
    "--api-key",
    type=str,
    required=False,
    help=(
        "API key to log in with. If omitted, you are prompted interactively "
        "to configure a remote in .trussrc."
    ),
)
@log_level_option
@error_handling
def login(api_key: Optional[str]):
    """
    Configures credentials for a remote.
    """
    # Aliased so the import does not shadow this command function's own name.
    from truss.api import login as api_login

    if not api_key:
        # Interactive path: prompt for a full remote config and persist it.
        remote_config = inquire_remote_config()
        RemoteFactory.update_remote_config(remote_config)
    else:
        api_login(api_key)
+
349
+
350
@truss_cli.command()
@click.option(
    "--remote",
    type=str,
    required=False,
    help="Name of the remote in .trussrc to check whoami.",
)
@log_level_option
@error_handling
def whoami(remote: Optional[str]):
    """
    Shows user information and exits.
    """
    # Aliased so the import does not shadow this command function's own name.
    from truss.api import whoami as api_whoami

    if not remote:
        remote = inquire_remote_name(RemoteFactory.get_available_config_names())

    user = api_whoami(remote)

    # Printed as `<workspace>\<email>`. The backslash must be escaped: a bare
    # "\{" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    console.print(f"{user.workspace_name}\\{user.user_email}")
370
+
371
+
372
+ @truss_cli.command()
373
+ @click.argument("target_directory", required=False, default=os.getcwd())
374
+ @click.option(
375
+ "--remote",
376
+ type=str,
377
+ required=False,
378
+ help="Name of the remote in .trussrc to patch changes to",
379
+ )
380
+ @log_level_option
381
+ @error_handling
382
+ def watch(target_directory: str, remote: str) -> None:
219
383
  """
220
384
  Seamless remote development with truss
221
385
 
222
386
  TARGET_DIRECTORY: A Truss directory. If none, use current directory.
223
387
  """
224
-
225
388
  # TODO: ensure that provider support draft
226
389
  if not remote:
227
390
  remote = inquire_remote_name(RemoteFactory.get_available_config_names())
@@ -231,16 +394,456 @@ def watch(
231
394
  tr = _get_truss_from_directory(target_directory=target_directory)
232
395
  model_name = tr.spec.config.model_name
233
396
  if not model_name:
234
- rich.print(
397
+ console.print(
235
398
  "🧐 NoneType model_name provided in config.yaml. "
236
399
  "Please check that you have the correct model name in your config file."
237
400
  )
238
401
  sys.exit(1)
239
402
 
240
403
  service = remote_provider.get_service(model_identifier=ModelName(model_name))
241
- logs_url = remote_provider.get_remote_logs_url(service)
242
- rich.print(f"🪵 View logs for your deployment at {logs_url}")
243
- remote_provider.sync_truss_to_dev_version_by_name(model_name, target_directory)
404
+ console.print(
405
+ f"🪵 View logs for your deployment at {_format_link(service.logs_url)}"
406
+ )
407
+
408
+ if not os.path.isfile(target_directory):
409
+ remote_provider.sync_truss_to_dev_version_by_name(
410
+ model_name, target_directory, console, error_console
411
+ )
412
+ else:
413
+ # These imports are delayed, to handle pydantic v1 envs gracefully.
414
+ from truss_chains.deployment import deployment_client
415
+
416
+ deployment_client.watch_model(
417
+ source=Path(target_directory),
418
+ model_name=model_name,
419
+ remote_provider=remote_provider,
420
+ console=console,
421
+ error_console=error_console,
422
+ )
423
+
424
+
425
+ # Chains Stuff #########################################################################
426
+
427
+
428
@click.group()
def chains():
    """Subcommands for truss chains"""
    # Intentionally empty: `push`, `watch` and `init` attach to this group via
    # `@chains.command(...)`.
431
+
432
+
433
+ def _make_chains_curl_snippet(run_remote_url: str, environment: Optional[str]) -> str:
434
+ if environment:
435
+ idx = run_remote_url.find("deployment")
436
+ if idx != -1:
437
+ run_remote_url = (
438
+ run_remote_url[:idx] + f"environments/{environment}/run_remote"
439
+ )
440
+ return (
441
+ f"curl -X POST '{run_remote_url}' \\\n"
442
+ ' -H "Authorization: Api-Key $BASETEN_API_KEY" \\\n'
443
+ " -d '<JSON_INPUT>'"
444
+ )
445
+
446
+
447
def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
    """Creates a status table similar to:

                                    ⛓️   ItestChain - Chain  ⛓️

                  🌐 Status page: https://app.baseten.co/chains/p7qrm93v/overview
    ╭──────────────────────┬──────────────────────────────┬─────────────────────────────────────────────╮
    │ Status               │ Chainlet                     │ Logs URL                                    │
    ├──────────────────────┼──────────────────────────────┼────────────────────────────────────────────┤
    │ 🛠️  BUILDING          │ ItestChain (entrypoint)      │ https://app.baseten.co/chains/.../logs/... │
    ├──────────────────────┼──────────────────────────────┼────────────────────────────────────────────┤
    │ 👾 DEPLOYING         │ GENERATE_DATA (internal)     │ https://app.baseten.co/chains/.../logs/... │
    │ 👾 DEPLOYING         │ SplitTextFailOnce (internal) │ https://app.baseten.co/chains/.../logs/... │
    │ 👾 DEPLOYING         │ TextReplicator (internal)    │ https://app.baseten.co/chains/.../logs/... │
    │ 🛠️  BUILDING          │ TextToNum (internal)         │ https://app.baseten.co/chains/.../logs/... │
    ╰──────────────────────┴──────────────────────────────┴────────────────────────────────────────────╯

    Returns:
        The table and the displayable status of each chainlet, in row order
        (entrypoint first, then the remaining chainlets sorted by name).

    Raises:
        ValueError: If the service reports no entrypoint chainlet.
    """
    title = (
        f"⛓️ {service.name} - Chain ⛓️\n\n "
        f"🌐 Status page: {_format_link(service.status_page_url)}"
    )
    table = rich.table.Table(
        show_header=True,
        header_style="bold yellow",
        title=title,
        box=rich.table.box.ROUNDED,
        border_style="blue",
    )
    table.add_column("Status", style="dim", min_width=20)
    table.add_column("Chainlet", min_width=20)
    table.add_column("Logs URL")

    # Materialize once: the result is scanned twice below (entrypoint lookup,
    # then the remaining chainlets). With a one-shot iterator, the second scan
    # would silently miss chainlets consumed by the first.
    chainlets = list(service.get_info())
    entrypoint = next((c for c in chainlets if c.is_entrypoint), None)
    if entrypoint is None:
        raise ValueError(f"Chain `{service.name}` has no entrypoint chainlet.")
    internal_chainlets = sorted(
        (c for c in chainlets if not c.is_entrypoint), key=lambda c: c.name
    )

    # Spinner used for each status within DEPLOYING_STATUSES; unlisted ones
    # fall back to "deploying".
    deploying_spinners = {"BUILDING": "building", "LOADING": "loading"}
    statuses = []
    for chainlet in [entrypoint, *internal_chainlets]:
        displayable_status = get_displayable_status(chainlet.status)
        if displayable_status == ACTIVE_STATUS:
            spinner_name = "active"
        elif displayable_status in DEPLOYING_STATUSES:
            spinner_name = deploying_spinners.get(displayable_status, "deploying")
        else:
            spinner_name = "failed"
        spinner = rich.spinner.Spinner(spinner_name, text=displayable_status)
        role = "entrypoint" if chainlet.is_entrypoint else "internal"
        display_name = f"{chainlet.name} ({role})"

        table.add_row(spinner, display_name, _format_link(chainlet.logs_url))
        # Add section divider after entrypoint, entrypoint must be first.
        if chainlet.is_entrypoint:
            table.add_section()
        statuses.append(displayable_status)
    return table, statuses
511
+
512
+
513
@chains.command(name="push")  # type: ignore
@click.argument("source", type=Path, required=True)
@click.argument("entrypoint", type=str, required=False)
@click.option(
    "--name",
    type=str,
    required=False,
    help="Name of the chain to be deployed, if not given, the entrypoint name is used.",
)
@click.option(
    "--publish/--no-publish",
    type=bool,
    default=False,
    help="Create chainlets as published deployments.",
)
@click.option(
    "--promote/--no-promote",
    type=bool,
    default=False,
    help="Replace production chainlets with newly deployed chainlets.",
)
@click.option(
    "--environment",
    type=str,
    required=False,
    help=(
        "Deploy the chain as a published deployment to the specified environment. "
        "If specified, --publish is implied and the supplied value of --promote will be ignored."
    ),
)
@click.option(
    "--wait/--no-wait",
    type=bool,
    default=True,
    help="Wait until all chainlets are ready (or deployment failed).",
)
@click.option(
    "--watch/--no-watch",
    type=bool,
    default=False,
    help=(
        "Watches the chains source code and applies live patches. Using this option "
        "will wait for the chain to be deployed (i.e. `--wait` flag is applied), "
        "before starting to watch for changes. This option requires the deployment "
        "to be a development deployment (i.e. `--no-promote`, `--no-publish` and "
        "no `--environment`)."
    ),
)
@click.option(
    "--dryrun",
    type=bool,
    default=False,
    is_flag=True,
    help="Produces only generated files, but doesn't deploy anything.",
)
@click.option(
    "--remote",
    type=str,
    required=False,
    help="Name of the remote in .trussrc to push to.",
)
@click.option(
    "--experimental-watch-chainlet-names",
    type=str,
    required=False,
    help=(
        "Runs `watch`, but only applies patches to specified chainlets. The option is "
        "a comma-separated list of chainlet (display) names. This option can give "
        "faster dev loops, but also lead to inconsistent deployments. Use with caution "
        "and refer to docs."
    ),
)
@log_level_option
@error_handling
def push_chain(
    source: Path,
    entrypoint: Optional[str],
    name: Optional[str],
    publish: bool,
    promote: bool,
    wait: bool,
    watch: bool,
    dryrun: bool,
    remote: Optional[str],
    environment: Optional[str],
    experimental_watch_chainlet_names: Optional[str],
) -> None:
    """
    Deploys a chain remotely.

    SOURCE: Path to a python file that contains the entrypoint chainlet.

    ENTRYPOINT: Class name of the entrypoint chainlet in source file. May be omitted
    if a chainlet definition in SOURCE is tagged with `@chains.mark_entrypoint`.
    """
    # These imports are delayed, to handle pydantic v1 envs gracefully.
    from truss_chains import definitions as chains_def
    from truss_chains import framework
    from truss_chains.deployment import deployment_client

    # Restricting the watched chainlets implies watching.
    if experimental_watch_chainlet_names:
        watch = True

    if watch:
        # Live patches only apply to development deployments. `--environment`
        # implies a published deployment, so it is rejected here too.
        if publish or promote or environment:
            raise ValueError(
                "When using `--watch`, the deployment cannot be published or promoted."
            )
        if not wait:
            console.print(
                "`--watch` is used. Will wait for deployment before watching files."
            )
            wait = True

    if promote and environment:
        promote_warning = (
            "`promote` flag and `environment` flag were both specified. "
            "Ignoring the value of `promote`."
        )
        console.print(promote_warning, style="yellow")

    if not remote:
        remote = inquire_remote_name(RemoteFactory.get_available_config_names())

    with framework.ChainletImporter.import_target(source, entrypoint) as entrypoint_cls:
        # Explicit `--name` wins, then the chain name from the entrypoint's
        # metadata, then its display name.
        chain_name = (
            name or entrypoint_cls.meta_data.chain_name or entrypoint_cls.display_name
        )
        options = chains_def.PushOptionsBaseten.create(
            chain_name=chain_name,
            promote=promote,
            publish=publish,
            only_generate_trusses=dryrun,
            remote=remote,
            environment=environment,
        )
        service = deployment_client.push(
            entrypoint_cls, options, progress_bar=progress.Progress
        )

    if dryrun:
        return

    assert isinstance(service, deployment_client.BasetenChainService)
    curl_snippet = _make_chains_curl_snippet(
        service.run_remote_url, options.environment
    )

    table, statuses = _create_chains_table(service)
    status_check_wait_sec = 2
    if wait:
        num_services = len(statuses)
        success = False
        num_failed = 0
        # Logging inferences with live display (even when using richHandler)
        # -> capture logs and print later.
        with LogInterceptor() as log_interceptor, rich.live.Live(
            table, console=console, refresh_per_second=4
        ) as live:
            # Poll until every chainlet is active, or until any chainlet is in
            # neither the active nor a deploying status (counted as failed).
            while True:
                table, statuses = _create_chains_table(service)
                live.update(table)
                num_active = sum(s == ACTIVE_STATUS for s in statuses)
                num_deploying = sum(s in DEPLOYING_STATUSES for s in statuses)
                if num_active == num_services:
                    success = True
                    break
                elif num_failed := num_services - num_active - num_deploying:
                    break
                time.sleep(status_check_wait_sec)

            intercepted_logs = log_interceptor.get_logs()

        # Prints must be outside `Live` context.
        if intercepted_logs:
            console.print("Logs intercepted during waiting:", style="blue")
            for log in intercepted_logs:
                console.print(f"\t{log}")
        if success:
            deploy_success_text = "Deployment succeeded."
            if environment:
                deploy_success_text = (
                    "Your chain has been deployed into "
                    f"the {options.environment} environment."
                )
            console.print(deploy_success_text, style="bold green")
            console.print(f"You can run the chain with:\n{curl_snippet}")

            if watch:  # Note that this command will print a startup message.
                if experimental_watch_chainlet_names:
                    included_chainlets = [
                        x.strip() for x in experimental_watch_chainlet_names.split(",")
                    ]
                else:
                    included_chainlets = None
                deployment_client.watch(
                    source,
                    entrypoint,
                    name,
                    remote,
                    console,
                    error_console,
                    show_stack_trace=not is_humanfriendly_log_level,
                    included_chainlets=included_chainlets,
                )
        else:
            console.print(f"Deployment failed ({num_failed} failures).", style="red")
    else:
        console.print(table)
        console.print(
            "Once all chainlets are deployed, "
            f"you can run the chain with:\n\n{curl_snippet}"
        )
725
+
726
+
727
@chains.command(name="watch")  # type: ignore
@click.argument("source", type=Path, required=True)
@click.argument("entrypoint", type=str, required=False)
@click.option(
    "--name",
    type=str,
    required=False,
    help="Name of the chain to be deployed, if not given, the entrypoint name is used.",
)
@click.option(
    "--remote",
    type=str,
    required=False,
    help="Name of the remote in .trussrc to push to.",
)
@click.option(
    "--experimental-chainlet-names",
    type=str,
    required=False,
    help=(
        "Runs `watch`, but only applies patches to specified chainlets. The option is "
        "a comma-separated list of chainlet (display) names. This option can give "
        "faster dev loops, but also lead to inconsistent deployments. Use with caution "
        "and refer to docs."
    ),
)
@log_level_option
@error_handling
def watch_chains(
    source: Path,
    entrypoint: Optional[str],
    name: Optional[str],
    remote: Optional[str],
    experimental_chainlet_names: Optional[str],
) -> None:
    """
    Watches the chains source code and applies live patches to a development deployment.

    The development deployment must have been deployed before running this command.

    SOURCE: Path to a python file that contains the entrypoint chainlet.

    ENTRYPOINT: Class name of the entrypoint chainlet in source file. May be omitted
    if a chainlet definition in SOURCE is tagged with `@chains.mark_entrypoint`.
    """
    # Deferred import so the CLI stays importable in pydantic v1 environments.
    from truss_chains.deployment import deployment_client

    # Fall back to an interactive prompt when no remote was given.
    remote = remote or inquire_remote_name(RemoteFactory.get_available_config_names())

    # Optional allow-list of chainlet display names (comma-separated on the CLI);
    # None means no filter was requested.
    included_chainlets = (
        [chainlet.strip() for chainlet in experimental_chainlet_names.split(",")]
        if experimental_chainlet_names
        else None
    )

    deployment_client.watch(
        source,
        entrypoint,
        name,
        remote,
        console,
        error_console,
        show_stack_trace=not is_humanfriendly_log_level,
        included_chainlets=included_chainlets,
    )
793
+
794
+
795
@chains.command(name="init")  # type: ignore
@click.argument("directory", type=Path, required=False)
@log_level_option
@error_handling
def init_chain(directory: Optional[Path]) -> None:
    """
    Initializes a chains project directory.

    DIRECTORY: Name of a new or existing directory to create the chain in. An
    existing directory must be empty. If not specified, the current directory is used.

    """
    if not directory:
        directory = Path.cwd()
    if directory.exists():
        if not directory.is_dir():
            raise ValueError(f"The path {directory} must be a directory.")
        if any(directory.iterdir()):
            raise ValueError(f"Directory {directory} must be empty.")
    else:
        # `parents=True` so a nested new path (e.g. `a/b/c`) works as well.
        directory.mkdir(parents=True)

    filename = inquirer.text(
        qmark="",
        message="Enter the python file name for the chain.",
        default="my_chain.py",
    ).execute()
    stripped_filename = str(filename).strip()
    # An empty name would make `filepath` the directory itself, and `write_text`
    # would then fail with a confusing OS error.
    if not stripped_filename:
        raise ValueError("The python file name for the chain must not be empty.")
    filepath = directory / stripped_filename
    console.print(f"Creating and populating {filepath}...\n")
    source_code = _load_example_chainlet_code()
    filepath.write_text(source_code)
    # The deploy command is registered as `truss chains push`.
    console.print(
        "Next steps:\n",
        f"💻 Run [bold green]`python {filepath}`[/bold green] for local debug "
        "execution.\n"
        f"🚢 Run [bold green]`truss chains push {filepath}`[/bold green] "
        "to deploy the chain to Baseten.\n",
    )
833
+
834
+
835
+ def _load_example_chainlet_code() -> str:
836
+ try:
837
+ from truss_chains.reference_code import reference_chainlet
838
+ # if the example is faulty, a validation error would be raised
839
+ except Exception as e:
840
+ raise Exception("Failed to load starter code. Please notify support.") from e
841
+
842
+ source = Path(reference_chainlet.__file__).read_text()
843
+ return source
844
+
845
+
846
+ # End Chains Stuff #####################################################################
244
847
 
245
848
 
246
849
  def _extract_and_validate_model_identifier(
@@ -315,7 +918,10 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]):
315
918
  "--model-version",
316
919
  type=str,
317
920
  required=False,
318
- help="[DEPRECATED] Use --model-deployment instead, this will be removed in future release. ID of model deployment",
921
+ help=(
922
+ "[DEPRECATED] Use --model-deployment instead, this will be "
923
+ "removed in future release. ID of model deployment"
924
+ ),
319
925
  )
320
926
  @click.option(
321
927
  "--model-deployment",
@@ -323,13 +929,8 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]):
323
929
  required=False,
324
930
  help="ID of model deployment to call",
325
931
  )
326
- @click.option(
327
- "--model",
328
- type=str,
329
- required=False,
330
- help="ID of model to call",
331
- )
332
- @echo_output
932
+ @click.option("--model", type=str, required=False, help="ID of model to call")
933
+ @log_level_option
333
934
  def predict(
334
935
  target_directory: str,
335
936
  remote: str,
@@ -356,7 +957,8 @@ def predict(
356
957
 
357
958
  if model_version:
358
959
  console.print(
359
- "[DEPRECATED] --model-version is deprecated, use --model-deployment instead.",
960
+ "[DEPRECATED] --model-version is deprecated, "
961
+ "use --model-deployment instead.",
360
962
  style="yellow",
361
963
  )
362
964
  model_deployment = model_version
@@ -376,7 +978,7 @@ def predict(
376
978
 
377
979
  # Log deployment ID for Baseten models.
378
980
  if isinstance(service, BasetenService):
379
- rich.print(
981
+ console.print(
380
982
  f"Calling predict on {'[cyan]development[/cyan] ' if service.is_draft else ''}"
381
983
  f"deployment ID {service.model_version_id}..."
382
984
  )
@@ -386,7 +988,40 @@ def predict(
386
988
  for chunk in result:
387
989
  click.echo(chunk, nl=False)
388
990
  return
389
- rich.print_json(data=result)
991
+ console.print_json(data=result)
992
+
993
+
994
@truss_cli.command()
@click.argument("script", required=True)
@click.argument("target_directory", required=False, default=os.getcwd())
def run_python(script, target_directory):
    """
    Runs a Python script in a container for a Truss and streams its output.

    SCRIPT: Path to the Python script to run.

    TARGET_DIRECTORY: A Truss directory. If none, use current directory.
    """
    # Validate inputs up front so the user gets a clear CLI error instead of a
    # failure from the container tooling.
    if not Path(script).exists():
        raise click.BadParameter(
            f"File {script} does not exist. Please provide a valid file."
        )

    if not Path(target_directory).exists():
        raise click.BadParameter(f"Directory {target_directory} does not exist.")

    if not (Path(target_directory) / "config.yaml").exists():
        raise click.BadParameter(
            f"Directory {target_directory} does not contain a valid Truss."
        )

    tr = _get_truss_from_directory(target_directory=target_directory)
    container = tr.run_python_script(Path(script))
    # NOTE(review): assumes each log item is an (output_type, raw_bytes) pair —
    # confirm against `run_python_script`'s container type.
    for output in container.logs():
        output_type = output[0]
        output_content = output[1]
        # stderr is shown in red; everything else uses the default color.
        options = {"fg": "red"} if output_type == "stderr" else {}
        click.secho(output_content.decode("utf-8", "replace"), nl=False, **options)
    # Mirror the script's exit status as this command's exit status.
    # NOTE(review): assumes `container.wait()` returns an integer exit code — confirm.
    exit_code = container.wait()
    sys.exit(exit_code)
390
1025
 
391
1026
 
392
1027
  @truss_cli.command()
@@ -422,6 +1057,15 @@ def predict(
422
1057
  "after deploy completes."
423
1058
  ),
424
1059
  )
1060
+ @click.option(
1061
+ "--environment",
1062
+ type=str,
1063
+ required=False,
1064
+ help=(
1065
+ "Push the truss as a published deployment to the specified environment."
1066
+ "If specified, --publish is implied and the supplied value of --promote will be ignored."
1067
+ ),
1068
+ )
425
1069
  @click.option(
426
1070
  "--preserve-previous-production-deployment",
427
1071
  type=bool,
@@ -429,9 +1073,9 @@ def predict(
429
1073
  required=False,
430
1074
  default=False,
431
1075
  help=(
432
- "Preserve the previous production deployment's autoscaling setting. When not specified, "
433
- "the previous production deployment will be updated to allow it to scale to zero. "
434
- "Can only be use in combination with --promote option"
1076
+ "Preserve the previous production deployment's autoscaling setting. When "
1077
+ "not specified, the previous production deployment will be updated to allow "
1078
+ "it to scale to zero. Can only be use in combination with --promote option."
435
1079
  ),
436
1080
  )
437
1081
  @click.option(
@@ -440,7 +1084,15 @@ def predict(
440
1084
  is_flag=True,
441
1085
  required=False,
442
1086
  default=False,
443
- help="Trust truss with hosted secrets.",
1087
+ help="[DEPRECATED]Trust truss with hosted secrets.",
1088
+ )
1089
+ @click.option(
1090
+ "--disable-truss-download",
1091
+ type=bool,
1092
+ is_flag=True,
1093
+ required=False,
1094
+ default=False,
1095
+ help="Disable downloading the truss directory from the UI.",
444
1096
  )
445
1097
  @click.option(
446
1098
  "--deployment-name",
@@ -448,19 +1100,41 @@ def predict(
448
1100
  required=False,
449
1101
  help=(
450
1102
  "Name of the deployment created by the push. Can only be "
451
- "used in combination with --publish or --promote."
1103
+ "used in combination with `--publish` or `--promote`."
452
1104
  ),
453
1105
  )
1106
+ @click.option(
1107
+ "--wait/--no-wait",
1108
+ type=bool,
1109
+ is_flag=True,
1110
+ required=False,
1111
+ default=False,
1112
+ help="Wait for the deployment to complete before returning.",
1113
+ )
1114
+ @click.option(
1115
+ "--timeout-seconds",
1116
+ type=int,
1117
+ required=False,
1118
+ help=(
1119
+ "Maximum time to wait for deployment to complete in seconds. Without "
1120
+ "specifying, the command will not complete until the deployment is complete."
1121
+ ),
1122
+ )
1123
+ @log_level_option
454
1124
  @error_handling
455
1125
  def push(
456
1126
  target_directory: str,
457
1127
  remote: str,
458
- model_name: Optional[str],
1128
+ model_name: str,
459
1129
  publish: bool = False,
460
1130
  trusted: bool = False,
1131
+ disable_truss_download: bool = False,
461
1132
  promote: bool = False,
462
1133
  preserve_previous_production_deployment: bool = False,
463
1134
  deployment_name: Optional[str] = None,
1135
+ wait: bool = False,
1136
+ timeout_seconds: Optional[int] = None,
1137
+ environment: Optional[str] = None,
464
1138
  ) -> None:
465
1139
  """
466
1140
  Pushes a truss to a TrussRemote.
@@ -472,27 +1146,72 @@ def push(
472
1146
  remote = inquire_remote_name(RemoteFactory.get_available_config_names())
473
1147
 
474
1148
  remote_provider = RemoteFactory.create(remote=remote)
475
-
476
1149
  tr = _get_truss_from_directory(target_directory=target_directory)
477
1150
 
478
1151
  model_name = model_name or tr.spec.config.model_name
479
1152
  if not model_name:
480
1153
  model_name = inquire_model_name()
481
1154
 
1155
+ if promote and environment:
1156
+ promote_warning = "`promote` flag and `environment` flag were both specified. Ignoring the value of `promote`"
1157
+ console.print(promote_warning, style="yellow")
1158
+ if promote and not environment:
1159
+ environment = PRODUCTION_ENVIRONMENT_NAME
1160
+
482
1161
  # Write model name to config if it's not already there
483
1162
  if model_name != tr.spec.config.model_name:
484
1163
  tr.spec.config.model_name = model_name
485
1164
  tr.spec.config.write_to_yaml_file(tr.spec.config_path, verbose=False)
486
1165
 
1166
+ # Log a warning if using --trusted.
1167
+ if trusted:
1168
+ trusted_deprecation_notice = (
1169
+ "[DEPRECATED] `--trusted` option is deprecated and no longer needed"
1170
+ )
1171
+ console.print(trusted_deprecation_notice, style="yellow")
1172
+
1173
+ # trt-llm engine builder checks
1174
+ if uses_trt_llm_builder(tr):
1175
+ if not publish:
1176
+ live_reload_disabled_text = "Development mode is currently not supported for trusses using TRT-LLM build flow, push as a published model using --publish"
1177
+ console.print(live_reload_disabled_text, style="red")
1178
+ sys.exit(1)
1179
+ if is_missing_secrets_for_trt_llm_builder(tr):
1180
+ missing_token_text = (
1181
+ "`hf_access_token` must be provided in secrets to build a gated model. "
1182
+ "Please see https://docs.baseten.co/deploy/guides/private-model for configuration instructions."
1183
+ )
1184
+ console.print(missing_token_text, style="red")
1185
+ sys.exit(1)
1186
+ if memory_updated_for_trt_llm_builder(tr):
1187
+ console.print(
1188
+ f"Automatically increasing memory for trt-llm builder to {TRTLLM_MIN_MEMORY_REQUEST_GI}Gi."
1189
+ )
1190
+ for trt_llm_build_config in tr.spec.config.parsed_trt_llm_build_configs:
1191
+ if (
1192
+ trt_llm_build_config.quantization_type
1193
+ in [TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV]
1194
+ and not trt_llm_build_config.num_builder_gpus
1195
+ ):
1196
+ fp8_and_num_builder_gpus_text = (
1197
+ "Warning: build specifies FP8 quantization but does not explicitly specify number of build GPUs. "
1198
+ "GPU memory required at build time may be significantly more than that required at inference time due to FP8 quantization, which can result in OOM failures during the engine build phase."
1199
+ "`num_builder_gpus` can be used to specify the number of GPUs to use at build time."
1200
+ )
1201
+ console.print(fp8_and_num_builder_gpus_text, style="yellow")
1202
+
487
1203
  # TODO(Abu): This needs to be refactored to be more generic
488
1204
  service = remote_provider.push(
489
1205
  tr,
490
- model_name,
1206
+ model_name=model_name,
491
1207
  publish=publish,
492
- trusted=trusted,
1208
+ trusted=True,
493
1209
  promote=promote,
494
1210
  preserve_previous_prod_deployment=preserve_previous_production_deployment,
495
1211
  deployment_name=deployment_name,
1212
+ environment=environment,
1213
+ disable_truss_download=disable_truss_download,
1214
+ progress_bar=progress.Progress,
496
1215
  ) # type: ignore
497
1216
 
498
1217
  click.echo(f"✨ Model {model_name} was successfully pushed ✨")
@@ -500,7 +1219,7 @@ def push(
500
1219
  if service.is_draft:
501
1220
  draft_model_text = """
502
1221
  |---------------------------------------------------------------------------------------|
503
- | Your model has been deployed as a development model. Development models allow you to |
1222
+ | Your model is deploying as a development model. Development models allow you to |
504
1223
  | iterate quickly during the deployment process. |
505
1224
  | |
506
1225
  | When you are ready to publish your deployed model as a new deployment, |
@@ -512,21 +1231,48 @@ def push(
512
1231
 
513
1232
  click.echo(draft_model_text)
514
1233
 
515
- # Log a warning if using secrets without --trusted.
516
- # TODO(helen): this could be moved to a separate function that includes more config checks.
517
- if tr.spec.config.secrets and not trusted:
518
- not_trusted_text = """Warning: your Truss has secrets but was not pushed with --trusted.
519
- Please push with --trusted to grant access to secrets.
520
- """
521
- console.print(not_trusted_text, style="red")
522
-
523
- if promote:
524
- promotion_text = """Your Truss has been deployed as a production model. After it successfully deploys,
525
- it will become the next production deployment of your model."""
1234
+ if environment:
1235
+ promotion_text = (
1236
+ f"Your Truss has been deployed into the {environment} environment. After it successfully "
1237
+ f"deploys, it will become the next {environment} deployment of your model."
1238
+ )
526
1239
  console.print(promotion_text, style="green")
527
1240
 
528
- logs_url = remote_provider.get_remote_logs_url(service) # type: ignore[attr-defined]
529
- rich.print(f"🪵 View logs for your deployment at {logs_url}")
1241
+ console.print(
1242
+ f"🪵 View logs for your deployment at {_format_link(service.logs_url)}"
1243
+ )
1244
+ if wait:
1245
+ start_time = time.time()
1246
+ with console.status("[bold green]Deploying...") as status:
1247
+ try:
1248
+ # Poll for the deployment status until we have reached. Either ACTIVE,
1249
+ # or a non-deploying status (in which case the deployment has failed).
1250
+ for deployment_status in service.poll_deployment_status():
1251
+ if (
1252
+ timeout_seconds is not None
1253
+ and time.time() - start_time > timeout_seconds
1254
+ ):
1255
+ console.print("Deployment timed out.", style="red")
1256
+ sys.exit(1)
1257
+
1258
+ status.update(
1259
+ f"[bold green]Deploying...Current Status: {deployment_status}"
1260
+ )
1261
+
1262
+ if deployment_status == ACTIVE_STATUS:
1263
+ console.print("Deployment succeeded.", style="bold green")
1264
+ return
1265
+
1266
+ if deployment_status not in DEPLOYING_STATUSES:
1267
+ console.print(
1268
+ f"Deployment failed with status {deployment_status}.",
1269
+ style="red",
1270
+ )
1271
+ sys.exit(1)
1272
+
1273
+ except RemoteNetworkError:
1274
+ console.print("Deployment failed: Could not reach remote.", style="red")
1275
+ sys.exit(1)
530
1276
 
531
1277
 
532
1278
  @truss_cli.command()
@@ -576,8 +1322,8 @@ def kill(target_directory: str) -> None:
576
1322
 
577
1323
  @container.command() # type: ignore
578
1324
  def kill_all() -> None:
579
- "Kills all truss containers that are not manually persisted"
580
- truss.kill_all()
1325
+ """Kills all truss containers that are not manually persisted."""
1326
+ docker.kill_all()
581
1327
 
582
1328
 
583
1329
  @truss_cli.command()
@@ -590,18 +1336,25 @@ def cleanup() -> None:
590
1336
  such as for building docker images. This command clears
591
1337
  that data to free up disk space.
592
1338
  """
593
- truss.build.cleanup()
1339
+ _cleanup()
594
1340
 
595
1341
 
596
- def _get_truss_from_directory(target_directory: Optional[str] = None) -> TrussHandle:
1342
+ def _get_truss_from_directory(target_directory: Optional[str] = None):
597
1343
  """Gets Truss from directory. If none, use the current directory"""
598
1344
  if target_directory is None:
599
1345
  target_directory = os.getcwd()
600
- return truss.load(target_directory)
1346
+ if not os.path.isfile(target_directory):
1347
+ return load(target_directory)
1348
+ # These imports are delayed, to handle pydantic v1 envs gracefully.
1349
+ from truss_chains.deployment import code_gen
1350
+
1351
+ truss_dir = code_gen.gen_truss_model_from_source(Path(target_directory))
1352
+ return load(truss_dir)
601
1353
 
602
1354
 
603
1355
  truss_cli.add_command(container)
604
1356
  truss_cli.add_command(image)
1357
+ truss_cli.add_command(chains)
605
1358
 
606
1359
  if __name__ == "__main__":
607
1360
  truss_cli()