truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. /truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. /truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -1,46 +1,78 @@
1
+ import asyncio
1
2
  import concurrent
3
+ import contextlib
4
+ import dataclasses
2
5
  import inspect
3
6
  import json
4
7
  import logging
8
+ import pathlib
9
+ import sys
5
10
  import tempfile
6
11
  import textwrap
7
12
  import time
8
13
  from concurrent.futures import ThreadPoolExecutor
9
14
  from pathlib import Path
10
15
  from threading import Thread
16
+ from typing import Iterator, Mapping, Optional
11
17
 
18
+ import httpx
19
+ import opentelemetry.trace.propagation.tracecontext as tracecontext
12
20
  import pytest
13
21
  import requests
22
+ from opentelemetry import context, trace
23
+ from python_on_whales import Container
14
24
  from requests.exceptions import RequestException
25
+
26
+ from truss.base.truss_config import map_to_supported_python_version
15
27
  from truss.local.local_config_handler import LocalConfigHandler
16
- from truss.model_inference import map_to_supported_python_version
17
28
  from truss.tests.helpers import create_truss
18
29
  from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all
19
- from truss.truss_handle import TrussHandle
30
+ from truss.truss_handle.truss_handle import TrussHandle, wait_for_truss
20
31
 
21
32
  logger = logging.getLogger(__name__)
22
33
 
23
34
  DEFAULT_LOG_ERROR = "Internal Server Error"
35
+ PREDICT_URL = "http://localhost:8090/v1/models/model:predict"
36
+ COMPLETIONS_URL = "http://localhost:8090/v1/completions"
37
+ CHAT_COMPLETIONS_URL = "http://localhost:8090/v1/chat/completions"
38
+
24
39
 
40
+ @pytest.fixture
41
+ def anyio_backend():
42
+ return "asyncio"
25
43
 
26
- def _log_contains_error(line: dict, error: str, message: str):
44
+
45
+ def _log_contains_line(
46
+ line: dict, message: str, level: str, error: Optional[str] = None
47
+ ):
27
48
  return (
28
- line["levelname"] == "ERROR"
29
- and line["message"] == message
30
- and error in line["exc_info"]
49
+ line["levelname"] == level
50
+ and message in line["message"]
51
+ and (error is None or error in line["exc_info"])
31
52
  )
32
53
 
33
54
 
34
- def assert_logs_contain_error(logs: str, error: str, message=DEFAULT_LOG_ERROR):
35
- loglines = logs.splitlines()
55
+ def _assert_logs_contain_error(logs: str, error: str, message=DEFAULT_LOG_ERROR):
56
+ loglines = [json.loads(line) for line in logs.splitlines()]
36
57
  assert any(
37
- _log_contains_error(json.loads(line), error, message) for line in loglines
58
+ _log_contains_line(line, message, "ERROR", error) for line in loglines
59
+ ), (
60
+ f"Did not find expected error in logs.\nExpected error: {error}\n"
61
+ f"Expected message: {message}\nActual logs:\n{loglines}"
62
+ )
63
+
64
+
65
+ def _assert_logs_contain(logs: str, message: str, level: str = "INFO"):
66
+ loglines = [json.loads(line) for line in logs.splitlines()]
67
+ assert any(_log_contains_line(line, message, level) for line in loglines), (
68
+ f"Did not find expected logs.\n"
69
+ f"Expected message: {message}\nActual logs:\n{loglines}"
38
70
  )
39
71
 
40
72
 
41
- class PropagatingThread(Thread):
73
+ class _PropagatingThread(Thread):
42
74
  """
43
- PropagatingThread allows us to run threads and keep track of exceptions
75
+ _PropagatingThread allows us to run threads and keep track of exceptions
44
76
  thrown.
45
77
  """
46
78
 
@@ -52,22 +84,31 @@ class PropagatingThread(Thread):
52
84
  self.exc = e
53
85
 
54
86
  def join(self, timeout=None):
55
- super(PropagatingThread, self).join(timeout)
87
+ super(_PropagatingThread, self).join(timeout)
56
88
  if self.exc:
57
89
  raise self.exc
58
90
  return self.ret
59
91
 
60
92
 
93
+ @contextlib.contextmanager
94
+ def _temp_truss(model_src: str, config_src: str = "") -> Iterator[TrussHandle]:
95
+ with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
96
+ truss_dir = Path(tmp_work_dir, "truss")
97
+ create_truss(truss_dir, config_src, textwrap.dedent(model_src))
98
+ yield TrussHandle(truss_dir)
99
+
100
+
101
+ # Test Cases ###########################################################################
102
+
103
+
61
104
  @pytest.mark.parametrize(
62
105
  "python_version, expected_python_version",
63
106
  [
64
- ("py37", "py38"),
65
107
  ("py38", "py38"),
66
108
  ("py39", "py39"),
67
109
  ("py310", "py310"),
68
110
  ("py311", "py311"),
69
111
  ("py312", "py311"),
70
- ("py36", "py38"),
71
112
  ],
72
113
  )
73
114
  def test_map_to_supported_python_version(python_version, expected_python_version):
@@ -75,11 +116,54 @@ def test_map_to_supported_python_version(python_version, expected_python_version
75
116
  assert out_python_version == expected_python_version
76
117
 
77
118
 
119
+ def test_not_supported_python_minor_versions():
120
+ with pytest.raises(
121
+ ValueError,
122
+ match="Mapping python version 3.6 to 3.8, "
123
+ "the lowest version that Truss currently supports.",
124
+ ):
125
+ map_to_supported_python_version("py36")
126
+ with pytest.raises(
127
+ ValueError,
128
+ match="Mapping python version 3.7 to 3.8, "
129
+ "the lowest version that Truss currently supports.",
130
+ ):
131
+ map_to_supported_python_version("py37")
132
+
133
+
134
+ def test_not_supported_python_major_versions():
135
+ with pytest.raises(NotImplementedError, match="Only python version 3 is supported"):
136
+ map_to_supported_python_version("py211")
137
+
138
+
78
139
  @pytest.mark.integration
79
- def test_model_load_failure_truss():
140
+ def test_model_load_logs(test_data_path):
141
+ model = """
142
+ from typing import Optional
143
+ import logging
144
+ class Model:
145
+ def load(self):
146
+ logging.info(f"User Load Message")
147
+
148
+ def predict(self, model_input):
149
+ return self.environment_name
150
+ """
151
+ config = "model_name: init-environment-truss"
152
+ with ensure_kill_all(), _temp_truss(model, config) as tr:
153
+ container = tr.docker_run(
154
+ local_port=8090, detach=True, wait_for_server_ready=True
155
+ )
156
+ logs = container.logs()
157
+ _assert_logs_contain(logs, message="Executing model.load()")
158
+ _assert_logs_contain(logs, message="Loading truss model from file")
159
+ _assert_logs_contain(logs, message="Completed model.load()")
160
+ _assert_logs_contain(logs, message="User Load Message")
161
+
162
+
163
+ @pytest.mark.integration
164
+ def test_model_load_failure_truss(test_data_path):
80
165
  with ensure_kill_all():
81
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
82
- truss_dir = truss_root / "test_data" / "model_load_failure_test"
166
+ truss_dir = test_data_path / "model_load_failure_test"
83
167
  tr = TrussHandle(truss_dir)
84
168
 
85
169
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)
@@ -110,6 +194,12 @@ def test_model_load_failure_truss():
110
194
  assert ready.status_code == expected_code
111
195
  return True
112
196
 
197
+ @handle_request_exception
198
+ def _test_is_loaded(expected_code):
199
+ ready = requests.get(f"{truss_server_addr}/v1/models/model/loaded")
200
+ assert ready.status_code == expected_code
201
+ return True
202
+
113
203
  @handle_request_exception
114
204
  def _test_ping(expected_code):
115
205
  ping = requests.get(f"{truss_server_addr}/ping")
@@ -118,43 +208,39 @@ def test_model_load_failure_truss():
118
208
 
119
209
  @handle_request_exception
120
210
  def _test_invocations(expected_code):
121
- invocations = requests.post(f"{truss_server_addr}/invocations", json={})
211
+ invocations = requests.post(
212
+ f"{truss_server_addr}/v1/models/model:predict", json={}
213
+ )
122
214
  assert invocations.status_code == expected_code
123
215
  return True
124
216
 
125
217
  # The server should be completely down so all requests should result in a RequestException.
126
218
  # The decorator handle_request_exception catches the RequestException and returns False.
127
- assert not _test_readiness_probe(expected_code=200)
128
219
  assert not _test_liveness_probe(expected_code=200)
220
+ assert not _test_readiness_probe(expected_code=200)
221
+ assert not _test_is_loaded(expected_code=200)
129
222
  assert not _test_ping(expected_code=200)
130
223
  assert not _test_invocations(expected_code=200)
131
224
 
132
225
 
133
226
  @pytest.mark.integration
134
- def test_concurrency_truss():
227
+ def test_concurrency_truss(test_data_path):
135
228
  # Tests that concurrency limits work correctly
136
229
  with ensure_kill_all():
137
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
138
-
139
- truss_dir = truss_root / "test_data" / "test_concurrency_truss"
140
-
230
+ truss_dir = test_data_path / "test_concurrency_truss"
141
231
  tr = TrussHandle(truss_dir)
142
-
143
232
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
144
233
 
145
- truss_server_addr = "http://localhost:8090"
146
- full_url = f"{truss_server_addr}/v1/models/model:predict"
147
-
148
234
  # Each request takes 2 seconds, for this thread, we allow
149
235
  # a concurrency of 2. This means the first two requests will
150
236
  # succeed within the 2 seconds, and the third will fail, since
151
237
  # it cannot start until the first two have completed.
152
238
  def make_request():
153
- requests.post(full_url, json={}, timeout=3)
239
+ requests.post(PREDICT_URL, json={}, timeout=3)
154
240
 
155
- successful_thread_1 = PropagatingThread(target=make_request)
156
- successful_thread_2 = PropagatingThread(target=make_request)
157
- failed_thread = PropagatingThread(target=make_request)
241
+ successful_thread_1 = _PropagatingThread(target=make_request)
242
+ successful_thread_2 = _PropagatingThread(target=make_request)
243
+ failed_thread = _PropagatingThread(target=make_request)
158
244
 
159
245
  successful_thread_1.start()
160
246
  successful_thread_2.start()
@@ -169,38 +255,40 @@ def test_concurrency_truss():
169
255
 
170
256
 
171
257
  @pytest.mark.integration
172
- def test_requirements_file_truss():
258
+ def test_requirements_file_truss(test_data_path):
173
259
  with ensure_kill_all():
174
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
175
-
176
- truss_dir = truss_root / "test_data" / "test_requirements_file_truss"
177
-
260
+ truss_dir = test_data_path / "test_requirements_file_truss"
178
261
  tr = TrussHandle(truss_dir)
179
-
180
262
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
181
- truss_server_addr = "http://localhost:8090"
182
- full_url = f"{truss_server_addr}/v1/models/model:predict"
263
+ time.sleep(3) # Sleeping to allow the load to finish
183
264
 
184
265
  # The prediction imports torch which is specified in a requirements.txt and returns if GPU is available.
185
- response = requests.post(full_url, json={})
266
+ response = requests.post(PREDICT_URL, json={})
186
267
  assert response.status_code == 200
187
268
  assert response.json() is False
188
269
 
189
270
 
190
271
  @pytest.mark.integration
191
- def test_async_truss():
272
+ @pytest.mark.parametrize("pydantic_major_version", ["1", "2"])
273
+ def test_requirements_pydantic(test_data_path, pydantic_major_version):
192
274
  with ensure_kill_all():
193
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
275
+ truss_dir = test_data_path / f"test_pyantic_v{pydantic_major_version}"
276
+ tr = TrussHandle(truss_dir)
277
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
194
278
 
195
- truss_dir = truss_root / "test_data" / "test_async_truss"
279
+ response = requests.post(PREDICT_URL, json={})
280
+ assert response.status_code == 200
281
+ assert response.json() == '{\n "foo": "bla",\n "bar": 123\n}'
196
282
 
197
- tr = TrussHandle(truss_dir)
198
283
 
284
+ @pytest.mark.integration
285
+ def test_async_truss(test_data_path):
286
+ with ensure_kill_all():
287
+ truss_dir = test_data_path / "test_async_truss"
288
+ tr = TrussHandle(truss_dir)
199
289
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
200
- truss_server_addr = "http://localhost:8090"
201
- full_url = f"{truss_server_addr}/v1/models/model:predict"
202
290
 
203
- response = requests.post(full_url, json={})
291
+ response = requests.post(PREDICT_URL, json={})
204
292
  assert response.json() == {
205
293
  "preprocess_value": "value",
206
294
  "postprocess_value": "value",
@@ -208,58 +296,44 @@ def test_async_truss():
208
296
 
209
297
 
210
298
  @pytest.mark.integration
211
- def test_async_streaming():
299
+ def test_async_streaming(test_data_path):
212
300
  with ensure_kill_all():
213
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
214
-
215
- truss_dir = truss_root / "test_data" / "test_streaming_async_generator_truss"
216
-
301
+ truss_dir = test_data_path / "test_streaming_async_generator_truss"
217
302
  tr = TrussHandle(truss_dir)
218
-
219
303
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
220
- truss_server_addr = "http://localhost:8090"
221
- full_url = f"{truss_server_addr}/v1/models/model:predict"
222
304
 
223
- response = requests.post(full_url, json={}, stream=True)
305
+ response = requests.post(PREDICT_URL, json={}, stream=True)
224
306
  assert response.headers.get("transfer-encoding") == "chunked"
225
307
  assert [
226
308
  byte_string.decode() for byte_string in list(response.iter_content())
227
309
  ] == ["0", "1", "2", "3", "4"]
228
310
 
229
311
  predict_non_stream_response = requests.post(
230
- full_url,
231
- json={},
232
- stream=True,
233
- headers={"accept": "application/json"},
312
+ PREDICT_URL, json={}, stream=True, headers={"accept": "application/json"}
234
313
  )
235
314
  assert "transfer-encoding" not in predict_non_stream_response.headers
236
315
  assert predict_non_stream_response.json() == "01234"
237
316
 
238
317
 
239
318
  @pytest.mark.integration
240
- def test_async_streaming_timeout():
319
+ def test_async_streaming_timeout(test_data_path):
241
320
  with ensure_kill_all():
242
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
243
-
244
- truss_dir = truss_root / "test_data" / "test_streaming_read_timeout"
245
-
321
+ truss_dir = test_data_path / "test_streaming_read_timeout"
246
322
  tr = TrussHandle(truss_dir)
247
-
248
323
  container = tr.docker_run(
249
324
  local_port=8090, detach=True, wait_for_server_ready=True
250
325
  )
251
- truss_server_addr = "http://localhost:8090"
252
- predict_url = f"{truss_server_addr}/v1/models/model:predict"
253
326
 
254
327
  # ChunkedEncodingError is raised when the chunk does not get processed due to streaming read timeout
255
328
  with pytest.raises(requests.exceptions.ChunkedEncodingError):
256
- response = requests.post(predict_url, json={}, stream=True)
329
+ response = requests.post(PREDICT_URL, json={}, stream=True)
257
330
 
258
331
  for chunk in response.iter_content():
259
332
  pass
260
333
 
261
334
  # Check to ensure the Timeout error is in the container logs
262
- assert_logs_contain_error(
335
+ # TODO: maybe intercept this error better?
336
+ _assert_logs_contain_error(
263
337
  container.logs(),
264
338
  error="raise exceptions.TimeoutError()",
265
339
  message="Exception in ASGI application\n",
@@ -267,20 +341,16 @@ def test_async_streaming_timeout():
267
341
 
268
342
 
269
343
  @pytest.mark.integration
270
- def test_streaming_with_error():
344
+ def test_streaming_with_error_and_stacktrace(test_data_path):
271
345
  with ensure_kill_all():
272
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
273
-
274
- truss_dir = truss_root / "test_data" / "test_streaming_truss_with_error"
275
-
346
+ truss_dir = test_data_path / "test_streaming_truss_with_error"
276
347
  tr = TrussHandle(truss_dir)
277
-
278
- _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
279
- truss_server_addr = "http://localhost:8090"
280
- predict_url = f"{truss_server_addr}/v1/models/model:predict"
348
+ container = tr.docker_run(
349
+ local_port=8090, detach=True, wait_for_server_ready=True
350
+ )
281
351
 
282
352
  predict_error_response = requests.post(
283
- predict_url, json={"throw_error": True}, stream=True, timeout=2
353
+ PREDICT_URL, json={"throw_error": True}, stream=True, timeout=2
284
354
  )
285
355
 
286
356
  # In error cases, the response will return whatever the stream returned,
@@ -293,73 +363,28 @@ def test_streaming_with_error():
293
363
 
294
364
  # Test that we are able to continue to make requests successfully
295
365
  predict_non_error_response = requests.post(
296
- predict_url, json={"throw_error": False}, stream=True, timeout=2
366
+ PREDICT_URL, json={"throw_error": False}, stream=True, timeout=2
297
367
  )
298
368
 
299
369
  assert [
300
370
  byte_string.decode()
301
371
  for byte_string in predict_non_error_response.iter_content()
302
372
  ] == ["0", "1", "2", "3", "4"]
303
-
304
-
305
- @pytest.mark.integration
306
- def test_streaming_truss():
307
- with ensure_kill_all():
308
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
309
- truss_dir = truss_root / "test_data" / "test_streaming_truss"
310
- tr = TrussHandle(truss_dir)
311
-
312
- _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
313
-
314
- truss_server_addr = "http://localhost:8090"
315
- predict_url = f"{truss_server_addr}/v1/models/model:predict"
316
-
317
- # A request for which response is not completely read
318
- predict_response = requests.post(predict_url, json={}, stream=True)
319
- # We just read the first part and leave it hanging here
320
- next(predict_response.iter_content())
321
-
322
- predict_response = requests.post(predict_url, json={}, stream=True)
323
-
324
- assert predict_response.headers.get("transfer-encoding") == "chunked"
325
- assert [
326
- byte_string.decode()
327
- for byte_string in list(predict_response.iter_content())
328
- ] == ["0", "1", "2", "3", "4"]
329
-
330
- # When accept is set to application/json, the response is not streamed.
331
- predict_non_stream_response = requests.post(
332
- predict_url,
333
- json={},
334
- stream=True,
335
- headers={"accept": "application/json"},
373
+ expected_stack_trace = (
374
+ "Traceback (most recent call last):\n"
375
+ ' File "/app/model/model.py", line 12, in inner\n'
376
+ " helpers_1.foo(123)\n"
377
+ ' File "/packages/helpers_1.py", line 5, in foo\n'
378
+ " return helpers_2.bar(x)\n"
379
+ ' File "/packages/helpers_2.py", line 2, in bar\n'
380
+ ' raise Exception("Crashed in `bar`.")\n'
381
+ "Exception: Crashed in `bar`."
382
+ )
383
+ _assert_logs_contain_error(
384
+ container.logs(),
385
+ error=expected_stack_trace,
386
+ message="Exception while generating streamed response: Crashed in `bar`.",
336
387
  )
337
- assert "transfer-encoding" not in predict_non_stream_response.headers
338
- assert predict_non_stream_response.json() == "01234"
339
-
340
- # Test that concurrency work correctly. The streaming Truss has a configured
341
- # concurrency of 1, so only one request can be in flight at a time. Each request
342
- # takes 2 seconds, so with a timeout of 3 seconds, we expect the first request to
343
- # succeed and for the second to timeout.
344
- #
345
- # Note that with streamed requests, requests.post raises a ReadTimeout exception if
346
- # `timeout` seconds has passed since receiving any data from the server.
347
- def make_request(delay: int):
348
- # For streamed responses, requests does not start receiving content from server until
349
- # `iter_content` is called, so we must call this in order to get an actual timeout.
350
- time.sleep(delay)
351
- list(requests.post(predict_url, json={}, stream=True).iter_content())
352
-
353
- with ThreadPoolExecutor() as e:
354
- # We use concurrent.futures.wait instead of the timeout property
355
- # on requests, since requests timeout property has a complex interaction
356
- # with streaming.
357
- first_request = e.submit(make_request, 0)
358
- second_request = e.submit(make_request, 0.2)
359
- futures = [first_request, second_request]
360
- done, not_done = concurrent.futures.wait(futures, timeout=3)
361
- assert first_request in done
362
- assert second_request in not_done
363
388
 
364
389
 
365
390
  @pytest.mark.integration
@@ -378,106 +403,50 @@ secrets:
378
403
 
379
404
  config_with_no_secret = "model_name: secrets-truss"
380
405
  missing_secret_error_message = """Secret 'secret' not found. Please ensure that:
381
- * Secret 'secret' is defined in the 'secrets' section of the Truss config file
382
- * The model was pushed with the --trusted flag"""
406
+ * Secret 'secret' is defined in the 'secrets' section of the Truss config file"""
383
407
 
384
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
385
- truss_dir = Path(tmp_work_dir, "truss")
386
-
387
- create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
388
-
389
- tr = TrussHandle(truss_dir)
408
+ with ensure_kill_all(), _temp_truss(inspect.getsource(Model), config) as tr:
390
409
  LocalConfigHandler.set_secret("secret", "secret_value")
391
410
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
392
- truss_server_addr = "http://localhost:8090"
393
- full_url = f"{truss_server_addr}/v1/models/model:predict"
394
411
 
395
- response = requests.post(full_url, json={})
412
+ response = requests.post(PREDICT_URL, json={})
396
413
 
397
414
  assert response.json() == "secret_value"
398
415
 
399
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
400
- # Case where the secret is not specified in the config
401
- truss_dir = Path(tmp_work_dir, "truss")
402
-
403
- create_truss(
404
- truss_dir, config_with_no_secret, textwrap.dedent(inspect.getsource(Model))
405
- )
406
- tr = TrussHandle(truss_dir)
416
+ # Case where the secret is not specified in the config
417
+ with ensure_kill_all(), _temp_truss(
418
+ inspect.getsource(Model), config_with_no_secret
419
+ ) as tr:
407
420
  LocalConfigHandler.set_secret("secret", "secret_value")
408
421
  container = tr.docker_run(
409
422
  local_port=8090, detach=True, wait_for_server_ready=True
410
423
  )
411
- truss_server_addr = "http://localhost:8090"
412
- full_url = f"{truss_server_addr}/v1/models/model:predict"
413
-
414
- response = requests.post(full_url, json={})
415
424
 
425
+ response = requests.post(PREDICT_URL, json={})
416
426
  assert "error" in response.json()
417
-
418
- assert_logs_contain_error(container.logs(), missing_secret_error_message)
427
+ _assert_logs_contain_error(container.logs(), missing_secret_error_message)
419
428
  assert "Internal Server Error" in response.json()["error"]
429
+ assert response.headers["x-baseten-error-source"] == "04"
430
+ assert response.headers["x-baseten-error-code"] == "600"
420
431
 
421
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
422
- # Case where the secret is not mounted
423
- truss_dir = Path(tmp_work_dir, "truss")
424
-
425
- create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
426
- tr = TrussHandle(truss_dir)
432
+ # Case where the secret is not mounted
433
+ with ensure_kill_all(), _temp_truss(inspect.getsource(Model), config) as tr:
427
434
  LocalConfigHandler.remove_secret("secret")
428
435
  container = tr.docker_run(
429
436
  local_port=8090, detach=True, wait_for_server_ready=True
430
437
  )
431
- truss_server_addr = "http://localhost:8090"
432
- full_url = f"{truss_server_addr}/v1/models/model:predict"
433
438
 
434
- response = requests.post(full_url, json={})
439
+ response = requests.post(PREDICT_URL, json={})
435
440
  assert response.status_code == 500
436
-
437
- assert_logs_contain_error(container.logs(), missing_secret_error_message)
441
+ _assert_logs_contain_error(container.logs(), missing_secret_error_message)
438
442
  assert "Internal Server Error" in response.json()["error"]
439
-
440
-
441
- @pytest.mark.integration
442
- def test_prints_captured_in_log():
443
- class Model:
444
- def predict(self, request):
445
- print("This is a message from the Truss: Hello World!")
446
- return {}
447
-
448
- config = """model_name: printing-truss"""
449
-
450
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
451
- # Case where the secret is not specified in the config
452
- truss_dir = Path(tmp_work_dir, "truss")
453
-
454
- create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
455
- tr = TrussHandle(truss_dir)
456
- container = tr.docker_run(
457
- local_port=8090, detach=True, wait_for_server_ready=True
458
- )
459
- truss_server_addr = "http://localhost:8090"
460
- full_url = f"{truss_server_addr}/v1/models/model:predict"
461
-
462
- _ = requests.post(full_url, json={})
463
-
464
- loglines = container.logs().splitlines()
465
-
466
- relevant_line = None
467
- for line in loglines:
468
- logline = json.loads(line)
469
- if logline["message"] == "This is a message from the Truss: Hello World!":
470
- relevant_line = logline
471
- break
472
-
473
- # check that log line has other attributes and could be found
474
- assert relevant_line is not None, "Relevant log line not found."
475
- assert "asctime" in relevant_line
476
- assert "levelname" in relevant_line
443
+ assert response.headers["x-baseten-error-source"] == "04"
444
+ assert response.headers["x-baseten-error-code"] == "600"
477
445
 
478
446
 
479
447
  @pytest.mark.integration
480
448
  def test_postprocess_with_streaming_predict():
449
+ # TODO: revisit the decision to forbid this. If so remove below comment.
481
450
  """
482
451
  Test a Truss that has streaming response from both predict and postprocess.
483
452
  In this case, the postprocess step continues to happen within the predict lock,
@@ -498,26 +467,26 @@ def test_postprocess_with_streaming_predict():
498
467
  yield str(i)
499
468
  """
500
469
 
501
- config = "model_name: error-truss"
502
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
503
- truss_dir = Path(tmp_work_dir, "truss")
504
-
505
- create_truss(truss_dir, config, textwrap.dedent(model))
470
+ with ensure_kill_all(), _temp_truss(model) as tr:
471
+ container = tr.docker_run(
472
+ local_port=8090, detach=True, wait_for_server_ready=True
473
+ )
506
474
 
507
- tr = TrussHandle(truss_dir)
508
- _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
509
- truss_server_addr = "http://localhost:8090"
510
- full_url = f"{truss_server_addr}/v1/models/model:predict"
511
- response = requests.post(full_url, json={}, stream=True)
512
- # Note that the postprocess function is applied to the
513
- # streamed response.
514
- assert response.content == b"0 modified1 modified"
475
+ response = requests.post(PREDICT_URL, json={}, stream=True)
476
+ logging.info(response.content)
477
+ _assert_logs_contain_error(
478
+ container.logs(),
479
+ "ModelDefinitionError: If the predict function returns a generator (streaming), you cannot use postprocessing.",
480
+ )
481
+ assert "Internal Server Error" in response.json()["error"]
482
+ assert response.headers["x-baseten-error-source"] == "04"
483
+ assert response.headers["x-baseten-error-code"] == "600"
515
484
 
516
485
 
517
486
  @pytest.mark.integration
518
487
  def test_streaming_postprocess():
519
488
  """
520
- Tests a Truss where predict returns non-streaming, but postprocess is streamd, and
489
+ Tests a Truss where predict returns non-streaming, but postprocess is streamed, and
521
490
  ensures that the postprocess step does not happen within the predict lock. To do this,
522
491
  we sleep for two seconds during the postprocess streaming process, and fire off two
523
492
  requests with a total timeout of 3 seconds, ensuring that if they were serialized
@@ -536,22 +505,14 @@ def test_streaming_postprocess():
536
505
  return ["0", "1"]
537
506
  """
538
507
 
539
- config = "model_name: error-truss"
540
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
541
- truss_dir = Path(tmp_work_dir, "truss")
542
-
543
- create_truss(truss_dir, config, textwrap.dedent(model))
544
-
545
- tr = TrussHandle(truss_dir)
508
+ with ensure_kill_all(), _temp_truss(model) as tr:
546
509
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
547
- truss_server_addr = "http://localhost:8090"
548
- full_url = f"{truss_server_addr}/v1/models/model:predict"
549
510
 
550
511
  def make_request(delay: int):
551
512
  # For streamed responses, requests does not start receiving content from server until
552
513
  # `iter_content` is called, so we must call this in order to get an actual timeout.
553
514
  time.sleep(delay)
554
- response = requests.post(full_url, json={}, stream=True)
515
+ response = requests.post(PREDICT_URL, json={}, stream=True)
555
516
 
556
517
  assert response.status_code == 200
557
518
  assert response.content == b"0 modified1 modified"
@@ -599,20 +560,12 @@ def test_postprocess():
599
560
 
600
561
  """
601
562
 
602
- config = "model_name: error-truss"
603
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
604
- truss_dir = Path(tmp_work_dir, "truss")
605
-
606
- create_truss(truss_dir, config, textwrap.dedent(model))
607
-
608
- tr = TrussHandle(truss_dir)
563
+ with ensure_kill_all(), _temp_truss(model) as tr:
609
564
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
610
- truss_server_addr = "http://localhost:8090"
611
- full_url = f"{truss_server_addr}/v1/models/model:predict"
612
565
 
613
566
  def make_request(delay: int):
614
567
  time.sleep(delay)
615
- response = requests.post(full_url, json={})
568
+ response = requests.post(PREDICT_URL, json={})
616
569
  assert response.status_code == 200
617
570
  assert response.json() == ["0 modified", "1 modified"]
618
571
 
@@ -642,27 +595,20 @@ def test_truss_with_errors():
642
595
  raise ValueError("error")
643
596
  """
644
597
 
645
- config = "model_name: error-truss"
646
-
647
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
648
- truss_dir = Path(tmp_work_dir, "truss")
649
-
650
- create_truss(truss_dir, config, textwrap.dedent(model))
651
-
652
- tr = TrussHandle(truss_dir)
598
+ with ensure_kill_all(), _temp_truss(model) as tr:
653
599
  container = tr.docker_run(
654
600
  local_port=8090, detach=True, wait_for_server_ready=True
655
601
  )
656
- truss_server_addr = "http://localhost:8090"
657
- full_url = f"{truss_server_addr}/v1/models/model:predict"
658
602
 
659
- response = requests.post(full_url, json={})
603
+ response = requests.post(PREDICT_URL, json={})
660
604
  assert response.status_code == 500
661
605
  assert "error" in response.json()
662
606
 
663
- assert_logs_contain_error(container.logs(), "ValueError: error")
607
+ _assert_logs_contain_error(container.logs(), "ValueError: error")
664
608
 
665
609
  assert "Internal Server Error" in response.json()["error"]
610
+ assert response.headers["x-baseten-error-source"] == "04"
611
+ assert response.headers["x-baseten-error-code"] == "600"
666
612
 
667
613
  model_preprocess_error = """
668
614
  class Model:
@@ -673,24 +619,19 @@ def test_truss_with_errors():
673
619
  return {"a": "b"}
674
620
  """
675
621
 
676
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
677
- truss_dir = Path(tmp_work_dir, "truss")
678
-
679
- create_truss(truss_dir, config, textwrap.dedent(model_preprocess_error))
680
-
681
- tr = TrussHandle(truss_dir)
622
+ with ensure_kill_all(), _temp_truss(model_preprocess_error) as tr:
682
623
  container = tr.docker_run(
683
624
  local_port=8090, detach=True, wait_for_server_ready=True
684
625
  )
685
- truss_server_addr = "http://localhost:8090"
686
- full_url = f"{truss_server_addr}/v1/models/model:predict"
687
626
 
688
- response = requests.post(full_url, json={})
627
+ response = requests.post(PREDICT_URL, json={})
689
628
  assert response.status_code == 500
690
629
  assert "error" in response.json()
691
630
 
692
- assert_logs_contain_error(container.logs(), "ValueError: error")
631
+ _assert_logs_contain_error(container.logs(), "ValueError: error")
693
632
  assert "Internal Server Error" in response.json()["error"]
633
+ assert response.headers["x-baseten-error-source"] == "04"
634
+ assert response.headers["x-baseten-error-code"] == "600"
694
635
 
695
636
  model_postprocess_error = """
696
637
  class Model:
@@ -701,23 +642,18 @@ def test_truss_with_errors():
701
642
  raise ValueError("error")
702
643
  """
703
644
 
704
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
705
- truss_dir = Path(tmp_work_dir, "truss")
706
-
707
- create_truss(truss_dir, config, textwrap.dedent(model_postprocess_error))
708
-
709
- tr = TrussHandle(truss_dir)
645
+ with ensure_kill_all(), _temp_truss(model_postprocess_error) as tr:
710
646
  container = tr.docker_run(
711
647
  local_port=8090, detach=True, wait_for_server_ready=True
712
648
  )
713
- truss_server_addr = "http://localhost:8090"
714
- full_url = f"{truss_server_addr}/v1/models/model:predict"
715
649
 
716
- response = requests.post(full_url, json={})
650
+ response = requests.post(PREDICT_URL, json={})
717
651
  assert response.status_code == 500
718
652
  assert "error" in response.json()
719
- assert_logs_contain_error(container.logs(), "ValueError: error")
653
+ _assert_logs_contain_error(container.logs(), "ValueError: error")
720
654
  assert "Internal Server Error" in response.json()["error"]
655
+ assert response.headers["x-baseten-error-source"] == "04"
656
+ assert response.headers["x-baseten-error-code"] == "600"
721
657
 
722
658
  model_async = """
723
659
  class Model:
@@ -725,32 +661,93 @@ def test_truss_with_errors():
725
661
  raise ValueError("error")
726
662
  """
727
663
 
728
- with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
729
- truss_dir = Path(tmp_work_dir, "truss")
664
+ with ensure_kill_all(), _temp_truss(model_async) as tr:
665
+ container = tr.docker_run(
666
+ local_port=8090, detach=True, wait_for_server_ready=True
667
+ )
730
668
 
731
- create_truss(truss_dir, config, textwrap.dedent(model_async))
669
+ response = requests.post(PREDICT_URL, json={})
670
+ assert response.status_code == 500
671
+ assert "error" in response.json()
732
672
 
733
- tr = TrussHandle(truss_dir)
673
+ _assert_logs_contain_error(container.logs(), "ValueError: error")
674
+
675
+ assert "Internal Server Error" in response.json()["error"]
676
+ assert response.headers["x-baseten-error-source"] == "04"
677
+ assert response.headers["x-baseten-error-code"] == "600"
678
+
679
+
680
+ @pytest.mark.integration
681
+ def test_truss_with_user_errors():
682
+ """Test that user-code raised `fastapi.HTTPExceptions` are passed through as is."""
683
+ model = """
684
+ import fastapi
685
+
686
+ class Model:
687
+ def predict(self, request):
688
+ raise fastapi.HTTPException(status_code=500, detail="My custom message.")
689
+ """
690
+
691
+ with ensure_kill_all(), _temp_truss(model) as tr:
734
692
  container = tr.docker_run(
735
693
  local_port=8090, detach=True, wait_for_server_ready=True
736
694
  )
737
- truss_server_addr = "http://localhost:8090"
738
- full_url = f"{truss_server_addr}/v1/models/model:predict"
739
695
 
740
- response = requests.post(full_url, json={})
696
+ response = requests.post(PREDICT_URL, json={})
741
697
  assert response.status_code == 500
742
698
  assert "error" in response.json()
699
+ assert response.headers["x-baseten-error-source"] == "04"
700
+ assert response.headers["x-baseten-error-code"] == "600"
701
+
702
+ _assert_logs_contain_error(
703
+ container.logs(),
704
+ "HTTPException: 500: My custom message.",
705
+ "Model raised HTTPException",
706
+ )
707
+
708
+ assert "My custom message." in response.json()["error"]
709
+ assert response.headers["x-baseten-error-source"] == "04"
710
+ assert response.headers["x-baseten-error-code"] == "600"
711
+
712
+
713
+ @pytest.mark.integration
714
+ def test_truss_with_error_stacktrace(test_data_path):
715
+ with ensure_kill_all():
716
+ truss_dir = test_data_path / "test_truss_with_error"
717
+ tr = TrussHandle(truss_dir)
718
+ container = tr.docker_run(
719
+ local_port=8090, detach=True, wait_for_server_ready=True
720
+ )
743
721
 
744
- assert_logs_contain_error(container.logs(), "ValueError: error")
722
+ response = requests.post(PREDICT_URL, json={})
723
+ assert response.status_code == 500
724
+ assert "error" in response.json()
745
725
 
746
726
  assert "Internal Server Error" in response.json()["error"]
727
+ assert response.headers["x-baseten-error-source"] == "04"
728
+ assert response.headers["x-baseten-error-code"] == "600"
729
+
730
+ expected_stack_trace = (
731
+ "Traceback (most recent call last):\n"
732
+ ' File "/app/model/model.py", line 8, in predict\n'
733
+ " return helpers_1.foo(123)\n"
734
+ ' File "/packages/helpers_1.py", line 5, in foo\n'
735
+ " return helpers_2.bar(x)\n"
736
+ ' File "/packages/helpers_2.py", line 2, in bar\n'
737
+ ' raise Exception("Crashed in `bar`.")\n'
738
+ "Exception: Crashed in `bar`."
739
+ )
740
+ _assert_logs_contain_error(
741
+ container.logs(),
742
+ error=expected_stack_trace,
743
+ message="Internal Server Error",
744
+ )
747
745
 
748
746
 
749
747
  @pytest.mark.integration
750
- def test_slow_truss():
748
+ def test_slow_truss(test_data_path):
751
749
  with ensure_kill_all():
752
- truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
753
- truss_dir = truss_root / "test_data" / "server_conformance_test_truss"
750
+ truss_dir = test_data_path / "server_conformance_test_truss"
754
751
  tr = TrussHandle(truss_dir)
755
752
 
756
753
  _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)
@@ -765,6 +762,10 @@ def test_slow_truss():
765
762
  ready = requests.get(f"{truss_server_addr}/v1/models/model")
766
763
  assert ready.status_code == expected_code
767
764
 
765
+ def _test_is_loaded(expected_code):
766
+ ready = requests.get(f"{truss_server_addr}/v1/models/model/loaded")
767
+ assert ready.status_code == expected_code
768
+
768
769
  def _test_ping(expected_code):
769
770
  ping = requests.get(f"{truss_server_addr}/ping")
770
771
  assert ping.status_code == expected_code
@@ -786,6 +787,7 @@ def test_slow_truss():
786
787
  for _ in range(LOAD_TEST_TIME):
787
788
  _test_liveness_probe(200)
788
789
  _test_readiness_probe(503)
790
+ _test_is_loaded(503)
789
791
  _test_ping(503)
790
792
  _test_invocations(503)
791
793
  time.sleep(1)
@@ -793,6 +795,7 @@ def test_slow_truss():
793
795
  time.sleep(LOAD_BUFFER_TIME)
794
796
  _test_liveness_probe(200)
795
797
  _test_readiness_probe(200)
798
+ _test_is_loaded(200)
796
799
  _test_ping(200)
797
800
 
798
801
  predict_call = Thread(
@@ -805,9 +808,1054 @@ def test_slow_truss():
805
808
  for _ in range(PREDICT_TEST_TIME):
806
809
  _test_liveness_probe(200)
807
810
  _test_readiness_probe(200)
811
+ _test_is_loaded(200)
808
812
  _test_ping(200)
809
813
  time.sleep(1)
810
814
 
811
815
  predict_call.join()
812
816
 
813
817
  _test_invocations(200)
818
+
819
+
820
+ @pytest.mark.integration
821
+ def test_init_environment_parameter():
822
+ # Test a truss deployment that is associated with an environment
823
+ model = """
824
+ from typing import Optional
825
+ class Model:
826
+ def __init__(self, **kwargs):
827
+ self._config = kwargs["config"]
828
+ self._environment = kwargs["environment"]
829
+ self.environment_name = self._environment.get("name") if self._environment else None
830
+
831
+ def load(self):
832
+ print(f"Executing model.load with environment: {self.environment_name}")
833
+
834
+ def predict(self, model_input):
835
+ return self.environment_name
836
+ """
837
+ config = "model_name: init-environment-truss"
838
+ with ensure_kill_all(), _temp_truss(model, config) as tr:
839
+ # Mimic environment changing to staging
840
+ staging_env = {"name": "staging"}
841
+ staging_env_str = json.dumps(staging_env)
842
+ LocalConfigHandler.set_dynamic_config("environment", staging_env_str)
843
+ container = tr.docker_run(
844
+ local_port=8090, detach=True, wait_for_server_ready=True
845
+ )
846
+ assert "Executing model.load with environment: staging" in container.logs()
847
+ response = requests.post(PREDICT_URL, json={})
848
+ assert response.json() == "staging"
849
+ assert response.status_code == 200
850
+ container.execute(["bash", "-c", "rm -f /etc/b10_dynamic_config/environment"])
851
+
852
+ # Test a truss deployment with no associated environment
853
+ config = "model_name: init-no-environment-truss"
854
+ with ensure_kill_all(), _temp_truss(model, config) as tr:
855
+ container = tr.docker_run(
856
+ local_port=8090, detach=True, wait_for_server_ready=True
857
+ )
858
+ assert "Executing model.load with environment: None" in container.logs()
859
+ response = requests.post(PREDICT_URL, json={})
860
+ assert response.json() is None
861
+ assert response.status_code == 200
862
+
863
+
864
+ @pytest.mark.integration
865
+ def test_setup_environment():
866
+ # Test truss that uses setup_environment() without load()
867
+ model = """
868
+ from typing import Optional
869
+ class Model:
870
+ def setup_environment(self, environment: Optional[dict]):
871
+ print("setup_environment called with", environment)
872
+ self.environment_name = environment.get("name") if environment else None
873
+ print(f"in {self.environment_name} environment")
874
+
875
+ def predict(self, model_input):
876
+ return model_input
877
+ """
878
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
879
+ container = tr.docker_run(
880
+ local_port=8090, detach=True, wait_for_server_ready=True
881
+ )
882
+ # Mimic environment changing to beta
883
+ beta_env = {"name": "beta"}
884
+ beta_env_str = json.dumps(beta_env)
885
+ container.execute(
886
+ [
887
+ "bash",
888
+ "-c",
889
+ f"echo '{beta_env_str}' > /etc/b10_dynamic_config/environment",
890
+ ]
891
+ )
892
+ time.sleep(30)
893
+ assert (
894
+ f"Executing model.setup_environment with environment: {beta_env}"
895
+ in container.logs()
896
+ )
897
+ single_quote_beta_env_str = beta_env_str.replace('"', "'")
898
+ assert (
899
+ f"setup_environment called with {single_quote_beta_env_str}"
900
+ in container.logs()
901
+ )
902
+ assert "in beta environment" in container.logs()
903
+ container.execute(["bash", "-c", "rm -f /etc/b10_dynamic_config/environment"])
904
+
905
+ # Test a truss that uses the environment in load()
906
+ model = """
907
+ from typing import Optional
908
+ class Model:
909
+ def setup_environment(self, environment: Optional[dict]):
910
+ print("setup_environment called with", environment)
911
+ self.environment_name = environment.get("name") if environment else None
912
+ print(f"in {self.environment_name} environment")
913
+
914
+ def load(self):
915
+ print("loading in environment", self.environment_name)
916
+
917
+ def predict(self, model_input):
918
+ return model_input
919
+ """
920
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
921
+ # Mimic environment changing to staging
922
+ staging_env = {"name": "staging"}
923
+ staging_env_str = json.dumps(staging_env)
924
+ LocalConfigHandler.set_dynamic_config("environment", staging_env_str)
925
+ container = tr.docker_run(
926
+ local_port=8090, detach=True, wait_for_server_ready=True
927
+ )
928
+ # Don't need to wait here because we explicitly grab the environment from dynamic_config_resolver before calling user's load()
929
+ assert (
930
+ f"Executing model.setup_environment with environment: {staging_env}"
931
+ in container.logs()
932
+ )
933
+ single_quote_staging_env_str = staging_env_str.replace('"', "'")
934
+ assert (
935
+ f"setup_environment called with {single_quote_staging_env_str}"
936
+ in container.logs()
937
+ )
938
+ assert "in staging environment" in container.logs()
939
+ assert "loading in environment staging" in container.logs()
940
+ # Set environment to None
941
+ no_env = None
942
+ no_env_str = json.dumps(no_env)
943
+ container.execute(
944
+ ["bash", "-c", f"echo '{no_env_str}' > /etc/b10_dynamic_config/environment"]
945
+ )
946
+ time.sleep(30)
947
+ assert (
948
+ f"Executing model.setup_environment with environment: {no_env}"
949
+ in container.logs()
950
+ )
951
+ assert "setup_environment called with None" in container.logs()
952
+ container.execute(["bash", "-c", "rm -f /etc/b10_dynamic_config/environment"])
953
+
954
+
955
+ @pytest.mark.integration
956
+ def test_health_check_configuration():
957
+ model = """
958
+ class Model:
959
+ def predict(self, model_input):
960
+ return model_input
961
+ """
962
+
963
+ config = """runtime:
964
+ health_checks:
965
+ restart_check_delay_seconds: 100
966
+ restart_threshold_seconds: 1700
967
+ """
968
+
969
+ with ensure_kill_all(), _temp_truss(model, config) as tr:
970
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
971
+
972
+ assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds == 100
973
+ assert tr.spec.config.runtime.health_checks.restart_threshold_seconds == 1700
974
+ assert (
975
+ tr.spec.config.runtime.health_checks.stop_traffic_threshold_seconds is None
976
+ )
977
+
978
+ config = """runtime:
979
+ health_checks:
980
+ restart_check_delay_seconds: 1200
981
+ restart_threshold_seconds: 90
982
+ stop_traffic_threshold_seconds: 50
983
+ """
984
+
985
+ with ensure_kill_all(), _temp_truss(model, config) as tr:
986
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
987
+
988
+ assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds == 1200
989
+ assert tr.spec.config.runtime.health_checks.restart_threshold_seconds == 90
990
+ assert tr.spec.config.runtime.health_checks.stop_traffic_threshold_seconds == 50
991
+
992
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
993
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
994
+
995
+ assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds is None
996
+ assert tr.spec.config.runtime.health_checks.restart_threshold_seconds is None
997
+ assert (
998
+ tr.spec.config.runtime.health_checks.stop_traffic_threshold_seconds is None
999
+ )
1000
+
1001
+
1002
+ @pytest.mark.integration
1003
+ def test_is_healthy():
1004
+ model = """
1005
+ class Model:
1006
+ def load(self):
1007
+ raise Exception("not loaded")
1008
+
1009
+ def is_healthy(self) -> bool:
1010
+ return True
1011
+
1012
+ def predict(self, model_input):
1013
+ return model_input
1014
+ """
1015
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
1016
+ container = tr.docker_run(
1017
+ local_port=8090, detach=True, wait_for_server_ready=False
1018
+ )
1019
+
1020
+ truss_server_addr = "http://localhost:8090"
1021
+ for _ in range(5):
1022
+ time.sleep(1)
1023
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1024
+ if healthy.status_code == 503:
1025
+ break
1026
+ assert healthy.status_code == 200
1027
+ assert healthy.status_code == 503
1028
+ diff = container.diff()
1029
+ assert "/root/inference_server_crashed.txt" in diff
1030
+ assert diff["/root/inference_server_crashed.txt"] == "A"
1031
+
1032
+ model = """
1033
+ class Model:
1034
+ def is_healthy(self, argument) -> bool:
1035
+ pass
1036
+
1037
+ def predict(self, model_input):
1038
+ return model_input
1039
+ """
1040
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
1041
+ container = tr.docker_run(
1042
+ local_port=8090, detach=True, wait_for_server_ready=False
1043
+ )
1044
+ time.sleep(1)
1045
+ _assert_logs_contain_error(
1046
+ container.logs(),
1047
+ message="Exception while loading model",
1048
+ error="`is_healthy` must have only one argument: `self`",
1049
+ )
1050
+
1051
+ model = """
1052
+ class Model:
1053
+ def is_healthy(self) -> bool:
1054
+ raise Exception("not healthy")
1055
+
1056
+ def predict(self, model_input):
1057
+ return model_input
1058
+ """
1059
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
1060
+ container = tr.docker_run(
1061
+ local_port=8090, detach=True, wait_for_server_ready=False
1062
+ )
1063
+
1064
+ # Sleep a few seconds to get the server some time to wake up
1065
+ time.sleep(10)
1066
+
1067
+ truss_server_addr = "http://localhost:8090"
1068
+
1069
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1070
+ assert healthy.status_code == 503
1071
+ assert (
1072
+ "Exception while checking if model is healthy: not healthy"
1073
+ in container.logs()
1074
+ )
1075
+ assert "Health check failed." in container.logs()
1076
+
1077
+ model = """
1078
+ import time
1079
+
1080
+ class Model:
1081
+ def load(self):
1082
+ time.sleep(10)
1083
+
1084
+ def is_healthy(self) -> bool:
1085
+ return False
1086
+
1087
+ def predict(self, model_input):
1088
+ return model_input
1089
+ """
1090
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
1091
+ container = tr.docker_run(
1092
+ local_port=8090, detach=True, wait_for_server_ready=False
1093
+ )
1094
+ truss_server_addr = "http://localhost:8090"
1095
+
1096
+ time.sleep(5)
1097
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1098
+ assert healthy.status_code == 503
1099
+ # Ensure we only log after model.load is complete
1100
+ assert "Health check failed." not in container.logs()
1101
+
1102
+ # Sleep a few seconds to get the server some time to wake up
1103
+ time.sleep(10)
1104
+
1105
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1106
+ assert healthy.status_code == 503
1107
+ assert container.logs().count("Health check failed.") == 1
1108
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1109
+ assert healthy.status_code == 503
1110
+ assert container.logs().count("Health check failed.") == 2
1111
+
1112
+ model = """
1113
+ import time
1114
+
1115
+ class Model:
1116
+ def __init__(self, **kwargs):
1117
+ self._healthy = False
1118
+
1119
+ def load(self):
1120
+ time.sleep(10)
1121
+ self._healthy = True
1122
+
1123
+ def is_healthy(self):
1124
+ return self._healthy
1125
+
1126
+ def predict(self, model_input):
1127
+ self._healthy = model_input["healthy"]
1128
+ return model_input
1129
+ """
1130
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
1131
+ tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)
1132
+ time.sleep(5)
1133
+ truss_server_addr = "http://localhost:8090"
1134
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1135
+ assert healthy.status_code == 503
1136
+ time.sleep(10)
1137
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1138
+ assert healthy.status_code == 200
1139
+
1140
+ healthy_responses = [True, "yessss", 34, {"woo": "hoo"}]
1141
+ for response in healthy_responses:
1142
+ predict_response = requests.post(PREDICT_URL, json={"healthy": response})
1143
+ assert predict_response.status_code == 200
1144
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1145
+ assert healthy.status_code == 200
1146
+
1147
+ not_healthy_responses = [False, "", 0, {}]
1148
+ for response in not_healthy_responses:
1149
+ predict_response = requests.post(PREDICT_URL, json={"healthy": response})
1150
+ assert predict_response.status_code == 200
1151
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1152
+ assert healthy.status_code == 503
1153
+
1154
+ model = """
1155
+ class Model:
1156
+ def is_healthy(self) -> bool:
1157
+ return True
1158
+
1159
+ def predict(self, model_input):
1160
+ return model_input
1161
+ """
1162
+ with ensure_kill_all(), _temp_truss(model, "") as tr:
1163
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
1164
+
1165
+ truss_server_addr = "http://localhost:8090"
1166
+
1167
+ healthy = requests.get(f"{truss_server_addr}/v1/models/model")
1168
+ assert healthy.status_code == 200
1169
+
1170
+
1171
+ def _patch_termination_timeout(container: Container, seconds: int, truss_container_fs):
1172
+ app_path = truss_container_fs / "app"
1173
+ sys.path.append(str(app_path))
1174
+ import truss_server
1175
+
1176
+ local_server_source = pathlib.Path(truss_server.__file__)
1177
+ container_server_source = "/app/truss_server.py"
1178
+ modified_content = local_server_source.read_text().replace(
1179
+ "TIMEOUT_GRACEFUL_SHUTDOWN = 120", f"TIMEOUT_GRACEFUL_SHUTDOWN = {seconds}"
1180
+ )
1181
+ with tempfile.NamedTemporaryFile() as patched_file:
1182
+ patched_file.write(modified_content.encode("utf-8"))
1183
+ patched_file.flush()
1184
+ container.copy_to(patched_file.name, container_server_source)
1185
+
1186
+
1187
+ @pytest.mark.anyio
1188
+ @pytest.mark.integration
1189
+ async def test_graceful_shutdown(truss_container_fs):
1190
+ model = """
1191
+ import time
1192
+ class Model:
1193
+ def predict(self, request):
1194
+ print(f"Received {request}")
1195
+ time.sleep(request["seconds"])
1196
+ print(f"Done {request}")
1197
+ return request
1198
+ """
1199
+
1200
+ async def predict_request(data: dict):
1201
+ async with httpx.AsyncClient() as client:
1202
+ response = await client.post(PREDICT_URL, json=data)
1203
+ response.raise_for_status()
1204
+ return response.json()
1205
+
1206
+ with ensure_kill_all(), _temp_truss(model) as tr:
1207
+ container = tr.docker_run(
1208
+ local_port=8090, detach=True, wait_for_server_ready=True
1209
+ )
1210
+ await predict_request({"seconds": 0, "task": 0}) # Warm up server.
1211
+
1212
+ # Test starting two requests, each taking 2 seconds, then terminating server.
1213
+ # They should both finish successfully since the server grace period is 120 s.
1214
+ task_0 = asyncio.create_task(predict_request({"seconds": 2, "task": 0}))
1215
+ await asyncio.sleep(0.1) # Yield to event loop to make above task run.
1216
+ task_1 = asyncio.create_task(predict_request({"seconds": 2, "task": 1}))
1217
+ await asyncio.sleep(0.1) # Yield to event loop to make above task run.
1218
+
1219
+ t0 = time.perf_counter()
1220
+ # Even though the server has 120s grace period, we expect to finish much
1221
+ # faster in the test here, so use 10s.
1222
+ container.stop(10)
1223
+ stop_time = time.perf_counter() - t0
1224
+ print(f"Stopped in {stop_time} seconds,")
1225
+
1226
+ assert 3 < stop_time < 5
1227
+ assert (await task_0) == {"seconds": 2, "task": 0}
1228
+ assert (await task_1) == {"seconds": 2, "task": 1}
1229
+
1230
+ # Now mess around in the docker container to reduce the grace period to 3 s.
1231
+ # (There's not nice way to patch this...)
1232
+ _patch_termination_timeout(container, 3, truss_container_fs)
1233
+ # Now only one request should complete.
1234
+ container.restart()
1235
+ wait_for_truss("http://localhost:8090", container, True)
1236
+ await predict_request({"seconds": 0, "task": 0}) # Warm up server.
1237
+
1238
+ task_2 = asyncio.create_task(predict_request({"seconds": 2, "task": 2}))
1239
+ await asyncio.sleep(0.1) # Yield to event loop to make above task run.
1240
+ task_3 = asyncio.create_task(predict_request({"seconds": 2, "task": 3}))
1241
+ await asyncio.sleep(0.1) # Yield to event loop to make above task run.
1242
+ t0 = time.perf_counter()
1243
+ container.stop(10)
1244
+ stop_time = time.perf_counter() - t0
1245
+ print(f"Stopped in {stop_time} seconds,")
1246
+ assert 3 < stop_time < 5
1247
+ assert (await task_2) == {"seconds": 2, "task": 2}
1248
+ with pytest.raises(httpx.HTTPStatusError):
1249
+ await task_3
1250
+
1251
+
1252
+ # Tracing ##############################################################################
1253
+
1254
+
1255
+ def _make_otel_headers() -> Mapping[str, str]:
1256
+ """
1257
+ Create and return a mapping with OpenTelemetry trace context headers.
1258
+
1259
+ This function starts a new span and injects the trace context into the headers,
1260
+ which can be used to propagate tracing information in outgoing HTTP requests.
1261
+
1262
+ Returns:
1263
+ Mapping[str, str]: A mapping containing the trace context headers.
1264
+ """
1265
+ # Initialize a tracer
1266
+ tracer = trace.get_tracer(__name__)
1267
+
1268
+ # Create a dictionary to hold the headers
1269
+ headers: dict[str, str] = {}
1270
+
1271
+ # Start a new span
1272
+ with tracer.start_as_current_span("outgoing-request-span"):
1273
+ # Use the TraceContextTextMapPropagator to inject the trace context into the headers
1274
+ propagator = tracecontext.TraceContextTextMapPropagator()
1275
+ propagator.inject(headers, context=context.get_current())
1276
+
1277
+ return headers
1278
+
1279
+
1280
+ @pytest.mark.integration
1281
+ @pytest.mark.parametrize("enable_tracing_data", [True, False])
1282
+ def test_streaming_truss_with_user_tracing(test_data_path, enable_tracing_data):
1283
+ with ensure_kill_all():
1284
+ truss_dir = test_data_path / "test_streaming_truss_with_tracing"
1285
+ tr = TrussHandle(truss_dir)
1286
+
1287
+ def enable_gpu_fn(conf):
1288
+ new_runtime = dataclasses.replace(
1289
+ conf.runtime, enable_tracing_data=enable_tracing_data
1290
+ )
1291
+ return dataclasses.replace(conf, runtime=new_runtime)
1292
+
1293
+ tr._update_config(enable_gpu_fn)
1294
+
1295
+ container = tr.docker_run(
1296
+ local_port=8090, detach=True, wait_for_server_ready=True
1297
+ )
1298
+
1299
+ # A request for which response is not completely read
1300
+ headers_0 = _make_otel_headers()
1301
+ predict_response = requests.post(
1302
+ PREDICT_URL, json={}, stream=True, headers=headers_0
1303
+ )
1304
+ # We just read the first part and leave it hanging here
1305
+ next(predict_response.iter_content())
1306
+
1307
+ headers_1 = _make_otel_headers()
1308
+ predict_response = requests.post(
1309
+ PREDICT_URL, json={}, stream=True, headers=headers_1
1310
+ )
1311
+ assert predict_response.headers.get("transfer-encoding") == "chunked"
1312
+
1313
+ # When accept is set to application/json, the response is not streamed.
1314
+ headers_2 = _make_otel_headers()
1315
+ predict_non_stream_response = requests.post(
1316
+ PREDICT_URL,
1317
+ json={},
1318
+ stream=True,
1319
+ headers={**headers_2, "accept": "application/json"},
1320
+ )
1321
+ assert "transfer-encoding" not in predict_non_stream_response.headers
1322
+ assert predict_non_stream_response.json() == "01234"
1323
+
1324
+ with tempfile.TemporaryDirectory() as tmp_dir:
1325
+ truss_traces_file = pathlib.Path(tmp_dir) / "otel_traces.ndjson"
1326
+ container.copy_from("/tmp/otel_traces.ndjson", truss_traces_file)
1327
+ truss_traces = [
1328
+ json.loads(s) for s in truss_traces_file.read_text().splitlines()
1329
+ ]
1330
+
1331
+ user_traces_file = pathlib.Path(tmp_dir) / "otel_user_traces.ndjson"
1332
+ container.copy_from("/tmp/otel_user_traces.ndjson", user_traces_file)
1333
+ user_traces = [
1334
+ json.loads(s) for s in user_traces_file.read_text().splitlines()
1335
+ ]
1336
+
1337
+ if not enable_tracing_data:
1338
+ assert len(truss_traces) == 0
1339
+ assert len(user_traces) > 0
1340
+ return
1341
+
1342
+ assert sum(1 for x in truss_traces if x["name"] == "predict-endpoint") == 3
1343
+ assert sum(1 for x in user_traces if x["name"] == "load_model") == 1
1344
+ assert sum(1 for x in user_traces if x["name"] == "predict") == 3
1345
+
1346
+ user_parents = set(x["parent_id"] for x in user_traces)
1347
+ truss_spans = set(x["context"]["span_id"] for x in truss_traces)
1348
+ truss_parents = set(x["parent_id"] for x in truss_traces)
1349
+ # Make sure there is no context creep into user traces. No user trace should
1350
+ # have a truss trace as parent.
1351
+ assert user_parents & truss_spans == set()
1352
+ # But make sure traces have parents at all.
1353
+ assert len(user_parents) > 3
1354
+ assert len(truss_parents) > 3
1355
+
1356
+
1357
+ # Returning Response Objects ###########################################################
1358
+
1359
+
1360
+ @pytest.mark.integration
1361
+ def test_truss_with_response():
1362
+ """Test that user-code can set a custom status code."""
1363
+ model = """
1364
+ from fastapi.responses import Response
1365
+
1366
+ class Model:
1367
+ def predict(self, inputs):
1368
+ return Response(status_code=inputs["code"])
1369
+ """
1370
+ from fastapi import status
1371
+
1372
+ with ensure_kill_all(), _temp_truss(model) as tr:
1373
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
1374
+
1375
+ response = requests.post(PREDICT_URL, json={"code": status.HTTP_204_NO_CONTENT})
1376
+ assert response.status_code == 204
1377
+ assert "x-baseten-error-source" not in response.headers
1378
+ assert "x-baseten-error-code" not in response.headers
1379
+
1380
+ response = requests.post(
1381
+ PREDICT_URL, json={"code": status.HTTP_500_INTERNAL_SERVER_ERROR}
1382
+ )
1383
+ assert response.status_code == 500
1384
+ assert response.headers["x-baseten-error-source"] == "04"
1385
+ assert response.headers["x-baseten-error-code"] == "700"
1386
+
1387
+
1388
+ @pytest.mark.integration
1389
+ def test_truss_with_streaming_response():
1390
+ # TODO: one issue with this is that (unlike our "builtin" streaming), this keeps
1391
+ # the semaphore claimed potentially longer if the client drops.
1392
+
1393
+ model = """from starlette.responses import StreamingResponse
1394
+ class Model:
1395
+ def predict(self, model_input):
1396
+ def text_generator():
1397
+ for i in range(3):
1398
+ yield f"data: {i}\\n\\n"
1399
+ return StreamingResponse(text_generator(), media_type="text/event-stream")
1400
+ """
1401
+
1402
+ with ensure_kill_all(), _temp_truss(model) as tr:
1403
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
1404
+
1405
+ # A request for which response is not completely read.
1406
+ predict_response = requests.post(PREDICT_URL, json={}, stream=True)
1407
+ assert (
1408
+ predict_response.headers["Content-Type"]
1409
+ == "text/event-stream; charset=utf-8"
1410
+ )
1411
+
1412
+ lines = predict_response.text.strip().split("\n")
1413
+ assert lines == ["data: 0", "", "data: 1", "", "data: 2"]
1414
+
1415
+
1416
+ # Using Request in Model ###############################################################
1417
+
1418
+
1419
+ @pytest.mark.integration
1420
+ def test_truss_with_request():
1421
+ model = """
1422
+ import fastapi
1423
+ class Model:
1424
+ async def preprocess(self, request: fastapi.Request):
1425
+ return await request.json()
1426
+
1427
+ async def predict(self, inputs, request: fastapi.Request):
1428
+ inputs["request_size"] = len(await request.body())
1429
+ return inputs
1430
+
1431
+ def postprocess(self, inputs):
1432
+ return {**inputs, "postprocess": "was here"}
1433
+ """
1434
+ with ensure_kill_all(), _temp_truss(model) as tr:
1435
+ _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
1436
+
1437
+ response = requests.post(PREDICT_URL, json={"test": 123})
1438
+ assert response.status_code == 200
1439
+ assert response.json() == {
1440
+ "test": 123,
1441
+ "request_size": 13,
1442
+ "postprocess": "was here",
1443
+ }
1444
+
1445
+
1446
+ @pytest.mark.integration
1447
+ def test_truss_with_requests_and_invalid_signatures():
1448
+ model = """
1449
+ class Model:
1450
+ def predict(self, inputs, invalid_arg): ...
1451
+ """
1452
+ with ensure_kill_all(), _temp_truss(model) as tr:
1453
+ container = tr.docker_run(
1454
+ local_port=8090, detach=True, wait_for_server_ready=False
1455
+ )
1456
+ time.sleep(1.0) # Wait for logs.
1457
+ _assert_logs_contain_error(
1458
+ container.logs(),
1459
+ "`predict` method with two arguments must have request as second argument",
1460
+ "Exception while loading model",
1461
+ )
1462
+
1463
+ model = """
1464
+ import fastapi
1465
+
1466
+ class Model:
1467
+ def predict(self, request: fastapi.Request, invalid_arg): ...
1468
+ """
1469
+ with ensure_kill_all(), _temp_truss(model) as tr:
1470
+ container = tr.docker_run(
1471
+ local_port=8090, detach=True, wait_for_server_ready=False
1472
+ )
1473
+ time.sleep(1.0) # Wait for logs.
1474
+ _assert_logs_contain_error(
1475
+ container.logs(),
1476
+ "`predict` method with two arguments is not allowed to have request as "
1477
+ "first argument",
1478
+ "Exception while loading model",
1479
+ )
1480
+
1481
+ model = """
1482
+ import fastapi
1483
+
1484
+ class Model:
1485
+ def predict(self, inputs, request: fastapi.Request, something): ...
1486
+ """
1487
+ with ensure_kill_all(), _temp_truss(model) as tr:
1488
+ container = tr.docker_run(
1489
+ local_port=8090, detach=True, wait_for_server_ready=False
1490
+ )
1491
+ time.sleep(1.0) # Wait for logs.
1492
+ _assert_logs_contain_error(
1493
+ container.logs(),
1494
+ "`predict` method cannot have more than two arguments",
1495
+ "Exception while loading model",
1496
+ )
1497
+
1498
+
1499
+ @pytest.mark.integration
1500
+ def test_truss_with_requests_and_invalid_argument_combinations():
1501
+ model = """
1502
+ import fastapi
1503
+ class Model:
1504
+ async def preprocess(self, inputs): ...
1505
+
1506
+ def predict(self, request: fastapi.Request): ...
1507
+ """
1508
+ with ensure_kill_all(), _temp_truss(model) as tr:
1509
+ container = tr.docker_run(
1510
+ local_port=8090, detach=True, wait_for_server_ready=False
1511
+ )
1512
+ time.sleep(1.0) # Wait for logs.
1513
+ _assert_logs_contain_error(
1514
+ container.logs(),
1515
+ "When using `preprocess`, the predict method cannot only have the request argument",
1516
+ "Exception while loading model",
1517
+ )
1518
+
1519
+ model = """
1520
+ import fastapi
1521
+ class Model:
1522
+ def preprocess(self, inputs): ...
1523
+
1524
+ async def predict(self, inputs, request: fastapi.Request): ...
1525
+
1526
+ def postprocess(self, request: fastapi.Request): ...
1527
+ """
1528
+ with ensure_kill_all(), _temp_truss(model) as tr:
1529
+ container = tr.docker_run(
1530
+ local_port=8090, detach=True, wait_for_server_ready=False
1531
+ )
1532
+ time.sleep(1.0) # Wait for logs.
1533
+ _assert_logs_contain_error(
1534
+ container.logs(),
1535
+ "The `postprocess` method cannot only have the request argument",
1536
+ "Exception while loading model",
1537
+ )
1538
+
1539
+ model = """
1540
+ import fastapi
1541
+ class Model:
1542
+ def preprocess(self, inputs): ...
1543
+ """
1544
+ with ensure_kill_all(), _temp_truss(model) as tr:
1545
+ container = tr.docker_run(
1546
+ local_port=8090, detach=True, wait_for_server_ready=False
1547
+ )
1548
+ time.sleep(1.0) # Wait for logs.
1549
+ _assert_logs_contain_error(
1550
+ container.logs(),
1551
+ "Truss model must have a `predict` method.",
1552
+ "Exception while loading model",
1553
+ )
1554
+
1555
+
1556
+ @pytest.mark.integration
1557
+ def test_truss_forbid_postprocessing_with_response():
1558
+ model = """
1559
+ import fastapi, json
1560
+ class Model:
1561
+ def predict(self, inputs):
1562
+ return fastapi.Response(content=json.dumps(inputs), status_code=200)
1563
+
1564
+ def postprocess(self, inputs):
1565
+ return inputs
1566
+ """
1567
+ with ensure_kill_all(), _temp_truss(model) as tr:
1568
+ container = tr.docker_run(
1569
+ local_port=8090, detach=True, wait_for_server_ready=True
1570
+ )
1571
+
1572
+ response = requests.post(PREDICT_URL, json={})
1573
+ assert response.status_code == 500
1574
+ assert response.headers["x-baseten-error-source"] == "04"
1575
+ assert response.headers["x-baseten-error-code"] == "600"
1576
+ _assert_logs_contain_error(
1577
+ container.logs(),
1578
+ "If the predict function returns a response object, you cannot "
1579
+ "use postprocessing.",
1580
+ )
1581
+
1582
+
1583
+ @pytest.mark.integration
1584
+ def test_async_streaming_with_cancellation():
1585
+ model = """
1586
+ import fastapi, asyncio, logging
1587
+
1588
+ class Model:
1589
+ async def predict(self, inputs, request: fastapi.Request):
1590
+ await asyncio.sleep(1)
1591
+ if await request.is_disconnected():
1592
+ logging.warning("Cancelled (before gen).")
1593
+ return
1594
+
1595
+ for i in range(5):
1596
+ await asyncio.sleep(1.0)
1597
+ logging.warning(i)
1598
+ yield str(i)
1599
+ if await request.is_disconnected():
1600
+ logging.warning("Cancelled (during gen).")
1601
+ return
1602
+ """
1603
+ with ensure_kill_all(), _temp_truss(model) as tr:
1604
+ container = tr.docker_run(
1605
+ local_port=8090, detach=True, wait_for_server_ready=True
1606
+ )
1607
+ # For hard cancellation we need to use httpx, requests' timeouts don't work.
1608
+ with pytest.raises(httpx.ReadTimeout):
1609
+ with httpx.Client(
1610
+ timeout=httpx.Timeout(1.0, connect=1.0, read=1.0)
1611
+ ) as client:
1612
+ response = client.post(PREDICT_URL, json={}, timeout=1.0)
1613
+ response.raise_for_status()
1614
+
1615
+ time.sleep(2) # Wait a bit to get all logs.
1616
+ assert "Cancelled (during gen)." in container.logs()
1617
+
1618
+
1619
+ @pytest.mark.integration
1620
+ def test_async_non_streaming_with_cancellation():
1621
+ model = """
1622
+ import fastapi, asyncio, logging
1623
+
1624
+ class Model:
1625
+ async def predict(self, inputs, request: fastapi.Request):
1626
+ logging.info("Start sleep")
1627
+ await asyncio.sleep(2)
1628
+ logging.info("done sleep, check request.")
1629
+ if await request.is_disconnected():
1630
+ logging.warning("Cancelled (before gen).")
1631
+ return
1632
+ logging.info("Not cancelled.")
1633
+ return "Done"
1634
+ """
1635
+ with ensure_kill_all(), _temp_truss(model) as tr:
1636
+ container = tr.docker_run(
1637
+ local_port=8090, detach=True, wait_for_server_ready=True
1638
+ )
1639
+ # For hard cancellation we need to use httpx, requests' timeouts don't work.
1640
+ with pytest.raises(httpx.ReadTimeout):
1641
+ with httpx.Client(
1642
+ timeout=httpx.Timeout(1.0, connect=1.0, read=1.0)
1643
+ ) as client:
1644
+ response = client.post(PREDICT_URL, json={}, timeout=1.0)
1645
+ response.raise_for_status()
1646
+
1647
+ time.sleep(2) # Wait a bit to get all logs.
1648
+ assert "Cancelled (before gen)." in container.logs()
1649
+
1650
+
1651
@pytest.mark.integration
def test_limit_concurrency_with_sse():
    """Verify that `runtime.predict_concurrency: 2` caps concurrent streaming
    predictions: while two held-open streaming requests occupy both slots, a
    third request cannot complete within its timeout; once they finish, the
    server serves requests again."""
    # It seems that the "builtin" functionality of the FastAPI server already buffers
    # the generator, so that it doesn't keep hanging around if the client doesn't
    # consume data. `_buffered_response_generator` might be redundant.
    # This can be observed by waiting for a long time in `make_request`: the server will
    # print `Done` for the tasks, while we still wait and hold the unconsumed response.
    # For testing we need to have actually slow generation to keep the server busy.
    # The model streams 5 chunks at 0.1 s each -> a full response takes ~0.5 s.
    model = """
    import asyncio

    class Model:
        async def predict(self, request):
            print(f"Starting {request}")
            for i in range(5):
                await asyncio.sleep(0.1)
                yield str(i)
            print(f"Done {request}")

    """

    config = """runtime:
  predict_concurrency: 2"""

    def make_request(consume_chunks, timeout, task_id):
        # Either consume the full stream (enforcing a wall-clock `timeout`), or
        # hold the connection open without reading, to keep a concurrency slot
        # busy from the server's perspective.
        t0 = time.time()
        with httpx.Client() as client:
            with client.stream(
                "POST", PREDICT_URL, json={"task_id": task_id}
            ) as response:
                assert response.status_code == 200
                if consume_chunks:
                    chunks = [chunk for chunk in response.iter_text()]
                    print(f"consumed chunks ({task_id}): {chunks}")
                    assert len(chunks) > 0
                    t1 = time.time()
                    # NOTE(review): manual elapsed-time check stands in for an
                    # overall stream-consumption timeout; raises the same
                    # exception type the busy-server case is asserted on.
                    if t1 - t0 > timeout:
                        raise httpx.ReadTimeout("Timeout")
                    return chunks
                else:
                    print(f"waiting ({task_id})")
                    time.sleep(0.5)  # Hold the connection.
                    print(f"waiting done ({task_id})")

    with ensure_kill_all(), _temp_truss(model, config) as tr:
        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
        # Processing full request takes 0.5s.
        print("Make warmup request")
        make_request(consume_chunks=True, timeout=0.55, task_id=0)

        with ThreadPoolExecutor() as executor:
            # Start two requests and hold them without consuming all chunks.
            # Each takes 0.5 s. Semaphore should be claimed, with 0 remaining.
            print("Start two tasks.")
            task1 = executor.submit(make_request, False, 0.55, 1)
            task2 = executor.submit(make_request, False, 0.55, 2)
            print("Wait for tasks to start.")
            time.sleep(0.05)
            print("Make a request while server is busy.")
            # Both slots are taken, so this request cannot finish within 0.55 s.
            with pytest.raises(httpx.ReadTimeout):
                make_request(True, timeout=0.55, task_id=3)

            task1.result()
            task2.result()
            print("Task 1 and 2 completed. Server should be free again.")

            result = make_request(True, timeout=0.55, task_id=4)
            print(f"Final chunks: {result}")
1719
+
1720
+
1721
@pytest.mark.integration
def test_custom_openai_endpoints():
    """
    Test a Truss that exposes an OpenAI compatible endpoint.

    The model defines `predict` and `completions`, so both routes must work
    (each keeps its own running counter); `chat_completions` is not defined,
    so that route must return 404.
    """
    model = """
    from typing import Dict

    class Model:
        def __init__(self):
            pass

        def load(self):
            self._predict_count = 0
            self._completions_count = 0

        async def predict(self, inputs: Dict) -> int:
            self._predict_count += inputs["increment"]
            return self._predict_count

        async def completions(self, inputs: Dict) -> int:
            self._completions_count += inputs["increment"]
            return self._completions_count
    """
    with ensure_kill_all(), _temp_truss(model) as tr:
        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

        # Each implemented endpoint accumulates its own counter.
        predict_resp = requests.post(PREDICT_URL, json={"increment": 1})
        assert predict_resp.status_code == 200
        assert predict_resp.json() == 1

        completions_resp = requests.post(COMPLETIONS_URL, json={"increment": 2})
        assert completions_resp.status_code == 200
        assert completions_resp.json() == 2

        # Unimplemented endpoint: no `chat_completions` method on the model.
        chat_resp = requests.post(CHAT_COMPLETIONS_URL, json={"increment": 3})
        assert chat_resp.status_code == 404
1758
+
1759
+
1760
@pytest.mark.integration
def test_postprocess_async_generator_streaming():
    """
    Test that an async-generator `postprocess` streams its items back as a
    chunked response, even though `predict` itself returns a plain list.

    (Docstring previously copy-pasted from the OpenAI-endpoint test; this test
    has nothing to do with OpenAI endpoints.)
    """
    model = """
    from typing import Dict, List, Generator

    class Model:
        def __init__(self):
            pass

        def load(self):
            pass

        async def predict(self, inputs: Dict) -> List[str]:
            nums: List[int] = inputs["nums"]
            return nums

        async def postprocess(self, nums: List[str]) -> Generator[str, None, None]:
            for num in nums:
                yield num
    """
    with ensure_kill_all(), _temp_truss(model) as tr:
        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

        response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]}, stream=True)
        # A generator `postprocess` must turn the response into a stream.
        assert response.headers.get("transfer-encoding") == "chunked"
        # `iter_content()` is already an iterator; no need to materialize it
        # with `list()` before iterating in the comprehension.
        assert [chunk.decode() for chunk in response.iter_content()] == ["1", "2"]
1791
+
1792
+
1793
@pytest.mark.integration
def test_preprocess_async_generator():
    """
    Test that an async-generator `preprocess` is forwarded to `predict`,
    which can consume it with `async for` and return the collected items.

    (Docstring previously copy-pasted from the OpenAI-endpoint test; this test
    has nothing to do with OpenAI endpoints.)
    """
    model = """
    from typing import Dict, List, AsyncGenerator

    class Model:
        def __init__(self):
            pass

        def load(self):
            pass

        async def preprocess(self, inputs: Dict) -> AsyncGenerator[str, None]:
            for num in inputs["nums"]:
                yield num

        async def predict(self, nums: AsyncGenerator[str, None]) -> List[str]:
            return [num async for num in nums]
    """
    with ensure_kill_all(), _temp_truss(model) as tr:
        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

        # Items yielded by `preprocess` round-trip unchanged through `predict`.
        response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]})
        assert response.status_code == 200
        assert response.json() == ["1", "2"]
1821
+
1822
+
1823
@pytest.mark.integration
def test_openai_client_streaming():
    """
    Test a Truss that exposes an OpenAI compatible endpoint.

    An OpenAI-client-style request (JSON `accept` header plus OpenAI
    `user-agent`) against `chat_completions` must still be streamed back as a
    chunked response rather than buffered into a single JSON body.
    """
    model = """
    from typing import Dict, AsyncGenerator

    class Model:
        def __init__(self):
            pass

        def load(self):
            pass

        async def chat_completions(self, inputs: Dict) -> AsyncGenerator[str, None]:
            for num in inputs["nums"]:
                yield num

        async def predict(self, inputs: Dict):
            pass
    """
    with ensure_kill_all(), _temp_truss(model) as tr:
        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

        response = requests.post(
            CHAT_COMPLETIONS_URL,
            json={"nums": ["1", "2"]},
            stream=True,
            # Despite requesting json, we should still stream results back.
            # The user-agent mimics the OpenAI Python client.
            headers={
                "accept": "application/json",
                "user-agent": "OpenAI/Python 1.61.0",
            },
        )
        assert response.headers.get("transfer-encoding") == "chunked"
        assert [
            byte_string.decode() for byte_string in list(response.iter_content())
        ] == ["1", "2"]