onnxruntime-directml 1.24.1 (onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py
@@ -0,0 +1,361 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # -------------------------------------------------------------------------
+
+ import logging
+ import os
+ import tempfile
+ from pathlib import Path
+
+ import numpy
+ import onnx
+ import torch
+ from onnx_model import OnnxModel
+ from past_helper import PastKeyValuesHelper
+ from t5_decoder import T5DecoderInit
+ from t5_encoder import T5Encoder, T5EncoderInputs
+ from torch_onnx_export_helper import torch_onnx_export
+ from transformers import MT5Config, T5Config
+
+ from onnxruntime import InferenceSession
+
+ logger = logging.getLogger(__name__)
+
+
+ class T5EncoderDecoderInit(torch.nn.Module):
+     """A combination of T5Encoder and T5DecoderInit."""
+
+     def __init__(
+         self,
+         encoder: torch.nn.Module,
+         decoder: torch.nn.Module,
+         lm_head: torch.nn.Linear,
+         config: T5Config | MT5Config,
+         decoder_start_token_id: int | None = None,
+         output_cross_only: bool = False,
+     ):
+         super().__init__()
+         self.config: T5Config | MT5Config = config
+         self.t5_encoder = T5Encoder(encoder, config)
+         self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
+         self.output_cross_only = output_cross_only
+
+     def forward(
+         self,
+         encoder_input_ids: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         decoder_input_ids: torch.Tensor | None = None,
+     ):
+         encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)
+
+         lm_logits, past_self, past_cross = self.t5_decoder_init(
+             decoder_input_ids, encoder_attention_mask, encoder_hidden_states
+         )
+
+         if self.output_cross_only:
+             return past_cross
+         else:
+             return lm_logits, encoder_hidden_states, past_self, past_cross
+
+
+ class T5EncoderDecoderInitInputs:
+     def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
+         self.encoder_input_ids: torch.LongTensor = encoder_input_ids
+         self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
+         self.decoder_input_ids: torch.LongTensor | None = decoder_input_ids
+
+     @staticmethod
+     def create_dummy(
+         config: T5Config | MT5Config,
+         batch_size: int,
+         encode_sequence_length: int,
+         use_decoder_input_ids: bool,
+         device: torch.device,
+         use_int32_inputs: bool = False,
+     ):  # -> T5EncoderDecoderInitInputs:
+         encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
+             batch_size,
+             encode_sequence_length,
+             config.vocab_size,
+             device,
+             use_int32_inputs=use_int32_inputs,
+         )
+         decoder_input_ids = None
+         if use_decoder_input_ids:
+             dtype = torch.int32 if use_int32_inputs else torch.int64
+             decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
+
+         return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
+
+     def to_list(self) -> list:
+         input_list = [self.encoder_input_ids, self.encoder_attention_mask]
+         if self.decoder_input_ids is not None:
+             input_list.append(self.decoder_input_ids)
+         return input_list
+
+
+ class T5EncoderDecoderInitHelper:
+     @staticmethod
+     def export_onnx(
+         model: T5EncoderDecoderInit,
+         device: torch.device,
+         onnx_model_path: str,
+         use_decoder_input_ids: bool = True,
+         verbose: bool = True,
+         use_external_data_format: bool = False,
+         use_int32_inputs: bool = False,
+     ):
+         """Export the combined encoder and decoder-init model to ONNX.
+
+         Args:
+             model (T5EncoderDecoderInit): the model to export
+             device (torch.device): device of decoder object
+             onnx_model_path (str): onnx path
+             use_decoder_input_ids (bool, optional): include decoder_input_ids as a model input. Defaults to True.
+             verbose (bool, optional): print verbose information. Defaults to True.
+             use_external_data_format (bool, optional): use external data format or not. Defaults to False.
+             use_int32_inputs (bool, optional): use int32 instead of int64 for integer inputs. Defaults to False.
+         """
+         assert isinstance(model, T5EncoderDecoderInit)
+
+         # Export all outputs first (not only cross) so that the cross attention past states show up in the graph.
+         output_cross_only = model.output_cross_only
+         model.output_cross_only = False
+
+         inputs = T5EncoderDecoderInitInputs.create_dummy(
+             model.config,
+             batch_size=2,
+             encode_sequence_length=3,
+             use_decoder_input_ids=use_decoder_input_ids,
+             device=device,
+             use_int32_inputs=use_int32_inputs,
+         )
+         input_list = inputs.to_list()
+
+         present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
+
+         output_names = ["logits", "encoder_hidden_states", *present_names]
+
+         # Shape of input tensors (sequence_length==1):
+         #    input_ids: (batch_size, sequence_length)
+         #    encoder_attention_mask: (batch_size, encode_sequence_length)
+         #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
+         #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
+         #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
+
+         # Shape of output tensors:
+         #    logits: (batch_size, sequence_length, vocab_size)
+         #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
+         #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
+
+         input_names = ["encoder_input_ids", "encoder_attention_mask"]
+
+         # ONNX exporter might mark dimension like 'present_value_self_1_dim_2' in shape inference.
+         # We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
+         sequence_length = "1"
+         num_heads = str(model.config.num_heads)
+         hidden_size = str(model.config.d_model)
+         head_size = str(model.config.d_kv)
+
+         dynamic_axes = {
+             "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
+             "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
+             "encoder_hidden_states": {
+                 0: "batch_size",
+                 1: "encode_sequence_length",
+                 2: hidden_size,
+             },
+             "logits": {
+                 0: "batch_size",
+                 1: sequence_length,
+             },
+         }
+
+         if use_decoder_input_ids:
+             input_names.append("decoder_input_ids")
+             dynamic_axes["decoder_input_ids"] = {
+                 0: "batch_size",
+                 1: sequence_length,
+             }
+
+         for name in present_names:
+             if "cross" in name:
+                 dynamic_axes[name] = {
+                     0: "batch_size",
+                     1: num_heads,
+                     2: "encode_sequence_length",
+                     3: head_size,
+                 }
+
+             else:  # self attention past state
+                 dynamic_axes[name] = {
+                     0: "batch_size",
+                     1: num_heads,
+                     2: sequence_length,
+                     3: head_size,
+                 }
+
+         with tempfile.TemporaryDirectory() as tmp_dir_name:
+             temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
+             Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+             torch_onnx_export(
+                 model,
+                 args=tuple(input_list),
+                 f=temp_onnx_model_path,
+                 export_params=True,
+                 input_names=input_names,
+                 output_names=output_names,
+                 dynamic_axes=dynamic_axes,
+                 opset_version=12,
+                 do_constant_folding=True,
+                 use_external_data_format=use_external_data_format,
+                 verbose=verbose,
+             )
+
+             # Restore output_cross_only setting.
+             model.output_cross_only = output_cross_only
+
+             # Workaround as mentioned earlier: change numeric dim_param to dim_value
+             exported_model: onnx.ModelProto = onnx.load(temp_onnx_model_path)
+             for tensor in exported_model.graph.output:
+                 for dim_proto in tensor.type.tensor_type.shape.dim:
+                     if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
+                         sequence_length,
+                         num_heads,
+                         hidden_size,
+                         head_size,
+                     ]:
+                         dim_value = int(dim_proto.dim_param)
+                         dim_proto.Clear()
+                         dim_proto.dim_value = dim_value
+
+             if output_cross_only:
+                 # Rewrite onnx graph to only keep present_[key|value]_cross_* outputs.
+                 onnx_model = OnnxModel(exported_model)
+                 output_name_to_node = onnx_model.output_name_to_node()
+
+                 for output in exported_model.graph.output:
+                     if "cross" in output.name:
+                         assert output.name in output_name_to_node
+
+                         transpose_node = output_name_to_node[output.name]
+                         assert transpose_node and transpose_node.op_type == "Transpose"
+
+                         permutation = OnnxModel.get_node_attribute(transpose_node, "perm")
+                         assert isinstance(permutation, list)
+                         assert permutation == [0, 2, 1, 3]
+
+                         matched_nodes = onnx_model.match_parent_path(
+                             transpose_node,
+                             ["Reshape", "MatMul"],
+                             [0, 0],
+                             output_name_to_node,
+                         )
+                         assert matched_nodes is not None
+
+                         reshape_node, matmul_node = matched_nodes
+                         assert "encoder_hidden_states" in matmul_node.input
+
+                         if not onnx_model.get_initializer("cross_reshape_shape"):
+                             shape_tensor = onnx.helper.make_tensor(
+                                 name="cross_reshape_shape",
+                                 data_type=onnx.TensorProto.INT64,
+                                 dims=[4],
+                                 vals=[0, 0, int(num_heads), int(head_size)],
+                                 raw=False,
+                             )
+                             onnx_model.add_initializer(shape_tensor)
+
+                         reshape_node.input[1] = "cross_reshape_shape"
+
+                 cross_outputs = [output.name for output in exported_model.graph.output if "cross" in output.name]
+                 onnx_model.prune_graph(cross_outputs, allow_remove_graph_inputs=True)
+
+             OnnxModel.save(
+                 exported_model,
+                 onnx_model_path,
+                 save_as_external_data=use_external_data_format,
+                 all_tensors_to_one_file=True,
+             )
+
+     @staticmethod
+     def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
+         """Run inference of ONNX model."""
+         logger.debug("start onnxruntime_inference")
+
+         ort_inputs = {
+             "encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
+             "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
+         }
+         if inputs.decoder_input_ids is not None:
+             ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
+
+         ort_outputs = ort_session.run(None, ort_inputs)
+         return ort_outputs
+
+     @staticmethod
+     def verify_onnx(
+         model: T5EncoderDecoderInit,
+         ort_session: InferenceSession,
+         device: torch.device,
+         use_int32_inputs: bool,
+         max_cases: int = 4,
+     ):
+         """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
+         ort_inputs = ort_session.get_inputs()
+         use_decoder_input_ids = len(ort_inputs) == 3
+
+         test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
+         test_cases_max_diff = []
+         for batch_size, encode_sequence_length in test_cases[:max_cases]:
+             inputs = T5EncoderDecoderInitInputs.create_dummy(
+                 model.config,
+                 batch_size,
+                 encode_sequence_length,
+                 use_decoder_input_ids=use_decoder_input_ids,
+                 device=device,
+                 use_int32_inputs=use_int32_inputs,
+             )
+
+             ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
+
+             # Run inference of PyTorch model
+             input_list = inputs.to_list()
+             torch_outputs = model(*input_list)
+
+             num_decoder_layers = model.config.num_decoder_layers
+
+             if not model.output_cross_only:
+                 assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
+                 max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
+                 logger.debug(f"logits max_diff={max_diff}")
+                 max_diff_all = max_diff
+
+                 assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
+                 max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
+                 logger.debug(f"encoder_hidden_states max_diff={max_diff}")
+                 max_diff_all = max(max_diff_all, max_diff)
+
+                 for i in range(2 * num_decoder_layers):
+                     max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
+                     logger.debug(f"self attention past state {i} max_diff={max_diff}")
+                     max_diff_all = max(max_diff_all, max_diff)
+
+                 for i in range(2 * num_decoder_layers):
+                     max_diff = numpy.amax(
+                         numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
+                     )
+                     logger.debug(f"cross attention past state {i} max_diff={max_diff}")
+                     max_diff_all = max(max_diff_all, max_diff)
+             else:
+                 max_diff_all = -float("inf")
+                 for i in range(2 * num_decoder_layers):
+                     max_diff = numpy.amax(numpy.abs(torch_outputs[i].cpu().numpy() - ort_outputs[i]))
+                     logger.debug(f"cross attention past state {i} max_diff={max_diff}")
+                     max_diff_all = max(max_diff_all, max_diff)
+
+             test_cases_max_diff.append(max_diff_all)
+             logger.info(
+                 f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
+             )
+
+         return max(test_cases_max_diff)
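
For orientation, a minimal sketch of how this helper might be driven end to end, assuming the flat tool-module imports above resolve (the bootstrap __init__.py at the end of this diff puts them on sys.path) and a CPU-only run; the model name and output path are illustrative, not taken from the package:

import torch
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
from transformers import T5ForConditionalGeneration

from onnxruntime import InferenceSession

# Wrap a pretrained T5 so one ONNX model runs the encoder pass plus the
# first decoder step (logits, self attention and cross attention past states).
device = torch.device("cpu")
hf_model = T5ForConditionalGeneration.from_pretrained("t5-small")
wrapper = T5EncoderDecoderInit(
    hf_model.encoder, hf_model.decoder, hf_model.lm_head, hf_model.config
).eval().to(device)

T5EncoderDecoderInitHelper.export_onnx(wrapper, device, "encoder_decoder_init.onnx")

# Parity check: the returned value is the largest absolute difference
# between PyTorch and ONNX Runtime outputs over a few dummy batches.
session = InferenceSession("encoder_decoder_init.onnx", providers=["CPUExecutionProvider"])
max_diff = T5EncoderDecoderInitHelper.verify_onnx(wrapper, session, device, use_int32_inputs=False)
print(f"max_diff={max_diff}")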
onnxruntime/transformers/models/t5/t5_helper.py
@@ -0,0 +1,302 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # -------------------------------------------------------------------------
+
+ import logging
+ import os
+ from pathlib import Path
+
+ import torch
+ from float16 import float_to_float16_max_diff
+ from onnx_model import OnnxModel
+ from optimizer import optimize_model
+ from t5_decoder import T5Decoder, T5DecoderHelper
+ from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
+ from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
+
+ from onnxruntime import InferenceSession
+
+ logger = logging.getLogger(__name__)
+
+ PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
+ PRETRAINED_MT5_MODELS = [
+     "google/mt5-small",
+     "google/mt5-base",
+     "google/mt5-large",
+     "google/mt5-xl",
+     "google/mt5-xxl",
+ ]
+
+
+ class T5Helper:
+     @staticmethod
+     def get_onnx_path(
+         output_dir: str,
+         model_name_or_path: str,
+         suffix: str = "",
+         new_folder: bool = False,
+     ) -> str:
+         """Build onnx path
+
+         Args:
+             output_dir (str): output directory
+             model_name_or_path (str): pretrained model name, or path to the model checkpoint
+             suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to the file name. Defaults to "".
+             new_folder (bool, optional): create a new directory for the model. Defaults to False.
+
+         Returns:
+             str: path of onnx model
+         """
+         model_name = model_name_or_path
+         if os.path.isdir(model_name_or_path):
+             model_name = Path(model_name_or_path).parts[-1]
+         else:
+             model_name = model_name.split("/")[-1]
+
+         model_name += suffix
+
+         directory = os.path.join(output_dir, model_name) if new_folder else output_dir
+         return os.path.join(directory, model_name + ".onnx")
+
+     @staticmethod
+     def load_model(
+         model_name_or_path: str,
+         cache_dir: str,
+         device: torch.device,
+         model_type: str = "t5",
+         state_dict_path: str = "",
+         encoder_decoder_init: bool = False,
+     ) -> dict[str, T5EncoderDecoderInit | T5Decoder]:
+         """Load model given a pretrained name or path, then build models for ONNX conversion.
+
+         Args:
+             model_name_or_path (str): pretrained model name or path
+             cache_dir (str): cache directory
+             device (torch.device): device to run the model
+             model_type (str, optional): model type "t5" or "mt5"
+             state_dict_path (str, optional): state dictionary path
+             encoder_decoder_init (bool, optional): combine encoder and decoder kv cache initialization into one model.
+         Returns:
+             Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
+         """
+         if model_type == "t5":
+             model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+         elif model_type == "mt5":
+             model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+         else:
+             raise ValueError("only support model_type=t5 or mt5")
+
+         if state_dict_path:
+             model.load_state_dict(torch.load(state_dict_path))
+
+         decoder = T5Decoder(model.decoder, model.lm_head, model.config)
+         decoder.eval().to(device)
+
+         encoder = T5EncoderDecoderInit(
+             model.encoder,
+             model.decoder,
+             model.lm_head,
+             model.config,
+             decoder_start_token_id=None,
+             output_cross_only=not encoder_decoder_init,
+         )
+
+         encoder_name = "encoder_decoder_init" if encoder_decoder_init else "encoder"
+         return {encoder_name: encoder, "decoder": decoder}
+
+     @staticmethod
+     def export_onnx(
+         model: T5Decoder | T5EncoderDecoderInit,
+         device: torch.device,
+         onnx_model_path: str,
+         verbose: bool = True,
+         use_external_data_format: bool = False,
+         use_decoder_input_ids: bool = True,
+         use_int32_inputs: bool = False,
+     ):
+         if isinstance(model, T5EncoderDecoderInit):
+             T5EncoderDecoderInitHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 use_decoder_input_ids,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs,
+             )
+         else:
+             T5DecoderHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs,
+             )
+
+     @staticmethod
+     def auto_mixed_precision(
+         onnx_model: OnnxModel,
+         op_block_list: list[str] | None = None,
+         force_fp16_logits: bool = False,
+         use_symbolic_shape_infer: bool = True,
+     ):
+         """Convert model to mixed precision.
+
+         It detects whether the original model has fp16 weights, and sets the parameters for float16 conversion automatically.
+
+         Args:
+             onnx_model (OnnxModel): optimized ONNX model
+             op_block_list (List[str], optional): operators that need to run in fp32.
+             force_fp16_logits (bool, optional): force logits and last MatMul node to be in float16. Defaults to False.
+             use_symbolic_shape_infer (bool, optional): use symbolic shape inference to convert float to float16. Defaults to True.
+
+         Returns:
+             parameters (dict): a dictionary of parameters used in float16 conversion
+         """
+         if op_block_list is None:
+             op_block_list = [
+                 "SimplifiedLayerNormalization",
+                 "SkipSimplifiedLayerNormalization",
+                 "Relu",
+                 "Add",
+             ]
+
+         op_full_set = {node.op_type for node in onnx_model.nodes()}
+         fp32_op_set = set(op_block_list)
+         fp16_op_set = op_full_set.difference(fp32_op_set)
+         logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
+
+         # logits is the first output
+         logits_output_name = onnx_model.graph().output[0].name
+
+         # We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
+         is_weight_fp16_precision = False
+         output_name_to_node = onnx_model.output_name_to_node()
+         assert logits_output_name in output_name_to_node
+         node = output_name_to_node[logits_output_name]
+         last_matmul_node = None
+         if node.op_type == "MatMul":
+             last_matmul_node = node
+             logger.info(f"Found last MatMul node for logits: {node.name}")
+             initializer = None
+             for input in node.input:
+                 initializer = onnx_model.get_initializer(input)
+                 if initializer is not None:
+                     break
+
+             # When the max difference of value after converting float to float16 is lower than a threshold (1e-6),
+             # we can deduce that the weights are stored in float16 precision.
+             max_diff = float_to_float16_max_diff(initializer)
+             logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
+             is_weight_fp16_precision = max_diff < 1e-6
+         else:
+             logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
+
+         keep_io_types = []
+         node_block_list = []
+         if (not is_weight_fp16_precision) and (last_matmul_node is not None) and not force_fp16_logits:
+             # When the original weights are float32, keeping logits and the last MatMul in float32 can give better precision.
+             keep_io_types = [logits_output_name]
+             node_block_list = [last_matmul_node.name]
+
+         if "Add" not in op_block_list:
+             input_name_to_nodes = onnx_model.input_name_to_nodes()
+             fp32_add = 0
+             changed = True
+             add_nodes = onnx_model.get_nodes_by_op_type("Add")
+             while changed:
+                 changed = False
+                 for node in add_nodes:
+                     if node.name not in node_block_list:
+                         parents = onnx_model.get_parents(node, output_name_to_node)
+                         children = onnx_model.get_children(node, input_name_to_nodes)
+                         blocked_children = [
+                             child for child in children if child.op_type in op_block_list or child.name in node_block_list
+                         ]
+                         blocked_parents = [
+                             parent for parent in parents if parent.op_type in op_block_list or parent.name in node_block_list
+                         ]
+                         # If any child or parent runs in fp32, place this Add node in fp32 as well.
+                         if (len(blocked_children) + len(blocked_parents)) > 0:
+                             node_block_list.append(node.name)
+                             fp32_add += 1
+                             changed = True
+             fp16_add = len(add_nodes) - fp32_add
+             logger.info(f"node counter of Add operator: fp32={fp32_add} fp16={fp16_add}")
+
+         logger.info(f"node_block_list: {node_block_list}")
+
+         parameters = {
+             "keep_io_types": keep_io_types,
+             "op_block_list": op_block_list,
+             "node_block_list": node_block_list,
+             "force_fp16_initializers": is_weight_fp16_precision,
+         }
+
+         logger.info(f"auto_mixed_precision parameters: {parameters}")
+         if use_symbolic_shape_infer:
+             onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
+         else:
+             # Workaround when symbolic shape inference fails.
+             # Need to enable shape_infer_before_optimization in convert_to_onnx.py as well.
+             from float16 import convert_float_to_float16  # noqa: PLC0415
+
+             convert_float_to_float16(
+                 onnx_model.model,
+                 disable_shape_infer=True,
+                 **parameters,
+             )
+
+         return parameters
+
+     @staticmethod
+     def optimize_onnx(
+         onnx_model_path: str,
+         optimized_model_path: str,
+         is_float16: bool,
+         num_attention_heads: int,
+         hidden_size: int,
+         use_external_data_format: bool = False,
+         auto_mixed_precision: bool = True,
+         use_gpu: bool = False,
+         force_fp16_io: bool = False,
+     ):
+         """Optimize ONNX model with an option to convert it to use mixed precision."""
+
+         from fusion_options import FusionOptions  # noqa: PLC0415
+
+         optimization_options = None
+         if is_float16:
+             optimization_options = FusionOptions("t5")
+             # SkipLayerNormalization is faster but might bring accuracy drop since it uses fp16 accumulation.
+             optimization_options.enable_skip_layer_norm = not auto_mixed_precision
+
+         m = optimize_model(
+             onnx_model_path,
+             model_type="t5",
+             num_heads=num_attention_heads,
+             hidden_size=hidden_size,
+             opt_level=0,
+             optimization_options=optimization_options,
+             use_gpu=use_gpu,
+         )
+
+         if is_float16:
+             if auto_mixed_precision:
+                 T5Helper.auto_mixed_precision(m, force_fp16_logits=force_fp16_io)
+             else:
+                 m.convert_model_float32_to_float16(cast_input_output=force_fp16_io)
+
+         m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
+
+     @staticmethod
+     def verify_onnx(
+         model: T5Decoder | T5EncoderDecoderInit,
+         ort_session: InferenceSession,
+         device: torch.device,
+         use_int32_inputs: bool,
+     ):
+         """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
+         if isinstance(model, T5EncoderDecoderInit):
+             return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
+
+         return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
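
A similar hedged sketch for the higher-level T5Helper flow defined above: load the PyTorch modules, export each to ONNX, then optionally fuse and convert to fp16 mixed precision. The paths, cache directory, and model name are placeholders; num_heads and d_model are standard T5Config attributes:

import os

import torch
from t5_helper import T5Helper
from transformers import T5Config

device = torch.device("cpu")
config = T5Config.from_pretrained("t5-small")
models = T5Helper.load_model("t5-small", cache_dir="./cache", device=device)

os.makedirs("./onnx_models", exist_ok=True)
for name, module in models.items():
    onnx_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name)
    T5Helper.export_onnx(module, device, onnx_path)

    # Optional second pass: graph fusions, then automatic mixed precision
    # (auto_mixed_precision keeps precision-sensitive ops like layer norm in fp32).
    fp16_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name + "_fp16")
    T5Helper.optimize_onnx(
        onnx_path,
        fp16_path,
        is_float16=True,
        num_attention_heads=config.num_heads,
        hidden_size=config.d_model,
    )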
onnxruntime/transformers/models/whisper/__init__.py
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)
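
The same 12-line bootstrap ships in every models/* subpackage (all of the +12 __init__.py entries in the listing above): it puts the model folder itself and the transformers tool root on sys.path so the scripts can use flat imports instead of package-relative ones. A sketch of what that enables; the WhisperHelper class name is assumed for illustration, not confirmed by this diff:

# Once the bootstrap has run, sibling files and the tool root both
# resolve by bare module name:
from whisper_helper import WhisperHelper  # sibling file in models/whisper (class name assumed)
from onnx_model import OnnxModel          # lives two levels up, in onnxruntime/transformers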