truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
truss/__init__.py CHANGED
@@ -1,9 +1,12 @@
1
- # flake8: noqa F401
2
-
1
+ import warnings
3
2
  from pathlib import Path
4
3
 
4
+ from pydantic import PydanticDeprecatedSince20
5
5
  from single_source import get_version
6
6
 
7
+ # Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
8
+ warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
9
+
7
10
  __version__ = get_version(__name__, Path(__file__).parent.parent)
8
11
 
9
12
 
@@ -11,4 +14,8 @@ def version():
11
14
  return __version__
12
15
 
13
16
 
14
- from truss.build import from_directory, init, kill_all, load
17
+ from truss.api import login, push, whoami
18
+ from truss.base import truss_config
19
+ from truss.truss_handle.build import load # TODO: Refactor all usages and remove.
20
+
21
+ __all__ = ["push", "login", "load", "whoami", "truss_config"]
truss/api/__init__.py ADDED
@@ -0,0 +1,123 @@
1
+ from typing import TYPE_CHECKING, Optional, Type, cast
2
+
3
+ if TYPE_CHECKING:
4
+ from rich import progress
5
+
6
+ from truss.api import definitions
7
+ from truss.remote.baseten.service import BasetenService
8
+ from truss.remote.remote_factory import RemoteFactory
9
+ from truss.remote.truss_remote import RemoteConfig
10
+ from truss.truss_handle.build import load
11
+
12
+
13
+ def login(api_key: str):
14
+ """
15
+ Logs user into Baseten account. Persists information to ~/.trussrc file,
16
+ so only needs to be invoked once.
17
+ Args:
18
+ api_key: Baseten API Key
19
+ """
20
+ remote_url = "https://app.baseten.co"
21
+ remote_config = RemoteConfig(
22
+ name="baseten",
23
+ configs={
24
+ "remote_provider": "baseten",
25
+ "api_key": api_key,
26
+ "remote_url": remote_url,
27
+ },
28
+ )
29
+ RemoteFactory.update_remote_config(remote_config)
30
+
31
+
32
+ def whoami(remote: Optional[str] = None):
33
+ """
34
+ Returns account information for the current user.
35
+ """
36
+ if not remote:
37
+ available_remotes = RemoteFactory.get_available_config_names()
38
+ if len(available_remotes) == 1:
39
+ remote = available_remotes[0]
40
+ elif len(available_remotes) == 0:
41
+ raise ValueError(
42
+ "Please authenticate via truss.login and pass it as an argument."
43
+ )
44
+ else:
45
+ raise ValueError(
46
+ "Multiple remotes found. Please pass the remote as an argument."
47
+ )
48
+
49
+ remote_provider = RemoteFactory.create(remote=remote)
50
+ return remote_provider.whoami()
51
+
52
+
53
+ def push(
54
+ target_directory: str,
55
+ remote: Optional[str] = None,
56
+ model_name: Optional[str] = None,
57
+ publish: bool = False,
58
+ promote: bool = False,
59
+ preserve_previous_production_deployment: bool = False,
60
+ trusted: bool = False,
61
+ deployment_name: Optional[str] = None,
62
+ environment: Optional[str] = None,
63
+ progress_bar: Optional[Type["progress.Progress"]] = None,
64
+ ) -> definitions.ModelDeployment:
65
+ """
66
+ Pushes a Truss to Baseten.
67
+
68
+ Args:
69
+ target_directory: Directory of Truss to push.
70
+ remote: Name of the remote in .trussrc to patch changes to.
71
+ model_name: The name of the model, if different from the one in the config.yaml.
72
+ publish: Push the truss as a published deployment. If no production deployment exists,
73
+ promote the truss to production after deploy completes.
74
+ promote: Push the truss as a published deployment. Even if a production deployment exists,
75
+ promote the truss to production after deploy completes.
76
+ preserve_previous_production_deployment: Preserve the previous production deployment’s autoscaling
77
+ setting. When not specified, the previous production deployment will be updated to allow it to
78
+ scale to zero. Can only be use in combination with `promote` option.
79
+ trusted: Give Truss access to secrets on remote host.
80
+ deployment_name: Name of the deployment created by the push. Can only be
81
+ used in combination with `publish` or `promote`. Deployment name must
82
+ only contain alphanumeric, ’.’, ’-’ or ’_’ characters.
83
+ environment: Name of stable environment on baseten.
84
+ progress_bar: Optional `rich.progress.Progress` if output is desired.
85
+
86
+ Returns:
87
+ The newly created ModelDeployment.
88
+ """
89
+
90
+ if not remote:
91
+ available_remotes = RemoteFactory.get_available_config_names()
92
+ if len(available_remotes) == 1:
93
+ remote = available_remotes[0]
94
+ elif len(available_remotes) == 0:
95
+ raise ValueError(
96
+ "Please authenticate via truss.login and pass it as an argument."
97
+ )
98
+ else:
99
+ raise ValueError(
100
+ "Multiple remotes found. Please pass the remote as an argument."
101
+ )
102
+
103
+ remote_provider = RemoteFactory.create(remote=remote)
104
+ tr = load(target_directory)
105
+ model_name = model_name or tr.spec.config.model_name
106
+ if not model_name:
107
+ raise ValueError(
108
+ "No model name provided. Please specify a model name in config.yaml."
109
+ )
110
+
111
+ service = remote_provider.push(
112
+ tr,
113
+ model_name=model_name,
114
+ publish=publish,
115
+ trusted=trusted,
116
+ promote=promote,
117
+ preserve_previous_prod_deployment=preserve_previous_production_deployment,
118
+ deployment_name=deployment_name,
119
+ environment=environment,
120
+ progress_bar=progress_bar,
121
+ ) # type: ignore
122
+
123
+ return definitions.ModelDeployment(cast(BasetenService, service))
truss/api/definitions.py ADDED
@@ -0,0 +1,51 @@
1
+ import time
2
+
3
+ import pydantic
4
+
5
+ from truss.remote.baseten import service
6
+ from truss.remote.baseten.core import ACTIVE_STATUS, DEPLOYING_STATUSES
7
+
8
+
9
+ class ModelDeployment:
10
+ model_config = pydantic.ConfigDict(protected_namespaces=())
11
+
12
+ model_id: str
13
+ model_deployment_id: str
14
+ _baseten_service: service.BasetenService
15
+
16
+ def __init__(self, service: service.BasetenService):
17
+ self.model_id = service._model_id
18
+ self.model_deployment_id = service._model_version_id
19
+ self._baseten_service = service
20
+
21
+ def wait_for_active(self, timeout_seconds: int = 600) -> bool:
22
+ """
23
+ Waits for the deployment to be active.
24
+
25
+ Args:
26
+ timeout_seconds: The maximum time to wait for the deployment to be active.
27
+
28
+ Returns:
29
+ The status of the deployment.
30
+ """
31
+ start_time = time.time()
32
+ for deployment_status in self._baseten_service.poll_deployment_status():
33
+ if (
34
+ timeout_seconds is not None
35
+ and time.time() - start_time > timeout_seconds
36
+ ):
37
+ raise TimeoutError("Deployment timed out.")
38
+
39
+ if deployment_status == ACTIVE_STATUS:
40
+ return True
41
+
42
+ if deployment_status not in DEPLOYING_STATUSES:
43
+ raise Exception(f"Deployment failed with status: {deployment_status}")
44
+
45
+ raise RuntimeError("Error polling deployment status.")
46
+
47
+ def __repr__(self):
48
+ return (
49
+ f"ModelDeployment(model_id={self.model_id}, "
50
+ f"model_deployment_id={self.model_deployment_id})"
51
+ )
@@ -0,0 +1,116 @@
1
+ import pathlib
2
+ from typing import Set
3
+
4
+ SKLEARN = "sklearn"
5
+ TENSORFLOW = "tensorflow"
6
+ KERAS = "keras"
7
+ XGBOOST = "xgboost"
8
+ PYTORCH = "pytorch"
9
+ CUSTOM = "custom"
10
+ HUGGINGFACE_TRANSFORMER = "huggingface_transformer"
11
+ LIGHTGBM = "lightgbm"
12
+
13
+ _TRUSS_ROOT = pathlib.Path(__file__).parent.parent.resolve()
14
+
15
+ TEMPLATES_DIR = _TRUSS_ROOT / "templates"
16
+ TRADITIONAL_CUSTOM_TEMPLATE_DIR = TEMPLATES_DIR / "custom"
17
+ PYTHON_DX_CUSTOM_TEMPLATE_DIR = TEMPLATES_DIR / "custom_python_dx"
18
+ DOCKER_SERVER_TEMPLATES_DIR = TEMPLATES_DIR / "docker_server"
19
+ SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "server"
20
+ TRITON_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "triton"
21
+ TRTLLM_TRUSS_DIR: pathlib.Path = TEMPLATES_DIR / "trtllm-briton"
22
+ SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME = "shared"
23
+ SHARED_SERVING_AND_TRAINING_CODE_DIR: pathlib.Path = (
24
+ TEMPLATES_DIR / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME
25
+ )
26
+ CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "control"
27
+ CHAINS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent / "truss-chains" / "truss_chains"
28
+
29
+ SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"}
30
+ MAX_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.12"
31
+ MIN_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.8"
32
+
33
+ TRTLLM_PREDICT_CONCURRENCY = 512
34
+ BEI_TRTLLM_CLIENT_BATCH_SIZE = 128
35
+ BEI_MAX_CONCURRENCY_TARGET_REQUESTS = 2048
36
+ BEI_REQUIRED_MAX_NUM_TOKENS = 16384
37
+
38
+ TRTLLM_MIN_MEMORY_REQUEST_GI = 16
39
+ HF_MODELS_API_URL = "https://huggingface.co/api/models"
40
+ HF_ACCESS_TOKEN_KEY = "hf_access_token"
41
+ TRUSSLESS_MAX_PAYLOAD_SIZE = "64M"
42
+ # Alias for TEMPLATES_DIR
43
+ SERVING_DIR: pathlib.Path = TEMPLATES_DIR
44
+
45
+ REQUIREMENTS_TXT_FILENAME = "requirements.txt"
46
+ USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME = "user_requirements.txt"
47
+ BASE_SERVER_REQUIREMENTS_TXT_FILENAME = "base_server_requirements.txt"
48
+ SERVER_REQUIREMENTS_TXT_FILENAME = "server_requirements.txt"
49
+ SYSTEM_PACKAGES_TXT_FILENAME = "system_packages.txt"
50
+
51
+ FILENAME_CONSTANTS_MAP = {
52
+ "config_requirements_filename": REQUIREMENTS_TXT_FILENAME,
53
+ "user_supplied_requirements_filename": USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME,
54
+ "base_server_requirements_filename": BASE_SERVER_REQUIREMENTS_TXT_FILENAME,
55
+ "server_requirements_filename": SERVER_REQUIREMENTS_TXT_FILENAME,
56
+ "system_packages_filename": SYSTEM_PACKAGES_TXT_FILENAME,
57
+ }
58
+
59
+ SERVER_DOCKERFILE_TEMPLATE_NAME = "server.Dockerfile.jinja"
60
+ MODEL_DOCKERFILE_NAME = "Dockerfile"
61
+
62
+ README_TEMPLATE_NAME = "README.md.jinja"
63
+ MODEL_README_NAME = "README.md"
64
+
65
+ CONFIG_FILE = "config.yaml"
66
+ DOCKERFILE = "Dockerfile"
67
+ # Used to indicate whether to associate a container with Truss
68
+ TRUSS = "truss"
69
+ # Used to create unique identifier based on last time truss was updated
70
+ TRUSS_MODIFIED_TIME = "truss_modified_time"
71
+ # Path of the Truss used to identify which Truss is being referred
72
+ TRUSS_DIR = "truss_dir"
73
+ TRUSS_HASH = "truss_hash"
74
+
75
+ HUGGINGFACE_TRANSFORMER_MODULE_NAME: Set[str] = set({})
76
+
77
+ # list from https://scikit-learn.org/stable/developers/advanced_installation.html
78
+ SKLEARN_REQ_MODULE_NAMES: Set[str] = {
79
+ "numpy",
80
+ "scipy",
81
+ "joblib",
82
+ "scikit-learn",
83
+ "threadpoolctl",
84
+ }
85
+
86
+ XGBOOST_REQ_MODULE_NAMES: Set[str] = {"xgboost"}
87
+
88
+ # list from https://www.tensorflow.org/install/pip
89
+ # if problematic, lets look to https://www.tensorflow.org/install/source
90
+ TENSORFLOW_REQ_MODULE_NAMES: Set[str] = {"tensorflow"}
91
+
92
+ LIGHTGBM_REQ_MODULE_NAMES: Set[str] = {"lightgbm"}
93
+
94
+ # list from https://pytorch.org/get-started/locally/
95
+ PYTORCH_REQ_MODULE_NAMES: Set[str] = {"torch", "torchvision", "torchaudio"}
96
+
97
+ MLFLOW_REQ_MODULE_NAMES: Set[str] = {"mlflow"}
98
+
99
+ INFERENCE_SERVER_PORT = 8080
100
+
101
+ HTTP_PUBLIC_BLOB_BACKEND = "http_public"
102
+
103
+ REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"
104
+
105
+ TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"
106
+ TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
107
+ TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.16.0-5be7b58"
108
+ TRTLLM_PYTHON_EXECUTABLE = "/usr/local/briton/venv/bin/python"
109
+ BASE_TRTLLM_REQUIREMENTS = ["briton==0.4.2"]
110
+ BEI_TRTLLM_BASE_IMAGE = "baseten/bei:0.0.17@sha256:9c3577f6ec672d6da5aca18e9c0ebdddd65ed80c8858e757fbde7e9cf48de01d"
111
+
112
+ BEI_TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
113
+
114
+ OPENAI_COMPATIBLE_TAG = "openai-compatible"
115
+
116
+ PRODUCTION_ENVIRONMENT_NAME = "production"
@@ -0,0 +1,29 @@
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from typing import Any
4
+
5
+
6
# TODO(marius/TaT): kill this.
class ModelFrameworkType(Enum):
    """Enumeration of model frameworks; each value is the lowercase name."""

    SKLEARN = "sklearn"
    TENSORFLOW = "tensorflow"
    KERAS = "keras"
    PYTORCH = "pytorch"
    HUGGINGFACE_TRANSFORMER = "huggingface_transformer"
    XGBOOST = "xgboost"
    LIGHTGBM = "lightgbm"
    MLFLOW = "mlflow"
    CUSTOM = "custom"
17
+
18
+
19
@dataclass
class Example:
    """A named example input, convertible to and from a plain dict."""

    name: str
    input: Any

    @classmethod
    def from_dict(cls, example_dict: dict) -> "Example":
        """Build an instance from a dict holding ``"name"`` and ``"input"``.

        Implemented as a classmethod so that subclasses get instances of
        their own type rather than of ``Example``.

        Raises:
            KeyError: If either key is missing from ``example_dict``.
        """
        return cls(name=example_dict["name"], input=example_dict["input"])

    def to_dict(self) -> dict:
        """Return the ``{"name": ..., "input": ...}`` form of this example."""
        return {"name": self.name, "input": self.input}
@@ -38,3 +38,7 @@ class ContainerNotFoundError(Error):
38
38
 
39
39
class ContainerAPINoResponseError(Error):
    # NOTE(review): presumably raised when a container's API gives no
    # response — confirm against the call sites.
    pass
41
+
42
+
43
class RemoteNetworkError(Exception):
    # Derives directly from Exception, unlike the neighbouring errors that
    # derive from this module's Error base.
    # NOTE(review): presumably signals a network failure while talking to a
    # remote — confirm against the call sites.
    pass
@@ -0,0 +1,310 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import warnings
7
+ from enum import Enum
8
+ from typing import Any, Optional
9
+
10
+ from huggingface_hub.errors import HFValidationError
11
+ from huggingface_hub.utils import validate_repo_id
12
+ from pydantic import BaseModel, PydanticDeprecatedSince20, model_validator, validator
13
+
14
+ from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS
15
+
16
logger = logging.getLogger(__name__)

# Pydantic V1-style APIs are still used here for backwards compatibility, so
# the deprecation warnings they trigger are silenced.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)

# Read once at import time; only the exact string "True" switches it on.
ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION = (
    os.getenv("ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION", "False") == "True"
)
23
+
24
+
25
class TrussTRTLLMModel(str, Enum):
    """Base model identifiers accepted by ``TrussTRTLLMBuildConfiguration.base_model``.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string value.
    """

    LLAMA = "llama"
    MISTRAL = "mistral"
    DEEPSEEK = "deepseek"
    WHISPER = "whisper"
    QWEN = "qwen"
    ENCODER = "encoder"
    PALMYRA = "palmyra"
33
+
34
+
35
class TrussTRTLLMQuantizationType(str, Enum):
    """Quantization choices for ``TrussTRTLLMBuildConfiguration.quantization_type``.

    The build configuration rejects members whose value contains ``"_kv"``
    when the base model is the encoder.
    """

    NO_QUANT = "no_quant"
    WEIGHTS_ONLY_INT8 = "weights_int8"
    WEIGHTS_KV_INT8 = "weights_kv_int8"
    WEIGHTS_ONLY_INT4 = "weights_int4"
    WEIGHTS_INT4_KV_INT8 = "weights_int4_kv_int8"
    SMOOTH_QUANT = "smooth_quant"
    FP8 = "fp8"
    FP8_KV = "fp8_kv"
44
+
45
+
46
class TrussTRTLLMPluginConfiguration(BaseModel):
    """Plugin switches used by ``TrussTRTLLMBuildConfiguration``.

    No cross-field checks happen here; the build configuration enforces that
    the fmha options are only enabled together with ``paged_kv_cache``.
    """

    paged_kv_cache: bool = True
    gemm_plugin: str = "auto"
    use_paged_context_fmha: bool = True
    use_fp8_context_fmha: bool = False
51
+
52
+
53
class CheckpointSource(str, Enum):
    """Kinds of location a checkpoint can be loaded from.

    ``CheckpointRepository`` runs an extra repo-id validation when the source
    is ``HF``.
    """

    HF = "HF"
    GCS = "GCS"
    LOCAL = "LOCAL"
    # REMOTE_URL is useful when the checkpoint lives on remote storage accessible via HTTP (e.g a presigned URL)
    REMOTE_URL = "REMOTE_URL"
59
+
60
+
61
class CheckpointRepository(BaseModel):
    """Where a checkpoint lives: a source kind, a repo string and an optional revision.

    For the ``HF`` source the repo string is additionally checked with
    ``validate_repo_id`` right after the regular pydantic validation.
    """

    source: CheckpointSource
    repo: str
    revision: Optional[str] = None

    def __init__(self, **data):
        super().__init__(**data)
        # Only HuggingFace repo ids have a format that can be checked here.
        if self.source != CheckpointSource.HF:
            return
        self._validate_hf_repo_id()

    def _validate_hf_repo_id(self):
        """Re-raise a HuggingFace repo-id validation failure as ``ValueError``."""
        try:
            validate_repo_id(self.repo)
        except HFValidationError as hf_error:
            message = f"HuggingFace repository validation failed: {str(hf_error)}"
            raise ValueError(message) from hf_error
78
+
79
+
80
class TrussTRTLLMBatchSchedulerPolicy(str, Enum):
    """Batch scheduler policy choices for the runtime configuration.

    ``TrussTRTLLMRuntimeConfiguration`` defaults to ``GUARANTEED_NO_EVICT``.
    """

    MAX_UTILIZATION = "max_utilization"
    GUARANTEED_NO_EVICT = "guaranteed_no_evict"
83
+
84
+
85
class TrussSpecDecMode(str, Enum):
    """Speculative decoding modes; ``DRAFT_EXTERNAL`` is the only member."""

    DRAFT_EXTERNAL = "DRAFT_TOKENS_EXTERNAL"
87
+
88
+
89
class TrussTRTLLMRuntimeConfiguration(BaseModel):
    """Runtime options for a TRT-LLM deployment; every field has a default.

    ``TRTLLMConfiguration`` requires the paged kv-cache and paged context
    fmha plugin flags when ``enable_chunked_context`` is on for a non-encoder
    base model.
    """

    kv_cache_free_gpu_mem_fraction: float = 0.9
    kv_cache_host_memory_bytes: Optional[int] = None
    enable_chunked_context: bool = True
    batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
        TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
    )
    request_default_max_tokens: Optional[int] = None
    total_token_limit: int = 500000
98
+
99
+
100
class TrussTRTLLMBuildConfiguration(BaseModel):
    """Engine build options for a TRT-LLM model.

    Unknown fields are rejected (``extra = "forbid"``). After the regular
    pydantic validation, ``__init__`` runs three extra steps in order:
    kv-cache flag validation, speculator validation, and the encoder-specific
    migration, which may overwrite ``max_num_tokens`` and the paged plugin
    flags.
    """

    base_model: TrussTRTLLMModel
    max_seq_len: int
    max_batch_size: int = 256
    max_num_tokens: int = 8192
    max_beam_width: int = 1
    max_prompt_embedding_table_size: int = 0
    checkpoint_repository: CheckpointRepository
    gather_all_token_logits: bool = False
    strongly_typed: bool = False
    quantization_type: TrussTRTLLMQuantizationType = (
        TrussTRTLLMQuantizationType.NO_QUANT
    )
    tensor_parallel_count: int = 1
    pipeline_parallel_count: int = 1
    plugin_configuration: TrussTRTLLMPluginConfiguration = (
        TrussTRTLLMPluginConfiguration()
    )
    num_builder_gpus: Optional[int] = None
    speculator: Optional[TrussSpeculatorConfiguration] = None

    class Config:
        extra = "forbid"

    def __init__(self, **data):
        super().__init__(**data)
        # The kv-cache flags are validated before the encoder migration turns
        # the paged options off, so user-supplied flags are what gets checked.
        self._validate_kv_cache_flags()
        self._validate_speculator_config()
        self._bei_specfic_migration()

    @validator("max_beam_width")
    def check_max_beam_width(cls, v: int):
        """Reject any integer beam width other than 1."""
        if isinstance(v, int):
            if v != 1:
                raise ValueError(
                    "max_beam_width greater than 1 is not currently supported"
                )
        return v

    def _bei_specfic_migration(self):
        """performs embedding specfic optimizations (no kv-cache, high batch size)

        Only acts when ``base_model`` is the encoder: raises ``max_num_tokens``
        to at least ``BEI_REQUIRED_MAX_NUM_TOKENS``, disables the paged
        kv-cache and paged context fmha plugin flags, and rejects quantization
        types whose value contains ``"_kv"``.

        Raises:
            ValueError: If a kv-specific quantization type is selected.
        """
        if self.base_model == TrussTRTLLMModel.ENCODER:
            # Encoder specific settings
            logger.info(
                f"Your setting of `build.max_seq_len={self.max_seq_len}` is not used and "
                "automatically inferred from the model repo config.json -> `max_position_embeddings`"
            )

            if self.max_num_tokens < BEI_REQUIRED_MAX_NUM_TOKENS:
                logger.warning(
                    f"build.max_num_tokens={self.max_num_tokens}, upgrading to {BEI_REQUIRED_MAX_NUM_TOKENS}"
                )
                self.max_num_tokens = BEI_REQUIRED_MAX_NUM_TOKENS
            self.plugin_configuration.paged_kv_cache = False
            self.plugin_configuration.use_paged_context_fmha = False

            if "_kv" in self.quantization_type.value:
                # The two literals are concatenated, so the first must end
                # with a separator to keep the message readable.
                raise ValueError(
                    "encoder does not have a kv-cache, therefore a kv specfic datatype is not valid; "
                    f"you selected build.quantization_type {self.quantization_type}"
                )

    def _validate_kv_cache_flags(self):
        """Check that the plugin flags and quantization type are consistent.

        Returns:
            ``self``, unchanged.

        Raises:
            ValueError: If an fmha option is enabled without the paged
                kv-cache, if fp8 context fmha is enabled without paged context
                fmha, or if fp8 context fmha is enabled without the ``FP8_KV``
                quantization type.
        """
        if not self.plugin_configuration.paged_kv_cache and (
            self.plugin_configuration.use_paged_context_fmha
            or self.plugin_configuration.use_fp8_context_fmha
        ):
            raise ValueError(
                "Using paged context fmha or fp8 context fmha requires paged kv cache"
            )
        if (
            self.plugin_configuration.use_fp8_context_fmha
            and not self.plugin_configuration.use_paged_context_fmha
        ):
            raise ValueError("Using fp8 context fmha requires paged context fmha")
        if (
            self.plugin_configuration.use_fp8_context_fmha
            and not self.quantization_type == TrussTRTLLMQuantizationType.FP8_KV
        ):
            raise ValueError("Using fp8 context fmha requires fp8 kv cache dtype")
        return self

    def _validate_speculator_config(self):
        """Check the constraints that apply when a speculator is configured.

        Raises:
            ValueError: If the base model is Whisper, if either paged plugin
                flag is off, or if the speculator's own build uses a different
                tensor parallel count than this build.
        """
        if self.speculator:
            if self.base_model is TrussTRTLLMModel.WHISPER:
                raise ValueError("Speculative decoding for Whisper is not supported.")
            if not all(
                [
                    self.plugin_configuration.use_paged_context_fmha,
                    self.plugin_configuration.paged_kv_cache,
                ]
            ):
                raise ValueError(
                    "KV cache block reuse must be enabled for speculative decoding target model."
                )
            if self.speculator.build:
                if (
                    self.tensor_parallel_count
                    != self.speculator.build.tensor_parallel_count
                ):
                    raise ValueError(
                        "Speculative decoding requires the same tensor parallelism for target and draft models."
                    )

    @property
    def max_draft_len(self) -> Optional[int]:
        """The speculator's ``num_draft_tokens``, or ``None`` without a speculator."""
        if self.speculator:
            return self.speculator.num_draft_tokens
        return None
209
+
210
+
211
class TrussSpeculatorConfiguration(BaseModel):
    """Draft-model settings for speculative decoding.

    Exactly one of ``checkpoint_repository`` and ``build`` has to be set; this
    is checked right after the regular pydantic validation.
    """

    speculative_decoding_mode: TrussSpecDecMode = TrussSpecDecMode.DRAFT_EXTERNAL
    num_draft_tokens: int
    checkpoint_repository: Optional[CheckpointRepository] = None
    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
    build: Optional[TrussTRTLLMBuildConfiguration] = None

    def __init__(self, **data):
        super().__init__(**data)
        self._validate_checkpoint()

    def _validate_checkpoint(self):
        """Raise ``ValueError`` unless exactly one checkpoint origin is set."""
        has_checkpoint = bool(self.checkpoint_repository)
        has_build = bool(self.build)
        # Both set or both missing are equally invalid.
        if has_checkpoint == has_build:
            raise ValueError(
                "Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
            )

    @property
    def resolved_checkpoint_repository(self) -> CheckpointRepository:
        """The checkpoint repository in effect, preferring the one in ``build``."""
        if self.build:
            return self.build.checkpoint_repository
        if self.checkpoint_repository:
            return self.checkpoint_repository
        raise ValueError(
            "Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
        )
238
+
239
+
240
class TRTLLMConfiguration(BaseModel):
    """Top-level TRT-LLM configuration: a required build plus runtime options."""

    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
    build: TrussTRTLLMBuildConfiguration

    @model_validator(mode="before")
    @classmethod
    def migrate_runtime_fields(cls, data: Any) -> Any:
        """Move runtime fields that were placed under ``build`` into ``runtime``.

        Only acts when the raw input is a dict whose ``build`` entry is a dict
        containing at least one key that is a runtime field and not a build
        field. In that case the runtime values are copied into ``runtime``
        without overriding values already present there, and ``build`` is
        reduced to the keys that are build fields. Any other input is returned
        untouched.

        The caller's dictionaries are never modified; a new top-level dict is
        returned when a migration takes place.
        """
        # A "before" validator receives the raw input, which need not be a dict.
        if not isinstance(data, dict):
            return data
        build = data.get("build")
        if not isinstance(build, dict):
            return data

        build_field_names = TrussTRTLLMBuildConfiguration.__annotations__
        runtime_field_names = TrussTRTLLMRuntimeConfiguration.__annotations__
        valid_build_fields = {}
        extra_runtime_fields = {}
        for key, value in build.items():
            if key in build_field_names:
                valid_build_fields[key] = value
            elif key in runtime_field_names:
                logger.warning(f"Found runtime.{key}: {value} in build config")
                extra_runtime_fields[key] = value

        if not extra_runtime_fields:
            return data

        logger.warning(
            f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
            " This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
        )
        runtime = data.get("runtime")
        merged_runtime = dict(runtime) if runtime else {}
        for key, value in extra_runtime_fields.items():
            # Values given explicitly under ``runtime`` take precedence.
            merged_runtime.setdefault(key, value)
        return {**data, "runtime": merged_runtime, "build": valid_build_fields}

    @model_validator(mode="after")
    def after(self: "TRTLLMConfiguration") -> "TRTLLMConfiguration":
        """Enforce the plugin flags that chunked context depends on.

        For non-encoder base models with ``runtime.enable_chunked_context``
        on, both paged plugin flags must be on. When the
        ``ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION`` flag is set they are
        switched on with a warning; otherwise a ``ValueError`` is raised.
        """
        # check if there is an error wrt. runtime.enable_chunked_context
        if (
            self.runtime.enable_chunked_context
            and (self.build.base_model != TrussTRTLLMModel.ENCODER)
            and not (
                self.build.plugin_configuration.use_paged_context_fmha
                and self.build.plugin_configuration.paged_kv_cache
            )
        ):
            if ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION:
                logger.warning(
                    "If trt_llm.runtime.enable_chunked_context is True, then trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache should be True. "
                    "Setting trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache to True."
                )
                self.build.plugin_configuration.use_paged_context_fmha = True
                self.build.plugin_configuration.paged_kv_cache = True
            else:
                raise ValueError(
                    "If runtime.enable_chunked_context is True, then build.plugin_configuration.use_paged_context_fmha and build.plugin_configuration.paged_kv_cache should be True"
                )

        return self

    @property
    def requires_build(self):
        """Whether a build configuration is present."""
        return self.build is not None

    def to_json_dict(self, verbose=True):
        """Return this configuration as a dict of JSON-compatible values.

        Args:
            verbose: When false, fields that were not explicitly set are
                left out.
        """
        # ``model_dump_json`` is what the deprecated ``.json()`` delegates to
        # in pydantic v2, which this module already requires.
        return json.loads(self.model_dump_json(exclude_unset=not verbose))