onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
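The wheel bundles onnxruntime/capi/DirectML.dll alongside onnxruntime.dll, so sessions created from this package can run on any DirectX 12 GPU via the DirectML execution provider. A minimal sketch of exercising the installed wheel (the model path is a placeholder and the model is assumed to take a single float32 input; neither comes from this package):

import numpy as np
import onnxruntime as ort

# "model.onnx" is a placeholder path, not a file shipped in the wheel.
session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],  # DirectML first, CPU fallback
)

# Build a dummy input from the declared shape, substituting 1 for dynamic dimensions.
meta = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in meta.shape]
outputs = session.run(None, {meta.name: np.zeros(shape, dtype=np.float32)})
print(meta.name, shape, [o.shape for o in outputs])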
onnxruntime/transformers/models/whisper/whisper_helper.py
@@ -0,0 +1,524 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Dict, Tuple, Union
+
+ import numpy as np
+ import torch
+ from float16 import float_to_float16_max_diff
+ from onnx_model import OnnxModel
+ from optimizer import optimize_model
+ from packaging import version
+ from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
+ from transformers import __version__ as transformers_version
+ from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
+ from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
+ from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
+
+ from onnxruntime import InferenceSession
+
+ logger = logging.getLogger(__name__)
+
+ PRETRAINED_WHISPER_MODELS = [
+     "whisper-tiny",
+     "whisper-tiny.en",
+     "whisper-base",
+     "whisper-base.en",
+     "whisper-small",
+     "whisper-small.en",
+     "whisper-medium",
+     "whisper-medium.en",
+     "whisper-large",
+     "whisper-large-v2",
+     "whisper-large-v3",
+ ]
+
+
+ class WhisperHelper:
+     @staticmethod
+     def get_onnx_path(
+         output_dir: str,
+         model_name_or_path: str,
+         suffix: str = "",
+         new_folder: bool = False,
+     ) -> str:
+         """Build onnx path
+
+         Args:
+             output_dir (str): output directory
+             model_name_or_path (str): pretrained model name, or path to the model checkpoint
+             suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to file name. Defaults to "".
+             new_folder (bool, optional): create a new directory for the model. Defaults to False.
+
+         Returns:
+             str: path of onnx model
+         """
+         model_name = model_name_or_path
+         if os.path.isdir(model_name_or_path):
+             model_name = Path(model_name_or_path).parts[-1]
+         else:
+             model_name = model_name.split("/")[-1]
+
+         model_name += suffix
+
+         directory = os.path.join(output_dir, model_name) if new_folder else output_dir
+         return os.path.join(directory, model_name + ".onnx")
+
+     @staticmethod
+     def load_model_openai(
+         model_name_or_path: str,
+         cache_dir: str,
+         device: torch.device,
+     ) -> torch.nn.Module:
+         """Load model given a pretrained name or path, then build the model for ONNX conversion.
+
+         Args:
+             model_name_or_path (str): pretrained model name or path
+             cache_dir (str): cache directory
+             device (torch.device): device to run the model
+         Returns:
+             torch.nn.Module: the OpenAI Whisper model, moved to the target device.
+         """
+         from whisper import _ALIGNMENT_HEADS, _MODELS, _download
+         from whisper.model import ModelDimensions, Whisper
+
+         in_memory = False
+
+         model_name = model_name_or_path.split("/")[-1][8:]
+         checkpoint_file, alignment_heads = None, None
+         if model_name in _MODELS:
+             checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory)
+             alignment_heads = _ALIGNMENT_HEADS[model_name]
+
+         with open(checkpoint_file, "rb") as fp:
+             checkpoint = torch.load(fp, map_location=device)
+         del checkpoint_file
+
+         dims = ModelDimensions(**checkpoint["dims"])
+         model = Whisper(dims)
+         model.load_state_dict(checkpoint["model_state_dict"])
+
+         if alignment_heads is not None:
+             model.set_alignment_heads(alignment_heads)
+         return model.to(device)
+
+     @staticmethod
+     def load_model(
+         model_name_or_path: str,
+         model_impl: str,
+         cache_dir: str,
+         device: torch.device,
+         merge_encoder_and_decoder_init: bool = True,
+         state_dict_path: str = "",
+     ) -> Dict[str, torch.nn.Module]:
+         """Load model given a pretrained name or path, then build models for ONNX conversion.
+
+         Args:
+             model_name_or_path (str): pretrained model name or path
+             model_impl (str): "openai" to export the OpenAI implementation; any other value exports the Hugging Face implementation
+             cache_dir (str): cache directory
+             device (torch.device): device to run the model
+             merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True.
+             state_dict_path (str, optional): path of a state dict to load into the model. Defaults to "".
+         Returns:
+             Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
+         """
+         extra_kwargs = {}
+         if version.parse(transformers_version) >= version.parse("4.36.0"):
+             extra_kwargs["attn_implementation"] = "eager"
+         model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)
+
+         if model_impl == "openai":
+             openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device)
+             model_encoder, model_decoder = openai_model.encoder, openai_model.decoder
+             passed_model = openai_model
+         else:
+             model_encoder, model_decoder = model, model
+             passed_model = None
+
+         if state_dict_path:
+             model.load_state_dict(torch.load(state_dict_path), strict=False)
+
+         decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model)
+         decoder.eval().to(device)
+
+         if merge_encoder_and_decoder_init:
+             encoder_decoder_init = WhisperEncoderDecoderInit(
+                 model_encoder,
+                 model_decoder,
+                 model.config,
+                 decoder_start_token_id=None,
+                 model_impl=model_impl,
+                 model=passed_model,
+             )
+             return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
+         else:
+             encoder = WhisperEncoder(model.model.encoder, model.config)
+             encoder.eval().to(device)
+             decoder_init = WhisperDecoderInit(model.decoder, model.config)
+             decoder_init.eval().to(device)
+             return {
+                 "encoder": encoder,
+                 "decoder": decoder,
+                 "decoder_init": decoder_init,
+             }
+
+     @staticmethod
+     def export_onnx(
+         model: Union[WhisperEncoder, WhisperDecoder, WhisperDecoderInit, WhisperEncoderDecoderInit],
+         device: torch.device,
+         onnx_model_path: str,
+         verbose: bool = True,
+         use_external_data_format: bool = False,
+         use_decoder_input_ids: bool = True,
+         use_int32_inputs: bool = False,
+     ):
+         if isinstance(model, WhisperEncoder):
+             WhisperEncoderHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 verbose,
+                 use_external_data_format,
+             )
+         elif isinstance(model, WhisperEncoderDecoderInit):
+             WhisperEncoderDecoderInitHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 use_decoder_input_ids,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs,
+             )
+         else:
+             WhisperDecoderHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs,
+             )
+
+     @staticmethod
+     def auto_mixed_precision(
+         onnx_model: OnnxModel,
+         op_block_list: Tuple[str] = (
+             "SimplifiedLayerNormalization",
+             "SkipSimplifiedLayerNormalization",
+             "Relu",
+             "Add",
+         ),
+     ):
+         """Convert model to mixed precision.
+         It detects whether the original model has fp16 precision weights, and sets parameters for float16 conversion automatically.
+         Args:
+             onnx_model (OnnxModel): optimized ONNX model
+             op_block_list (List[str], optional): operators to keep in float32. Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
+         Returns:
+             parameters(dict): a dictionary of parameters used in float16 conversion
+         """
+         op_full_set = set([node.op_type for node in onnx_model.nodes()])
+         fp32_op_set = set(op_block_list)
+         fp16_op_set = op_full_set.difference(fp32_op_set)
+         logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
+
+         # logits is the first output
+         logits_output_name = onnx_model.graph().output[0].name
+
+         # We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
+         is_weight_fp16_precision = False
+         output_name_to_node = onnx_model.output_name_to_node()
+         assert logits_output_name in output_name_to_node
+         node = output_name_to_node[logits_output_name]
+         last_matmul_node = None
+         if node.op_type == "MatMul":
+             last_matmul_node = node
+             logger.info(f"Found last MatMul node for logits: {node.name}")
+             initializer = None
+             for input in node.input:
+                 initializer = onnx_model.get_initializer(input)
+                 if initializer is not None:
+                     break
+
+             # when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
+             # we can deduce that the weights are stored in float16 precision.
+             max_diff = float_to_float16_max_diff(initializer)
+             logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
+             is_weight_fp16_precision = max_diff < 1e-6
+         else:
+             logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
+
+         keep_io_types = []
+         node_block_list = []
+         if (not is_weight_fp16_precision) and (last_matmul_node is not None):
+             # When original weight is float32 precision, keeping logits and the last MatMul in float32 can give better precision.
+             keep_io_types = [logits_output_name]
+             node_block_list = [last_matmul_node.name]
+
+         parameters = {
+             "keep_io_types": keep_io_types,
+             "op_block_list": list(op_block_list),
+             "node_block_list": node_block_list,
+             "force_fp16_initializers": is_weight_fp16_precision,
+         }
+
+         logger.info(f"auto_mixed_precision parameters: {parameters}")
+         onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
+
+         return parameters
+
+     @staticmethod
+     def optimize_onnx(
+         onnx_model_path: str,
+         optimized_model_path: str,
+         is_float16: bool,
+         num_attention_heads: int,
+         hidden_size: int,
+         use_external_data_format: bool = False,
+         auto_mixed_precision: bool = True,
+         use_gpu: bool = False,
+         provider: str = "cpu",
+     ):
+         """Optimize ONNX model with an option to convert it to use mixed precision."""
+
+         from fusion_options import FusionOptions
+
+         optimization_options = FusionOptions("bart")
+         optimization_options.use_multi_head_attention = True
+         optimization_options.disable_multi_head_attention_bias = provider == "rocm"
+
+         m = optimize_model(
+             onnx_model_path,
+             model_type="bart",
+             num_heads=num_attention_heads,
+             hidden_size=hidden_size,
+             opt_level=2 if not use_external_data_format else None,
+             optimization_options=optimization_options,
+             use_gpu=use_gpu,
+             only_onnxruntime=False,
+         )
+
+         if is_float16:
+             if auto_mixed_precision:
+                 WhisperHelper.auto_mixed_precision(m)
+             else:
+                 m.convert_model_float32_to_float16(cast_input_output=False)
+
+         m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
+
+     @staticmethod
+     def pt_transcription_for_verify_onnx(
+         processor: WhisperProcessor,
+         pt_model: torch.nn.Module,
+         device: torch.device,
+         batch_size: int = 1,
+         prompt_mode: bool = False,
+     ):
+         # Try to import `datasets` pip package
+         try:
+             from datasets import load_dataset
+         except Exception as e:
+             logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)  # noqa: G201
+             install_cmd = "pip install datasets"
+             logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
+             os.system(install_cmd)
+
+             from datasets import load_dataset
+
+         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+         input_features_ = []
+         if batch_size == 1:
+             input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
+         else:
+             input_features_ = [
+                 processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
+                 processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
+             ]
+             assert len(input_features_) == batch_size
+             input_features = torch.cat((input_features_[0], input_features_[1]))
+
+         max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
+         length_penalty, repetition_penalty = 1.0, 1.0
+         inputs = {
+             "input_features": input_features.to(device),
+             "max_length": max_length,
+             "min_length": min_length,
+             "num_beams": num_beams,
+             "num_return_sequences": num_return_sequences,
+             "length_penalty": length_penalty,
+             "repetition_penalty": repetition_penalty,
+             "early_stopping": True,
+             "use_cache": True,
+         }
+
+         if prompt_mode:
+             prompts = ["John has doubts", "Maria has grave doubts"]
+             prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
+             pt_transcription = []
+             pt_outputs = []
+             # The looping for model.generate is necessary here due to the limitation as per
+             # https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids
+             # prompt_ids input requires a tensor of rank 1
+             for i in range(batch_size):
+                 inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
+                 inputs["input_features"] = input_features_[i].to(device)
+                 pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
+                 pt_outputs.append(pt_output)
+                 pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
+             inputs["input_features"] = input_features
+             del inputs["prompt_ids"]
+         else:
+             prompt_ids = []
+             pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
+             pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
+             pt_outputs = list(pt_outputs)
+         del inputs["early_stopping"]
+         del inputs["use_cache"]
+         return inputs, pt_transcription, pt_outputs, prompt_ids
+
+     @staticmethod
+     def select_transcription_options(
+         batch_size: int,
+         prompt_mode: bool,
+     ):
+         if batch_size > 1 and prompt_mode:
+             expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
+             expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
+             expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
+             expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
+             expected_transcription_options = {
+                 expected_transcription_no_comma_prompt1,
+                 expected_transcription_no_comma_prompt2,
+                 expected_transcription_misspelled_prompt1,
+                 expected_transcription_misspelled_prompt2,
+             }
+         else:
+             expected_transcription_no_comma = (
+                 " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
+             )
+             expected_transcription_with_comma = (
+                 " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
+             )
+             expected_transcription_with_quote_and_comma = (
+                 ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+             )
+             expected_transcription_options = {
+                 expected_transcription_no_comma,
+                 expected_transcription_with_comma,
+                 expected_transcription_with_quote_and_comma,
+             }
+         return expected_transcription_options
+
+     @staticmethod
+     def verify_onnx(
+         model_name_or_path: str,
+         cache_dir: str,
+         ort_session: InferenceSession,
+         device: torch.device,
+         batch_size: int = 1,
+         prompt_mode: bool = False,
+     ):
+         """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
+         extra_kwargs = {}
+         if version.parse(transformers_version) >= version.parse("4.36.0"):
+             extra_kwargs["attn_implementation"] = "eager"
+         pt_model = WhisperForConditionalGeneration.from_pretrained(
+             model_name_or_path, cache_dir=cache_dir, **extra_kwargs
+         ).to(device)
+         processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+         config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+
+         inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
+             processor,
+             pt_model,
+             device,
+             batch_size=batch_size,
+             prompt_mode=prompt_mode,
+         )
+
+         start_id = [config.decoder_start_token_id]  # ex: [50258]
+         prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
+         prompt_ids = list(map(lambda token: token[1], prompt_ids))  # ex: [50259, 50358, 50363]
+         forced_decoder_ids = start_id + prompt_ids  # ex: [50258, 50259, 50358, 50363]
+
+         ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
+         ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
+         ort_to_np = {
+             "tensor(float)": np.float32,
+             "tensor(float16)": np.float16,
+             "tensor(int64)": np.int64,
+             "tensor(int32)": np.int32,
+             "tensor(int8)": np.int8,
+             "tensor(uint8)": np.uint8,
+         }
+
+         use_extra_decoding_ids = "extra_decoding_ids" in ort_names
+         for name, dtype in zip(ort_names, ort_dtypes):
+             if name == "input_features":
+                 inputs[name] = inputs[name].detach().cpu().numpy()
+             elif name == "vocab_mask":
+                 inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype])
+             elif name == "prefix_vocab_mask":
+                 inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
+             elif name == "decoder_input_ids":
+                 if not prompt_mode:
+                     raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
+                     inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
+                 else:
+                     # This logic handles the scenario for when prompts are not of the same size
+                     # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
+                     # The final decoder_input_ids will look as such after padding
+                     # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
+                     # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
+                     ort_prompts = []
+                     for i in range(batch_size):
+                         ort_prompts.append(decoder_prompt_ids[i].tolist())
+                     max_len = max(len(p) for p in ort_prompts)
+                     padded_prompts = []
+                     for p in ort_prompts:
+                         padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
+                         padded_prompts.append(padded_prompt + forced_decoder_ids)
+                     inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
+             elif name == "logits_processor":
+                 inputs[name] = np.array([1], dtype=ort_to_np[dtype])
+             elif name == "cross_qk_layer_head":
+                 inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype])
+             elif name == "extra_decoding_ids":
+                 inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0)
+             elif name == "temperature":
+                 inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
+             else:
+                 inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
+         ort_outputs = ort_session.run(None, inputs)[0][:, 0, :]
+         ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
+         expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)
+
+         parity = 1
+         for i in range(batch_size):
+             parity *= (
+                 pt_transcription[i] in expected_transcription_options
+                 and ort_transcription[i] in expected_transcription_options
+             )
+         max_diff = 0
+
+         if not parity:
+             for i in range(batch_size):
+                 if pt_outputs[i].shape != ort_outputs[i].shape:
+                     diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])]
+                 else:
+                     diff = pt_outputs[i] - ort_outputs[i]
+                 max_diff_i = max(diff.min(), diff.max(), key=abs)
+                 max_diff = max(max_diff, max_diff_i)
+
+         if max_diff != 0:
+             logger.warning(f"PyTorch outputs: {pt_transcription}")
+             logger.warning(f"ONNX Runtime outputs: {ort_transcription}")
+
+         return max_diff
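WhisperHelper above is the driver behind models/whisper/convert_to_onnx.py. A sketch of how its pieces chain together, assuming the transformers tools directories are on sys.path so the script-style imports (onnx_model, optimizer, whisper_decoder, ...) resolve; the model id, the directories, and the WhisperConfig attributes used for head count and hidden size are illustrative assumptions, not taken from this file:

import os
import sys

import torch
from transformers import WhisperConfig

sys.path.append("path/to/transformers/tools")  # hypothetical path to the modules imported above

from whisper_helper import WhisperHelper

model_name = "openai/whisper-tiny"
cache_dir, output_dir = "./cache", "./onnx_models"
device = torch.device("cpu")
os.makedirs(output_dir, exist_ok=True)

config = WhisperConfig.from_pretrained(model_name, cache_dir=cache_dir)

# Any model_impl other than "openai" follows the Hugging Face branch of load_model;
# the default gives a merged encoder/decoder-init model plus a decoder-with-past model.
models = WhisperHelper.load_model(model_name, "hf", cache_dir, device)

for name, module in models.items():
    onnx_path = WhisperHelper.get_onnx_path(output_dir, model_name, suffix=f"_{name}")
    WhisperHelper.export_onnx(module, device, onnx_path, verbose=False)

    # Fuse attention subgraphs, then convert to mixed precision (auto_mixed_precision by default).
    optimized_path = WhisperHelper.get_onnx_path(output_dir, model_name, suffix=f"_{name}_fp16")
    WhisperHelper.optimize_onnx(
        onnx_path,
        optimized_path,
        is_float16=True,
        num_attention_heads=config.encoder_attention_heads,
        hidden_size=config.d_model,
    )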
onnxruntime/transformers/models/whisper/whisper_openai_helper.py
@@ -0,0 +1,84 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ class WhisperDecoderInitOpenai(torch.nn.Module):
+     """WhisperDecoderInit for Openai."""
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         decoder: torch.nn.Module,
+     ):
+         super().__init__()
+         self.whisper_model = model
+         self.whisper_decoder = decoder
+         self.kv_cache = {}
+
+     @torch.no_grad()
+     def forward(
+         self,
+         tokens,
+         audio_features,
+         past=None,
+         remove_hooks=False,
+     ):
+         # Create a kv_cache for past_values
+         past_kv_cache = dict()
+         if past is not None:
+             # Convert past values from 4D to 3D
+             past = [torch.transpose(val, 1, 2) for val in past]
+             past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
+             half_idx = len(past) // 2
+             for idx, block in enumerate(self.whisper_decoder.blocks):
+                 past_kv_cache[block.attn.key] = past[2 * idx]
+                 past_kv_cache[block.attn.value] = past[2 * idx + 1]
+                 past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
+                 past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]
+
+         hooks = None
+         if not self.kv_cache:
+             self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks()
+
+         logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)
+
+         # Add concat node for past values
+         if past is not None:
+             for block in self.whisper_decoder.blocks:
+                 self.kv_cache[block.attn.key] = torch.cat(
+                     [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
+                 ).detach()
+                 self.kv_cache[block.attn.value] = torch.cat(
+                     [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
+                 ).detach()
+
+         present_self, present_cross = [], []
+         # Group self and cross values
+         for block in self.whisper_decoder.blocks:
+             present_self.append(self.kv_cache[block.attn.key])
+             present_self.append(self.kv_cache[block.attn.value])
+             if past is None:
+                 present_cross.append(self.kv_cache[block.cross_attn.key])
+                 present_cross.append(self.kv_cache[block.cross_attn.value])
+
+         present_self = present_self + present_cross
+         # Add reshape and transpose ops to convert from 3D to 4D
+         present_self = [
+             present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
+         ]
+
+         # Remove forward hooks to avoid model cloning step
+         if hooks is not None and remove_hooks:
+             self.kv_cache = {}
+             for hook in hooks:
+                 hook.remove()
+         return logits, present_self
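The transpose/reshape pairs in forward() exist because ONNX Runtime's beam search exchanges 4D KV caches of shape (batch, num_heads, seq_len, head_dim), while the OpenAI kv_cache hooks store 3D tensors of shape (batch, seq_len, num_heads * head_dim). A self-contained sketch of that round trip (the dimensions are made up; head_dim = 64 matches the constant hard-coded in the reshape above, which holds for the released Whisper checkpoints):

import torch

batch, num_heads, seq_len, head_dim = 2, 6, 10, 64
past_4d = torch.randn(batch, num_heads, seq_len, head_dim)

# 4D -> 3D, as done for incoming `past` values at the top of forward()
past_3d = torch.transpose(past_4d, 1, 2)              # (batch, seq_len, num_heads, head_dim)
past_3d = past_3d.reshape(past_3d.shape[:2] + (-1,))  # (batch, seq_len, num_heads * head_dim)

# 3D -> 4D, as done for the returned `present_self` values
present_4d = past_3d.reshape(past_3d.shape[:2] + (-1, 64)).transpose(1, 2)

assert torch.equal(past_4d, present_4d)  # the layout conversion is lossless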