truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -1,49 +1,125 @@
1
+ import enum
1
2
  import logging
2
3
  import re
3
4
  from pathlib import Path
4
- from typing import List, Optional, Tuple
5
+ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Type
5
6
 
6
- import click
7
- import rich
8
7
  import yaml
9
8
  from requests import ReadTimeout
9
+ from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
10
+
11
+ if TYPE_CHECKING:
12
+ from rich import console as rich_console
13
+ from rich import progress
14
+ from truss.base import validation
15
+ from truss.base.truss_config import ModelServer
10
16
  from truss.local.local_config_handler import LocalConfigHandler
17
+ from truss.remote.baseten import custom_types
11
18
  from truss.remote.baseten.api import BasetenApi
12
19
  from truss.remote.baseten.auth import AuthService
13
20
  from truss.remote.baseten.core import (
21
+ ChainDeploymentHandleAtomic,
14
22
  ModelId,
15
23
  ModelIdentifier,
16
24
  ModelName,
17
25
  ModelVersionId,
18
26
  archive_truss,
27
+ create_chain_atomic,
19
28
  create_truss_service,
20
29
  exists_model,
21
30
  get_dev_version,
22
31
  get_dev_version_from_versions,
23
32
  get_model_versions,
24
33
  get_prod_version_from_versions,
34
+ get_truss_watch_state,
25
35
  upload_truss,
36
+ validate_truss_config,
26
37
  )
27
- from truss.remote.baseten.error import ApiError
28
- from truss.remote.baseten.service import BasetenService
38
+ from truss.remote.baseten.error import ApiError, RemoteError
39
+ from truss.remote.baseten.service import BasetenService, URLConfig
29
40
  from truss.remote.baseten.utils.transfer import base64_encoded_json_str
30
- from truss.remote.truss_remote import TrussRemote, TrussService
31
- from truss.truss_config import ModelServer
32
- from truss.truss_handle import TrussHandle
33
- from truss.util.path import is_ignored, load_trussignore_patterns
41
+ from truss.remote.truss_remote import RemoteUser, TrussRemote
42
+ from truss.truss_handle import build as truss_build
43
+ from truss.truss_handle.truss_handle import TrussHandle
44
+ from truss.util.path import is_ignored, load_trussignore_patterns_from_truss_dir
34
45
  from watchfiles import watch
35
46
 
36
47
 
48
+ class PatchStatus(enum.Enum):
49
+ SUCCESS = enum.auto()
50
+ FAILED = enum.auto()
51
+ SKIPPED = enum.auto()
52
+
53
+
54
+ class PatchResult(NamedTuple):
55
+ status: PatchStatus
56
+ message: str
57
+
58
+
59
+ class FinalPushData(custom_types.OracleData):
60
+ class Config:
61
+ protected_namespaces = ()
62
+
63
+ is_draft: bool
64
+ model_id: Optional[str]
65
+ preserve_previous_prod_deployment: bool
66
+ origin: Optional[custom_types.ModelOrigin] = None
67
+ environment: Optional[str] = None
68
+ allow_truss_download: bool
69
+
70
+
37
71
  class BasetenRemote(TrussRemote):
38
- def __init__(self, remote_url: str, api_key: str, **kwargs):
39
- super().__init__(remote_url, **kwargs)
72
+ def __init__(self, remote_url: str, api_key: str):
73
+ super().__init__(remote_url)
40
74
  self._auth_service = AuthService(api_key=api_key)
41
- self._api = BasetenApi(f"{self._remote_url}/graphql/", self._auth_service)
75
+ self._api = BasetenApi(remote_url, self._auth_service)
76
+
77
+ @property
78
+ def api(self) -> BasetenApi:
79
+ return self._api
80
+
81
+ def get_chainlets(
82
+ self, chain_deployment_id: str
83
+ ) -> List[custom_types.DeployedChainlet]:
84
+ return [
85
+ custom_types.DeployedChainlet(
86
+ name=chainlet["name"],
87
+ is_entrypoint=chainlet["is_entrypoint"],
88
+ is_draft=chainlet["oracle_version"]["is_draft"],
89
+ status=chainlet["oracle_version"]["current_model_deployment_status"][
90
+ "status"
91
+ ],
92
+ logs_url=URLConfig.chainlet_logs_url(
93
+ self.remote_url,
94
+ chainlet["chain"]["id"],
95
+ chain_deployment_id,
96
+ chainlet["id"],
97
+ ),
98
+ oracle_predict_url=URLConfig.invocation_url(
99
+ self._api.rest_api_url,
100
+ URLConfig.MODEL,
101
+ chainlet["oracle"]["id"],
102
+ chainlet["oracle_version"]["id"],
103
+ chainlet["oracle_version"]["is_draft"],
104
+ ),
105
+ oracle_name=chainlet["oracle"]["name"],
106
+ )
107
+ for chainlet in self._api.get_chainlets_by_deployment_id(
108
+ chain_deployment_id
109
+ )
110
+ ]
42
111
 
43
- def authenticate(self):
44
- return self._auth_service.validate()
112
+ def whoami(self) -> RemoteUser:
113
+ resp = self._api._post_graphql_query(
114
+ "query{organization{workspace_name}user{email}}"
115
+ )
116
+ workspace_name = resp["data"]["organization"]["workspace_name"]
117
+ user_email = resp["data"]["user"]["email"]
118
+ return RemoteUser(workspace_name, user_email)
45
119
 
46
- def push( # type: ignore
120
+ # Validate and finalize options.
121
+ # Upload Truss files to S3 and return S3 key.
122
+ def _prepare_push(
47
123
  self,
48
124
  truss_handle: TrussHandle,
49
125
  model_name: str,
@@ -51,20 +127,30 @@ class BasetenRemote(TrussRemote):
51
127
  trusted: bool = False,
52
128
  promote: bool = False,
53
129
  preserve_previous_prod_deployment: bool = False,
130
+ disable_truss_download: bool = False,
54
131
  deployment_name: Optional[str] = None,
55
- ):
132
+ origin: Optional[custom_types.ModelOrigin] = None,
133
+ environment: Optional[str] = None,
134
+ progress_bar: Optional[Type["progress.Progress"]] = None,
135
+ ) -> FinalPushData:
56
136
  if model_name.isspace():
57
137
  raise ValueError("Model name cannot be empty")
58
138
 
59
- model_id = exists_model(self._api, model_name)
139
+ if truss_handle.is_scattered():
140
+ truss_handle = TrussHandle(truss_handle.gather())
60
141
 
61
- gathered_truss = TrussHandle(truss_handle.gather())
62
- if gathered_truss.spec.model_server != ModelServer.TrussServer:
142
+ if truss_handle.spec.model_server != ModelServer.TrussServer:
63
143
  publish = True
64
144
 
65
145
  if promote:
66
- # If we are promoting a model after deploy, it must be published.
67
- # Draft models cannot be promoted.
146
+ environment = PRODUCTION_ENVIRONMENT_NAME
147
+
148
+ # If there is a target environment, it must be published.
149
+ # Draft models cannot be promoted.
150
+ if environment and not publish:
151
+ logging.info(
152
+ f"Automatically publishing model '{model_name}' based on environment setting."
153
+ )
68
154
  publish = True
69
155
 
70
156
  if not publish and deployment_name:
@@ -74,7 +160,8 @@ class BasetenRemote(TrussRemote):
74
160
 
75
161
  if not promote and preserve_previous_prod_deployment:
76
162
  raise ValueError(
77
- "preserve-previous-production-deployment can only be used with the '--promote' option"
163
+ "preserve-previous-production-deployment can only be used "
164
+ "with the '--promote' option"
78
165
  )
79
166
 
80
167
  if deployment_name and not re.match(r"^[0-9a-zA-Z_\-\.]*$", deployment_name):
@@ -82,34 +169,146 @@ class BasetenRemote(TrussRemote):
82
169
  "Deployment name must only contain alphanumeric, -, _ and . characters"
83
170
  )
84
171
 
172
+ model_id = exists_model(self._api, model_name)
173
+
174
+ if model_id is not None and disable_truss_download:
175
+ raise ValueError("disable-truss-download can only be used for new models")
176
+
177
+ temp_file = archive_truss(truss_handle, progress_bar)
178
+ s3_key = upload_truss(self._api, temp_file, progress_bar)
85
179
  encoded_config_str = base64_encoded_json_str(
86
- gathered_truss._spec._config.to_dict()
180
+ truss_handle._spec._config.to_dict()
87
181
  )
88
182
 
89
- temp_file = archive_truss(gathered_truss)
90
- s3_key = upload_truss(self._api, temp_file)
183
+ validate_truss_config(self._api, encoded_config_str)
91
184
 
92
- model_id, model_version_id = create_truss_service(
93
- api=self._api,
185
+ return FinalPushData(
94
186
  model_name=model_name,
95
187
  s3_key=s3_key,
96
- config=encoded_config_str,
188
+ encoded_config_str=encoded_config_str,
97
189
  is_draft=not publish,
98
190
  model_id=model_id,
99
191
  is_trusted=trusted,
192
+ preserve_previous_prod_deployment=preserve_previous_prod_deployment,
193
+ version_name=deployment_name,
194
+ origin=origin,
195
+ environment=environment,
196
+ allow_truss_download=not disable_truss_download,
197
+ )
198
+
def push(  # type: ignore
    self,
    truss_handle: TrussHandle,
    model_name: str,
    publish: bool = True,
    trusted: bool = False,
    promote: bool = False,
    preserve_previous_prod_deployment: bool = False,
    disable_truss_download: bool = False,
    deployment_name: Optional[str] = None,
    origin: Optional[custom_types.ModelOrigin] = None,
    environment: Optional[str] = None,
    progress_bar: Optional[Type["progress.Progress"]] = None,
) -> BasetenService:
    """Upload a truss and register it as a new deployment of ``model_name``.

    Validation, archiving and uploading are delegated to ``_prepare_push``;
    this method then registers the uploaded archive with
    ``create_truss_service`` and wraps the returned ids in a
    ``BasetenService``.

    Returns:
        A ``BasetenService`` for the created model version; its ``is_draft``
        flag mirrors the draft decision made by ``_prepare_push``.

    Raises:
        ValueError: Propagated from ``_prepare_push`` on invalid arguments.
    """
    prepared = self._prepare_push(
        truss_handle=truss_handle,
        model_name=model_name,
        publish=publish,
        trusted=trusted,
        promote=promote,
        preserve_previous_prod_deployment=preserve_previous_prod_deployment,
        disable_truss_download=disable_truss_download,
        deployment_name=deployment_name,
        origin=origin,
        environment=environment,
        progress_bar=progress_bar,
    )

    # TODO(Tyron): This set of args is duplicated across many functions. We
    # should consolidate them into a data class with standardized default
    # values so we're not drilling these arguments everywhere.
    created_model_id, created_version_id = create_truss_service(
        api=self._api,
        model_name=prepared.model_name,
        s3_key=prepared.s3_key,
        config=prepared.encoded_config_str,
        is_draft=prepared.is_draft,
        model_id=prepared.model_id,
        is_trusted=prepared.is_trusted,
        preserve_previous_prod_deployment=prepared.preserve_previous_prod_deployment,
        allow_truss_download=prepared.allow_truss_download,
        deployment_name=prepared.version_name,
        origin=prepared.origin,
        environment=prepared.environment,
    )

    api_key = self._auth_service.authenticate().value
    version_url = f"{self._remote_url}/model_versions/{created_version_id}"
    return BasetenService(
        model_id=created_model_id,
        model_version_id=created_version_id,
        is_draft=prepared.is_draft,
        api_key=api_key,
        service_url=version_url,
        truss_handle=truss_handle,
        api=self._api,
    )
255
+
def push_chain_atomic(
    self,
    chain_name: str,
    entrypoint_artifact: custom_types.ChainletArtifact,
    dependency_artifacts: List[custom_types.ChainletArtifact],
    publish: bool = False,
    environment: Optional[str] = None,
    progress_bar: Optional[Type["progress.Progress"]] = None,
) -> ChainDeploymentHandleAtomic:
    """Upload all chainlets of a chain and create the chain deployment at once.

    Each chainlet truss (entrypoint first, then dependencies) is loaded,
    validated, archived and uploaded via ``_prepare_push``. The collected
    per-chainlet data is then submitted in a single ``create_chain_atomic``
    call.

    Args:
        chain_name: Name of the chain to deploy.
        entrypoint_artifact: Artifact of the chain's entrypoint chainlet.
        dependency_artifacts: Artifacts of all remaining chainlets.
        publish: Whether to create a published (non-draft) deployment.
            Forced to ``True`` when ``environment`` is set.
        environment: Optional target environment for the deployment.
        progress_bar: Optional progress bar class forwarded to
            ``_prepare_push`` for archiving/uploading.

    Returns:
        The handle returned by ``create_chain_atomic``.

    Raises:
        ValueError: If a chainlet's truss config has an empty or invalid
            model name, or if ``_prepare_push`` rejects its arguments.
    """
    # If we are promoting a model to an environment after deploy, it must be
    # published. Draft models cannot be promoted.
    if environment and not publish:
        publish = True

    chainlet_data: List[custom_types.ChainletDataAtomic] = []

    # The entrypoint must be first: it is split off as ``chainlet_data[0]``
    # when calling ``create_chain_atomic`` below.
    for artifact in [entrypoint_artifact, *dependency_artifacts]:
        truss_handle = truss_build.load(str(artifact.truss_dir))
        model_name = truss_handle.spec.config.model_name

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and would let invalid names through unchecked.
        if not (model_name and validation.is_valid_model_name(model_name)):
            raise ValueError(
                f"Invalid model name {model_name!r} in truss config at "
                f"'{artifact.truss_dir}'."
            )

        push_data = self._prepare_push(
            truss_handle=truss_handle,
            model_name=model_name,
            # Models must be trusted to use the API KEY secret.
            trusted=True,
            publish=publish,
            origin=custom_types.ModelOrigin.CHAINS,
            progress_bar=progress_bar,
        )
        oracle_data = custom_types.OracleData(
            model_name=push_data.model_name,
            s3_key=push_data.s3_key,
            encoded_config_str=push_data.encoded_config_str,
            is_draft=push_data.is_draft,
            model_id=push_data.model_id,
            is_trusted=push_data.is_trusted,
            version_name=push_data.version_name,
        )
        chainlet_data.append(
            custom_types.ChainletDataAtomic(
                name=artifact.display_name, oracle=oracle_data
            )
        )

    chain_deployment_handle = create_chain_atomic(
        api=self._api,
        chain_name=chain_name,
        entrypoint=chainlet_data[0],
        dependencies=chainlet_data[1:],
        is_draft=not publish,
        environment=environment,
    )
    logging.info("Successfully pushed to baseten. Chain is building and deploying.")
    return chain_deployment_handle
113
312
 
114
313
  @staticmethod
115
314
  def _get_matching_version(model_versions: List[dict], published: bool) -> dict:
@@ -117,7 +316,7 @@ class BasetenRemote(TrussRemote):
117
316
  # Return the development model version.
118
317
  dev_version = get_dev_version_from_versions(model_versions)
119
318
  if not dev_version:
120
- raise click.UsageError(
319
+ raise RemoteError(
121
320
  "No development model found. Run `truss push` then try again."
122
321
  )
123
322
  return dev_version
@@ -125,7 +324,7 @@ class BasetenRemote(TrussRemote):
125
324
  # Return the production deployment version.
126
325
  prod_version = get_prod_version_from_versions(model_versions)
127
326
  if not prod_version:
128
- raise click.UsageError(
327
+ raise RemoteError(
129
328
  "No production model found. Run `truss push --publish` then try again."
130
329
  )
131
330
  return prod_version
@@ -138,9 +337,7 @@ class BasetenRemote(TrussRemote):
138
337
  try:
139
338
  model_version = api.get_model_version_by_id(model_identifier.value)
140
339
  except ApiError:
141
- raise click.UsageError(
142
- f"Model version {model_identifier.value} not found."
143
- )
340
+ raise RemoteError(f"Model version {model_identifier.value} not found.")
144
341
  model_version_id = model_version["model_version"]["id"]
145
342
  model_id = model_version["model_version"]["oracle"]["id"]
146
343
  service_url_path = f"/model_versions/{model_version_id}"
@@ -159,14 +356,15 @@ class BasetenRemote(TrussRemote):
159
356
  try:
160
357
  model = api.get_model_by_id(model_identifier.value)
161
358
  except ApiError:
162
- raise click.UsageError(f"Model {model_identifier.value} not found.")
359
+ raise RemoteError(f"Model {model_identifier.value} not found.")
163
360
  model_id = model["model"]["id"]
164
361
  model_version_id = model["model"]["primary_version"]["id"]
165
362
  service_url_path = f"/models/{model_id}"
166
363
  else:
167
364
  # Model identifier is of invalid type.
168
- raise click.UsageError(
169
- "You must either be inside of a Truss directory, or provide --model-deployment or --model options."
365
+ raise RemoteError(
366
+ "You must either be inside of a Truss directory, or provide "
367
+ "--model-deployment or --model options."
170
368
  )
171
369
 
172
370
  return service_url_path, model_id, model_version_id
@@ -178,12 +376,10 @@ class BasetenRemote(TrussRemote):
178
376
  raise ValueError("Baseten Service requires a model_identifier")
179
377
 
180
378
  published = kwargs.get("published", False)
181
- (
182
- service_url_path,
183
- model_id,
184
- model_version_id,
185
- ) = self._get_service_url_path_and_model_ids(
186
- self._api, model_identifier, published
379
+ (service_url_path, model_id, model_version_id) = (
380
+ self._get_service_url_path_and_model_ids(
381
+ self._api, model_identifier, published
382
+ )
187
383
  )
188
384
 
189
385
  return BasetenService(
@@ -192,119 +388,174 @@ class BasetenRemote(TrussRemote):
192
388
  is_draft=not published,
193
389
  api_key=self._auth_service.authenticate().value,
194
390
  service_url=f"{self._remote_url}{service_url_path}",
391
+ api=self._api,
195
392
  )
196
393
 
197
- def get_remote_logs_url(
198
- self,
199
- service: TrussService,
200
- ) -> str:
201
- return service.logs_url(self._remote_url)
202
-
def sync_truss_to_dev_version_by_name(
    self,
    model_name: str,
    target_directory: str,
    console: "rich_console.Console",
    error_console: "rich_console.Console",
) -> None:
    """Keep a model's development deployment in sync with a local truss.

    Patches once immediately, then watches ``target_directory`` and patches
    again every time ``watch`` reports changes to files that are not covered
    by the truss's ignore patterns. Each outcome is reported by ``patch`` on
    ``console`` / ``error_console``.

    Raises:
        RemoteError: If ``model_name`` has no development deployment.
    """
    # A development deployment must already exist for this model name.
    if not get_dev_version(self._api, model_name):
        raise RemoteError(
            "No development model found. Run `truss push` then try again."
        )

    truss_root = Path(target_directory)
    ignore_patterns = load_trussignore_patterns_from_truss_dir(truss_root)

    def sync_once() -> None:
        self.patch(truss_root, ignore_patterns, console, error_console)

    def is_relevant_change(_change, changed_path) -> bool:
        # React only to paths that the truss ignore patterns do not exclude.
        return not is_ignored(Path(changed_path), ignore_patterns)

    # Silence the watchfiles logger.
    logging.getLogger("watchfiles.main").disabled = True

    console.print(f"🚰 Attempting to sync truss at '{truss_root}' with remote")
    sync_once()

    console.print(f"👀 Watching for changes to truss at '{truss_root}' ...")
    for _changes in watch(
        truss_root, watch_filter=is_relevant_change, raise_interrupt=False
    ):
        sync_once()
235
423
 
236
- def patch(
def _patch(
    self,
    watch_path: Path,
    truss_ignore_patterns: List[str],
    console: Optional["rich_console.Console"] = None,
) -> PatchResult:
    """Sync the truss at ``watch_path`` to its model's development deployment.

    Loads the local truss and looks up the development deployment of its
    model. It then computes a patch against the hash the remote currently
    records. If there are local changes, the patch is uploaded. If there is
    nothing new locally but the remote still reports unapplied patches, a
    sync is requested instead.

    Args:
        watch_path: Directory of the local truss.
        truss_ignore_patterns: Ignore patterns passed to the patch calculation.
        console: Optional console; when given, a status spinner is displayed
            while the patch/sync request is in flight.

    Returns:
        A ``PatchResult`` whose status is:
        - ``SKIPPED`` when local and remote are already in sync;
        - ``SUCCESS`` when the remote accepted the patch;
        - ``FAILED`` (with an explanatory message) otherwise.
        Config parse errors and ``ValueError`` while loading the truss, any
        error during patch calculation, and any error from the patch/sync API
        call are reported as ``FAILED`` rather than raised.
    """
    import contextlib

    try:
        truss_handle = TrussHandle(watch_path)
    except yaml.parser.ParserError as e:
        return PatchResult(PatchStatus.FAILED, f"Unable to parse config file. {e}")
    except ValueError as e:
        return PatchResult(
            PatchStatus.FAILED,
            f"Error when reading truss from directory {watch_path}. {e}",
        )

    model_name = truss_handle.spec.config.model_name
    dev_version = get_dev_version(self._api, model_name)  # type: ignore
    if not dev_version:
        return PatchResult(
            PatchStatus.FAILED,
            f"No development deployment found for model: {model_name}.",
        )

    truss_hash = dev_version.get("truss_hash", None)
    truss_signature = dev_version.get("truss_signature", None)
    if not (truss_hash and truss_signature):
        return PatchResult(
            PatchStatus.FAILED,
            (
                "Failed to inspect a running remote deployment to watch for "
                "changes. Ensure that there exists a running remote deployment"
                " before attempting to watch for changes."
            ),
        )

    truss_watch_state = get_truss_watch_state(self._api, model_name)  # type: ignore
    # Make sure the patches are calculated against the current django patch
    # state, if it exists. This is important to ensure that the sequence of
    # patches for a given session forms a valid patch sequence (via a linked
    # list).
    patches = truss_watch_state.patches
    if patches:
        truss_hash = patches.django_patch_state.current_hash
        truss_signature = patches.django_patch_state.current_signature
        logging.debug("db patch hash: %s", truss_hash)
        logging.debug(
            "container_patch_hash: %s", patches.container_patch_state.current_hash
        )
    LocalConfigHandler.add_signature(truss_hash, truss_signature)

    try:
        patch_request = truss_handle.calc_patch(truss_hash, truss_ignore_patterns)
    except Exception as e:
        # Boundary: any calculation failure is reported, not raised.
        return PatchResult(PatchStatus.FAILED, f"Failed to calculate patch. {e}")
    if not patch_request:
        return PatchResult(
            PatchStatus.FAILED,
            "Failed to calculate patch. Change type might not be supported.",
        )

    # The remote tracks a "django" patch state and a container patch state.
    # Differing hashes count as pending work, unless the container was built
    # from a push.
    django_has_unapplied_patches = bool(
        not truss_watch_state.is_container_built_from_push
        and patches
        and (
            patches.django_patch_state.current_hash
            != patches.container_patch_state.current_hash
        )
    )
    should_create_patch = (
        patch_request.prev_hash != patch_request.next_hash
        and len(patch_request.patch_ops) > 0
    )
    if not django_has_unapplied_patches and not should_create_patch:
        return PatchResult(
            PatchStatus.SKIPPED, "No changes observed, skipping patching."
        )

    try:
        # Show a spinner only when a console was supplied.
        status = (
            console.status("Applying patch...")
            if console
            else contextlib.nullcontext()
        )
        with status:
            if should_create_patch:
                resp = self._api.patch_draft_truss_two_step(model_name, patch_request)
            else:
                # No new local changes; presumably asks the remote to apply
                # the patches it has already recorded.
                resp = self._api.sync_draft_truss(model_name)
    except ReadTimeout:
        return PatchResult(
            PatchStatus.FAILED, "Read Timeout when attempting to patch remote."
        )
    except Exception as e:
        return PatchResult(
            PatchStatus.FAILED, f"Failed to patch draft deployment. {e}"
        )

    if resp["succeeded"]:
        return PatchResult(
            PatchStatus.SUCCESS,
            resp.get("success_message", f"Model {model_name} patched successfully."),
        )

    if resp.get("needs_full_deploy", None):
        message = (
            f"Model {model_name} is not able to be patched: `{resp['error']}`. "
            "Use `truss push` to deploy."
        )
    else:
        message = (
            f"Failed to patch. Server error: `{resp['error']}`. "
            "Model left in original state."
        )
    return PatchResult(PatchStatus.FAILED, message)
544
+
def patch(
    self,
    watch_path: Path,
    truss_ignore_patterns: List[str],
    console: "rich_console.Console",
    error_console: "rich_console.Console",
):
    """Run one patch cycle for the truss at ``watch_path`` and report it.

    Successful and skipped outcomes are printed in green on ``console``;
    failures are printed on ``error_console``.

    Args:
        watch_path: Directory of the local truss.
        truss_ignore_patterns: Ignore patterns used for the patch calculation.
        console: Console for progress and success/skip messages.
        error_console: Console for failure messages.
    """
    # Forward ``console`` so ``_patch`` shows its "Applying patch..." status
    # spinner. Only ``patch_for_chainlet`` is meant to run without a console.
    result = self._patch(watch_path, truss_ignore_patterns, console=console)
    if result.status in (PatchStatus.SUCCESS, PatchStatus.SKIPPED):
        console.print(result.message, style="green")
    else:
        error_console.print(result.message)
557
+
def patch_for_chainlet(
    self, watch_path: Path, truss_ignore_patterns: List[str]
) -> PatchResult:
    """Patch a chainlet's development deployment without a rich console.

    No status spinner is shown and nothing is printed here. The caller
    receives the ``PatchResult`` and decides how to report it.
    """
    return self._patch(watch_path, truss_ignore_patterns, None)