onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,57 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ import os
7
+
8
+ import torch.distributed as dist
9
+
10
+
11
def init_dist():
    """Initialize the torch.distributed NCCL process group from environment variables.

    Two launchers are recognized:
      * torchrun / torch.distributed launcher, which sets LOCAL_RANK / RANK / WORLD_SIZE
      * Open MPI (mpirun), which sets the OMPI_COMM_WORLD_* variables

    When neither set of variables is present, the process is assumed to be a
    single-process run and no process group is created.
    """
    if "LOCAL_RANK" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
    elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
        world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))

        # NOTE: a different rendezvous port than the torchrun path so the two
        # launch modes cannot collide on the same host.
        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
    else:
        # don't need to do init for single process
        pass
27
+
28
+
29
def _get_comm():
    """Return the MPI COMM_WORLD communicator, or None when mpi4py is unavailable."""
    try:
        from mpi4py import MPI  # noqa: PLC0415
    except ImportError:
        return None
    return MPI.COMM_WORLD
37
+
38
+
39
def get_rank():
    """Return this process's MPI rank, or 0 when MPI is not available."""
    comm = _get_comm()
    if comm is None:
        return 0
    return comm.Get_rank()
42
+
43
+
44
def get_size():
    """Return the MPI world size, or 1 when MPI is not available."""
    comm = _get_comm()
    if comm is None:
        return 1
    return comm.Get_size()
47
+
48
+
49
def barrier():
    """Synchronize all MPI ranks; no-op when MPI is not available."""
    comm = _get_comm()
    if comm is None:
        return
    comm.Barrier()
53
+
54
+
55
def print_out(*args):
    """Print *args only on rank 0 to avoid duplicated output in multi-process runs."""
    if get_rank() != 0:
        return
    print(*args)
@@ -0,0 +1,504 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+ import torch
10
+ from transformers import AutoConfig, AutoTokenizer
11
+ from transformers.cache_utils import DynamicCache
12
+
13
+ from onnxruntime import InferenceSession, OrtValue
14
+
15
+
16
+ # Get position_ids from attention_mask
17
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
    """Derive position_ids from an attention mask.

    Valid (non-padding) tokens are numbered left-to-right starting at 0;
    positions under padded slots are filled with 1. With use_past_kv=True only
    the last position is kept, shaped (batch_size, 1) for single-token
    decoding; otherwise the full (batch_size, sequence_length) tensor is
    returned.
    """
    positions = attention_mask.long().cumsum(-1) - 1
    positions.masked_fill_(attention_mask == 0, 1)
    if not use_past_kv:
        # Shape: (batch_size, sequence_length)
        return positions
    # Shape: (batch_size, 1) — position of the single token being generated
    return positions[:, -1].unsqueeze(-1)
26
+
27
+
28
+ # Inputs for first pass to get initial past_key_values
29
+ # input_ids: (batch_size, sequence_length)
30
+ # attention_mask: (batch_size, sequence_length)
31
+ # position_ids: (batch_size, sequence_length)
32
def get_sample_inputs(
    config: "AutoConfig",
    device: torch.device,
    batch_size: int,
    seq_len: int,
    engine: str = "pt",
    return_dict: bool = False,
):
    """Create random first-pass (prompt) inputs with no KV cache.

    Returns the tuple (input_ids, attention_mask, position_ids) for export
    when return_dict is False, otherwise a name->tensor dict. With
    engine="ort" the tensors are NumPy arrays; otherwise they are torch
    tensors moved to `device`.
    """
    to_ort = engine == "ort"

    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
    position_ids = get_position_ids(attention_mask, use_past_kv=False)

    def _convert(tensor):
        # NumPy for ORT, device-resident tensor for PyTorch
        return tensor.numpy() if to_ort else tensor.to(device)

    input_ids = _convert(input_ids)
    attention_mask = _convert(attention_mask)
    position_ids = _convert(position_ids)

    if return_dict:
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
    # For export
    return (input_ids, attention_mask, position_ids)
59
+
60
+
61
+ # Inputs for subsequent passes with past_key_values
62
+ # input_ids: (batch_size, 1)
63
+ # attention_mask: (batch_size, past_sequence_length + 1)
64
+ # position_ids: (batch_size, 1)
65
+ # past_key: (batch_size, num_heads, past_sequence_length, head_size)
66
+ # past_value: (batch_size, num_heads, past_sequence_length, head_size)
67
def get_sample_with_past_kv_inputs(
    config: "AutoConfig",
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
    world_size: int = 1,
):
    """Create random token-generation inputs with an existing past KV cache.

    Shapes: input_ids / position_ids are (batch_size, 1); attention_mask is
    (batch_size, past_seq_len + 1). With engine="ort" the tensors become
    NumPy arrays and the cache is flattened into per-layer named entries;
    for PyTorch the cache stays a list of per-layer (key, value) pairs on
    `device`.
    """
    to_ort = engine == "ort"

    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
    attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
    # position_ids is of shape (batch_size, 1)
    position_ids = get_position_ids(attention_mask, use_past_kv=True)
    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
    input_ids = input_ids.numpy() if to_ort else input_ids.to(device)
    attention_mask = attention_mask.numpy() if to_ort else attention_mask.to(device)
    position_ids = position_ids.numpy() if to_ort else position_ids.to(device)
    if to_ort:
        past_kv = flatten_past_kv_inputs(past_kv)
    else:
        past_kv = [(key.to(device), value.to(device)) for key, value in past_kv]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if to_ort:
        assert isinstance(past_kv, dict)
        inputs.update(past_kv)
    else:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv

    return inputs
109
+
110
+
111
+ # Inputs for all passes with past_key_values
112
+ # input_ids: (batch_size, sequence_length)
113
+ # attention_mask: (batch_size, past_sequence_length + sequence_length)
114
+ # position_ids: (batch_size, sequence_length)
115
+ # past_key: (batch_size, num_heads, kv_sequence_length, head_size)
116
+ # For models with GQA, kv_sequence_length = max_sequence_length
117
+ # For models without GQA, kv_sequence_length = past_sequence_length
118
+ # past_value: (batch_size, num_heads, kv_sequence_length, head_size)
119
+ # For models with GQA, kv_sequence_length = max_sequence_length
120
+ # For models without GQA, kv_sequence_length = past_sequence_length
121
def get_merged_sample_with_past_kv_inputs(
    config: "AutoConfig",
    device: torch.device,
    batch_size: int,
    seq_len: int,
    past_seq_len: int,
    max_seq_len: int,
    use_fp16: bool = False,
    use_buffer_share: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
    world_size: int = 1,
):
    """Create random inputs for a merged (prompt + token) decoder model.

    Shapes: input_ids / position_ids are (batch_size, seq_len); attention_mask
    is (batch_size, past_seq_len + seq_len). With engine="ort" the cache is
    flattened into named NumPy entries and, when use_buffer_share is set,
    resized for past/present buffer sharing up to max_seq_len; for PyTorch the
    cache stays a list of per-layer (key, value) pairs on `device`.
    """
    to_ort = engine == "ort"

    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
    # (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
    position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
    input_ids = input_ids.numpy() if to_ort else input_ids.to(device)
    attention_mask = attention_mask.numpy() if to_ort else attention_mask.to(device)
    position_ids = position_ids.numpy() if to_ort else position_ids.to(device)
    if to_ort:
        past_kv = flatten_past_kv_inputs(past_kv)
    else:
        past_kv = [(key.to(device), value.to(device)) for key, value in past_kv]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if to_ort:
        assert isinstance(past_kv, dict)
        inputs.update(past_kv)

        if use_buffer_share:
            inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)

    else:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv

    return inputs
170
+
171
+
172
+ # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
173
def get_msft_sample_inputs(
    config: "AutoConfig",
    batch_size: int,
    past_seq_len: int,
    seq_len: int,
    max_seq_len: int,
    use_fp16: bool,
    use_buffer_share: bool,
    split_kv: bool,
):
    """Create random inputs for the Microsoft Llama-2 ONNX export
    (https://github.com/microsoft/Llama-2-Onnx).

    With split_kv=False the KV cache is a single stacked k_cache/v_cache pair
    and the attention mask is an additive float mask; with split_kv=True each
    layer gets its own k_{i}_cache / v_{i}_cache input and the mask is int32.
    """
    np_dtype = np.float16 if use_fp16 else np.float32
    head_size = config.hidden_size // config.num_attention_heads

    if not split_kv:
        ort_inputs = {
            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
            # Additive causal mask: 0 on/below the diagonal, -10000 above it
            "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
            "k_cache": np.random.rand(
                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
            ).astype(np_dtype),
            "v_cache": np.random.rand(
                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
            ).astype(np_dtype),
            "pos": np.array(past_seq_len, dtype=np.int64),
        }
    else:
        ort_inputs = {
            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
            "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
                np.int32
            ),
            "pos": np.array(past_seq_len, dtype=np.int64),
        }
        cache_shape = (batch_size, config.num_attention_heads, past_seq_len, head_size)
        for layer in range(config.num_hidden_layers):
            ort_inputs[f"k_{layer}_cache"] = np.random.rand(*cache_shape).astype(np_dtype)
            ort_inputs[f"v_{layer}_cache"] = np.random.rand(*cache_shape).astype(np_dtype)

    if use_buffer_share:
        ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)

    return ort_inputs
222
+
223
+
224
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
    """Build a list of random (past_key, past_value) tensor pairs, one pair per hidden layer."""
    # KV heads are sharded evenly across ranks when running with tensor parallelism
    kv_heads = config.num_key_value_heads // world_size
    head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    kv_shape = (batch_size, kv_heads, past_seq_len, head_dim)

    past_kv = []
    for _ in range(config.num_hidden_layers):
        past_kv.append((torch.rand(*kv_shape, dtype=torch_dtype), torch.rand(*kv_shape, dtype=torch_dtype)))
    return past_kv
238
+
239
+
240
# Convert list of past_key_values to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
    """Flatten past KV (key, value) pairs into a {input_name: numpy array} dict for ORT.

    `DynamicCache` containers use the `past_key_values_key_cache_{i}` /
    `past_key_values_value_cache_{i}` naming scheme; plain lists of tuples use the
    `past_key_values.{i}.key` / `past_key_values.{i}.value` scheme.
    """
    # Hoisted out of the loop: the container's type cannot change mid-iteration,
    # so there is no need to re-run the isinstance check per layer.
    is_dynamic_cache = isinstance(past_key_values, DynamicCache)
    past_kv = {}
    for i, (past_k, past_v) in enumerate(past_key_values):
        if is_dynamic_cache:
            past_kv[f"past_key_values_key_cache_{i}"] = past_k.detach().cpu().numpy()
            past_kv[f"past_key_values_value_cache_{i}"] = past_v.detach().cpu().numpy()
        else:
            past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
            past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
    return past_kv
251
+
252
+
253
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
    pt_inputs: dict,
    use_buffer_share: bool = False,
    past_seq_len: int = 0,
    max_seq_len: int = 2048,
):
    """Convert a dict of PyTorch inputs into numpy arrays consumable by ONNX Runtime."""
    ort_inputs = {}
    for name, value in pt_inputs.items():
        if isinstance(value, np.ndarray):
            # Already numpy; pass through untouched
            ort_inputs[name] = value
        elif name == "past_key_values":
            # Expand the list of KV pairs into individually named entries
            ort_inputs.update(flatten_past_kv_inputs(value))
        else:
            ort_inputs[name] = value.detach().cpu().numpy()

    # Reshape KV caches if using past-present-share-buffer
    if use_buffer_share:
        ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)

    return ort_inputs
274
+
275
+
276
# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
    """Grow every KV-cache entry to max_seq_len along the sequence axis, zero-padding the tail."""
    for name, value in ort_inputs.items():
        # Only KV-cache entries participate in buffer sharing (GQA); skip everything else
        if "cache" not in name and "past_key_values" not in name:
            continue
        batch_size, num_heads, _, head_size = value.shape
        padded = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=value.dtype)
        padded[:, :, :past_seq_len, :] = value
        ort_inputs[name] = padded
    return ort_inputs
288
+
289
+
290
# Verify ONNX Runtime inputs with model
def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
    """Ensure ort_inputs covers every model input; drop any extras the model does not accept."""
    expected = {model_input.name for model_input in model.get_inputs()}
    provided = set(ort_inputs)

    # Fail loudly if any declared model input was not supplied
    missing_inputs = expected - provided
    if missing_inputs:
        print(f"The following model inputs are missing: {missing_inputs}")
        raise Exception("There are missing inputs to the model. Please add them and try again.")

    # Strip inputs that the model does not declare
    for extra in provided - expected:
        del ort_inputs[extra]

    return ort_inputs
307
+
308
+
309
# Add IO bindings for execution providers using OrtValue
# Use when you need to run inference once or twice to save memory
def add_io_bindings_as_ortvalues(
    model: InferenceSession,
    ort_inputs: dict,
    device: str,
    device_id: int,
    use_buffer_share: bool,
    kv_cache_ortvalues: dict,
):
    """Bind inputs/outputs through OrtValues, reusing cached KV-cache OrtValues when buffer sharing."""
    io_binding = model.io_binding()

    model_inputs = {i.name for i in model.get_inputs()}
    for name, value in ort_inputs.items():
        # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
        # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
        # but `position_ids` is used as a PyTorch model input
        if name not in model_inputs:
            continue

        # Bind OrtValue inputs to device
        is_kv_entry = "cache" in name or "past_key_values" in name
        if use_buffer_share and is_kv_entry:
            if name in kv_cache_ortvalues:
                # Reuse the device buffer allocated on a previous call
                kv_cache_ortvalues[name].update_inplace(value)
                io_binding.bind_ortvalue_input(name, kv_cache_ortvalues[name])
            else:
                on_device = OrtValue.ortvalue_from_numpy(value, device_type=device, device_id=device_id)
                io_binding.bind_ortvalue_input(name, on_device)
                kv_cache_ortvalues[name] = on_device
        else:
            on_device = OrtValue.ortvalue_from_numpy(value, device_type=device, device_id=device_id)
            io_binding.bind_ortvalue_input(name, on_device)

    for output in model.get_outputs():
        out_name = output.name
        if use_buffer_share and ("out" in out_name or "present" in out_name):
            # Bind present KV cache outputs to past KV cache inputs in order to buffer share
            past_name = out_name.replace("out", "cache").replace("present", "past_key_values")
            io_binding.bind_ortvalue_output(out_name, kv_cache_ortvalues[past_name])
        else:
            io_binding.bind_output(out_name, device_type=device, device_id=device_id)

    return io_binding, kv_cache_ortvalues
352
+
353
+
354
# Add IO bindings for execution providers using PyTorch tensors
# Use when you need to run inference many times
def add_io_bindings_as_tensors(
    model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
):
    """Bind pre-allocated PyTorch tensors directly into an IO binding (zero-copy via data_ptr)."""
    # Verify model inputs
    inputs = verify_ort_inputs(model, inputs)

    torch_to_numpy_dtype = {
        "torch.int32": np.int32,
        "torch.int64": np.int64,
        "torch.float16": np.float16,
        "torch.float32": np.float32,
    }

    # Bind inputs/outputs to IO binding
    io_binding = model.io_binding()
    device = None
    for name, tensor in inputs.items():
        # Remember the (last) input device; outputs are bound on the same device below
        device = tensor.device
        io_binding.bind_input(
            name=name,
            device_type=device.type,
            device_id=0 if device.type == "cpu" else device.index,
            element_type=torch_to_numpy_dtype[repr(tensor.dtype)],
            shape=tuple(tensor.shape),
            buffer_ptr=tensor.data_ptr(),
        )

    for output in model.get_outputs():
        out_name = output.name
        # Bind KV cache outputs to KV cache inputs when buffer sharing
        if use_buffer_share and "present" in out_name:
            target = inputs[out_name.replace("present", "past_key_values")]
        else:
            target = outputs[out_name]
        io_binding.bind_output(
            name=out_name,
            device_type=device.type,
            device_id=0 if device.type == "cpu" else device.index,
            element_type=(np.float16 if use_fp16 else np.float32),
            shape=tuple(target.shape),
            buffer_ptr=target.data_ptr(),
        )

    return io_binding
401
+
402
+
403
# Get actual inputs when using real data (instead of sample data) and initialize outputs
def get_initial_inputs_and_outputs(
    config: AutoConfig,
    tokenizer: AutoTokenizer,
    requested_length: int,
    prompt: list[str],
    device: torch.device,
    use_fp16: bool,
    use_buffer_share: bool,
    engine: str,
):
    """Tokenize `prompt`, force it to exactly `requested_length` tokens, and build the
    initial model inputs plus (for `engine == "ort"`) pre-allocated output tensors.

    Returns a `(inputs, outputs)` tuple; `outputs` is None for non-ORT engines.
    """
    tokenizer.pad_token = tokenizer.eos_token
    encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
    torch_dtype = torch.float16 if use_fp16 else torch.float32

    # input_ids: pad token id is 0
    # attention_mask: pad token id is 0
    # position_ids: pad token id is 1
    input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
    position_ids = get_position_ids(attention_mask, use_past_kv=False)

    # Check if tokenized prompt length matches the requested prompt length
    tokenized_length = input_ids.shape[-1]
    if tokenized_length > requested_length:
        # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
        input_ids = input_ids[:, :requested_length]
        attention_mask = attention_mask[:, :requested_length]
        position_ids = get_position_ids(attention_mask, use_past_kv=False)
    elif tokenized_length < requested_length:
        # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
        # by left-padding with copies of the first column. A single repeat + hstack replaces the
        # previous one-column-at-a-time loop, which copied O(pad_len^2) elements.
        pad_len = requested_length - tokenized_length
        input_ids = torch.hstack((input_ids[:, :1].repeat(1, pad_len), input_ids))
        attention_mask = torch.hstack((attention_mask[:, :1].repeat(1, pad_len), attention_mask))
        position_ids = get_position_ids(attention_mask, use_past_kv=False)

    tokenized_length = input_ids.shape[-1]
    assert tokenized_length == requested_length

    # Create inputs (ORT requires contiguous buffers for IO binding)
    inputs = {
        "input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
        "attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
        "position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
    }
    if engine != "ort":
        inputs["past_key_values"] = []

    # Get shape of KV cache inputs
    batch_size, sequence_length = input_ids.shape
    max_sequence_length = config.max_position_embeddings
    num_heads = config.num_key_value_heads
    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads

    # Create KV cache inputs: full-length buffers when buffer sharing, otherwise empty
    # (sequence dim 0) caches that the runtime grows per step
    for i in range(config.num_hidden_layers):
        past_key = torch.zeros(
            batch_size,
            num_heads,
            max_sequence_length if use_buffer_share else 0,
            head_size,
            device=device,
            dtype=torch_dtype,
        )
        past_value = torch.zeros(
            batch_size,
            num_heads,
            max_sequence_length if use_buffer_share else 0,
            head_size,
            device=device,
            dtype=torch_dtype,
        )
        if engine == "ort":
            inputs.update(
                {
                    f"past_key_values.{i}.key": past_key.contiguous(),
                    f"past_key_values.{i}.value": past_value.contiguous(),
                }
            )
        else:
            inputs["past_key_values"].append((past_key, past_value))

    outputs = None
    if engine == "ort":
        # Create outputs: logits plus, without buffer sharing, per-layer present KV buffers
        logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
        outputs = {"logits": logits.contiguous()}
        if not use_buffer_share:
            for i in range(config.num_hidden_layers):
                present_key = torch.zeros(
                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
                )
                present_value = torch.zeros(
                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
                )
                outputs.update(
                    {f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
                )

    return inputs, outputs