truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. /truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. /truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -1,29 +1,44 @@
1
- import click
1
+ from urllib import parse
2
+
2
3
  import pytest
3
4
  import requests_mock
4
- from truss.remote.baseten.core import ModelId, ModelName, ModelVersionId
5
+ import truss
6
+ from truss.remote.baseten.core import (
7
+ ModelId,
8
+ ModelName,
9
+ ModelVersionId,
10
+ create_chain_atomic,
11
+ )
12
+ from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
13
+ from truss.remote.baseten.error import RemoteError
5
14
  from truss.remote.baseten.remote import BasetenRemote
6
- from truss.truss_handle import TrussHandle
15
+ from truss.truss_handle.truss_handle import TrussHandle
7
16
 
8
17
  _TEST_REMOTE_URL = "http://test_remote.com"
18
+ _TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/"
19
+
20
+
21
+ def assert_request_matches_expected_query(request, expected_query) -> None:
22
+ unescaped_content = parse.unquote_plus(request.text)
23
+ actual_lines = tuple(
24
+ line.strip()
25
+ for line in unescaped_content.replace("query=", "").strip().split("\n")
26
+ if line.strip()
27
+ )
28
+ expected_lines = tuple(
29
+ line.strip() for line in expected_query.split("\n") if line.strip()
30
+ )
31
+ assert actual_lines == expected_lines
9
32
 
10
33
 
11
34
  def test_get_service_by_version_id():
12
35
  remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
13
36
 
14
- version = {
15
- "id": "version_id",
16
- "oracle": {
17
- "id": "model_id",
18
- },
19
- }
37
+ version = {"id": "version_id", "oracle": {"id": "model_id"}}
20
38
  model_version_response = {"data": {"model_version": version}}
21
39
 
22
40
  with requests_mock.Mocker() as m:
23
- m.post(
24
- remote._api._api_url,
25
- json=model_version_response,
26
- )
41
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_version_response)
27
42
  service = remote.get_service(model_identifier=ModelVersionId("version_id"))
28
43
 
29
44
  assert service.model_id == "model_id"
@@ -34,11 +49,8 @@ def test_get_service_by_version_id_no_version():
34
49
  remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
35
50
  model_version_response = {"errors": [{"message": "error"}]}
36
51
  with requests_mock.Mocker() as m:
37
- m.post(
38
- remote._api._api_url,
39
- json=model_version_response,
40
- )
41
- with pytest.raises(click.UsageError):
52
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_version_response)
53
+ with pytest.raises(RemoteError):
42
54
  remote.get_service(model_identifier=ModelVersionId("version_id"))
43
55
 
44
56
 
@@ -52,19 +64,12 @@ def test_get_service_by_model_name():
52
64
  ]
53
65
  model_response = {
54
66
  "data": {
55
- "model": {
56
- "name": "model_name",
57
- "id": "model_id",
58
- "versions": versions,
59
- }
67
+ "model": {"name": "model_name", "id": "model_id", "versions": versions}
60
68
  }
61
69
  }
62
70
 
63
71
  with requests_mock.Mocker() as m:
64
- m.post(
65
- remote._api._api_url,
66
- json=model_response,
67
- )
72
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
68
73
 
69
74
  # Check that the production version is returned when published is True.
70
75
  service = remote.get_service(
@@ -84,24 +89,15 @@ def test_get_service_by_model_name():
84
89
  def test_get_service_by_model_name_no_dev_version():
85
90
  remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
86
91
 
87
- versions = [
88
- {"id": "1", "is_draft": False, "is_primary": True},
89
- ]
92
+ versions = [{"id": "1", "is_draft": False, "is_primary": True}]
90
93
  model_response = {
91
94
  "data": {
92
- "model": {
93
- "name": "model_name",
94
- "id": "model_id",
95
- "versions": versions,
96
- }
95
+ "model": {"name": "model_name", "id": "model_id", "versions": versions}
97
96
  }
98
97
  }
99
98
 
100
99
  with requests_mock.Mocker() as m:
101
- m.post(
102
- remote._api._api_url,
103
- json=model_response,
104
- )
100
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
105
101
 
106
102
  # Check that the production version is returned when published is True.
107
103
  service = remote.get_service(
@@ -112,7 +108,7 @@ def test_get_service_by_model_name_no_dev_version():
112
108
 
113
109
  # Since no development version exists, calling get_service with
114
110
  # published=False should raise an error.
115
- with pytest.raises(click.UsageError):
111
+ with pytest.raises(RemoteError):
116
112
  remote.get_service(
117
113
  model_identifier=ModelName("model_name"), published=False
118
114
  )
@@ -121,28 +117,19 @@ def test_get_service_by_model_name_no_dev_version():
121
117
  def test_get_service_by_model_name_no_prod_version():
122
118
  remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
123
119
 
124
- versions = [
125
- {"id": "1", "is_draft": True, "is_primary": False},
126
- ]
120
+ versions = [{"id": "1", "is_draft": True, "is_primary": False}]
127
121
  model_response = {
128
122
  "data": {
129
- "model": {
130
- "name": "model_name",
131
- "id": "model_id",
132
- "versions": versions,
133
- }
123
+ "model": {"name": "model_name", "id": "model_id", "versions": versions}
134
124
  }
135
125
  }
136
126
 
137
127
  with requests_mock.Mocker() as m:
138
- m.post(
139
- remote._api._api_url,
140
- json=model_response,
141
- )
128
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
142
129
 
143
130
  # Since no production version exists, calling get_service with
144
131
  # published=True should raise an error.
145
- with pytest.raises(click.UsageError):
132
+ with pytest.raises(RemoteError):
146
133
  remote.get_service(model_identifier=ModelName("model_name"), published=True)
147
134
 
148
135
  # Check that the development version is returned when published is False.
@@ -167,10 +154,7 @@ def test_get_service_by_model_id():
167
154
  }
168
155
 
169
156
  with requests_mock.Mocker() as m:
170
- m.post(
171
- remote._api._api_url,
172
- json=model_response,
173
- )
157
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
174
158
 
175
159
  service = remote.get_service(model_identifier=ModelId("model_id"))
176
160
  assert service.model_id == "model_id"
@@ -181,11 +165,8 @@ def test_get_service_by_model_id_no_model():
181
165
  remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
182
166
  model_response = {"errors": [{"message": "error"}]}
183
167
  with requests_mock.Mocker() as m:
184
- m.post(
185
- remote._api._api_url,
186
- json=model_response,
187
- )
188
- with pytest.raises(click.UsageError):
168
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
169
+ with pytest.raises(RemoteError):
189
170
  remote.get_service(model_identifier=ModelId("model_id"))
190
171
 
191
172
 
@@ -203,17 +184,22 @@ def test_push_raised_value_error_when_deployment_name_and_not_publish(
203
184
  }
204
185
  }
205
186
  with requests_mock.Mocker() as m:
206
- m.post(
207
- remote._api._api_url,
208
- json=model_response,
209
- )
187
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
210
188
  th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
211
189
 
212
190
  with pytest.raises(
213
191
  ValueError,
214
192
  match="Deployment name cannot be used for development deployment",
215
193
  ):
216
- remote.push(th, "model_name", False, False, False, False, "dep_name")
194
+ remote.push(
195
+ th,
196
+ "model_name",
197
+ publish=False,
198
+ trusted=False,
199
+ promote=False,
200
+ preserve_previous_prod_deployment=False,
201
+ deployment_name="dep_name",
202
+ )
217
203
 
218
204
 
219
205
  def test_push_raised_value_error_when_deployment_name_is_not_valid(
@@ -230,17 +216,22 @@ def test_push_raised_value_error_when_deployment_name_is_not_valid(
230
216
  }
231
217
  }
232
218
  with requests_mock.Mocker() as m:
233
- m.post(
234
- remote._api._api_url,
235
- json=model_response,
236
- )
219
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
237
220
  th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
238
221
 
239
222
  with pytest.raises(
240
223
  ValueError,
241
224
  match="Deployment name must only contain alphanumeric, -, _ and . characters",
242
225
  ):
243
- remote.push(th, "model_name", True, False, False, False, "dep//name")
226
+ remote.push(
227
+ th,
228
+ "model_name",
229
+ publish=True,
230
+ trusted=False,
231
+ promote=False,
232
+ preserve_previous_prod_deployment=False,
233
+ deployment_name="dep//name",
234
+ )
244
235
 
245
236
 
246
237
  def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promote(
@@ -257,14 +248,435 @@ def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promot
257
248
  }
258
249
  }
259
250
  with requests_mock.Mocker() as m:
260
- m.post(
261
- remote._api._api_url,
262
- json=model_response,
263
- )
251
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
264
252
  th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
265
253
 
266
254
  with pytest.raises(
267
255
  ValueError,
268
256
  match="preserve-previous-production-deployment can only be used with the '--promote' option",
269
257
  ):
270
- remote.push(th, "model_name", False, False, False, True)
258
+ remote.push(
259
+ th,
260
+ "model_name",
261
+ publish=False,
262
+ trusted=False,
263
+ promote=False,
264
+ preserve_previous_prod_deployment=True,
265
+ )
266
+
267
+
268
+ def test_create_chain_with_no_publish():
269
+ remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
270
+
271
+ with requests_mock.Mocker() as m:
272
+ m.post(
273
+ _TEST_REMOTE_GRAPHQL_PATH,
274
+ [
275
+ {"json": {"data": {"chains": []}}},
276
+ {
277
+ "json": {
278
+ "data": {
279
+ "deploy_chain_atomic": {
280
+ "chain_id": "new-chain-id",
281
+ "chain_deployment_id": "new-chain-deployment-id",
282
+ "entrypoint_model_id": "new-entrypoint-model-id",
283
+ "entrypoint_model_version_id": "new-entrypoint-model-version-id",
284
+ }
285
+ }
286
+ }
287
+ },
288
+ ],
289
+ )
290
+
291
+ deployment_handle = create_chain_atomic(
292
+ api=remote.api,
293
+ chain_name="draft_chain",
294
+ entrypoint=ChainletDataAtomic(
295
+ name="chainlet-1",
296
+ oracle=OracleData(
297
+ model_name="model-1",
298
+ s3_key="s3-key-1",
299
+ encoded_config_str="encoded-config-str-1",
300
+ is_trusted=True,
301
+ ),
302
+ ),
303
+ dependencies=[],
304
+ is_draft=True,
305
+ environment=None,
306
+ )
307
+
308
+ get_chains_graphql_request = m.request_history[0]
309
+ create_chain_graphql_request = m.request_history[1]
310
+
311
+ expected_get_chains_query = """
312
+ {
313
+ chains {
314
+ id
315
+ name
316
+ }
317
+ }
318
+ """.strip()
319
+
320
+ assert_request_matches_expected_query(
321
+ get_chains_graphql_request, expected_get_chains_query
322
+ )
323
+
324
+ chainlets_string = """
325
+ {
326
+ name: "chainlet-1",
327
+ oracle: {
328
+ model_name: "model-1",
329
+ s3_key: "s3-key-1",
330
+ encoded_config_str: "encoded-config-str-1",
331
+ is_trusted: true,
332
+ semver_bump: "MINOR"
333
+ }
334
+ }
335
+ """.strip()
336
+
337
+ # Note that if publish=False and promote=True, we set publish to True and create
338
+ # a non-draft deployment
339
+ expected_create_chain_mutation = f"""
340
+ mutation {{
341
+ deploy_chain_atomic(
342
+ chain_name: "draft_chain"
343
+ is_draft: true
344
+ entrypoint: {chainlets_string}
345
+ dependencies: []
346
+ client_version: "{truss.version()}"
347
+ ) {{
348
+ chain_id
349
+ chain_deployment_id
350
+ entrypoint_model_id
351
+ entrypoint_model_version_id
352
+ }}
353
+ }}
354
+ """.strip()
355
+
356
+ assert_request_matches_expected_query(
357
+ create_chain_graphql_request, expected_create_chain_mutation
358
+ )
359
+
360
+ assert deployment_handle.chain_id == "new-chain-id"
361
+ assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
362
+
363
+
364
+ def test_create_chain_no_existing_chain():
365
+ remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
366
+
367
+ with requests_mock.Mocker() as m:
368
+ m.post(
369
+ _TEST_REMOTE_GRAPHQL_PATH,
370
+ [
371
+ {"json": {"data": {"chains": []}}},
372
+ {
373
+ "json": {
374
+ "data": {
375
+ "deploy_chain_atomic": {
376
+ "chain_id": "new-chain-id",
377
+ "chain_deployment_id": "new-chain-deployment-id",
378
+ "entrypoint_model_id": "new-entrypoint-model-id",
379
+ "entrypoint_model_version_id": "new-entrypoint-model-version-id",
380
+ }
381
+ }
382
+ }
383
+ },
384
+ ],
385
+ )
386
+
387
+ deployment_handle = create_chain_atomic(
388
+ api=remote.api,
389
+ chain_name="new_chain",
390
+ entrypoint=ChainletDataAtomic(
391
+ name="chainlet-1",
392
+ oracle=OracleData(
393
+ model_name="model-1",
394
+ s3_key="s3-key-1",
395
+ encoded_config_str="encoded-config-str-1",
396
+ is_trusted=True,
397
+ ),
398
+ ),
399
+ dependencies=[],
400
+ is_draft=False,
401
+ environment=None,
402
+ )
403
+
404
+ get_chains_graphql_request = m.request_history[0]
405
+ create_chain_graphql_request = m.request_history[1]
406
+
407
+ expected_get_chains_query = """
408
+ {
409
+ chains {
410
+ id
411
+ name
412
+ }
413
+ }
414
+ """.strip()
415
+
416
+ assert_request_matches_expected_query(
417
+ get_chains_graphql_request, expected_get_chains_query
418
+ )
419
+
420
+ chainlets_string = """
421
+ {
422
+ name: "chainlet-1",
423
+ oracle: {
424
+ model_name: "model-1",
425
+ s3_key: "s3-key-1",
426
+ encoded_config_str: "encoded-config-str-1",
427
+ is_trusted: true,
428
+ semver_bump: "MINOR"
429
+ }
430
+ }
431
+ """.strip()
432
+
433
+ expected_create_chain_mutation = f"""
434
+ mutation {{
435
+ deploy_chain_atomic(
436
+ chain_name: "new_chain"
437
+ is_draft: false
438
+ entrypoint: {chainlets_string}
439
+ dependencies: []
440
+ client_version: "{truss.version()}"
441
+ ) {{
442
+ chain_id
443
+ chain_deployment_id
444
+ entrypoint_model_id
445
+ entrypoint_model_version_id
446
+ }}
447
+ }}
448
+ """.strip()
449
+
450
+ assert_request_matches_expected_query(
451
+ create_chain_graphql_request, expected_create_chain_mutation
452
+ )
453
+
454
+ assert deployment_handle.chain_id == "new-chain-id"
455
+ assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
456
+
457
+
458
+ def test_create_chain_with_existing_chain_promote_to_environment_publish_false():
459
+ remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
460
+
461
+ with requests_mock.Mocker() as m:
462
+ m.post(
463
+ _TEST_REMOTE_GRAPHQL_PATH,
464
+ [
465
+ {
466
+ "json": {
467
+ "data": {
468
+ "chains": [{"id": "old-chain-id", "name": "old_chain"}]
469
+ }
470
+ }
471
+ },
472
+ {
473
+ "json": {
474
+ "data": {
475
+ "deploy_chain_atomic": {
476
+ "chain_id": "new-chain-id",
477
+ "chain_deployment_id": "new-chain-deployment-id",
478
+ "entrypoint_model_id": "new-entrypoint-model-id",
479
+ "entrypoint_model_version_id": "new-entrypoint-model-version-id",
480
+ }
481
+ }
482
+ }
483
+ },
484
+ ],
485
+ )
486
+
487
+ deployment_handle = create_chain_atomic(
488
+ api=remote.api,
489
+ chain_name="old_chain",
490
+ entrypoint=ChainletDataAtomic(
491
+ name="chainlet-1",
492
+ oracle=OracleData(
493
+ model_name="model-1",
494
+ s3_key="s3-key-1",
495
+ encoded_config_str="encoded-config-str-1",
496
+ is_trusted=True,
497
+ ),
498
+ ),
499
+ dependencies=[],
500
+ is_draft=True,
501
+ environment="production",
502
+ )
503
+
504
+ get_chains_graphql_request = m.request_history[0]
505
+ create_chain_graphql_request = m.request_history[1]
506
+
507
+ expected_get_chains_query = """
508
+ {
509
+ chains {
510
+ id
511
+ name
512
+ }
513
+ }
514
+ """.strip()
515
+
516
+ assert_request_matches_expected_query(
517
+ get_chains_graphql_request, expected_get_chains_query
518
+ )
519
+
520
+ # Note that if publish=False and environment!=None, we set publish to True and create
521
+ # a non-draft deployment
522
+ chainlets_string = """
523
+ {
524
+ name: "chainlet-1",
525
+ oracle: {
526
+ model_name: "model-1",
527
+ s3_key: "s3-key-1",
528
+ encoded_config_str: "encoded-config-str-1",
529
+ is_trusted: true,
530
+ semver_bump: "MINOR"
531
+ }
532
+ }
533
+ """.strip()
534
+
535
+ expected_create_chain_mutation = f"""
536
+ mutation {{
537
+ deploy_chain_atomic(
538
+ chain_id: "old-chain-id"
539
+ environment: "production"
540
+ is_draft: false
541
+ entrypoint: {chainlets_string}
542
+ dependencies: []
543
+ client_version: "{truss.version()}"
544
+ ) {{
545
+ chain_id
546
+ chain_deployment_id
547
+ entrypoint_model_id
548
+ entrypoint_model_version_id
549
+ }}
550
+ }}
551
+ """.strip()
552
+
553
+ assert_request_matches_expected_query(
554
+ create_chain_graphql_request, expected_create_chain_mutation
555
+ )
556
+
557
+ assert deployment_handle.chain_id == "new-chain-id"
558
+ assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
559
+
560
+
561
+ def test_create_chain_existing_chain_publish_true_no_promotion():
562
+ remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
563
+
564
+ with requests_mock.Mocker() as m:
565
+ m.post(
566
+ _TEST_REMOTE_GRAPHQL_PATH,
567
+ [
568
+ {
569
+ "json": {
570
+ "data": {
571
+ "chains": [{"id": "old-chain-id", "name": "old_chain"}]
572
+ }
573
+ }
574
+ },
575
+ {
576
+ "json": {
577
+ "data": {
578
+ "deploy_chain_atomic": {
579
+ "chain_id": "new-chain-id",
580
+ "chain_deployment_id": "new-chain-deployment-id",
581
+ "entrypoint_model_id": "new-entrypoint-model-id",
582
+ "entrypoint_model_version_id": "new-entrypoint-model-version-id",
583
+ }
584
+ }
585
+ }
586
+ },
587
+ ],
588
+ )
589
+
590
+ deployment_handle = create_chain_atomic(
591
+ api=remote.api,
592
+ chain_name="old_chain",
593
+ entrypoint=ChainletDataAtomic(
594
+ name="chainlet-1",
595
+ oracle=OracleData(
596
+ model_name="model-1",
597
+ s3_key="s3-key-1",
598
+ encoded_config_str="encoded-config-str-1",
599
+ is_trusted=True,
600
+ ),
601
+ ),
602
+ dependencies=[],
603
+ is_draft=False,
604
+ environment=None,
605
+ )
606
+
607
+ get_chains_graphql_request = m.request_history[0]
608
+ create_chain_graphql_request = m.request_history[1]
609
+
610
+ expected_get_chains_query = """
611
+ {
612
+ chains {
613
+ id
614
+ name
615
+ }
616
+ }
617
+ """.strip()
618
+
619
+ assert_request_matches_expected_query(
620
+ get_chains_graphql_request, expected_get_chains_query
621
+ )
622
+
623
+ chainlets_string = """
624
+ {
625
+ name: "chainlet-1",
626
+ oracle: {
627
+ model_name: "model-1",
628
+ s3_key: "s3-key-1",
629
+ encoded_config_str: "encoded-config-str-1",
630
+ is_trusted: true,
631
+ semver_bump: "MINOR"
632
+ }
633
+ }
634
+ """.strip()
635
+
636
+ expected_create_chain_mutation = f"""
637
+ mutation {{
638
+ deploy_chain_atomic(
639
+ chain_id: "old-chain-id"
640
+ is_draft: false
641
+ entrypoint: {chainlets_string}
642
+ dependencies: []
643
+ client_version: "{truss.version()}"
644
+ ) {{
645
+ chain_id
646
+ chain_deployment_id
647
+ entrypoint_model_id
648
+ entrypoint_model_version_id
649
+ }}
650
+ }}
651
+ """.strip()
652
+
653
+ assert_request_matches_expected_query(
654
+ create_chain_graphql_request, expected_create_chain_mutation
655
+ )
656
+
657
+ assert deployment_handle.chain_id == "new-chain-id"
658
+ assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"
659
+
660
+
661
+ @pytest.mark.parametrize("publish", [True, False])
662
+ def test_push_raised_value_error_when_disable_truss_download_for_existing_model(
663
+ publish, custom_model_truss_dir_with_pre_and_post
664
+ ):
665
+ remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
666
+ model_response = {
667
+ "data": {
668
+ "model": {
669
+ "name": "model_name",
670
+ "id": "model_id",
671
+ "primary_version": {"id": "version_id"},
672
+ }
673
+ }
674
+ }
675
+ with requests_mock.Mocker() as m:
676
+ m.post(_TEST_REMOTE_GRAPHQL_PATH, json=model_response)
677
+ th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
678
+
679
+ with pytest.raises(
680
+ ValueError, match="disable-truss-download can only be used for new models"
681
+ ):
682
+ remote.push(th, "model_name", publish=publish, disable_truss_download=True)