truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic; see the registry's advisory for this release for more details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -1,434 +0,0 @@
1
- import asyncio
2
- import importlib
3
- import inspect
4
- import logging
5
- import os
6
- import sys
7
- import time
8
- from collections.abc import Generator
9
- from contextlib import asynccontextmanager
10
- from enum import Enum
11
- from multiprocessing import Lock
12
- from pathlib import Path
13
- from threading import Thread
14
- from typing import Any, AsyncGenerator, Dict, Optional, Set, Union
15
-
16
- import pydantic
17
- from anyio import Semaphore, to_thread
18
- from fastapi import HTTPException
19
- from pydantic import BaseModel
20
- from truss.server.common.patches import apply_patches
21
- from truss.server.common.retry import retry
22
- from truss.server.common.schema import TrussSchema
23
- from truss.server.shared.secrets_resolver import SecretsResolver
24
-
25
# Base name assigned to ``ModelWrapper.name``.
MODEL_BASENAME = "model"

# Retry count handed to ``retry`` for the user's ``model.load()``; read once at
# import time from the NUM_LOAD_RETRIES_TRUSS environment variable.
NUM_LOAD_RETRIES = int(os.getenv("NUM_LOAD_RETRIES_TRUSS", "1"))
# Fallback for ``runtime.streaming_read_timeout``: seconds allowed between
# streamed chunks before a timeout is triggered.
STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS = 60
# Fallback for ``runtime.predict_concurrency``: size of the predict semaphore.
DEFAULT_PREDICT_CONCURRENCY = 1
30
-
31
-
32
class DeferredSemaphoreManager:
    """Records whether releasing a held semaphore has been handed to the caller.

    ``deferred_semaphore`` yields an instance of this class. Once ``defer()``
    has been called, the context manager no longer releases the semaphore on
    exit, and whoever called ``defer()`` must call the function it returned.
    """

    def __init__(self, semaphore: Semaphore):
        self.semaphore = semaphore
        # Set by ``defer()``; checked by ``deferred_semaphore`` when it exits.
        self.deferred = False

    def defer(self):
        """Hand off responsibility for releasing the semaphore.

        Returns:
            The semaphore's ``release`` method, which the caller must invoke to
            free the slot.
        """
        self.deferred = True
        release_fn = self.semaphore.release
        return release_fn
50
-
51
-
52
@asynccontextmanager
async def deferred_semaphore(semaphore: Semaphore):
    """Acquire ``semaphore`` for a context block, optionally deferring its release.

    Yields a ``DeferredSemaphoreManager``. If the block calls
    ``manager.defer()``, the semaphore is left held on exit and the caller must
    invoke the release function that ``defer()`` returned. Otherwise the
    semaphore is released in a ``finally`` clause when the block exits.
    """
    manager = DeferredSemaphoreManager(semaphore)
    await semaphore.acquire()

    try:
        yield manager
    finally:
        # Release here only if nobody took ownership of the release via defer().
        release_now = not manager.deferred
        if release_now:
            semaphore.release()
68
-
69
-
70
- class ModelWrapper:
71
class Status(Enum):
    """Loading states tracked in ``ModelWrapper._status``."""

    NOT_READY = 0  # initial state assigned in ``__init__``
    LOADING = 1  # set by ``load()`` just before it calls ``try_load()``
    READY = 2  # set by ``load()`` after ``try_load()`` succeeds
    FAILED = 3  # set by ``load()`` when ``try_load()`` raises
76
-
77
def __init__(self, config: Dict):
    """Store the truss config and set up loading and concurrency state.

    Args:
        config: The truss configuration dict. ``runtime.predict_concurrency``
            (default ``DEFAULT_PREDICT_CONCURRENCY``) sets how many permits the
            predict semaphore has.
    """
    self._config = config
    # No name is given, so this is the root logger.
    self._logger = logging.getLogger()
    self.name = MODEL_BASENAME
    self.ready = False
    self._load_lock = Lock()
    self._status = ModelWrapper.Status.NOT_READY

    runtime_section = self._config.get("runtime", {})
    concurrency_limit = runtime_section.get(
        "predict_concurrency", DEFAULT_PREDICT_CONCURRENCY
    )
    self._predict_semaphore = Semaphore(concurrency_limit)

    self._background_tasks: Set[asyncio.Task] = set()
    # Populated later by ``set_truss_schema()``.
    self.truss_schema: Optional[TrussSchema] = None

    # APP_HOME falls back to the current working directory.
    fallback_home = str(Path.cwd())
    self.app_home: Path = Path(os.environ.get("APP_HOME", default=fallback_home))
92
-
93
def load(self) -> bool:
    """Load the user model, making sure only one caller actually loads it.

    Callers that arrive while another caller is loading block on
    ``self._load_lock``. Once they acquire the lock, they re-check
    ``self.ready`` and return immediately if the other caller's load
    succeeded, instead of running ``try_load()`` a second time.

    Returns:
        ``self.ready``: True once the model has loaded, otherwise False.
        Exceptions raised by ``try_load()`` are logged rather than propagated,
        and in that case ``self._status`` is set to ``Status.FAILED``.
    """
    # Fast path: once the model is loaded, skip the lock entirely.
    if self.ready:
        return self.ready

    # If we are already loading, block on acquiring the lock;
    # this worker will return 503 while the worker with the lock is loading.
    with self._load_lock:
        # Re-check under the lock: a concurrent load() may have finished while
        # this caller was waiting. Without this check the user's model would be
        # constructed and loaded a second time.
        if self.ready:
            return self.ready

        self._status = ModelWrapper.Status.LOADING

        self._logger.info("Executing model.load()...")

        try:
            start_time = time.perf_counter()
            self.try_load()
            self.ready = True
            self._status = ModelWrapper.Status.READY
            self._logger.info(
                f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
            )

            return self.ready
        except Exception:
            self._logger.exception("Exception while loading model")
            self._status = ModelWrapper.Status.FAILED

    return self.ready
119
-
120
def start_load(self):
    """Start ``load()`` on a background thread if ``should_load()`` allows it.

    The thread is neither returned nor joined; progress is observed through
    ``ready`` / ``load_failed()``.
    """
    if not self.should_load():
        return
    Thread(target=self.load).start()
124
-
125
def load_failed(self) -> bool:
    """Report whether the current load status is ``Status.FAILED``."""
    failed_state = ModelWrapper.Status.FAILED
    return self._status == failed_state
127
-
128
def should_load(self) -> bool:
    """Return True when a (new) load should be started.

    No load is needed once the model is ready, and a failed load is
    deliberately never retried.
    """
    if self.ready:
        return False
    return self._status != ModelWrapper.Status.FAILED
131
-
132
def try_load(self):
    """Import the user's model class, instantiate it, and run its ``load()``.

    Steps:
      1. Create a local ``data`` directory, which may be passed as ``data_dir``.
      2. Make ``app_home`` importable. Also make ``app_home/packages``
         importable when the config has ``bundled_packages_dir`` and that
         directory exists.
      3. Import ``<model_module_dir>.<model_class_filename minus suffix>`` and
         fetch the ``model_class_name`` attribute from it.
      4. Pass ``config``, ``data_dir`` and ``secrets`` to the constructor, each
         only when ``_signature_accepts_keyword_arg`` reports that the class
         accepts that keyword.
      5. Apply library patches, construct the model, derive the truss schema,
         and call the model's ``load()`` (if it has one) through ``retry``.

    Raises:
        Any exception raised by the import, the constructor, or the model's
        ``load()``. ``load()`` catches and logs these.
    """
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)

    import_roots = [self.app_home]
    if "bundled_packages_dir" in self._config:
        bundled_packages_path = self.app_home / "packages"
        if bundled_packages_path.exists():
            import_roots.append(bundled_packages_path)
    # Skip entries that are already present so that repeated load attempts
    # do not keep growing sys.path with duplicates.
    for import_root in map(str, import_roots):
        if import_root not in sys.path:
            sys.path.append(import_root)

    model_module_name = str(
        Path(self._config["model_class_filename"]).with_suffix("")
    )

    module = importlib.import_module(
        f"{self._config['model_module_dir']}.{model_module_name}"
    )
    model_class = getattr(module, self._config["model_class_name"])
    model_class_signature = inspect.signature(model_class)
    model_init_params = {}
    if _signature_accepts_keyword_arg(model_class_signature, "config"):
        model_init_params["config"] = self._config
    if _signature_accepts_keyword_arg(model_class_signature, "data_dir"):
        model_init_params["data_dir"] = data_dir
    if _signature_accepts_keyword_arg(model_class_signature, "secrets"):
        model_init_params["secrets"] = SecretsResolver.get_secrets(self._config)
    apply_patches(
        self._config.get("apply_library_patches", True),
        self._config["requirements"],
    )
    self._model = model_class(**model_init_params)

    self.set_truss_schema()

    if hasattr(self._model, "load"):
        retry(
            self._model.load,
            NUM_LOAD_RETRIES,
            # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
            self._logger.warning,
            "Failed to load model.",
            gap_seconds=1.0,
        )
174
-
175
def set_truss_schema(self):
    """Build ``self.truss_schema`` from the user model's method signatures.

    The input parameters come from ``preprocess`` when the model defines it,
    otherwise from ``predict``. The output annotation comes from
    ``postprocess`` when defined, otherwise from ``predict``.
    """
    model = self._model

    input_fn = model.preprocess if hasattr(model, "preprocess") else model.predict
    parameters = inspect.signature(input_fn).parameters

    output_fn = (
        model.postprocess if hasattr(model, "postprocess") else model.predict
    )
    outputs_annotation = inspect.signature(output_fn).return_annotation

    self.truss_schema = TrussSchema.from_signature(parameters, outputs_annotation)
189
-
190
async def preprocess(
    self,
    payload: Any,
    headers: Optional[Dict[str, str]] = None,
) -> Any:
    """Run the model's ``preprocess`` on ``payload``, or return it unchanged.

    A coroutine ``preprocess`` is awaited on the event loop. A synchronous one
    runs via ``to_thread.run_sync`` so the event loop is not blocked.
    ``headers`` is accepted but not used here.
    """
    model = self._model
    if not hasattr(model, "preprocess"):
        return payload

    user_preprocess = model.preprocess
    if not inspect.iscoroutinefunction(user_preprocess):
        return await to_thread.run_sync(
            _intercept_exceptions_sync(user_preprocess), payload
        )
    return await _intercept_exceptions_async(user_preprocess)(payload)
204
-
205
async def predict(
    self,
    payload: Any,
    headers: Optional[Dict[str, str]] = None,
) -> Any:
    """Invoke the model's ``predict`` in whichever way suits its kind.

    - Generator or async-generator function: return the (async) generator
      without awaiting it, so the caller can stream it.
    - Coroutine function: await it on the event loop.
    - Plain function: run it via ``to_thread.run_sync`` so the event loop is
      not blocked.

    ``headers`` is accepted but not used here.
    """
    user_predict = self._model.predict

    is_streaming = inspect.isasyncgenfunction(
        user_predict
    ) or inspect.isgeneratorfunction(user_predict)
    if is_streaming:
        return user_predict(payload)

    if inspect.iscoroutinefunction(user_predict):
        return await _intercept_exceptions_async(user_predict)(payload)

    sync_call = _intercept_exceptions_sync(user_predict)
    return await to_thread.run_sync(sync_call, payload)
229
-
230
async def postprocess(
    self,
    response: Any,
    headers: Optional[Dict[str, str]] = None,
) -> Any:
    """Apply the model's ``postprocess`` to ``response``, or return it as-is.

    This mirrors ``predict``. A generator or async-generator ``postprocess``
    is returned unawaited. A coroutine is awaited. A plain function runs via
    ``to_thread.run_sync``. ``headers`` is accepted but not used here.
    """
    if not hasattr(self._model, "postprocess"):
        return response

    user_postprocess = self._model.postprocess

    returns_stream = inspect.isasyncgenfunction(
        user_postprocess
    ) or inspect.isgeneratorfunction(user_postprocess)
    if returns_stream:
        return user_postprocess(response)

    if inspect.iscoroutinefunction(user_postprocess):
        return await _intercept_exceptions_async(user_postprocess)(response)

    sync_call = _intercept_exceptions_sync(user_postprocess)
    return await to_thread.run_sync(sync_call, response)
255
-
256
async def write_response_to_queue(
    self, queue: asyncio.Queue, generator: AsyncGenerator
):
    """Pump every chunk of *generator* into *queue*, then a ``None`` sentinel.

    Each chunk is boxed in ``ResponseChunk`` so that a chunk whose value is
    itself ``None`` cannot be mistaken for the end-of-stream sentinel.
    Exceptions raised while iterating are logged and swallowed. The sentinel
    is enqueued in a ``finally`` clause, so the queue reader always sees the
    end of the stream.
    """
    try:
        async for item in generator:
            boxed = ResponseChunk(item)
            await queue.put(boxed)
    except Exception as exc:
        self._logger.exception(
            "Exception while reading stream response: " + str(exc)
        )
    finally:
        await queue.put(None)
266
-
267
async def __call__(
    self, body: Any, headers: Optional[Dict[str, str]] = None
) -> Union[Dict, Generator]:
    """Method to call predictor or explainer with the given input.

    Args:
        body (Any): Request payload body. When ``self.truss_schema`` is set,
            ``body`` is unpacked with ``**`` into
            ``truss_schema.input_type``. A pydantic ``ValidationError`` is
            logged and surfaced as HTTP 400.
        headers (Dict): Request headers. If ``accept`` is exactly
            ``"application/json"``, a streaming response is fully consumed
            and returned as one string.

    Returns:
        Dict: Response output from preprocess -> predictor -> postprocess
            (pydantic models are converted with ``.dict()``).
        Generator: In case of streaming response, an async generator that
            reads chunks from a queue filled by a background task.

    Raises:
        HTTPException: 400 when request validation fails.
        asyncio.TimeoutError: raised from the returned streaming generator
            when the gap between two chunks exceeds the configured
            ``streaming_read_timeout``.
    """

    # The streaming read timeout is the amount of time in between streamed
    # chunks before a timeout is triggered.
    streaming_read_timeout = self._config.get("runtime", {}).get(
        "streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
    )

    if self.truss_schema is not None:
        try:
            body = self.truss_schema.input_type(**body)
        except pydantic.ValidationError as e:
            # Pass the message itself: wrapping it in ``{...}`` would build a
            # one-element set and log its repr instead of the error text.
            self._logger.info("Request Validation Error: %s", str(e))
            raise HTTPException(
                status_code=400, detail="Request Validation Error"
            ) from e

    payload = await self.preprocess(body, headers)

    async with deferred_semaphore(self._predict_semaphore) as semaphore_manager:
        response = await self.predict(payload, headers)

        # Streaming cases
        if inspect.isgenerator(response) or inspect.isasyncgen(response):
            if hasattr(self._model, "postprocess"):
                # Use the wrapper's logger like the rest of this method; the
                # two literals need a separating space.
                self._logger.warning(
                    "Predict returned a streaming response, while a postprocess is defined. "
                    "Note that in this case, the postprocess will run within the predict lock."
                )

            response = await self.postprocess(response)

            async_generator = _force_async_generator(response)

            if headers and headers.get("accept") == "application/json":
                # In the case of a streaming response, consume stream
                # if the http accept header is set, and json is requested.
                return await _convert_streamed_response_to_string(async_generator)

            # To ensure that a partial read from a client does not cause the
            # semaphore to stay claimed, we immediately write all of the data
            # from the stream to a queue. We then return a new generator that
            # reads from the queue, and then exit the semaphore block.
            response_queue: asyncio.Queue = asyncio.Queue()

            # This task will be triggered and run in the background.
            task = asyncio.create_task(
                self.write_response_to_queue(response_queue, async_generator)
            )

            # Keep a strong reference so the task is not garbage collected
            # after this method returns while it is still running.
            self._background_tasks.add(task)

            # Defer the release of the semaphore until the
            # write_response_to_queue task completes.
            semaphore_release_function = semaphore_manager.defer()
            task.add_done_callback(lambda _: semaphore_release_function())
            task.add_done_callback(self._background_tasks.discard)

            # The gap between responses in a stream must be < streaming_read_timeout
            async def _response_generator():
                while True:
                    chunk = await asyncio.wait_for(
                        response_queue.get(),
                        timeout=streaming_read_timeout,
                    )
                    # ``None`` is the end-of-stream sentinel; real chunks
                    # arrive boxed in ResponseChunk.
                    if chunk is None:
                        return
                    yield chunk.value

            return _response_generator()

    # Non-streaming postprocess runs after the predict semaphore is released.
    processed_response = await self.postprocess(response)

    if isinstance(processed_response, BaseModel):
        # If we return a pydantic object, convert it back to a dict
        processed_response = processed_response.dict()
    return processed_response
358
-
359
-
360
class ResponseChunk:
    """Box around one streamed chunk placed on the response queue.

    Boxing lets a chunk whose payload is ``None`` be told apart from the bare
    ``None`` used as the end-of-stream sentinel.
    """

    def __init__(self, value):
        # The wrapped chunk; readers unwrap it through ``.value``.
        self.value = value
363
-
364
-
365
- async def _convert_streamed_response_to_string(response: AsyncGenerator):
366
- return "".join([str(chunk) async for chunk in response])
367
-
368
-
369
- def _force_async_generator(gen: Union[Generator, AsyncGenerator]) -> AsyncGenerator:
370
- """
371
- Takes a generator, and converts it into an async generator if it is not already.
372
- """
373
- if inspect.isasyncgen(gen):
374
- return gen
375
-
376
- async def _convert_generator_to_async():
377
- """
378
- Runs each iteration of the generator in an offloaded thread, to ensure
379
- the main loop is not blocked, and yield to create an async generator.
380
- """
381
- FINAL_GENERATOR_VALUE = object()
382
- while True:
383
- # Note that this is the equivalent of running:
384
- # next(gen, FINAL_GENERATOR_VALUE) on a separate thread,
385
- # ensuring that if there is anything blocking in the generator,
386
- # it does not block the main loop.
387
- chunk = await to_thread.run_sync(next, gen, FINAL_GENERATOR_VALUE)
388
- if chunk == FINAL_GENERATOR_VALUE:
389
- break
390
- yield chunk
391
-
392
- return _convert_generator_to_async()
393
-
394
-
395
def _signature_accepts_keyword_arg(signature: inspect.Signature, kwarg: str) -> bool:
    """Return True if *kwarg* can be passed by keyword to *signature*.

    This holds when a parameter with that exact name exists, or when the
    signature collects arbitrary keywords through a ``**kwargs`` parameter.
    """
    if kwarg in signature.parameters:
        return True
    return _signature_accepts_kwargs(signature)
397
-
398
-
399
- def _signature_accepts_kwargs(signature: inspect.Signature) -> bool:
400
- for param in signature.parameters.values():
401
- if param.kind == inspect.Parameter.VAR_KEYWORD:
402
- return True
403
- return False
404
-
405
-
406
- def _elapsed_ms(since_micro_seconds: float) -> int:
407
- return int((time.perf_counter() - since_micro_seconds) * 1000)
408
-
409
-
410
def _handle_exception():
    """Log the in-flight exception with its traceback, then raise an HTTP 500.

    Intended to be called from inside an ``except`` block. ``logging.exception``
    records the current exception's stack trace on the root logger so the
    failure can be debugged from logs. The client only receives a generic
    "Internal Server Error" detail.
    """
    logging.exception("Internal Server Error")
    server_error = HTTPException(detail="Internal Server Error", status_code=500)
    raise server_error
415
-
416
-
417
- def _intercept_exceptions_sync(func):
418
- def inner(*args, **kwargs):
419
- try:
420
- return func(*args, **kwargs)
421
- except Exception:
422
- _handle_exception()
423
-
424
- return inner
425
-
426
-
427
- def _intercept_exceptions_async(func):
428
- async def inner(*args, **kwargs):
429
- try:
430
- return await func(*args, **kwargs)
431
- except Exception:
432
- _handle_exception()
433
-
434
- return inner
@@ -1,81 +0,0 @@
1
- import logging
2
- import os
3
- import sys
4
-
5
- from pythonjsonlogger import jsonlogger
6
-
7
# Level applied to every logger (and the JSON handler) by setup_logging().
LEVEL: int = logging.INFO

# JSON logging is switched on by the JSON_LOG environment variable. Any
# non-empty value enables it except the explicit "off" spellings below, so
# JSON_LOG=false or JSON_LOG=0 leaves it disabled instead of enabling it.
_FALSY_ENV_VALUES = ("", "0", "false", "no", "off")
use_json_logs = (
    os.environ.get("JSON_LOG", "").strip().lower() not in _FALSY_ENV_VALUES
)

# Shared JSON handler. It binds the sys.stdout that exists when this module is
# imported, i.e. before setup_logging() replaces sys.stdout with a
# StreamToLogger, so records are written to the real stream.
JSON_LOG_HANDLER = logging.StreamHandler(stream=sys.stdout)
# The name lets setup_logging() detect loggers that already carry it.
JSON_LOG_HANDLER.set_name("json_logger_handler")
JSON_LOG_HANDLER.setLevel(LEVEL)
JSON_LOG_HANDLER.setFormatter(
    jsonlogger.JsonFormatter("%(asctime)s %(levelname)s %(message)s")
)
17
-
18
-
19
class HealthCheckFilter(logging.Filter):
    """Drop log records for health-check requests.

    A record is rejected when its rendered message contains ``"GET / "`` or
    ``"GET /v1/models/model "`` (note the trailing space). Every other record
    passes through.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        is_health_check = "GET / " in message or "GET /v1/models/model " in message
        return not is_health_check
26
-
27
-
28
- class StreamToLogger:
29
- """
30
- StreamToLogger redirects stdout and stderr to logger
31
- """
32
-
33
- def __init__(self, logger, log_level, stream):
34
- self.logger = logger
35
- self.log_level = log_level
36
- self.stream = stream
37
-
38
- def __getattr__(self, name):
39
- # we need to pass `isatty` from the stream for uvicorn
40
- # this is a more general, less hacky fix
41
- return getattr(self.stream, name)
42
-
43
- def write(self, buf):
44
- self.logger.log(self.log_level, buf)
45
-
46
- def flush(self):
47
- """
48
- This is a no-op function. It only exists to prevent
49
- AttributeError in case some part of the code attempts to call flush()
50
- on instances of StreamToLogger. Thus, we define this method as a safety
51
- measure.
52
- """
53
- pass
54
-
55
-
56
def setup_logging() -> None:
    """Route stdout/stderr into logging and normalise every known logger.

    * ``sys.stdout`` / ``sys.stderr`` are replaced with ``StreamToLogger``
      shims that log at INFO on the root logger.
    * Every logger (root plus all names registered so far) is set to
      ``LEVEL`` and stops propagating.
    * When JSON logs are enabled, each logger's handlers are replaced with the
      shared ``JSON_LOG_HANDLER``, unless it is already attached.
    * ``uvicorn.access`` gets a single ``HealthCheckFilter``. Repeated calls
      do not stack duplicate filters.
    """
    loggers = [logging.getLogger()] + [
        logging.getLogger(name) for name in logging.root.manager.loggerDict
    ]

    sys.stdout = StreamToLogger(logging.getLogger(), logging.INFO, sys.__stdout__)  # type: ignore
    sys.stderr = StreamToLogger(logging.getLogger(), logging.INFO, sys.__stderr__)  # type: ignore

    for logger in loggers:
        logger.setLevel(LEVEL)
        logger.propagate = False

        if use_json_logs:
            # let's not thrash the handlers unnecessarily
            already_configured = any(
                handler.name == JSON_LOG_HANDLER.name for handler in logger.handlers
            )
            if not already_configured:
                logger.handlers.clear()
                logger.addHandler(JSON_LOG_HANDLER)

        # some special handling for request logging; add the filter only once
        # so calling setup_logging() again does not accumulate duplicates.
        if logger.name == "uvicorn.access" and not any(
            isinstance(existing, HealthCheckFilter) for existing in logger.filters
        ):
            logger.addFilter(HealthCheckFilter())
@@ -1,97 +0,0 @@
1
- import os
2
- from itertools import count
3
-
4
- import build_engine_utils
5
- from constants import (
6
- GRPC_SERVICE_PORT,
7
- HF_AUTH_KEY_CONSTANT,
8
- HTTP_SERVICE_PORT,
9
- TOKENIZER_KEY_CONSTANT,
10
- )
11
- from schema import ModelInput, TrussBuildConfig
12
- from transformers import AutoTokenizer
13
- from triton_client import TritonClient, TritonServer
14
-
15
-
16
class Model:
    """Truss model that forwards requests to a locally started Triton server.

    ``load`` optionally builds an engine into the data dir, starts a
    ``TritonServer``, connects a ``TritonClient`` over gRPC and loads the
    tokenizer. ``predict`` sends one request through the client and either
    streams the results or concatenates them into a single string.
    """

    def __init__(self, data_dir, config, secrets):
        self._data_dir = data_dir
        self._config = config
        self._secrets = secrets
        # Per-process counter; predict() prefixes it with the PID to form
        # request ids.
        self._request_id_counter = count(start=1)
        self.triton_client = None
        self.triton_server = None
        self.tokenizer = None
        self.uses_openai_api = None
        # Populated in load(); initialised here so the attribute always exists.
        self.eos_token_id = None

    def load(self):
        """Prepare the engine, start Triton and the gRPC client, load the tokenizer."""
        build_config = TrussBuildConfig(**self._config["build"]["arguments"])
        # Models tagged "openai-compatible" return a bare string (not
        # {"text": ...}) for non-streaming requests.
        self.uses_openai_api = "openai-compatible" in self._config.get(
            "model_metadata", {}
        ).get("tags", [])
        hf_access_token = None
        if "hf_access_token" in self._secrets._base_secrets.keys():
            hf_access_token = self._secrets["hf_access_token"]

        # TODO(Abu): Move to pre-runtime
        if build_config.requires_build:
            build_engine_utils.build_engine_from_config_args(
                engine_build_args=build_config.engine_build_args,
                dst=self._data_dir,
            )

        self.triton_server = TritonServer(
            grpc_port=GRPC_SERVICE_PORT,
            http_port=HTTP_SERVICE_PORT,
        )

        # A freshly built engine was written into the data dir above, so an
        # external engine repository is only passed when no build was needed.
        self.triton_server.create_model_repository(
            truss_data_dir=self._data_dir,
            engine_repository_path=build_config.engine_repository
            if not build_config.requires_build
            else None,
            huggingface_auth_token=hf_access_token,
        )

        # Environment handed to the Triton server process.
        env = {}
        if hf_access_token:
            env[HF_AUTH_KEY_CONSTANT] = hf_access_token
        env[TOKENIZER_KEY_CONSTANT] = build_config.tokenizer_repository

        self.triton_server.start(
            tensor_parallelism=build_config.tensor_parallel_count,
            env=env,
        )

        self.triton_client = TritonClient(
            grpc_service_port=GRPC_SERVICE_PORT,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            build_config.tokenizer_repository, token=hf_access_token
        )
        self.eos_token_id = self.tokenizer.eos_token_id

    async def predict(self, model_input):
        """Run one inference request through Triton.

        Args:
            model_input: request dict. It is mutated in place with
                ``request_id`` and ``eos_token_id``, then validated as
                ``ModelInput``.

        Returns:
            An async generator of results when ``stream`` is set. Otherwise
            the concatenated output, either as a bare string
            (openai-compatible models) or as ``{"text": ...}``.
        """
        model_input["request_id"] = str(os.getpid()) + str(
            next(self._request_id_counter)
        )
        model_input["eos_token_id"] = self.eos_token_id

        self.triton_client.start_grpc_stream()

        model_input = ModelInput(**model_input)

        result_iterator = self.triton_client.infer(model_input)

        async def generate():
            async for result in result_iterator:
                yield result

        if model_input.stream:
            return generate()

        # generate() is an async generator: str.join cannot consume it
        # directly (that raises TypeError), so drain it with an async
        # comprehension first.
        text = "".join([chunk async for chunk in generate()])
        if self.uses_openai_api:
            return text
        return {"text": text}
@@ -1,34 +0,0 @@
1
- from pathlib import Path
2
-
3
- from schema import EngineBuildArgs
4
-
5
-
6
def build_engine_from_config_args(
    engine_build_args: EngineBuildArgs,
    dst: Path,
):
    """Build an engine from *engine_build_args* and copy its files into *dst*.

    Files that already exist in *dst* are left untouched (never overwritten).

    Args:
        engine_build_args: build arguments. ``model_dump()`` is splatted into
            the base image's ``Engine``.
        dst: destination directory; created (with parents) if missing.

    Returns:
        ``dst``, unchanged.
    """
    import os
    import shutil
    import sys

    # NOTE: These are provided by the underlying base image
    # TODO(Abu): Remove this when we have a better way of handling this
    base_image_path = "/app/baseten"
    if base_image_path not in sys.path:
        # Avoid growing sys.path on every call.
        sys.path.append(base_image_path)
    from build_engine import Engine, build_engine
    from trtllm_utils import docker_tag_aware_file_cache

    engine = Engine(**engine_build_args.model_dump())

    with docker_tag_aware_file_cache("/root/.cache/trtllm"):
        built_engine = build_engine(engine, download_remote=True)

    # exist_ok avoids the race between a separate exists() check and makedirs().
    os.makedirs(dst, exist_ok=True)

    engine_dir = str(built_engine)
    for filename in os.listdir(engine_dir):
        source_file = os.path.join(engine_dir, filename)
        destination_file = os.path.join(dst, filename)
        if not os.path.exists(destination_file):
            shutil.copy(source_file, destination_file)

    return dst
@@ -1,11 +0,0 @@
1
- from pathlib import Path
2
-
3
# If changing the model repo path, please update it inside the tensorrt_llm
# config.pbtxt as well.
TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path(
    "/app/packages/tensorrt_llm_model_repository/"
)
# Ports for the local Triton server: gRPC (also used by the TritonClient)
# and HTTP.
GRPC_SERVICE_PORT = 8001
HTTP_SERVICE_PORT = 8003
# Names of environment variables set for the Triton server process.
HF_AUTH_KEY_CONSTANT = "HUGGING_FACE_HUB_TOKEN"
TOKENIZER_KEY_CONSTANT = "TRITON_TOKENIZER_REPOSITORY"
# NOTE(review): presumably the Triton model name that requests target — confirm.
ENTRYPOINT_MODEL_NAME = "ensemble"