onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
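
Worth noting in the listing: the wheel bundles DirectML.dll alongside onnxruntime.dll under onnxruntime/capi/, so the DirectML execution provider is available without a separate install. A quick post-install sanity check (a minimal sketch; model.onnx is a placeholder path, not part of the package):

import onnxruntime as ort

# The DirectML build should report DmlExecutionProvider in addition to CPU.
print(ort.get_available_providers())

# Request DirectML first, with CPU as the fallback provider.
session = ort.InferenceSession(
    "model.onnx", providers=["DmlExecutionProvider", "CPUExecutionProvider"]
)
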
onnxruntime/transformers/models/llama/dist_settings.py
@@ -0,0 +1,57 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ import os
+
+ import torch.distributed as dist
+
+
+ def init_dist():
+     if "LOCAL_RANK" in os.environ:
+         int(os.environ["LOCAL_RANK"])
+         rank = int(os.environ["RANK"])
+         world_size = int(os.environ["WORLD_SIZE"])
+
+         dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
+     elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
+         int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
+         rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
+         world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+
+         dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
+     else:
+         # don't need to do init for single process
+         pass
+
+
+ def _get_comm():
+     try:
+         from mpi4py import MPI
+
+         comm = MPI.COMM_WORLD
+         return comm
+     except ImportError:
+         return None
+
+
+ def get_rank():
+     comm = _get_comm()
+     return comm.Get_rank() if comm is not None else 0
+
+
+ def get_size():
+     comm = _get_comm()
+     return comm.Get_size() if comm is not None else 1
+
+
+ def barrier():
+     comm = _get_comm()
+     if comm is not None:
+         comm.Barrier()
+
+
+ def print_out(*args):
+     if get_rank() == 0:
+         print(*args)
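
The hunk above adds onnxruntime/transformers/models/llama/dist_settings.py (file 221 in the listing): init_dist() reads the rank/world-size environment variables set by torchrun or mpirun and initializes an NCCL process group, while the mpi4py-backed helpers degrade to single-process defaults when MPI is absent. A minimal usage sketch, assuming a torchrun launch (the script name and per-rank work are illustrative, not part of the package):

# driver.py -- hypothetical; launch with: torchrun --nproc_per_node 2 driver.py
# torchrun sets LOCAL_RANK/RANK/WORLD_SIZE, so init_dist() takes the NCCL path.
from onnxruntime.transformers.models.llama.dist_settings import (
    barrier,
    get_size,
    init_dist,
    print_out,
)

init_dist()  # no-op when run as a plain single process

# get_size()/print_out() go through mpi4py when it is installed; without it
# they fall back to a world size of 1 and printing on every process.
print_out(f"world size: {get_size()}")

# ... per-rank export or inference work goes here ...

barrier()  # synchronize MPI ranks (no-op without mpi4py)
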
onnxruntime/transformers/models/llama/llama_inputs.py
@@ -0,0 +1,503 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+ from transformers import AutoConfig, AutoTokenizer
+
+ from onnxruntime import InferenceSession, OrtValue
+
+
+ # Get position_ids from attention_mask
+ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
+     position_ids = attention_mask.long().cumsum(-1) - 1
+     position_ids.masked_fill_(attention_mask == 0, 1)
+     if use_past_kv:
+         # Shape: (batch_size, 1)
+         position_ids = position_ids[:, -1].unsqueeze(-1)
+
+     # Shape: (batch_size, sequence_length)
+     return position_ids
+
+
+ # Inputs for first pass to get initial past_key_values
+ # input_ids: (batch_size, sequence_length)
+ # attention_mask: (batch_size, sequence_length)
+ # position_ids: (batch_size, sequence_length)
+ def get_sample_inputs(
+     config: AutoConfig,
+     device: torch.device,
+     batch_size: int,
+     seq_len: int,
+     engine: str = "pt",
+     return_dict: bool = False,
+ ):
+     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
+     attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
+     position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+     # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+     input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+     attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+     position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+
+     if not return_dict:
+         # For export
+         return (input_ids, attention_mask, position_ids)
+
+     inputs = {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "position_ids": position_ids,
+     }
+     return inputs
+
+
+ # Inputs for subsequent passes with past_key_values
+ # input_ids: (batch_size, 1)
+ # attention_mask: (batch_size, past_sequence_length + 1)
+ # position_ids: (batch_size, 1)
+ # past_key: (batch_size, num_heads, past_sequence_length, head_size)
+ # past_value: (batch_size, num_heads, past_sequence_length, head_size)
+ def get_sample_with_past_kv_inputs(
+     config: AutoConfig,
+     device: torch.device,
+     batch_size: int,
+     past_seq_len: int,
+     use_fp16: bool = False,
+     engine: str = "pt",
+     return_dict: bool = False,
+     world_size: int = 1,
+ ):
+     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
+     attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
+     # position_ids is of shape (batch_size, 1)
+     position_ids = get_position_ids(attention_mask, use_past_kv=True)
+     past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
+
+     # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+     input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+     attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+     position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+     past_kv = (
+         flatten_past_kv_inputs(past_kv)
+         if engine == "ort"
+         else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
+     )
+
+     if not return_dict:
+         # For export
+         assert isinstance(past_kv, list)
+         return (input_ids, attention_mask, position_ids, past_kv)
+
+     inputs = {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "position_ids": position_ids,
+     }
+     if engine == "ort":
+         assert isinstance(past_kv, dict)
+         inputs.update(past_kv)
+     else:
+         assert isinstance(past_kv, list)
+         inputs["past_key_values"] = past_kv
+
+     return inputs
+
+
+ # Inputs for all passes with past_key_values
+ # input_ids: (batch_size, sequence_length)
+ # attention_mask: (batch_size, past_sequence_length + sequence_length)
+ # position_ids: (batch_size, sequence_length)
+ # past_key: (batch_size, num_heads, kv_sequence_length, head_size)
+ #   For models with GQA, kv_sequence_length = max_sequence_length
+ #   For models without GQA, kv_sequence_length = past_sequence_length
+ # past_value: (batch_size, num_heads, kv_sequence_length, head_size)
+ #   For models with GQA, kv_sequence_length = max_sequence_length
+ #   For models without GQA, kv_sequence_length = past_sequence_length
+ def get_merged_sample_with_past_kv_inputs(
+     config: AutoConfig,
+     device: torch.device,
+     batch_size: int,
+     seq_len: int,
+     past_seq_len: int,
+     max_seq_len: int,
+     use_fp16: bool = False,
+     use_buffer_share: bool = False,
+     engine: str = "pt",
+     return_dict: bool = False,
+     world_size: int = 1,
+ ):
+     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
+     attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
+     # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
+     position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
+     past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
+
+     # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+     input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+     attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+     position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+     past_kv = (
+         flatten_past_kv_inputs(past_kv)
+         if engine == "ort"
+         else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
+     )
+
+     if not return_dict:
+         # For export
+         assert isinstance(past_kv, list)
+         return (input_ids, attention_mask, position_ids, past_kv)
+
+     inputs = {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "position_ids": position_ids,
+     }
+     if engine == "ort":
+         assert isinstance(past_kv, dict)
+         inputs.update(past_kv)
+
+         if use_buffer_share:
+             inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
+
+     else:
+         assert isinstance(past_kv, list)
+         inputs["past_key_values"] = past_kv
+
+     return inputs
+
+
+ # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+ def get_msft_sample_inputs(
+     config: AutoConfig,
+     batch_size: int,
+     past_seq_len: int,
+     seq_len: int,
+     max_seq_len: int,
+     use_fp16: bool,
+     use_buffer_share: bool,
+     split_kv: bool,
+ ):
+     np_dtype = np.float16 if use_fp16 else np.float32
+     head_size = config.hidden_size // config.num_attention_heads
+
+     if not split_kv:
+         ort_inputs = {
+             "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+             "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
+             "k_cache": np.random.rand(
+                 batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+             ).astype(np_dtype),
+             "v_cache": np.random.rand(
+                 batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+             ).astype(np_dtype),
+             "pos": np.array(past_seq_len, dtype=np.int64),
+         }
+     else:
+         ort_inputs = {
+             "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+             "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
+                 np.int32
+             ),
+             "pos": np.array(past_seq_len, dtype=np.int64),
+         }
+         for i in range(config.num_hidden_layers):
+             ort_inputs.update(
+                 {
+                     f"k_{i}_cache": np.random.rand(
+                         batch_size, config.num_attention_heads, past_seq_len, head_size
+                     ).astype(np_dtype),
+                     f"v_{i}_cache": np.random.rand(
+                         batch_size, config.num_attention_heads, past_seq_len, head_size
+                     ).astype(np_dtype),
+                 }
+             )
+
+     if use_buffer_share:
+         ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
+
+     return ort_inputs
+
+
+ # Create past_key_values
+ # Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
+ def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
+     num_heads = config.num_key_value_heads // world_size
+     head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+     torch_dtype = torch.float16 if use_fp16 else torch.float32
+     past_kv = [
+         (
+             torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
+             torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
+         )
+         for _ in range(config.num_hidden_layers)
+     ]
+     return past_kv
+
+
+ # Convert list of past_key_values to dict of past_key and past_value
+ def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
+     past_kv = {}
+     for i, (past_k, past_v) in enumerate(past_key_values):
+         past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
+         past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
+     return past_kv
+
+
+ # Format PyTorch inputs to ONNX Runtime inputs
+ def convert_inputs_for_ort(
+     pt_inputs: dict,
+     use_buffer_share: bool = False,
+     past_seq_len: int = 0,
+     max_seq_len: int = 2048,
+ ):
+     ort_inputs = {}
+     for k, v in pt_inputs.items():
+         if isinstance(v, np.ndarray):
+             ort_inputs[k] = v
+         elif k == "past_key_values":
+             ort_inputs.update(flatten_past_kv_inputs(v))
+         else:
+             ort_inputs[k] = v.detach().cpu().numpy()
+
+     # Reshape KV caches if using past-present-share-buffer
+     if use_buffer_share:
+         ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
+
+     return ort_inputs
+
+
+ # Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
+ # (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
+ def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
+     for k, v in ort_inputs.items():
+         # Allocate new buffers with max_sequence_length for GQA
+         if "cache" in k or "past_key_values" in k:
+             # Copy v (BxSxPxH) into new_v (BxSxMxH)
+             batch_size, num_heads, _, head_size = v.shape
+             new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
+             new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
+             ort_inputs[k] = new_v
+     return ort_inputs
+
+
+ # Verify ONNX Runtime inputs with model
+ def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
+     # Check that all model inputs will be provided
+     model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
+     user_inputs = set(ort_inputs.keys())
+     missing_inputs = model_inputs - user_inputs
+     if len(missing_inputs):
+         print(f"The following model inputs are missing: {missing_inputs}")
+         raise Exception("There are missing inputs to the model. Please add them and try again.")
+
+     # Remove unnecessary inputs from model inputs
+     unnecessary_inputs = user_inputs - model_inputs
+     if len(unnecessary_inputs):
+         for unnecessary_input in unnecessary_inputs:
+             del ort_inputs[unnecessary_input]
+
+     return ort_inputs
+
+
+ # Add IO bindings for execution providers using OrtValue
+ # Use when you need to run inference once or twice to save memory
+ def add_io_bindings_as_ortvalues(
+     model: InferenceSession,
+     ort_inputs: dict,
+     device: str,
+     device_id: int,
+     use_buffer_share: bool,
+     kv_cache_ortvalues: dict,
+ ):
+     io_binding = model.io_binding()
+
+     model_inputs = set(map(lambda i: i.name, model.get_inputs()))
+     for k, v in ort_inputs.items():
+         # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
+         # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
+         # but `position_ids` is used as a PyTorch model input
+         if k not in model_inputs:
+             continue
+
+         # Bind OrtValue inputs to device
+         if use_buffer_share and ("cache" in k or "past_key_values" in k):
+             if k not in kv_cache_ortvalues:
+                 v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
+                 io_binding.bind_ortvalue_input(k, v_device)
+                 kv_cache_ortvalues[k] = v_device
+             else:
+                 kv_cache_ortvalues[k].update_inplace(v)
+                 io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
+         else:
+             v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
+             io_binding.bind_ortvalue_input(k, v_device)
+
+     for output in model.get_outputs():
+         name = output.name
+         if use_buffer_share and ("out" in name or "present" in name):
+             # Bind present KV cache outputs to past KV cache inputs in order to buffer share
+             input_name = name.replace("out", "cache").replace("present", "past_key_values")
+             io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
+         else:
+             io_binding.bind_output(name, device_type=device, device_id=device_id)
+
+     return io_binding, kv_cache_ortvalues
+
+
+ # Add IO bindings for execution providers using PyTorch tensors
+ # Use when you need to run inference many times
+ def add_io_bindings_as_tensors(
+     model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
+ ):
+     # Verify model inputs
+     inputs = verify_ort_inputs(model, inputs)
+
+     device = None
+     pt_to_np = {
+         "torch.int32": np.int32,
+         "torch.int64": np.int64,
+         "torch.float16": np.float16,
+         "torch.float32": np.float32,
+     }
+
+     # Bind inputs/outputs to IO binding
+     io_binding = model.io_binding()
+     for k, v in inputs.items():
+         io_binding.bind_input(
+             name=k,
+             device_type=v.device.type,
+             device_id=0 if v.device.type == "cpu" else v.device.index,
+             element_type=pt_to_np[repr(v.dtype)],
+             shape=tuple(v.shape),
+             buffer_ptr=v.data_ptr(),
+         )
+         device = v.device
+
+     for output in model.get_outputs():
+         name = output.name
+         # Bind KV cache outputs to KV cache inputs
+         v = (
+             inputs[name.replace("present", "past_key_values")]
+             if use_buffer_share and "present" in name
+             else outputs[name]
+         )
+         io_binding.bind_output(
+             name=name,
+             device_type=device.type,
+             device_id=0 if device.type == "cpu" else device.index,
+             element_type=(np.float16 if use_fp16 else np.float32),
+             shape=tuple(v.shape),
+             buffer_ptr=v.data_ptr(),
+         )
+
+     return io_binding
+
+
+ # Get actual inputs when using real data (instead of sample data) and initialize outputs
+ def get_initial_inputs_and_outputs(
+     config: AutoConfig,
+     tokenizer: AutoTokenizer,
+     requested_length: int,
+     prompt: list[str],
+     device: torch.device,
+     use_fp16: bool,
+     use_buffer_share: bool,
+     engine: str,
+ ):
+     tokenizer.pad_token = tokenizer.eos_token
+     encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
+     torch_dtype = torch.float16 if use_fp16 else torch.float32
+
+     # input_ids: pad token id is 0
+     # attention_mask: pad token id is 0
+     # position_ids: pad token id is 1
+     input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
+     attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
+     position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+     # Check if tokenized prompt length matches the requested prompt length
+     tokenized_length = input_ids.shape[-1]
+     if tokenized_length > requested_length:
+         # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+         input_ids = input_ids[:, :requested_length]
+         attention_mask = attention_mask[:, :requested_length]
+         position_ids = get_position_ids(attention_mask, use_past_kv=False)
+     elif tokenized_length < requested_length:
+         # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+         input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
+         attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
+         for _ in range(requested_length - tokenized_length):
+             input_ids = torch.hstack((input_ids_first_col, input_ids))
+             attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
+         position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+     tokenized_length = input_ids.shape[-1]
+     assert tokenized_length == requested_length
+
+     # Create inputs
+     inputs = {
+         "input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
+         "attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
+         "position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
+     }
+     if engine != "ort":
+         inputs["past_key_values"] = []
+
+     # Get shape of KV cache inputs
+     batch_size, sequence_length = input_ids.shape
+     max_sequence_length = config.max_position_embeddings
+     num_heads = config.num_key_value_heads
+     head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+
+     # Create KV cache inputs
+     for i in range(config.num_hidden_layers):
+         past_key = torch.zeros(
+             batch_size,
+             num_heads,
+             max_sequence_length if use_buffer_share else 0,
+             head_size,
+             device=device,
+             dtype=torch_dtype,
+         )
+         past_value = torch.zeros(
+             batch_size,
+             num_heads,
+             max_sequence_length if use_buffer_share else 0,
+             head_size,
+             device=device,
+             dtype=torch_dtype,
+         )
+         if engine == "ort":
+             inputs.update(
+                 {
+                     f"past_key_values.{i}.key": past_key.contiguous(),
+                     f"past_key_values.{i}.value": past_value.contiguous(),
+                 }
+             )
+         else:
+             inputs["past_key_values"].append((past_key, past_value))
+
+     outputs = None
+     if engine == "ort":
+         # Create outputs
+         logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
+         outputs = {"logits": logits.contiguous()}
+         if not use_buffer_share:
+             for i in range(config.num_hidden_layers):
+                 present_key = torch.zeros(
+                     batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
+                 )
+                 present_value = torch.zeros(
+                     batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
+                 )
+                 outputs.update(
+                     {f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
+                 )
+
+     return inputs, outputs
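
The 503-line hunk above is onnxruntime/transformers/models/llama/llama_inputs.py (file 222 in the listing), the input-construction toolkit shared by the LLaMA benchmark and parity scripts: random sample inputs for the prompt and token-generation phases, KV-cache flattening into past_key_values.{i}.key/value entries, max-length buffer pre-allocation for GQA buffer sharing, and the two IO-binding helpers. A rough sketch of the prompt-phase flow under ONNX Runtime (the model id and ONNX path are placeholders, not part of the package):

import torch
from transformers import AutoConfig

from onnxruntime import InferenceSession
from onnxruntime.transformers.models.llama.llama_inputs import (
    get_merged_sample_with_past_kv_inputs,
    verify_ort_inputs,
)

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder id
session = InferenceSession("llama2.onnx", providers=["CPUExecutionProvider"])  # placeholder path

# Prompt phase: 8 new tokens, no past KV. engine="ort" yields NumPy arrays with
# the flattened past_key_values.{i}.key / past_key_values.{i}.value entries merged in.
inputs = get_merged_sample_with_past_kv_inputs(
    config,
    device=torch.device("cpu"),
    batch_size=1,
    seq_len=8,
    past_seq_len=0,
    max_seq_len=config.max_position_embeddings,
    engine="ort",
    return_dict=True,
)

inputs = verify_ort_inputs(session, inputs)  # drop extras, fail fast on missing inputs
outputs = session.run(None, inputs)          # outputs[0] holds the logits
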