truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,215 @@
1
+ import io
2
+ import re
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import tensorrt_llm
7
+ import torch
8
+ import torchaudio
9
+ from torch import Tensor
10
+
11
+ from whisper_trt.assets import download_assets
12
+ from whisper_trt.batching import WhisperBatchProcessor
13
+ from whisper_trt.custom_types import (
14
+ DEFAULT_MAX_NEW_TOKENS,
15
+ DEFAULT_NUM_BEAMS,
16
+ SUPPORTED_SAMPLE_RATE,
17
+ BatchWhisperItem,
18
+ Segment,
19
+ WhisperResult,
20
+ )
21
+ from whisper_trt.modeling import WhisperDecoding, WhisperEncoding
22
+ from whisper_trt.tokenizer import REVERSED_LANGUAGES, get_tokenizer
23
+ from whisper_trt.utils import log_mel_spectrogram
24
+
25
+ SEGMENTS_PATTERN = re.compile(r"<\|([\d.]+)\|>([^<]+)<\|([\d.]+)\|>")  # captures (start, text, end) from a "<|t0|>text<|t1|>" span; t0/t1 are runs of digits and dots
26
+ LANG_CODE_PATTERN = re.compile(r"<\|([a-z]{2})\|>")  # captures the two lowercase letters inside a "<|xx|>" tag; longer codes do not match
27
+
28
+
29
class WhisperModel(object):
    """Whisper transcription pipeline built on TensorRT-LLM encoder/decoder engines.

    Construction loads both engines from ``engine_dir``, builds the tokenizer
    and creates a ``WhisperBatchProcessor`` through which every decoding
    request is routed.
    """

    def __init__(
        self,
        engine_dir,
        tokenizer_name="multilingual",
        debug_mode=False,
        assets_dir=None,
        max_queue_time=0.01,  # 10 ms by default
    ):
        """Load the engines, tokenizer and batch processor.

        Args:
            engine_dir: Directory holding the encoder/decoder engines; coerced
                to ``Path``.
            tokenizer_name: Name forwarded to ``get_tokenizer``.
            debug_mode: Forwarded to ``WhisperDecoding``.
            assets_dir: Directory with tokenizer / mel-filter assets. When
                ``None`` the assets are fetched with ``download_assets()``.
            max_queue_time: Forwarded to ``WhisperBatchProcessor``.
        """
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

        engine_dir = Path(engine_dir)

        self.assets_dir = assets_dir
        if self.assets_dir is None:
            self.assets_dir = download_assets()

        self.encoder = WhisperEncoding(engine_dir)
        self.decoder = WhisperDecoding(
            engine_dir, runtime_mapping, debug_mode=debug_mode
        )
        self.batch_size = self.decoder.decoder_config["max_batch_size"]
        self.n_mels = self.encoder.n_mels
        self.tokenizer = get_tokenizer(
            name=tokenizer_name,
            num_languages=self.encoder.num_languages,
            tokenizer_dir=self.assets_dir,
        )
        self.eot_id = self.tokenizer.encode(
            "<|endoftext|>", allowed_special=self.tokenizer.special_tokens_set
        )[0]

        self.batch_processor = WhisperBatchProcessor(
            self, max_batch_size=self.batch_size, max_queue_time=max_queue_time
        )

    def preprocess_audio(self, binary_data) -> Tensor:
        """Decode raw audio bytes and resample them to ``SUPPORTED_SAMPLE_RATE``.

        Args:
            binary_data: Encoded audio as bytes, read via ``torchaudio.load``.

        Returns:
            The waveform returned by ``torchaudio.load``, resampled when its
            sample rate differs from ``SUPPORTED_SAMPLE_RATE``.
        """
        waveform, sample_rate = torchaudio.load(io.BytesIO(binary_data))

        # Resample audio to a rate compatible with what the model was trained at.
        if sample_rate != SUPPORTED_SAMPLE_RATE:
            resample = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=SUPPORTED_SAMPLE_RATE
            )
            waveform = resample(waveform)

        return waveform

    def _get_text_prefix(
        self,
        language: str = "english",
        prompt: Optional[str] = None,
        timestamps: bool = False,
        task: str = "transcribe",
        prefix: Optional[str] = None,
    ):
        """Build the decoder prompt string made of Whisper special tokens.

        Args:
            language: Key into ``REVERSED_LANGUAGES``; a value that is not a
                key is used verbatim as the language code.
            prompt: Optional text placed after ``<|startofprev|>`` ahead of the
                transcript start token.
            timestamps: Append ``<|0.00|>`` when true, ``<|notimestamps|>``
                otherwise.
            task: Task token name inserted after the language token.
            prefix: Optional text appended at the very end of the prompt.

        Returns:
            The assembled prompt string.
        """
        try:
            language_code = REVERSED_LANGUAGES[language]
        except KeyError:
            language_code = language
        text_prefix = f"<|startoftranscript|><|{language_code}|><|{task}|>"
        if prompt is not None:
            text_prefix = f"<|startofprev|> {prompt}" + text_prefix
        if timestamps:
            text_prefix += "<|0.00|>"
        else:
            text_prefix += "<|notimestamps|>"
        if prefix is not None:
            text_prefix += prefix
        return text_prefix

    def process_batch(
        self,
        mel_batch,
        decoder_input_ids,
        num_beams=DEFAULT_NUM_BEAMS,
        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
    ) -> Tensor:
        """Encode a batch of mel spectrograms and generate output token ids.

        Args:
            mel_batch: Batched mel input passed to the encoder.
            decoder_input_ids: Batched prompt ids passed to the decoder.
            num_beams: Beam count forwarded to ``decoder.generate``.
            max_new_tokens: Generation limit forwarded to ``decoder.generate``.

        Returns:
            Whatever ``decoder.generate`` returns for the batch.
        """
        encoder_output = self.encoder.get_audio_features(mel_batch)
        output_ids = self.decoder.generate(
            decoder_input_ids,
            encoder_output,
            self.eot_id,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )
        return output_ids

    def decode_output_ids(self, output_ids, text_prefix):
        """Decode ``output_ids[0]`` with the tokenizer and strip whitespace.

        ``text_prefix`` is kept in the signature for callers but is NOT removed
        from the result. The previous implementation called
        ``text.replace(text_prefix, "")`` and discarded the return value, so
        callers have always received the full decoded text. That behaviour is
        preserved on purpose: ``_postprocess_transcript`` reads the language
        tag from this text, and callers that need the prompt gone strip it
        themselves.
        """
        return self.tokenizer.decode(output_ids[0]).strip()

    async def detect_audio_and_language(self, mel) -> Optional[str]:
        """
        Detects the audio and language from the given mel spectrogram.

        Args:
            mel: The mel spectrogram of the audio.

        Returns:
            The detected language code, or None if no speech is detected.
        """
        text_prefix = "<|startoftranscript|>"

        prompt_ids = self.tokenizer.encode(
            text_prefix, allowed_special=self.tokenizer.special_tokens_set
        )

        output_ids = await self.batch_processor.process(
            item=BatchWhisperItem(mel=mel, prompt_ids=prompt_ids, max_new_tokens=1)
        )
        decoded = self.decode_output_ids(output_ids, text_prefix)
        # Strip the prompt BEFORE the no-speech comparison: if the decoded text
        # still starts with the prompt, comparing the raw text would never
        # equal "<|nospeech|>" and "nospeech" would be returned as a language.
        detected = decoded.replace(text_prefix, "")
        if detected == "<|nospeech|>":
            return None
        return detected.replace("<|", "").replace("|>", "")

    async def transcribe(
        self,
        waveform,
        prompt: Optional[str] = None,
        language: Optional[str] = None,
        timestamps: bool = False,
        num_beams: int = DEFAULT_NUM_BEAMS,
        prefix: Optional[str] = None,
        task: str = "transcribe",
        max_new_tokens=128,
    ):
        """Transcribe a waveform into a ``WhisperResult``.

        Args:
            waveform: Audio tensor; converted with ``.numpy()`` before the mel
                spectrogram is computed.
            prompt: Optional previous-context text for the decoder prompt.
            language: Language to decode in. When ``None`` it is detected
                first via ``detect_audio_and_language``.
            timestamps: Whether the prompt requests timestamp tokens.
            num_beams: Accepted for API compatibility; not used by this
                method's body.
            prefix: Optional text appended to the decoder prompt.
            task: Task token name for the decoder prompt.
            max_new_tokens: Generation limit stored on the batch item.

        Returns:
            A ``WhisperResult``; it has no segments and ``language_code=None``
            when language detection reports no speech.
        """
        mel = await log_mel_spectrogram(
            waveform.numpy(),
            self.n_mels,
            device="cuda",
            mel_filters_dir=self.assets_dir,
        )
        mel = mel.type(torch.float16)
        if language is None:
            language = await self.detect_audio_and_language(mel)
            if language is None:
                # No speech was detected, so return a result with no segments.
                return WhisperResult(segments=[], language_code=None)
        text_prefix = self._get_text_prefix(
            language=language,
            prompt=prompt,
            timestamps=timestamps,
            prefix=prefix,
            task=task,
        )

        prompt_ids = self.tokenizer.encode(
            text_prefix, allowed_special=self.tokenizer.special_tokens_set
        )

        output_ids: Tensor = await self.batch_processor.process(
            item=BatchWhisperItem(
                mel=mel, prompt_ids=prompt_ids, max_new_tokens=max_new_tokens
            )
        )

        return self._postprocess_transcript(
            self.decode_output_ids(output_ids, text_prefix)
        )

    def _postprocess_transcript(self, transcribed_text: str) -> WhisperResult:
        """
        Post-process the output of the transcription model.

        The language code is the first ``<|xx|>`` two-letter tag found in the
        text, or ``None`` when there is none (previously this raised
        ``IndexError``). Each ``<|start|>text<|end|>`` span becomes one
        ``Segment``; text that is not enclosed by such a pair of tags yields
        no segment.
        """
        language_match = LANG_CODE_PATTERN.search(transcribed_text)
        language_code = language_match.group(1) if language_match else None

        segments = [
            Segment(start=float(start), end=float(end), text=text.strip())
            for start, text, end in SEGMENTS_PATTERN.findall(transcribed_text)
        ]

        return WhisperResult(segments=segments, language_code=language_code)
@@ -0,0 +1,25 @@
1
+ import os
2
+ import urllib.request
3
+
4
+
5
def _download_files(urls):
    """Download each URL into the per-user whisper-trt asset cache.

    Files are stored under ``~/.cache/whisper-trt/assets`` using the URL's
    basename. A file that already exists is not downloaded again.

    Each download is written to a ``.part`` file first and moved into place
    with ``os.replace`` only after it completes. This way an interrupted
    transfer cannot leave a truncated file that later calls would accept as
    already downloaded.

    Args:
        urls: Iterable of URL strings.

    Returns:
        The path of the asset directory.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises; the partial file is
        removed before the exception propagates.
    """
    assets_dir = os.path.join(
        os.path.expanduser("~"), ".cache", "whisper-trt", "assets"
    )
    os.makedirs(assets_dir, exist_ok=True)

    for url in urls:
        file_path = os.path.join(assets_dir, os.path.basename(url))
        if os.path.exists(file_path):
            continue

        partial_path = f"{file_path}.part"
        try:
            urllib.request.urlretrieve(url, partial_path)
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a half-written file.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, file_path)
    return assets_dir
17
+
18
+
19
def download_assets():
    """Fetch the Whisper asset files and return the directory that holds them.

    Requests ``multilingual.tiktoken`` and ``mel_filters.npz`` from the
    ``openai/whisper`` repository through ``_download_files`` and returns its
    result.
    """
    asset_urls = [
        "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken",
        "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz",
    ]
    return _download_files(asset_urls)
@@ -0,0 +1,52 @@
1
+ import logging
2
+ from typing import TYPE_CHECKING, List
3
+
4
+ if TYPE_CHECKING:
5
+ from whisper_trt import WhisperModel
6
+
7
+ import torch
8
+ from async_batcher.batcher import AsyncBatcher
9
+ from torch import Tensor
10
+
11
+ from whisper_trt.custom_types import DEFAULT_NUM_BEAMS, BatchWhisperItem
12
+
13
# NOTE: the identifier is misspelled ("PRFIX"). It is kept unchanged because
# other modules may import it under this name; the class below does not use it.
FIXED_TEXT_PRFIX = "<|startoftranscript|><|en|><|transcribe|><|0.00|>"


class WhisperBatchProcessor(AsyncBatcher[List[BatchWhisperItem], List[str]]):
    """Collects Whisper requests and runs them through the model as one batch.

    Every batch is padded up to ``self.max_batch_size`` by repeating its last
    element, and the padded results are cut off again before returning.
    """

    def __init__(self, model, *args, **kwargs):
        """Store ``model`` and forward all other arguments to ``AsyncBatcher``."""
        super().__init__(*args, **kwargs)
        self.model: "WhisperModel" = model

    @staticmethod
    def _pad_with_last(items: List, target_size: int) -> List:
        """Return a new list padded to ``target_size`` by repeating the last item.

        The input list is left untouched. Raises ``IndexError`` when padding is
        required but ``items`` is empty.
        """
        missing = target_size - len(items)
        if missing <= 0:
            return list(items)
        return list(items) + [items[-1]] * missing

    def concat_and_pad_mels(self, tensors: List[Tensor]):
        """Concatenates mel spectrograms to the maximum batch size using the last mel spectrogram as padding."""
        padded = self._pad_with_last(tensors, self.max_batch_size)
        return torch.cat(padded, dim=0).type(torch.float16)

    def concat_and_pad_prompts(self, prompts: List[List]) -> Tensor:
        """Concatenates prompts to the maximum batch size using the last prompt as padding."""
        return Tensor(self._pad_with_last(prompts, self.max_batch_size))

    def process_batch(self, batch: List[BatchWhisperItem]) -> List[float]:
        """Run one padded batch through the model and return the un-padded results.

        The largest ``max_new_tokens`` among the items is used for the whole
        batch. ``DEFAULT_NUM_BEAMS`` is always passed to the model; the
        per-item ``num_beams`` value is not consulted.
        """
        # NOTE(review): the declared return type comes from the original code;
        # the actual element type is whatever ``model.process_batch`` yields.
        # `logging.warn` is deprecated and removed in Python 3.13.
        logging.warning("Processing batch of size %d", len(batch))

        # Need to pad the batch up to the maximum batch size
        decoder_input_ids = self.concat_and_pad_prompts(
            [item.prompt_ids for item in batch]
        )
        mel_batch = self.concat_and_pad_mels([item.mel for item in batch])

        max_new_tokens = max(item.max_new_tokens for item in batch)
        batch_result = self.model.process_batch(
            mel_batch,
            decoder_input_ids,
            max_new_tokens=max_new_tokens,
            num_beams=DEFAULT_NUM_BEAMS,
        )
        # Slicing to len(batch) removes the padding entries added by
        # `concat_and_pad_mels` and `concat_and_pad_prompts`.
        return batch_result[: len(batch)]
@@ -0,0 +1,26 @@
1
+ from typing import List, NamedTuple
2
+
3
+ from pydantic import BaseModel
4
+ from torch import Tensor
5
+
6
# Presumably the audio sample rate in Hz that the pipeline accepts — TODO confirm.
SUPPORTED_SAMPLE_RATE = 16_000
# Defaults used when a request does not specify its own values.
DEFAULT_NUM_BEAMS = 1
DEFAULT_MAX_NEW_TOKENS = 128


class BatchWhisperItem(NamedTuple):
    """One request queued for batched Whisper inference."""

    # Mel spectrogram tensor for the request.
    mel: Tensor
    # Prompt token ids. Producers pass the plain list returned by the
    # tokenizer's ``encode`` and the batch processor stacks those lists
    # itself, so this is a list of ints rather than a tensor.
    prompt_ids: List[int]
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    num_beams: int = DEFAULT_NUM_BEAMS
16
+
17
+
18
# One timestamped piece of a transcript.
class Segment(BaseModel):
    # Segment boundaries; presumably seconds — TODO confirm against the
    # timestamp tokens the decoder emits.
    start: float
    end: float
    # Transcribed text for this span.
    text: str
22
+
23
+
24
# Structured transcription output: the segments plus a language code.
class WhisperResult(BaseModel):
    # Transcript segments.
    segments: List[Segment]
    # Language code associated with the transcript.
    language_code: str
@@ -0,0 +1,184 @@
1
+ from collections import OrderedDict
2
+
3
+ import tensorrt_llm
4
+ import tensorrt_llm.logger as logger
5
+ import torch
6
+ from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_torch
7
+ from tensorrt_llm.runtime import ModelConfig, SamplingConfig
8
+ from tensorrt_llm.runtime.session import Session, TensorInfo
9
+
10
+ from whisper_trt.custom_types import DEFAULT_NUM_BEAMS
11
+ from whisper_trt.utils import read_config, remove_tensor_padding
12
+
13
+
14
class WhisperEncoding:
    """Wraps the serialized Whisper encoder engine stored under ``engine_dir``."""

    def __init__(self, engine_dir):
        self.session = self.get_session(engine_dir)

    def get_session(self, engine_dir):
        """Load the encoder config and deserialize ``encoder/rank0.engine``.

        As a side effect the config and its ``dtype``, ``n_mels`` and
        ``num_languages`` entries are stored on the instance.
        """
        encoder_config = read_config("encoder", engine_dir)
        self.encoder_config = encoder_config

        self.dtype = encoder_config["dtype"]
        self.n_mels = encoder_config["n_mels"]
        self.num_languages = encoder_config["num_languages"]

        engine_path = engine_dir / "encoder" / "rank0.engine"
        with open(engine_path, "rb") as engine_file:
            return Session.from_serialized_engine(engine_file.read())

    def get_audio_features(self, mel):
        """Run the encoder engine on ``mel`` and return its ``encoder_output`` tensor."""
        batch_size = mel.shape[0]
        num_frames = mel.shape[2]

        # Computed from the untransposed input, before any padding removal.
        input_lengths = torch.full(
            (batch_size,), num_frames // 2, dtype=torch.int32, device=mel.device
        )

        if self.encoder_config["plugin_config"]["remove_input_padding"]:
            mel_input_lengths = torch.full(
                (batch_size,), num_frames, dtype=torch.int32, device="cuda"
            )
            # mel B,D,T -> B,T,D -> BxT, D
            mel = remove_tensor_padding(mel.transpose(1, 2), mel_input_lengths)

        inputs = OrderedDict()
        inputs["input_features"] = mel
        inputs["input_lengths"] = input_lengths

        # Describe the inputs so the session can infer the output shapes.
        input_info = [
            TensorInfo("input_features", str_dtype_to_trt(self.dtype), mel.shape),
            TensorInfo("input_lengths", str_dtype_to_trt("int32"), input_lengths.shape),
        ]
        output_info = self.session.infer_shapes(input_info)
        logger.debug(f"output info {output_info}")

        # Pre-allocate one output buffer per tensor the engine reports.
        outputs = {}
        for info in output_info:
            outputs[info.name] = torch.empty(
                tuple(info.shape), dtype=trt_dtype_to_torch(info.dtype), device="cuda"
            )

        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)
        assert ok, "Engine execution failed"
        stream.synchronize()
        return outputs["encoder_output"]
71
+
72
+
73
class WhisperDecoding:
    """Wraps a TensorRT-LLM generation session for the Whisper decoder engine."""

    def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
        self.decoder_config = read_config("decoder", engine_dir)
        self.decoder_generation_session = self.get_session(
            engine_dir, runtime_mapping, debug_mode
        )

    def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
        """Read ``decoder/rank0.engine`` and build a ``GenerationSession`` for it."""
        engine_path = engine_dir / "decoder" / "rank0.engine"
        with open(engine_path, "rb") as engine_file:
            engine_buffer = engine_file.read()

        cfg = self.decoder_config
        model_config = ModelConfig(
            max_batch_size=cfg["max_batch_size"],
            max_beam_width=cfg["max_beam_width"],
            num_heads=cfg["num_attention_heads"],
            num_kv_heads=cfg["num_attention_heads"],
            hidden_size=cfg["hidden_size"],
            vocab_size=cfg["vocab_size"],
            cross_attention=True,
            num_layers=cfg["num_hidden_layers"],
            gpt_attention_plugin=cfg["plugin_config"]["gpt_attention_plugin"],
            remove_input_padding=cfg["plugin_config"]["remove_input_padding"],
            has_position_embedding=cfg["has_position_embedding"],
            dtype=cfg["dtype"],
            has_token_type_embedding=False,
        )
        return tensorrt_llm.runtime.GenerationSession(
            model_config,
            engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode,
        )

    def generate(
        self,
        decoder_input_ids,
        encoder_outputs,
        eot_id,
        max_new_tokens=40,
        num_beams=DEFAULT_NUM_BEAMS,
    ):
        """Decode token ids for a batch and return them as nested Python lists.

        ``eot_id`` is used as both the end id and the pad id of the sampling
        configuration.
        """
        batch_size = encoder_outputs.shape[0]
        encoder_len = encoder_outputs.shape[1]

        # Every row has the same length, so the length vectors are constant.
        encoder_input_lengths = torch.full(
            (batch_size,), encoder_len, dtype=torch.int32, device="cuda"
        )
        decoder_input_lengths = torch.full(
            (decoder_input_ids.shape[0],),
            decoder_input_ids.shape[-1],
            dtype=torch.int32,
            device="cuda",
        )
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        cross_attention_mask = torch.ones([batch_size, 1, encoder_len]).int().cuda()

        # generation config
        sampling_config = SamplingConfig(
            end_id=eot_id, pad_id=eot_id, num_beams=num_beams
        )
        self.decoder_generation_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            beam_width=num_beams,
            encoder_max_input_length=encoder_len,
        )

        torch.cuda.synchronize()

        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
        if self.decoder_config["plugin_config"]["remove_input_padding"]:
            # 50256 is the index of <pad> for all whisper models' decoder
            WHISPER_PAD_TOKEN_ID = 50256
            decoder_input_ids = remove_tensor_padding(
                decoder_input_ids, pad_value=WHISPER_PAD_TOKEN_ID
            )
            if encoder_outputs.dim() == 3:
                encoder_output_lens = torch.full(
                    (batch_size,), encoder_len, dtype=torch.int32, device="cuda"
                )
                encoder_outputs = remove_tensor_padding(
                    encoder_outputs, encoder_output_lens
                )

        output_ids = self.decoder_generation_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_outputs,
            encoder_input_lengths=encoder_input_lengths,
            cross_attention_mask=cross_attention_mask,
        )
        torch.cuda.synchronize()

        # Convert the output tensor into plain nested lists of ints.
        return output_ids.cpu().numpy().tolist()
@@ -0,0 +1,185 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
16
+ import base64
17
+ import os
18
+ from typing import Optional
19
+
20
+ import tiktoken
21
+
22
+ LANGUAGES = {
23
+ "en": "english",
24
+ "zh": "chinese",
25
+ "de": "german",
26
+ "es": "spanish",
27
+ "ru": "russian",
28
+ "ko": "korean",
29
+ "fr": "french",
30
+ "ja": "japanese",
31
+ "pt": "portuguese",
32
+ "tr": "turkish",
33
+ "pl": "polish",
34
+ "ca": "catalan",
35
+ "nl": "dutch",
36
+ "ar": "arabic",
37
+ "sv": "swedish",
38
+ "it": "italian",
39
+ "id": "indonesian",
40
+ "hi": "hindi",
41
+ "fi": "finnish",
42
+ "vi": "vietnamese",
43
+ "he": "hebrew",
44
+ "uk": "ukrainian",
45
+ "el": "greek",
46
+ "ms": "malay",
47
+ "cs": "czech",
48
+ "ro": "romanian",
49
+ "da": "danish",
50
+ "hu": "hungarian",
51
+ "ta": "tamil",
52
+ "no": "norwegian",
53
+ "th": "thai",
54
+ "ur": "urdu",
55
+ "hr": "croatian",
56
+ "bg": "bulgarian",
57
+ "lt": "lithuanian",
58
+ "la": "latin",
59
+ "mi": "maori",
60
+ "ml": "malayalam",
61
+ "cy": "welsh",
62
+ "sk": "slovak",
63
+ "te": "telugu",
64
+ "fa": "persian",
65
+ "lv": "latvian",
66
+ "bn": "bengali",
67
+ "sr": "serbian",
68
+ "az": "azerbaijani",
69
+ "sl": "slovenian",
70
+ "kn": "kannada",
71
+ "et": "estonian",
72
+ "mk": "macedonian",
73
+ "br": "breton",
74
+ "eu": "basque",
75
+ "is": "icelandic",
76
+ "hy": "armenian",
77
+ "ne": "nepali",
78
+ "mn": "mongolian",
79
+ "bs": "bosnian",
80
+ "kk": "kazakh",
81
+ "sq": "albanian",
82
+ "sw": "swahili",
83
+ "gl": "galician",
84
+ "mr": "marathi",
85
+ "pa": "punjabi",
86
+ "si": "sinhala",
87
+ "km": "khmer",
88
+ "sn": "shona",
89
+ "yo": "yoruba",
90
+ "so": "somali",
91
+ "af": "afrikaans",
92
+ "oc": "occitan",
93
+ "ka": "georgian",
94
+ "be": "belarusian",
95
+ "tg": "tajik",
96
+ "sd": "sindhi",
97
+ "gu": "gujarati",
98
+ "am": "amharic",
99
+ "yi": "yiddish",
100
+ "lo": "lao",
101
+ "uz": "uzbek",
102
+ "fo": "faroese",
103
+ "ht": "haitian creole",
104
+ "ps": "pashto",
105
+ "tk": "turkmen",
106
+ "nn": "nynorsk",
107
+ "mt": "maltese",
108
+ "sa": "sanskrit",
109
+ "lb": "luxembourgish",
110
+ "my": "myanmar",
111
+ "bo": "tibetan",
112
+ "tl": "tagalog",
113
+ "mg": "malagasy",
114
+ "as": "assamese",
115
+ "tt": "tatar",
116
+ "haw": "hawaiian",
117
+ "ln": "lingala",
118
+ "ha": "hausa",
119
+ "ba": "bashkir",
120
+ "jw": "javanese",
121
+ "su": "sundanese",
122
+ "yue": "cantonese",
123
+ }
124
+
125
+ REVERSED_LANGUAGES = {v: k for k, v in LANGUAGES.items()}
126
+
127
+
128
def get_tokenizer(
    name: str = "multilingual",
    num_languages: int = 99,
    tokenizer_dir: Optional[str] = None,
):
    """Build the Whisper ``tiktoken.Encoding`` from a ``<name>.tiktoken`` vocab file.

    Args:
        name: base name of the vocabulary file (without ``.tiktoken``).
        num_languages: how many entries of ``LANGUAGES`` (in declaration
            order) receive a ``<|lang|>`` special token.
        tokenizer_dir: directory holding the vocab file. When ``None`` the
            ``assets`` directory next to this module is used.

    Returns:
        A ``tiktoken.Encoding`` whose special tokens are numbered directly
        after the mergeable ranks, in the order they are listed below.

    Raises:
        OSError: if the vocabulary file cannot be opened.
    """
    if tokenizer_dir is None:
        vocab_path = os.path.join(os.path.dirname(__file__), f"assets/{name}.tiktoken")
    else:
        vocab_path = os.path.join(tokenizer_dir, f"{name}.tiktoken")

    # The context manager closes the file handle deterministically. Blank
    # lines (e.g. a trailing newline) are skipped; a bare `if line` never
    # filters them because lines read from a file keep their "\n".
    with open(vocab_path, encoding="utf-8") as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (
                line.split() for line in vocab_file if line.strip()
            )
        }

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    # Special tokens take consecutive ids right after the regular vocabulary.
    first_special_id = len(ranks)
    special_tokens = {
        token: first_special_id + offset for offset, token in enumerate(specials)
    }
    n_vocab = first_special_id + len(specials)

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
168
+
169
+
170
if __name__ == "__main__":
    # Manual smoke check: encode and decode a few strings and print the results.
    tokenizer = get_tokenizer()

    sample_prompt = "<|startofprev|> Nvidia<|startoftranscript|><|en|><|transcribe|>"
    prompt_ids = tokenizer.encode(
        sample_prompt, allowed_special=tokenizer.special_tokens_set
    )
    decoded_a = tokenizer.decode([50361, 45, 43021, 50258, 50259, 50359])
    decoded_b = tokenizer.decode([50361, 46284, 50258, 50259, 50359])
    print(prompt_ids, decoded_a, decoded_b)

    # Id of the start-of-transcript special token.
    sot_ids = tokenizer.encode(
        "<|startoftranscript|>", allowed_special=tokenizer.special_tokens_set
    )
    print(sot_ids[0])

    # Round-trip a non-ASCII string.
    chinese_text = "好好学习"
    chinese_ids = tokenizer.encode(
        chinese_text, allowed_special=tokenizer.special_tokens_set
    )
    round_tripped = tokenizer.decode(chinese_ids)
    print(type(round_tripped))
    print(chinese_ids, round_tripped)