truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

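Potentially problematic release.


This version of truss has been flagged as potentially problematic. See the registry's details page for this release for more information (the original link was not preserved in this copy).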

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -1,4 +1,4 @@
1
- from truss.patch.dir_signature import directory_content_signature
1
+ from truss.truss_handle.patch.dir_signature import directory_content_signature
2
2
 
3
3
 
4
4
  def test_directory_content_signature(tmp_path):
@@ -12,12 +12,7 @@ def test_directory_content_signature(tmp_path):
12
12
 
13
13
  content_sign = directory_content_signature(root)
14
14
 
15
- assert content_sign.keys() == {
16
- "dir",
17
- "dir/file3",
18
- "file1",
19
- "file2",
20
- }
15
+ assert content_sign.keys() == {"dir", "dir/file3", "file1", "file2"}
21
16
 
22
17
 
23
18
  def test_directory_content_signature_ignore_patterns(tmp_path):
@@ -40,8 +35,4 @@ def test_directory_content_signature_ignore_patterns(tmp_path):
40
35
  root=root, ignore_patterns=["data/*", ".git"]
41
36
  )
42
37
 
43
- assert content_sign.keys() == {
44
- "data",
45
- "file1",
46
- "file2",
47
- }
38
+ assert content_sign.keys() == {"data", "file1", "file2"}
@@ -4,7 +4,7 @@ from pathlib import Path
4
4
  from typing import Callable, List
5
5
 
6
6
  import pytest
7
- from truss.patch.hash import (
7
+ from truss.truss_handle.patch.hash import (
8
8
  directory_content_hash,
9
9
  file_content_hash,
10
10
  file_content_hash_str,
@@ -1,4 +1,4 @@
1
- from truss.patch.signature import calc_truss_signature
1
+ from truss.truss_handle.patch.signature import calc_truss_signature
2
2
 
3
3
 
4
4
  def test_calc_truss_signature(custom_model_truss_dir):
@@ -2,17 +2,18 @@ import logging
2
2
  from pathlib import Path
3
3
 
4
4
  import yaml
5
- from truss.patch.truss_dir_patch_applier import TrussDirPatchApplier
6
- from truss.server.control.patch.types import (
5
+ from truss.base.truss_config import TrussConfig
6
+ from truss.templates.control.control.helpers.custom_types import (
7
7
  Action,
8
8
  ConfigPatch,
9
9
  ModelCodePatch,
10
+ PackagePatch,
10
11
  Patch,
11
12
  PatchType,
12
13
  PythonRequirementPatch,
13
14
  SystemPackagePatch,
14
15
  )
15
- from truss.truss_config import TrussConfig
16
+ from truss.truss_handle.patch.truss_dir_patch_applier import TrussDirPatchApplier
16
17
 
17
18
  TEST_LOGGER = logging.getLogger("test_logger")
18
19
 
@@ -32,6 +33,23 @@ def test_model_code_patch(custom_model_truss_dir: Path):
32
33
  assert (custom_model_truss_dir / "model" / "model.py").read_text() == "test_content"
33
34
 
34
35
 
36
+ def test_packages_patch(custom_model_truss_dir: Path):
37
+ applier = TrussDirPatchApplier(custom_model_truss_dir, TEST_LOGGER)
38
+ applier(
39
+ [
40
+ Patch(
41
+ type=PatchType.PACKAGE,
42
+ body=PackagePatch(
43
+ action=Action.UPDATE, path="user_package.py", content="import sys"
44
+ ),
45
+ )
46
+ ]
47
+ )
48
+ assert (
49
+ custom_model_truss_dir / "packages" / "user_package.py"
50
+ ).read_text() == "import sys"
51
+
52
+
35
53
  def test_python_requirement_patch(custom_model_truss_dir: Path):
36
54
  req = "git+https://github.com/huggingface/transformers.git"
37
55
  applier = TrussDirPatchApplier(custom_model_truss_dir, TEST_LOGGER)
@@ -41,10 +59,7 @@ def test_python_requirement_patch(custom_model_truss_dir: Path):
41
59
  [
42
60
  Patch(
43
61
  type=PatchType.PYTHON_REQUIREMENT,
44
- body=PythonRequirementPatch(
45
- action=Action.ADD,
46
- requirement=req,
47
- ),
62
+ body=PythonRequirementPatch(action=Action.ADD, requirement=req),
48
63
  ),
49
64
  Patch(
50
65
  type=PatchType.CONFIG,
@@ -73,10 +88,7 @@ def test_system_requirement_patch(custom_model_truss_dir: Path):
73
88
  ),
74
89
  Patch(
75
90
  type=PatchType.SYSTEM_PACKAGE,
76
- body=SystemPackagePatch(
77
- action=Action.ADD,
78
- package="curl",
79
- ),
91
+ body=SystemPackagePatch(action=Action.ADD, package="curl"),
80
92
  ),
81
93
  ]
82
94
  )
@@ -1,5 +1,5 @@
1
- from truss.patch.signature import calc_truss_signature
2
- from truss.patch.types import TrussSignature
1
+ from truss.truss_handle.patch.custom_types import TrussSignature
2
+ from truss.truss_handle.patch.signature import calc_truss_signature
3
3
 
4
4
 
5
5
  def test_truss_signature_type(custom_model_truss_dir):
@@ -4,9 +4,11 @@ import pytest
4
4
  import requests
5
5
  from requests import Response
6
6
  from truss.remote.baseten.api import BasetenApi
7
+ from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData
7
8
  from truss.remote.baseten.error import ApiError
8
9
 
9
10
 
11
+ @pytest.fixture
10
12
  def mock_auth_service():
11
13
  auth_service = mock.Mock()
12
14
  auth_token = mock.Mock(headers=lambda: {"Authorization": "Api-Key token"})
@@ -52,45 +54,62 @@ def mock_create_model_response():
52
54
  return response
53
55
 
54
56
 
55
- @mock.patch("truss.remote.baseten.auth.AuthService")
56
- @mock.patch("requests.post", return_value=mock_successful_response())
57
- def test_post_graphql_query_success(mock_post, mock_auth_service):
58
- api_url = "https://test.com/api"
59
- api = BasetenApi(api_url, mock_auth_service)
57
+ def mock_create_development_model_response():
58
+ response = Response()
59
+ response.status_code = 200
60
+ response.json = mock.Mock(
61
+ return_value={"data": {"deploy_draft_truss": {"id": "12345"}}}
62
+ )
63
+ return response
64
+
65
+
66
+ def mock_deploy_chain_deployment_response():
67
+ response = Response()
68
+ response.status_code = 200
69
+ response.json = mock.Mock(
70
+ return_value={
71
+ "data": {
72
+ "deploy_chain_atomic": {
73
+ "chain_id": "12345",
74
+ "chain_deployment_id": "54321",
75
+ "entrypoint_model_id": "67890",
76
+ "entrypoint_model_version_id": "09876",
77
+ }
78
+ }
79
+ }
80
+ )
81
+ return response
82
+
60
83
 
84
+ @pytest.fixture
85
+ def baseten_api(mock_auth_service):
86
+ return BasetenApi("https://app.test.com", mock_auth_service)
87
+
88
+
89
+ @mock.patch("requests.post", return_value=mock_successful_response())
90
+ def test_post_graphql_query_success(mock_post, baseten_api):
61
91
  response_data = {"data": {"status": "success"}}
62
92
 
63
- result = api._post_graphql_query("sample_query_string")
93
+ result = baseten_api._post_graphql_query("sample_query_string")
64
94
 
65
95
  assert result == response_data
66
96
 
67
97
 
68
- @mock.patch("truss.remote.baseten.auth.AuthService")
69
98
  @mock.patch("requests.post", return_value=mock_graphql_error_response())
70
- def test_post_graphql_query_error(mock_post, mock_auth_service):
71
- api_url = "https://test.com/api"
72
- api = BasetenApi(api_url, mock_auth_service)
73
-
99
+ def test_post_graphql_query_error(mock_post, baseten_api):
74
100
  with pytest.raises(ApiError):
75
- api._post_graphql_query("sample_query_string")
101
+ baseten_api._post_graphql_query("sample_query_string")
76
102
 
77
103
 
78
- @mock.patch("truss.remote.baseten.auth.AuthService")
79
104
  @mock.patch("requests.post", return_value=mock_unsuccessful_response())
80
- def test_post_requests_error(mock_post, mock_auth_service):
81
- api_url = "https://test.com/api"
82
- api = BasetenApi(api_url, mock_auth_service)
105
+ def test_post_requests_error(mock_post, baseten_api):
83
106
  with pytest.raises(requests.exceptions.HTTPError):
84
- api._post_graphql_query("sample_query_string")
107
+ baseten_api._post_graphql_query("sample_query_string")
85
108
 
86
109
 
87
- @mock.patch("truss.remote.baseten.auth.AuthService")
88
110
  @mock.patch("requests.post", return_value=mock_create_model_version_response())
89
- def test_create_model_version_from_truss(mock_post, mock_auth_service):
90
- api_url = "https://test.com/api"
91
- api = BasetenApi(api_url, mock_auth_service)
92
-
93
- api.create_model_version_from_truss(
111
+ def test_create_model_version_from_truss(mock_post, baseten_api):
112
+ baseten_api.create_model_version_from_truss(
94
113
  "model_id",
95
114
  "s3key",
96
115
  "config_str",
@@ -98,8 +117,8 @@ def test_create_model_version_from_truss(mock_post, mock_auth_service):
98
117
  "client_version",
99
118
  False,
100
119
  False,
101
- False,
102
120
  "deployment_name",
121
+ "production",
103
122
  )
104
123
 
105
124
  gql_mutation = mock_post.call_args[1]["data"]["query"]
@@ -109,27 +128,22 @@ def test_create_model_version_from_truss(mock_post, mock_auth_service):
109
128
  assert 'semver_bump: "semver_bump"' in gql_mutation
110
129
  assert 'client_version: "client_version"' in gql_mutation
111
130
  assert "is_trusted: false" in gql_mutation
112
- assert "promote_after_deploy: false" in gql_mutation
113
131
  assert "scale_down_old_production: true" in gql_mutation
114
132
  assert 'name: "deployment_name"' in gql_mutation
133
+ assert 'environment_name: "production"' in gql_mutation
115
134
 
116
135
 
117
- @mock.patch("truss.remote.baseten.auth.AuthService")
118
136
  @mock.patch("requests.post", return_value=mock_create_model_version_response())
119
137
  def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_specified(
120
- mock_post, mock_auth_service
138
+ mock_post, baseten_api
121
139
  ):
122
- api_url = "https://test.com/api"
123
- api = BasetenApi(api_url, mock_auth_service)
124
-
125
- api.create_model_version_from_truss(
140
+ baseten_api.create_model_version_from_truss(
126
141
  "model_id",
127
142
  "s3key",
128
143
  "config_str",
129
144
  "semver_bump",
130
145
  "client_version",
131
146
  True,
132
- True,
133
147
  False,
134
148
  deployment_name=None,
135
149
  )
@@ -141,20 +155,16 @@ def test_create_model_version_from_truss_does_not_send_deployment_name_if_not_sp
141
155
  assert 'semver_bump: "semver_bump"' in gql_mutation
142
156
  assert 'client_version: "client_version"' in gql_mutation
143
157
  assert "is_trusted: true" in gql_mutation
144
- assert "promote_after_deploy: true" in gql_mutation
145
158
  assert "scale_down_old_production: true" in gql_mutation
146
- assert "name: " not in gql_mutation
159
+ assert " name: " not in gql_mutation
160
+ assert "environment_name: " not in gql_mutation
147
161
 
148
162
 
149
- @mock.patch("truss.remote.baseten.auth.AuthService")
150
163
  @mock.patch("requests.post", return_value=mock_create_model_version_response())
151
164
  def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep_previous_prod_settings(
152
- mock_post, mock_auth_service
165
+ mock_post, baseten_api
153
166
  ):
154
- api_url = "https://test.com/api"
155
- api = BasetenApi(api_url, mock_auth_service)
156
-
157
- api.create_model_version_from_truss(
167
+ baseten_api.create_model_version_from_truss(
158
168
  "model_id",
159
169
  "s3key",
160
170
  "config_str",
@@ -162,8 +172,8 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
162
172
  "client_version",
163
173
  True,
164
174
  True,
165
- True,
166
175
  deployment_name=None,
176
+ environment="staging",
167
177
  )
168
178
 
169
179
  gql_mutation = mock_post.call_args[1]["data"]["query"]
@@ -173,25 +183,21 @@ def test_create_model_version_from_truss_does_not_scale_old_prod_to_zero_if_keep
173
183
  assert 'semver_bump: "semver_bump"' in gql_mutation
174
184
  assert 'client_version: "client_version"' in gql_mutation
175
185
  assert "is_trusted: true" in gql_mutation
176
- assert "promote_after_deploy: true" in gql_mutation
177
186
  assert "scale_down_old_production: false" in gql_mutation
178
- assert "name: " not in gql_mutation
187
+ assert " name: " not in gql_mutation
188
+ assert 'environment_name: "staging"' in gql_mutation
179
189
 
180
190
 
181
- @mock.patch("truss.remote.baseten.auth.AuthService")
182
191
  @mock.patch("requests.post", return_value=mock_create_model_response())
183
- def test_create_model_from_truss(mock_post, mock_auth_service):
184
- api_url = "https://test.com/api"
185
- api = BasetenApi(api_url, mock_auth_service)
186
-
187
- api.create_model_from_truss(
192
+ def test_create_model_from_truss(mock_post, baseten_api):
193
+ baseten_api.create_model_from_truss(
188
194
  "model_name",
189
195
  "s3key",
190
196
  "config_str",
191
197
  "semver_bump",
192
198
  "client_version",
193
- False,
194
- "deployment_name",
199
+ is_trusted=False,
200
+ deployment_name="deployment_name",
195
201
  )
196
202
 
197
203
  gql_mutation = mock_post.call_args[1]["data"]["query"]
@@ -204,21 +210,17 @@ def test_create_model_from_truss(mock_post, mock_auth_service):
204
210
  assert 'version_name: "deployment_name"' in gql_mutation
205
211
 
206
212
 
207
- @mock.patch("truss.remote.baseten.auth.AuthService")
208
213
  @mock.patch("requests.post", return_value=mock_create_model_response())
209
214
  def test_create_model_from_truss_does_not_send_deployment_name_if_not_specified(
210
- mock_post, mock_auth_service
215
+ mock_post, baseten_api
211
216
  ):
212
- api_url = "https://test.com/api"
213
- api = BasetenApi(api_url, mock_auth_service)
214
-
215
- api.create_model_from_truss(
217
+ baseten_api.create_model_from_truss(
216
218
  "model_name",
217
219
  "s3key",
218
220
  "config_str",
219
221
  "semver_bump",
220
222
  "client_version",
221
- True,
223
+ is_trusted=True,
222
224
  deployment_name=None,
223
225
  )
224
226
 
@@ -230,3 +232,96 @@ def test_create_model_from_truss_does_not_send_deployment_name_if_not_specified(
230
232
  assert 'client_version: "client_version"' in gql_mutation
231
233
  assert "is_trusted: true" in gql_mutation
232
234
  assert "version_name: " not in gql_mutation
235
+
236
+
237
+ @mock.patch("requests.post", return_value=mock_create_model_response())
238
+ def test_create_model_from_truss_with_allow_truss_download(mock_post, baseten_api):
239
+ baseten_api.create_model_from_truss(
240
+ "model_name",
241
+ "s3key",
242
+ "config_str",
243
+ "semver_bump",
244
+ "client_version",
245
+ is_trusted=True,
246
+ allow_truss_download=False,
247
+ )
248
+
249
+ gql_mutation = mock_post.call_args[1]["data"]["query"]
250
+ assert 'name: "model_name"' in gql_mutation
251
+ assert 's3_key: "s3key"' in gql_mutation
252
+ assert 'config: "config_str"' in gql_mutation
253
+ assert 'semver_bump: "semver_bump"' in gql_mutation
254
+ assert 'client_version: "client_version"' in gql_mutation
255
+ assert "is_trusted: true" in gql_mutation
256
+ assert "allow_truss_download: false" in gql_mutation
257
+
258
+
259
+ @mock.patch("requests.post", return_value=mock_create_development_model_response())
260
+ def test_create_development_model_from_truss_with_allow_truss_download(
261
+ mock_post, baseten_api
262
+ ):
263
+ baseten_api.create_development_model_from_truss(
264
+ "model_name",
265
+ "s3key",
266
+ "config_str",
267
+ "client_version",
268
+ is_trusted=True,
269
+ allow_truss_download=False,
270
+ )
271
+
272
+ gql_mutation = mock_post.call_args[1]["data"]["query"]
273
+ assert 'name: "model_name"' in gql_mutation
274
+ assert 's3_key: "s3key"' in gql_mutation
275
+ assert 'config: "config_str"' in gql_mutation
276
+ assert 'client_version: "client_version"' in gql_mutation
277
+ assert "is_trusted: true" in gql_mutation
278
+ assert "allow_truss_download: false" in gql_mutation
279
+
280
+
281
+ @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
282
+ def test_deploy_chain_deployment(mock_post, baseten_api):
283
+ baseten_api.deploy_chain_atomic(
284
+ environment="production",
285
+ chain_id="chain_id",
286
+ dependencies=[],
287
+ entrypoint=ChainletDataAtomic(
288
+ name="chainlet-1",
289
+ oracle=OracleData(
290
+ model_name="model-1",
291
+ s3_key="s3-key-1",
292
+ encoded_config_str="encoded-config-str-1",
293
+ is_trusted=True,
294
+ ),
295
+ ),
296
+ )
297
+
298
+ gql_mutation = mock_post.call_args[1]["data"]["query"]
299
+
300
+ assert 'environment: "production"' in gql_mutation
301
+ assert 'chain_id: "chain_id"' in gql_mutation
302
+ assert "dependencies:" in gql_mutation
303
+ assert "entrypoint:" in gql_mutation
304
+
305
+
306
+ @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
307
+ def test_deploy_chain_deployment_no_environment(mock_post, baseten_api):
308
+ baseten_api.deploy_chain_atomic(
309
+ chain_id="chain_id",
310
+ dependencies=[],
311
+ entrypoint=ChainletDataAtomic(
312
+ name="chainlet-1",
313
+ oracle=OracleData(
314
+ model_name="model-1",
315
+ s3_key="s3-key-1",
316
+ encoded_config_str="encoded-config-str-1",
317
+ is_trusted=True,
318
+ ),
319
+ ),
320
+ )
321
+
322
+ gql_mutation = mock_post.call_args[1]["data"]["query"]
323
+
324
+ assert 'chain_id: "chain_id"' in gql_mutation
325
+ assert "environment" not in gql_mutation
326
+ assert "dependencies:" in gql_mutation
327
+ assert "entrypoint:" in gql_mutation
@@ -9,7 +9,8 @@ def test_api_key():
9
9
  assert key.header() == {"Authorization": "Api-Key test_key"}
10
10
 
11
11
 
12
- def test_auth_service_no_key():
12
+ def test_auth_service_no_key(monkeypatch: pytest.MonkeyPatch):
13
+ monkeypatch.delenv("BASETEN_API_KEY", raising=False)
13
14
  auth_service = AuthService()
14
15
  with pytest.raises(AuthorizationError):
15
16
  auth_service.authenticate()
@@ -1,8 +1,13 @@
1
+ import json
1
2
  from tempfile import NamedTemporaryFile
2
3
  from unittest.mock import MagicMock
3
4
 
5
+ import pytest
6
+ from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
7
+ from truss.base.errors import ValidationError
4
8
  from truss.remote.baseten import core
5
9
  from truss.remote.baseten.api import BasetenApi
10
+ from truss.remote.baseten.core import create_truss_service
6
11
  from truss.remote.baseten.error import ApiError
7
12
 
8
13
 
@@ -35,31 +40,23 @@ def test_upload_truss():
35
40
  core.multipart_upload_boto3 = MagicMock()
36
41
  core.multipart_upload_boto3.return_value = None
37
42
  test_file = NamedTemporaryFile()
38
- assert core.upload_truss(api, test_file) == "key"
43
+ assert core.upload_truss(api, test_file, None) == "key"
39
44
 
40
45
 
41
46
  def test_get_dev_version_from_versions():
42
- versions = [
43
- {"id": "1", "is_draft": False},
44
- {"id": "2", "is_draft": True},
45
- ]
47
+ versions = [{"id": "1", "is_draft": False}, {"id": "2", "is_draft": True}]
46
48
  dev_version = core.get_dev_version_from_versions(versions)
47
49
  assert dev_version["id"] == "2"
48
50
 
49
51
 
50
52
  def test_get_dev_version_from_versions_error():
51
- versions = [
52
- {"id": "1", "is_draft": False},
53
- ]
53
+ versions = [{"id": "1", "is_draft": False}]
54
54
  dev_version = core.get_dev_version_from_versions(versions)
55
55
  assert dev_version is None
56
56
 
57
57
 
58
58
  def test_get_dev_version():
59
- versions = [
60
- {"id": "1", "is_draft": False},
61
- {"id": "2", "is_draft": True},
62
- ]
59
+ versions = [{"id": "1", "is_draft": False}, {"id": "2", "is_draft": True}]
63
60
  api = MagicMock()
64
61
  api.get_model.return_value = {"model": {"versions": versions}}
65
62
 
@@ -84,3 +81,154 @@ def test_get_prod_version_from_versions_error():
84
81
  ]
85
82
  prod_version = core.get_prod_version_from_versions(versions)
86
83
  assert prod_version is None
84
+
85
+
86
+ @pytest.mark.parametrize("environment", [None, PRODUCTION_ENVIRONMENT_NAME])
87
+ def test_create_truss_service_handles_eligible_environment_values(environment):
88
+ api = MagicMock()
89
+ return_value = {"id": "id", "version_id": "model_version_id"}
90
+ api.create_model_from_truss.return_value = return_value
91
+ model_id, model_version_id = create_truss_service(
92
+ api,
93
+ "model_name",
94
+ "s3_key",
95
+ "config",
96
+ is_trusted=False,
97
+ preserve_previous_prod_deployment=False,
98
+ is_draft=False,
99
+ model_id=None,
100
+ deployment_name="deployment_name",
101
+ environment=environment,
102
+ )
103
+ assert model_id == return_value["id"]
104
+ assert model_version_id == return_value["version_id"]
105
+ api.create_model_from_truss.assert_called_once()
106
+
107
+
108
+ @pytest.mark.parametrize("model_id", ["some_model_id", None])
109
+ def test_create_truss_services_handles_is_draft(model_id):
110
+ api = MagicMock()
111
+ return_value = {"id": "id", "version_id": "model_version_id"}
112
+ api.create_development_model_from_truss.return_value = return_value
113
+ model_id, model_version_id = create_truss_service(
114
+ api,
115
+ "model_name",
116
+ "s3_key",
117
+ "config",
118
+ is_trusted=False,
119
+ preserve_previous_prod_deployment=False,
120
+ is_draft=True,
121
+ model_id=model_id,
122
+ deployment_name="deployment_name",
123
+ )
124
+ assert model_id == return_value["id"]
125
+ assert model_version_id == return_value["version_id"]
126
+ api.create_development_model_from_truss.assert_called_once()
127
+
128
+
129
+ @pytest.mark.parametrize(
130
+ "inputs",
131
+ [
132
+ {
133
+ "environment": None,
134
+ "deployment_name": "some deployment",
135
+ "is_trusted": True,
136
+ "preserve_previous_prod_deployment": False,
137
+ },
138
+ {
139
+ "environment": PRODUCTION_ENVIRONMENT_NAME,
140
+ "deployment_name": None,
141
+ "is_trusted": True,
142
+ "preserve_previous_prod_deployment": False,
143
+ },
144
+ {
145
+ "environment": "staging",
146
+ "deployment_name": "some_deployment_name",
147
+ "is_trusted": False,
148
+ "preserve_previous_prod_deployment": True,
149
+ },
150
+ ],
151
+ )
152
+ def test_create_truss_service_handles_existing_model(inputs):
153
+ api = MagicMock()
154
+ return_value = {"id": "model_version_id"}
155
+ api.create_model_version_from_truss.return_value = return_value
156
+ model_id, model_version_id = create_truss_service(
157
+ api,
158
+ "model_name",
159
+ "s3_key",
160
+ "config",
161
+ is_draft=False,
162
+ model_id="model_id",
163
+ **inputs,
164
+ )
165
+
166
+ assert model_id == "model_id"
167
+ assert model_version_id == return_value["id"]
168
+ api.create_model_version_from_truss.assert_called_once()
169
+ _, kwargs = api.create_model_version_from_truss.call_args
170
+ for k, v in inputs.items():
171
+ assert kwargs[k] == v
172
+
173
+
174
+ @pytest.mark.parametrize("allow_truss_download", [True, False])
175
+ @pytest.mark.parametrize("is_draft", [True, False])
176
+ def test_create_truss_service_handles_allow_truss_download_for_new_models(
177
+ is_draft, allow_truss_download
178
+ ):
179
+ api = MagicMock()
180
+ return_value = {"id": "id", "version_id": "model_version_id"}
181
+ api.create_model_from_truss.return_value = return_value
182
+ api.create_development_model_from_truss.return_value = return_value
183
+
184
+ model_id = None
185
+ model_id, model_version_id = create_truss_service(
186
+ api,
187
+ "model_name",
188
+ "s3_key",
189
+ "config",
190
+ is_trusted=False,
191
+ preserve_previous_prod_deployment=False,
192
+ is_draft=is_draft,
193
+ model_id=model_id,
194
+ deployment_name="deployment_name",
195
+ allow_truss_download=allow_truss_download,
196
+ )
197
+ assert model_id == return_value["id"]
198
+ assert model_version_id == return_value["version_id"]
199
+
200
+ create_model_mock = (
201
+ api.create_development_model_from_truss
202
+ if is_draft
203
+ else api.create_model_from_truss
204
+ )
205
+ create_model_mock.assert_called_once()
206
+ _, kwargs = create_model_mock.call_args
207
+ assert kwargs["allow_truss_download"] is allow_truss_download
208
+
209
+
210
+ def test_validate_truss_config():
211
+ def mock_validate_truss(client_version, config):
212
+ if config == {}:
213
+ return {"success": True, "details": json.dumps({})}
214
+ elif "hi" in config:
215
+ return {"success": False, "details": json.dumps({"errors": ["error"]})}
216
+ else:
217
+ return {
218
+ "success": False,
219
+ "details": json.dumps({"errors": ["error", "and another one"]}),
220
+ }
221
+
222
+ api = MagicMock()
223
+ api.validate_truss.side_effect = mock_validate_truss
224
+
225
+ assert core.validate_truss_config(api, {}) is None
226
+ with pytest.raises(
227
+ ValidationError, match="Validation failed with the following errors:\n error"
228
+ ):
229
+ core.validate_truss_config(api, {"hi": "hi"})
230
+ with pytest.raises(
231
+ ValidationError,
232
+ match="Validation failed with the following errors:\n error\n and another one",
233
+ ):
234
+ core.validate_truss_config(api, {"should_error": "hi"})