truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic; see the package's registry page for more details.

Files changed (362)
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. /truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. /truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. /truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. /truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. /truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. /truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. /truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. /truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. /truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. /truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. /truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. /truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. /truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -1,216 +0,0 @@
1
- from enum import Enum
2
- from pathlib import Path
3
- from typing import Optional
4
-
5
- import numpy as np
6
- import tritonclient
7
- import tritonclient.grpc.aio as grpcclient
8
- from pydantic import BaseModel, ConfigDict, PrivateAttr
9
-
10
-
11
- class ModelInput:
12
- def __init__(
13
- self,
14
- prompt: str,
15
- request_id: int,
16
- max_tokens: int = 50,
17
- beam_width: int = 1,
18
- bad_words_list: Optional[list] = None,
19
- stop_words_list: Optional[list] = None,
20
- repetition_penalty: float = 1.0,
21
- ignore_eos: bool = False,
22
- stream: bool = True,
23
- eos_token_id: int = None, # type: ignore
24
- ) -> None:
25
- self.stream = stream
26
- self.request_id = request_id
27
- self._prompt = prompt
28
- self._max_tokens = max_tokens
29
- self._beam_width = beam_width
30
- self._bad_words_list = [""] if bad_words_list is None else bad_words_list
31
- self._stop_words_list = [""] if stop_words_list is None else stop_words_list
32
- self._repetition_penalty = repetition_penalty
33
- self._eos_token_id = eos_token_id
34
- self._ignore_eos = ignore_eos
35
-
36
- def _prepare_grpc_tensor(
37
- self, name: str, input_data: np.ndarray
38
- ) -> grpcclient.InferInput:
39
- tensor = grpcclient.InferInput(
40
- name,
41
- input_data.shape,
42
- tritonclient.utils.np_to_triton_dtype(input_data.dtype),
43
- )
44
- tensor.set_data_from_numpy(input_data)
45
- return tensor
46
-
47
- def to_tensors(self):
48
- if self._eos_token_id is None and self._ignore_eos:
49
- raise ValueError("eos_token_id is required when ignore_eos is True")
50
-
51
- prompt_data = np.array([[self._prompt]], dtype=object)
52
- output_len_data = np.ones_like(prompt_data, dtype=np.uint32) * self._max_tokens
53
- bad_words_data = np.array([self._bad_words_list], dtype=object)
54
- stop_words_data = np.array([self._stop_words_list], dtype=object)
55
- stream_data = np.array([[self.stream]], dtype=bool)
56
- beam_width_data = np.array([[self._beam_width]], dtype=np.uint32)
57
- repetition_penalty_data = np.array(
58
- [[self._repetition_penalty]], dtype=np.float32
59
- )
60
-
61
- inputs = [
62
- self._prepare_grpc_tensor("text_input", prompt_data),
63
- self._prepare_grpc_tensor("max_tokens", output_len_data),
64
- self._prepare_grpc_tensor("bad_words", bad_words_data),
65
- self._prepare_grpc_tensor("stop_words", stop_words_data),
66
- self._prepare_grpc_tensor("stream", stream_data),
67
- self._prepare_grpc_tensor("beam_width", beam_width_data),
68
- self._prepare_grpc_tensor("repetition_penalty", repetition_penalty_data),
69
- ]
70
-
71
- if not self._ignore_eos:
72
- end_id_data = np.array([[self._eos_token_id]], dtype=np.uint32)
73
- inputs.append(self._prepare_grpc_tensor("end_id", end_id_data))
74
-
75
- return inputs
76
-
77
-
78
- class Quant(Enum):
79
- NO_QUANT = "no_quant"
80
- WEIGHTS_ONLY = "weights_only"
81
- WEIGHTS_KV_INT8 = "weights_kv_int8"
82
- SMOOTH_QUANT = "smooth_quant"
83
-
84
-
85
- class EngineType(Enum):
86
- LLAMA = "llama"
87
- MISTRAL = "mistral"
88
-
89
-
90
- class ArgsConfig(BaseModel):
91
- max_input_len: Optional[int] = None
92
- max_output_len: Optional[int] = None
93
- max_batch_size: Optional[int] = None
94
- tp_size: Optional[int] = None
95
- pp_size: Optional[int] = None
96
- world_size: Optional[int] = None
97
- gather_all_token_logits: Optional[bool] = None
98
- multi_block_mode: Optional[bool] = None
99
- remove_input_padding: Optional[bool] = None
100
- use_gpt_attention_plugin: Optional[str] = None
101
- paged_kv_cache: Optional[bool] = None
102
- use_inflight_batching: Optional[bool] = None
103
- enable_context_fmha: Optional[bool] = None
104
- use_gemm_plugin: Optional[str] = None
105
- use_weight_only: Optional[bool] = None
106
- output_dir: Optional[str] = None
107
- model_dir: Optional[str] = None
108
- ft_model_dir: Optional[str] = None
109
- dtype: Optional[str] = None
110
- int8_kv_cache: Optional[bool] = None
111
- use_smooth_quant: Optional[bool] = None
112
- per_token: Optional[bool] = None
113
- per_channel: Optional[bool] = None
114
- parallel_build: Optional[bool] = None
115
-
116
- # to disable warning because `model_dir` starts with `model_` prefix
117
- model_config = ConfigDict(protected_namespaces=()) # type: ignore
118
-
119
- def as_command_arguments(self) -> list:
120
- non_bool_args = [
121
- element
122
- for arg, value in self.dict().items()
123
- for element in [f"--{arg}", str(value)]
124
- if value is not None and not isinstance(value, bool)
125
- ]
126
- bool_args = [
127
- f"--{arg}"
128
- for arg, value in self.dict().items()
129
- if isinstance(value, bool) and value
130
- ]
131
- return non_bool_args + bool_args
132
-
133
-
134
- class CalibrationConfig(BaseModel):
135
- kv_cache: Optional[bool] = None # either to calibrate kv cache
136
- sq_alpha: Optional[float] = None
137
-
138
- def cache_path(self) -> Path:
139
- if self.kv_cache is not None:
140
- return Path("kv_cache")
141
- else:
142
- return Path(f"sq_{self.sq_alpha}")
143
-
144
-
145
- class EngineBuildArgs(BaseModel, use_enum_values=True):
146
- repo: Optional[str] = None
147
- args: Optional[ArgsConfig] = None
148
- quant: Optional[Quant] = None
149
- calibration: Optional[CalibrationConfig] = None
150
- engine_type: Optional[EngineType] = None
151
-
152
-
153
- class TrussBuildConfig(BaseModel):
154
- """
155
- This is a spec for what the config.yaml looks like to take advantage of TRT-LLM + TRT-LLM builds. We structure the
156
- configuration with the below top-level keys.
157
-
158
- Example (for building an engine)
159
- ```
160
- build:
161
- model_server: TRT_LLM
162
- arguments:
163
- tokenizer_repository: "mistralai/mistral-v2-instruct"
164
- arguments:
165
- max_input_len: 1024
166
- max_output_len: 1024
167
- max_batch_size: 64
168
- quant: "weights_kv_int8"
169
- tensor_parallel_count: 2
170
- pipeline_parallel_count: 1
171
- ```
172
-
173
- Example (for using an existing engine)
174
- ```
175
- build:
176
- model_server: TRT_LLM
177
- arguments:
178
- engine_repository: "baseten/mistral-v2-32k"
179
- tensor_parallel_count: 2
180
- pipeline_parallel_count: 1
181
- ```
182
-
183
- """
184
-
185
- tokenizer_repository: str
186
- quant: Quant = Quant.NO_QUANT
187
- pipeline_parallel_count: int = 1
188
- tensor_parallel_count: int = 1
189
- arguments: Optional[ArgsConfig] = None
190
- engine_repository: Optional[str] = None
191
- calibration: Optional[CalibrationConfig] = None
192
- engine_type: Optional[EngineType] = None
193
- _engine_build_args: Optional[EngineBuildArgs] = PrivateAttr(default=None)
194
-
195
- @property
196
- def engine_build_args(self) -> EngineBuildArgs:
197
- if self._engine_build_args is None:
198
- repo = self.tokenizer_repository
199
- quant = self.quant
200
- calibration = self.calibration
201
- engine_type = self.engine_type
202
- args = self.arguments or ArgsConfig()
203
- args.tp_size = self.tensor_parallel_count
204
- args.pp_size = self.pipeline_parallel_count
205
- self._engine_build_args = EngineBuildArgs(
206
- repo=repo,
207
- quant=quant,
208
- calibration=calibration,
209
- engine_type=engine_type,
210
- args=args,
211
- )
212
- return self._engine_build_args
213
-
214
- @property
215
- def requires_build(self):
216
- return self.engine_repository is None
@@ -1,246 +0,0 @@
1
- # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # Redistribution and use in source and binary forms, with or without
4
- # modification, are permitted provided that the following conditions
5
- # are met:
6
- # * Redistributions of source code must retain the above copyright
7
- # notice, this list of conditions and the following disclaimer.
8
- # * Redistributions in binary form must reproduce the above copyright
9
- # notice, this list of conditions and the following disclaimer in the
10
- # documentation and/or other materials provided with the distribution.
11
- # * Neither the name of NVIDIA CORPORATION nor the names of its
12
- # contributors may be used to endorse or promote products derived
13
- # from this software without specific prior written permission.
14
- #
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
- # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
- # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
- # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
- # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
- # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
- # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
- # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
-
27
- name: "ensemble"
28
- platform: "ensemble"
29
- max_batch_size: 2048
30
- input [
31
- {
32
- name: "text_input"
33
- data_type: TYPE_STRING
34
- dims: [ -1 ]
35
- },
36
- {
37
- name: "max_tokens"
38
- data_type: TYPE_UINT32
39
- dims: [ -1 ]
40
- },
41
- {
42
- name: "bad_words"
43
- data_type: TYPE_STRING
44
- dims: [ -1 ]
45
- },
46
- {
47
- name: "stop_words"
48
- data_type: TYPE_STRING
49
- dims: [ -1 ]
50
- },
51
- {
52
- name: "end_id"
53
- data_type: TYPE_UINT32
54
- dims: [ 1 ]
55
- optional: true
56
- },
57
- {
58
- name: "pad_id"
59
- data_type: TYPE_UINT32
60
- dims: [ 1 ]
61
- optional: true
62
- },
63
- {
64
- name: "top_k"
65
- data_type: TYPE_UINT32
66
- dims: [ 1 ]
67
- optional: true
68
- },
69
- {
70
- name: "top_p"
71
- data_type: TYPE_FP32
72
- dims: [ 1 ]
73
- optional: true
74
- },
75
- {
76
- name: "temperature"
77
- data_type: TYPE_FP32
78
- dims: [ 1 ]
79
- optional: true
80
- },
81
- {
82
- name: "length_penalty"
83
- data_type: TYPE_FP32
84
- dims: [ 1 ]
85
- optional: true
86
- },
87
- {
88
- name: "repetition_penalty"
89
- data_type: TYPE_FP32
90
- dims: [ 1 ]
91
- optional: true
92
- },
93
- {
94
- name: "min_length"
95
- data_type: TYPE_UINT32
96
- dims: [ 1 ]
97
- optional: true
98
- },
99
- {
100
- name: "presence_penalty"
101
- data_type: TYPE_FP32
102
- dims: [ 1 ]
103
- optional: true
104
- },
105
- {
106
- name: "random_seed"
107
- data_type: TYPE_UINT64
108
- dims: [ 1 ]
109
- optional: true
110
- },
111
- {
112
- name: "beam_width"
113
- data_type: TYPE_UINT32
114
- dims: [ 1 ]
115
- optional: true
116
- },
117
- {
118
- name: "stream"
119
- data_type: TYPE_BOOL
120
- dims: [ 1 ]
121
- optional: true
122
- }
123
- ]
124
- output [
125
- {
126
- name: "text_output"
127
- data_type: TYPE_STRING
128
- dims: [ -1, -1 ]
129
- }
130
- ]
131
- ensemble_scheduling {
132
- step [
133
- {
134
- model_name: "preprocessing"
135
- model_version: -1
136
- input_map {
137
- key: "QUERY"
138
- value: "text_input"
139
- }
140
- input_map {
141
- key: "REQUEST_OUTPUT_LEN"
142
- value: "max_tokens"
143
- }
144
- input_map {
145
- key: "BAD_WORDS_DICT"
146
- value: "bad_words"
147
- }
148
- input_map {
149
- key: "STOP_WORDS_DICT"
150
- value: "stop_words"
151
- }
152
- output_map {
153
- key: "REQUEST_INPUT_LEN"
154
- value: "_REQUEST_INPUT_LEN"
155
- }
156
- output_map {
157
- key: "INPUT_ID"
158
- value: "_INPUT_ID"
159
- }
160
- output_map {
161
- key: "REQUEST_OUTPUT_LEN"
162
- value: "_REQUEST_OUTPUT_LEN"
163
- }
164
- },
165
- {
166
- model_name: "tensorrt_llm"
167
- model_version: -1
168
- input_map {
169
- key: "input_ids"
170
- value: "_INPUT_ID"
171
- }
172
- input_map {
173
- key: "input_lengths"
174
- value: "_REQUEST_INPUT_LEN"
175
- }
176
- input_map {
177
- key: "request_output_len"
178
- value: "_REQUEST_OUTPUT_LEN"
179
- }
180
- input_map {
181
- key: "end_id"
182
- value: "end_id"
183
- }
184
- input_map {
185
- key: "pad_id"
186
- value: "pad_id"
187
- }
188
- input_map {
189
- key: "runtime_top_k"
190
- value: "top_k"
191
- }
192
- input_map {
193
- key: "runtime_top_p"
194
- value: "top_p"
195
- }
196
- input_map {
197
- key: "temperature"
198
- value: "temperature"
199
- }
200
- input_map {
201
- key: "len_penalty"
202
- value: "length_penalty"
203
- }
204
- input_map {
205
- key: "repetition_penalty"
206
- value: "repetition_penalty"
207
- }
208
- input_map {
209
- key: "min_length"
210
- value: "min_length"
211
- }
212
- input_map {
213
- key: "presence_penalty"
214
- value: "presence_penalty"
215
- }
216
- input_map {
217
- key: "random_seed"
218
- value: "random_seed"
219
- }
220
- input_map {
221
- key: "beam_width"
222
- value: "beam_width"
223
- }
224
- input_map {
225
- key: "streaming"
226
- value: "stream"
227
- }
228
- output_map {
229
- key: "output_ids"
230
- value: "_TOKENS_BATCH"
231
- }
232
- },
233
- {
234
- model_name: "postprocessing"
235
- model_version: -1
236
- input_map {
237
- key: "TOKENS_BATCH"
238
- value: "_TOKENS_BATCH"
239
- }
240
- output_map {
241
- key: "OUTPUT"
242
- value: "text_output"
243
- }
244
- }
245
- ]
246
- }
@@ -1,181 +0,0 @@
1
- # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # Redistribution and use in source and binary forms, with or without
4
- # modification, are permitted provided that the following conditions
5
- # are met:
6
- # * Redistributions of source code must retain the above copyright
7
- # notice, this list of conditions and the following disclaimer.
8
- # * Redistributions in binary form must reproduce the above copyright
9
- # notice, this list of conditions and the following disclaimer in the
10
- # documentation and/or other materials provided with the distribution.
11
- # * Neither the name of NVIDIA CORPORATION nor the names of its
12
- # contributors may be used to endorse or promote products derived
13
- # from this software without specific prior written permission.
14
- #
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
- # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
- # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
- # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
- # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
- # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
- # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
- # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
-
27
- import json
28
- import os
29
- from collections import OrderedDict
30
-
31
- import numpy as np
32
- import triton_python_backend_utils as pb_utils
33
- from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
34
-
35
-
36
class TritonPythonModel:
    """Triton Python-backend postprocessing model that detokenizes token batches.

    Triton requires every Python model to use the class name
    ``TritonPythonModel``. For each request, ``TOKENS_BATCH`` is decoded and
    only the text that is new relative to the previously seen token for the
    same request ID is emitted as ``OUTPUT``. The last token per request ID is
    kept in a bounded LRU cache (``state_dict``).
    """

    def initialize(self, args):
        """Load the tokenizer and output config; called once when the model loads.

        Parameters
        ----------
        args : dict
            Both keys and values are strings. The dictionary keys and values are:
            * model_config: A JSON string containing the model configuration
            * model_instance_kind: A string containing model instance kind
            * model_instance_device_id: A string containing model instance device ID
            * model_repository: Model repository path
            * model_version: Model version
            * model_name: Model name

        Raises
        ------
        AttributeError
            If the ``tokenizer_type`` parameter is not ``t5``, ``auto`` or ``llama``.
        """
        # Parse model configs
        model_config = json.loads(args["model_config"])
        # NOTE: Keep this in sync with the truss model.py variable
        tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"]
        tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"]

        if tokenizer_type == "t5":
            # tokenizer_dir is a tokenizer repository/directory, not a
            # sentencepiece vocab file, so load it like the other branches.
            self.tokenizer = T5Tokenizer.from_pretrained(
                tokenizer_dir, padding_side="left"
            )
        elif tokenizer_type == "auto":
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_dir, padding_side="left"
            )
        elif tokenizer_type == "llama":
            self.tokenizer = LlamaTokenizer.from_pretrained(
                tokenizer_dir, legacy=False, padding_side="left"
            )
        else:
            raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Parse model output configs
        output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT")
        # Convert Triton types to numpy types
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

        # request_id -> {"prev_token": token}, ordered least- to most-recently used.
        self.state_dict = OrderedDict()
        # TODO(pankaj) This should come from the batch size
        self.cache_size = 2048

    def execute(self, requests):
        """Detokenize each request's ``TOKENS_BATCH`` into an ``OUTPUT`` string.

        Parameters
        ----------
        requests : list
            A list of pb_utils.InferenceRequest

        Returns
        -------
        list
            A list of pb_utils.InferenceResponse. The length of this list is
            always the same as `requests`, as Triton requires.
        """
        responses = []

        for request in requests:
            tokens_batch = (
                pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH")
                .as_numpy()
                .flatten()
            )
            # Triton needs exactly one response per request, so an empty
            # batch produces an empty string instead of being skipped.
            if len(tokens_batch) == 0:
                delta = ""
            else:
                delta = self._decode_delta(request.request_id(), tokens_batch)

            output_tensor = pb_utils.Tensor(
                "OUTPUT", np.array([delta]).astype(self.output_dtype)
            )
            responses.append(
                pb_utils.InferenceResponse(output_tensors=[output_tensor])
            )

        return responses

    def finalize(self):
        print("Cleaning up...")

    def _decode_delta(self, request_id, tokens_batch):
        """Decode the new text for ``tokens_batch`` and remember its last token."""
        prev_token = self._get_prev_token(request_id)
        self._store_prev_token(request_id, tokens_batch[-1])
        if prev_token is None:
            return self.tokenizer.decode(tokens_batch)
        # TODO(pankaj) Figure out how to make tokenizer.decode not
        # ignore initial whitespace so we can avoid this hack.
        # Decode with and without the previous token and diff the results,
        # because tokenizer.decode strips initial whitespace.
        old_string = self.tokenizer.decode([prev_token])
        with_prev_token = np.concatenate(([prev_token], tokens_batch))
        new_string = self.tokenizer.decode(with_prev_token)
        return self._compute_delta(old_string, new_string)

    def _store_prev_token(self, request_id, token):
        """Record ``token`` as the latest token for ``request_id`` (LRU cache)."""
        if request_id in self.state_dict:
            self.state_dict[request_id]["prev_token"] = token
            # Move request ID to end of queue to prevent it from being evicted
            self.state_dict.move_to_end(request_id)
        else:
            # Evict least recently used items so the cache never exceeds
            # cache_size entries after the insert below.
            while self.state_dict and len(self.state_dict) >= self.cache_size:
                self.state_dict.popitem(last=False)
            self.state_dict[request_id] = {"prev_token": token}

    def _get_prev_token(self, request_id):
        """Return the cached previous token for ``request_id``, or None."""
        entry = self.state_dict.get(request_id)
        return None if entry is None else entry["prev_token"]

    def _compute_delta(self, prev_str, new_str):
        """Return the part of ``new_str`` after its common prefix with ``prev_str``."""
        common = 0
        for old_char, new_char in zip(prev_str, new_str):
            if old_char != new_char:
                break
            common += 1
        return new_str[common:]

    def _postprocessing(self, tokens):
        decoded_tokens = self.tokenizer.decode(tokens)
        return decoded_tokens
@@ -1,64 +0,0 @@
1
- # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # Redistribution and use in source and binary forms, with or without
4
- # modification, are permitted provided that the following conditions
5
- # are met:
6
- # * Redistributions of source code must retain the above copyright
7
- # notice, this list of conditions and the following disclaimer.
8
- # * Redistributions in binary form must reproduce the above copyright
9
- # notice, this list of conditions and the following disclaimer in the
10
- # documentation and/or other materials provided with the distribution.
11
- # * Neither the name of NVIDIA CORPORATION nor the names of its
12
- # contributors may be used to endorse or promote products derived
13
- # from this software without specific prior written permission.
14
- #
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
- # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
- # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
- # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
- # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
- # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
- # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
- # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
-
27
- name: "postprocessing"
28
- backend: "python"
29
- max_batch_size: 2048
30
- input [
31
- {
32
- name: "TOKENS_BATCH"
33
- data_type: TYPE_INT32
34
- dims: [ -1, -1 ]
35
- }
36
- ]
37
- output [
38
- {
39
- name: "OUTPUT"
40
- data_type: TYPE_STRING
41
- dims: [ -1, -1 ]
42
- }
43
- ]
44
-
45
- parameters {
46
- key: "tokenizer_dir"
47
- value: {
48
- string_value: "NousResearch/Llama-2-7b-hf"
49
- }
50
- }
51
-
52
- parameters {
53
- key: "tokenizer_type"
54
- value: {
55
- string_value: "auto"
56
- }
57
- }
58
-
59
- instance_group [
60
- {
61
- count: 1
62
- kind: KIND_CPU
63
- }
64
- ]