onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
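
All 322 files are additions, so this diff amounts to the first publish of this DirectML wheel for CPython 3.14. A quick smoke test of the installed package might look like the sketch below (the model path is a placeholder; DmlExecutionProvider is the execution provider name the DirectML build of ONNX Runtime registers):

    import onnxruntime as ort

    # The DirectML build should list DmlExecutionProvider alongside the CPU provider.
    print(ort.get_available_providers())

    # "model.onnx" is a placeholder; any valid ONNX model works here.
    session = ort.InferenceSession(
        "model.onnx",
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )
    print([inp.name for inp in session.get_inputs()])

The hunks below show the diffs for the new quantization operator files (entries 58 through 63 in the list above).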
onnxruntime/quantization/operators/direct_q8.py
@@ -0,0 +1,78 @@
+from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType
+from .base_operator import QuantOperatorBase
+from .qdq_base_operator import QDQOperatorBase
+
+
+# For operators that support 8-bit operations directly and whose output can
+# reuse input[0]'s type, zero point, and scale; for example, Transpose, Reshape, etc.
+class Direct8BitOp(QuantOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def quantize(self):
+        node = self.node
+
+        if not self.quantizer.force_quantize_no_input_check:
+            # Keep backward compatibility:
+            # quantize only when input[0] is already quantized; otherwise keep the node as is.
+            quantized_input_value = self.quantizer.find_quantized_value(node.input[0])
+            if quantized_input_value is None:
+                self.quantizer.new_nodes += [node]
+                return
+
+            quantized_output_value = QuantizedValue(
+                node.output[0],
+                node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
+                quantized_input_value.scale_name,
+                quantized_input_value.zp_name,
+                quantized_input_value.value_type,
+            )
+            self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
+
+            node.input[0] = quantized_input_value.q_name
+            node.output[0] = quantized_output_value.q_name
+            self.quantizer.new_nodes += [node]
+
+        else:
+            # Force-quantize these ops if possible; use the exclude-node list if this is not what you want.
+            if not self.quantizer.is_valid_quantize_weight(node.input[0]):
+                super().quantize()
+                return
+
+            (
+                quantized_input_names,
+                zero_point_names,
+                scale_names,
+                nodes,
+            ) = self.quantizer.quantize_activation(node, [0])
+            if quantized_input_names is None:
+                return super().quantize()
+
+            # Create an entry for the output quantized value.
+            quantized_output_value = QuantizedValue(
+                node.output[0],
+                node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
+                scale_names[0],
+                zero_point_names[0],
+                QuantizedValueType.Input,
+            )
+            self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
+
+            node.input[0] = quantized_input_names[0]
+            node.output[0] = quantized_output_value.q_name
+            nodes.append(node)
+
+            self.quantizer.new_nodes += nodes
+
+
+class QDQDirect8BitOp(QDQOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def quantize(self):
+        if self.quantizer.force_quantize_no_input_check:
+            self.quantizer.quantize_activation_tensor(self.node.input[0])
+            if not self.disable_qdq_for_node_output:
+                self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name)
+        elif self.quantizer.is_tensor_quantized(self.node.input[0]) and not self.disable_qdq_for_node_output:
+            self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name)
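
Direct8BitOp applies to value-preserving operators (Transpose, Reshape, and the like) that can run on 8-bit data as-is, so the output simply inherits input[0]'s scale and zero point. A minimal numpy sketch of why no requantization is needed (the values and names here are illustrative, not from the package):

    import numpy as np

    scale, zero_point = 0.02, 128  # one (scale, zp) pair shared by input and output

    x = np.random.rand(2, 3).astype(np.float32)
    q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

    # Transpose in the quantized domain: the op only rearranges values,
    # so the original scale and zero point still describe the output.
    x_t = (q.T.astype(np.float32) - zero_point) * scale
    assert np.allclose(x_t, x.T, atol=scale)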
onnxruntime/quantization/operators/embed_layernorm.py
@@ -0,0 +1,121 @@
+import logging
+
+import onnx
+from onnx import onnx_pb as onnx_proto  # noqa: F401
+
+from ..quant_utils import attribute_to_kwarg, ms_domain
+from .base_operator import QuantOperatorBase
+
+"""
+Quantizes the EmbedLayerNorm fused ONNXRuntime Op.
+
+This Quant operator keeps the input and segment IDs at int32 but will quantize all initializer and
+weight inputs associated with the node to uint8.
+"""
+
+
+class EmbedLayerNormalizationQuant(QuantOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def should_quantize(self):
+        return self.quantizer.should_quantize_node(self.node)
+
+    def quantize(self):
+        node = self.node
+        assert node.op_type == "EmbedLayerNormalization"
+
+        if len(node.output) > 2:
+            logging.info(f"Quantization is not applied to {node.name} since it has 3 outputs")
+            return super().quantize()
+
+        """
+        Pre-quantization EmbedLayerNorm inputs:
+        [0] input_ids (int32)
+        [1] segment_ids (int32)
+        [2] word_embedding (float32)
+        [3] position_embedding (float32)
+        [4] segment_embedding (float32)
+        [5] gamma (float32)
+        [6] beta (float32)
+        [7] mask (int32) (optional)
+        """
+        (
+            quantized_input_names,
+            zero_point_names,
+            scale_names,
+            nodes,
+        ) = self.quantizer.quantize_activation(node, [2, 3, 4, 5, 6])
+        if quantized_input_names is None:
+            return super().quantize()
+
+        qembed_layer_norm_name = "" if not node.name else node.name + "_quant"
+
+        """
+        Quantized Input Tensor List
+        [0] input_ids (int32)
+        [1] segment_ids (int32)
+        [2] word_embedding (uint8)
+        [3] position_embedding (uint8)
+        [4] segment_embedding (uint8)
+        [5] gamma (uint8)
+        [6] beta (uint8)
+        [7] mask (int32) (optional)
+        [8] word_embedding_scale (float)
+        [9] position_embedding_scale (float)
+        [10] segment_embedding_scale (float)
+        [11] gamma_scale (float)
+        [12] beta_scale (float)
+        [13] word_embedding_zero_point (uint8)
+        [14] position_embedding_zero_point (uint8)
+        [15] segment_embedding_zero_point (uint8)
+        [16] gamma_zero_point (uint8)
+        [17] beta_zero_point (uint8)
+        """
+        inputs = []
+        # 'input_ids'
+        inputs.extend([node.input[0]])
+        # 'segment_ids'
+        inputs.extend([node.input[1]])
+        # 'word_embedding_quant'
+        inputs.extend([quantized_input_names[0]])
+        # 'position_embedding_quant'
+        inputs.extend([quantized_input_names[1]])
+        # 'segment_embedding_quant'
+        inputs.extend([quantized_input_names[2]])
+        # 'gamma_quant'
+        inputs.extend([quantized_input_names[3]])
+        # 'beta_quant'
+        inputs.extend([quantized_input_names[4]])
+        # 'mask' (optional)
+        inputs.extend([node.input[7] if len(node.input) > 7 else ""])
+
+        # Add all scales:
+        inputs.extend([scale_names[0]])
+        inputs.extend([scale_names[1]])
+        inputs.extend([scale_names[2]])
+        inputs.extend([scale_names[3]])
+        inputs.extend([scale_names[4]])
+
+        # Add all zero points:
+        inputs.extend([zero_point_names[0]])
+        inputs.extend([zero_point_names[1]])
+        inputs.extend([zero_point_names[2]])
+        inputs.extend([zero_point_names[3]])
+        inputs.extend([zero_point_names[4]])
+
+        kwargs = {}
+        for attribute in node.attribute:
+            kwargs.update(attribute_to_kwarg(attribute))
+        kwargs["domain"] = ms_domain
+
+        qembed_layer_norm_node = onnx.helper.make_node(
+            "QEmbedLayerNormalization",
+            inputs,
+            node.output,
+            qembed_layer_norm_name,
+            **kwargs,
+        )
+        nodes.append(qembed_layer_norm_node)
+
+        self.quantizer.new_nodes += nodes
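
The comment block above fixes the 18-slot input layout of QEmbedLayerNormalization: five uint8 tensors followed by their five scales and five zero points. Each (tensor, scale, zero point) triple comes from asymmetric uint8 quantization of a float32 initializer; a self-contained sketch of that computation (shapes and the exact range rule are illustrative assumptions, not taken from the package):

    import numpy as np

    word_embedding = np.random.randn(1000, 128).astype(np.float32)  # hypothetical initializer

    # Asymmetric uint8 quantization over the full tensor; the range is widened
    # to include 0.0 so that it maps exactly onto the zero point.
    rmin = min(word_embedding.min(), 0.0)
    rmax = max(word_embedding.max(), 0.0)
    scale = (rmax - rmin) / 255.0
    zero_point = np.uint8(round(-rmin / scale))

    q = np.clip(np.round(word_embedding / scale) + zero_point, 0, 255).astype(np.uint8)
    recovered = (q.astype(np.float32) - zero_point) * scale
    assert np.abs(recovered - word_embedding).max() <= scale

A triple like this would fill slots [2], [8], and [13] for the word embedding, and likewise for the other four float inputs.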
onnxruntime/quantization/operators/gather.py
@@ -0,0 +1,64 @@
+from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType
+from .base_operator import QuantOperatorBase
+from .qdq_base_operator import QDQOperatorBase
+
+"""
+Quantize Gather
+"""
+
+
+class GatherQuant(QuantOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def should_quantize(self):
+        if not self.quantizer.should_quantize_node(self.node):
+            return False
+
+        return self.quantizer.is_valid_quantize_weight(self.node.input[0])
+
+    def quantize(self):
+        node = self.node
+        assert node.op_type == "Gather"
+
+        (
+            quantized_input_names,
+            zero_point_names,
+            scale_names,
+            nodes,
+        ) = self.quantizer.quantize_activation(node, [0])
+        if quantized_input_names is None:
+            return super().quantize()
+
+        gather_new_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+
+        # Create an entry for this quantized value
+        q_output = QuantizedValue(
+            node.output[0],
+            gather_new_output,
+            scale_names[0],
+            zero_point_names[0],
+            QuantizedValueType.Input,
+        )
+        self.quantizer.quantized_value_map[node.output[0]] = q_output
+
+        node.output[0] = gather_new_output
+        node.input[0] = quantized_input_names[0]
+        nodes.append(node)
+
+        self.quantizer.new_nodes += nodes
+
+
+class QDQGather(QDQOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def quantize(self):
+        node = self.node
+        assert node.op_type == "Gather" or node.op_type == "GatherElements"
+
+        if self.quantizer.is_valid_quantize_weight(node.input[0]) or self.quantizer.force_quantize_no_input_check:
+            self.quantizer.quantize_activation_tensor(node.input[0])
+            self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name)
+        elif self.quantizer.is_tensor_quantized(node.input[0]):
+            self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name)
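
Gather only selects elements, so it commutes with dequantization; that is what lets both GatherQuant and QDQGather give the output the same scale and zero point as the data input. A short demonstration with invented values:

    import numpy as np

    scale, zp = 0.1, 3
    data_q = np.array([[10, 20], [30, 40], [50, 60]], dtype=np.int8)
    indices = np.array([2, 0])

    def dequant(q):
        return (q.astype(np.float32) - zp) * scale

    # Gather-then-dequantize equals dequantize-then-gather, so no
    # requantization step is needed on the gathered output.
    assert np.array_equal(dequant(data_q[indices]), dequant(data_q)[indices])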
onnxruntime/quantization/operators/gavgpool.py
@@ -0,0 +1,62 @@
+import onnx
+
+from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+from .base_operator import QuantOperatorBase
+
+
+class QGlobalAveragePool(QuantOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def quantize(self):
+        node = self.node
+        assert node.op_type == "GlobalAveragePool"
+
+        # If the input to this node is not quantized, keep this node as is.
+        if node.input[0] not in self.quantizer.quantized_value_map:
+            return super().quantize()
+
+        quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
+
+        # Create an entry for the output quantized value.
+        quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
+        (
+            data_found,
+            output_scale_name_from_parameter,
+            output_zp_name_from_parameter,
+            _,
+            _,
+        ) = self.quantizer._get_quantization_params(node.output[0])
+        # Just use the input scale and zp if parameters for the output are not specified.
+        output_scale_name = output_scale_name_from_parameter if data_found else quantized_input_value.scale_name
+        output_zp_name = output_zp_name_from_parameter if data_found else quantized_input_value.zp_name
+        quantized_output_value = QuantizedValue(
+            node.output[0],
+            node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
+            output_scale_name,
+            output_zp_name,
+            QuantizedValueType.Input,
+        )
+        self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
+
+        kwargs = {}
+        for attribute in node.attribute:
+            kwargs.update(attribute_to_kwarg(attribute))
+        kwargs["domain"] = ms_domain
+        kwargs["channels_last"] = 0
+        qnode_name = node.name + "_quant" if node.name else ""
+
+        qnode = onnx.helper.make_node(
+            "QLinear" + node.op_type,
+            [
+                quantized_input_value.q_name,
+                quantized_input_value.scale_name,
+                quantized_input_value.zp_name,
+                output_scale_name,
+                output_zp_name,
+            ],
+            [quantized_output_value.q_name],
+            qnode_name,
+            **kwargs,
+        )
+        self.quantizer.new_nodes += [qnode]
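
QLinearGlobalAveragePool carries explicit input and output quantization parameters, so its reference semantics can be read as: dequantize, take the global spatial mean, then requantize with the output scale and zero point (which, per the code above, fall back to the input's when no calibration data is available). A numpy sketch of those semantics, with all parameters invented:

    import numpy as np

    in_scale, in_zp = 0.05, 10
    out_scale, out_zp = 0.04, 12

    x_q = np.random.randint(0, 256, size=(1, 3, 8, 8)).astype(np.uint8)

    # Dequantize -> global average over H and W -> requantize.
    x = (x_q.astype(np.float32) - in_zp) * in_scale
    avg = x.mean(axis=(2, 3), keepdims=True)
    y_q = np.clip(np.round(avg / out_scale) + out_zp, 0, 255).astype(np.uint8)
    print(y_q.shape)  # (1, 3, 1, 1)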
onnxruntime/quantization/operators/gemm.py
@@ -0,0 +1,172 @@
+import logging
+
+import numpy as np  # noqa: F401
+import onnx
+
+from ..quant_utils import (
+    TENSOR_NAME_QUANT_SUFFIX,
+    QuantizedValue,
+    QuantizedValueType,
+    attribute_to_kwarg,
+    find_by_name,  # noqa: F401
+    get_mul_node,  # noqa: F401
+    ms_domain,
+)
+from .base_operator import QuantOperatorBase  # noqa: F401
+from .matmul import QOpMatMul
+from .qdq_base_operator import QDQOperatorBase
+
+
+def is_B_transposed(gemm_node):  # noqa: N802
+    transB_attribute = [attr for attr in gemm_node.attribute if attr.name == "transB"]  # noqa: N806
+    if transB_attribute:
+        return onnx.helper.get_attribute_value(transB_attribute[0]) > 0
+
+    return False
+
+
+def get_beta(gemm_node):
+    beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"]
+    if beta_attribute:
+        return onnx.helper.get_attribute_value(beta_attribute[0])
+
+    return 1.0
+
+
+def set_default_beta(gemm_node):
+    beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"]
+    if beta_attribute:
+        beta_attribute[0].f = 1.0
+
+    return 1.0
+
+
+class QLinearGemm(QOpMatMul):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def quantize(self):
+        node = self.node
+        assert node.op_type == "Gemm"
+
+        (
+            data_found,
+            output_scale_name,
+            output_zp_name,
+            _,
+            _,
+        ) = self.quantizer._get_quantization_params(node.output[0])
+
+        if self.quantizer.is_input_a_initializer(node.input[1]) and self.quantizer.is_per_channel():
+            (
+                quantized_input_names,
+                zero_point_names,
+                scale_names,
+                nodes,
+            ) = self.quantizer.quantize_activation(node, [0])
+            quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
+                node.input[1],
+                self.quantizer.weight_qType,
+                0 if is_B_transposed(node) else 1,
+            )
+            quantized_input_names.append(quant_weight_tuple[0])
+            zero_point_names.append(quant_weight_tuple[1])
+            scale_names.append(quant_weight_tuple[2])
+        else:
+            # Get quantized names from both the activation (input[0]) and the weight (input[1]).
+            (
+                quantized_input_names,
+                zero_point_names,
+                scale_names,
+                nodes,
+            ) = self.quantizer.quantize_activation(node, [0])
+
+            (
+                quantized_input_names_weight,
+                zero_point_names_weight,
+                scale_names_weight,
+                nodes_weight,
+            ) = self.quantizer.quantize_weight(node, [1], reduce_range=self.quantizer.reduce_range)
+            quantized_input_names.extend(quantized_input_names_weight)
+            zero_point_names.extend(zero_point_names_weight)
+            scale_names.extend(scale_names_weight)
+            nodes.extend(nodes_weight)
+
+        if not data_found or quantized_input_names is None:
+            return super().quantize()
+
+        quantized_bias_name = ""
+        if len(node.input) == 3:
+            if not self.quantizer.is_input_a_initializer(node.input[2]):
+                return super().quantize()
+
+            # Note: if the quantized type is float 8, the bias is converted to float 16.
+            # cublasLtMatMul only supports (b)float16 or float32 bias.
+            quantized_bias_name = self.quantizer.quantize_bias_static(
+                node.input[2], node.input[0], node.input[1], get_beta(self.node)
+            )
+
+        qgemm_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+        qgemm_name = node.name + "_quant" if node.name else ""
+
+        kwargs = {}
+        for attribute in node.attribute:
+            if attribute.name != "beta":
+                kwargs.update(attribute_to_kwarg(attribute))
+        kwargs["domain"] = ms_domain
+
+        # Generate the QGemm input list.
+        qgemm_inputs = []
+        for i in range(2):
+            qgemm_inputs.extend([quantized_input_names[i], scale_names[i], zero_point_names[i]])
+
+        qgemm_inputs.extend([quantized_bias_name, output_scale_name, output_zp_name])
+
+        qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output], qgemm_name, **kwargs)
+        nodes.append(qgemm_node)
+
+        # Create an entry for this quantized value
+        q_output = QuantizedValue(
+            node.output[0],
+            qgemm_output,
+            output_scale_name,
+            output_zp_name,
+            QuantizedValueType.Input,
+            node_type=node.op_type,
+            node_qtype=self.quantizer.weight_qType,
+        )
+        self.quantizer.quantized_value_map[node.output[0]] = q_output
+
+        self.quantizer.new_nodes += nodes
+
+
+class QDQGemm(QDQOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def quantize(self):
+        node = self.node
+        assert node.op_type == "Gemm"
+
+        self.quantizer.quantize_activation_tensor(node.input[0])
+        if not self.disable_qdq_for_node_output:
+            self.quantizer.quantize_activation_tensor(node.output[0])
+
+        is_weight_per_channel, weight_axis = self.quantizer.is_tensor_per_channel(
+            node.input[1], default_axis=0 if is_B_transposed(node) else 1
+        )
+        if is_weight_per_channel:
+            self.quantizer.quantize_weight_tensor_per_channel(node.input[1], weight_axis)
+        else:
+            self.quantizer.quantize_weight_tensor(node.input[1])
+
+        if len(node.input) == 3:
+            if self.quantizer.is_input_a_initializer(node.input[2]):
+                self.quantizer.quantize_bias_tensor(
+                    node.name, node.input[2], node.input[0], node.input[1], get_beta(self.node)
+                )
+                set_default_beta(self.node)
+            else:
+                logging.warning(
+                    f"Bias of Gemm node '{self.node.name}' is not constant. Please exclude this node for better performance."
+                )
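
QGemm consumes a (quantized tensor, scale, zero point) triple for each of A and B, plus a bias and the output quantization parameters. Its reference arithmetic can be read as dequantize, Gemm, requantize; a rough sketch with all parameters invented for illustration:

    import numpy as np

    sa, za = 0.1, 112   # activation A: uint8 scale / zero point
    sb, zb = 0.02, 0    # weight B: int8, symmetric (zero point 0)
    so, zo = 0.3, 120   # output quantization parameters

    a_q = np.random.randint(0, 256, size=(4, 8)).astype(np.uint8)
    b_q = np.random.randint(-127, 128, size=(8, 6)).astype(np.int8)
    bias = np.random.randn(6).astype(np.float32)

    # Dequantize both operands, run the float Gemm, requantize the result.
    a = (a_q.astype(np.float32) - za) * sa
    b = (b_q.astype(np.float32) - zb) * sb
    y_q = np.clip(np.round((a @ b + bias) / so) + zo, 0, 255).astype(np.uint8)
    print(y_q.shape)  # (4, 6)

This is also why the kwargs loop above skips the beta attribute: beta is handed to quantize_bias_static so it can be folded into the quantized bias rather than kept on the node.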
onnxruntime/quantization/operators/lstm.py
@@ -0,0 +1,121 @@
+import numpy
+import onnx
+from onnx import onnx_pb as onnx_proto
+
+from ..quant_utils import QuantType, attribute_to_kwarg, ms_domain  # noqa: F401
+from .base_operator import QuantOperatorBase
+
+"""
+Quantize LSTM
+"""
+
+
+class LSTMQuant(QuantOperatorBase):
+    def __init__(self, onnx_quantizer, onnx_node):
+        super().__init__(onnx_quantizer, onnx_node)
+
+    def quantize(self):
+        """
+        parameter node: LSTM node.
+        parameter new_nodes_list: List of new nodes created before processing this node.
+        return: a list of nodes in topological order that represents the quantized LSTM node.
+        """
+        node = self.node
+        assert node.op_type == "LSTM"
+
+        if not self.quantizer.is_valid_quantize_weight(node.input[1]) or not self.quantizer.is_valid_quantize_weight(
+            node.input[2]
+        ):
+            super().quantize()
+            return
+
+        model = self.quantizer.model
+        W = model.get_initializer(node.input[1])  # noqa: N806
+        R = model.get_initializer(node.input[2])  # noqa: N806
+
+        if len(W.dims) != 3 or len(R.dims) != 3:
+            super().quantize()
+            return
+
+        [W_num_dir, W_4_hidden_size, W_input_size] = W.dims  # noqa: N806
+        [R_num_dir, R_4_hidden_size, R_hidden_size] = R.dims  # noqa: N806
+
+        if self.quantizer.is_per_channel():
+            del W.dims[0]
+            del R.dims[0]
+            W.dims[0] = W_num_dir * W_4_hidden_size
+            R.dims[0] = R_num_dir * R_4_hidden_size
+
+        quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel(
+            node.input[1],
+            onnx_proto.TensorProto.INT8,
+            0,  # self.quantizer.weight_qType?
+        )
+        quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel(
+            node.input[2],
+            onnx_proto.TensorProto.INT8,
+            0,  # self.quantizer.weight_qType?
+        )
+
+        W_quant_weight = model.get_initializer(quant_input_weight_tuple[0])  # noqa: N806
+        R_quant_weight = model.get_initializer(quant_recurrent_weight_tuple[0])  # noqa: N806
+
+        W_quant_array = onnx.numpy_helper.to_array(W_quant_weight)  # noqa: N806
+        R_quant_array = onnx.numpy_helper.to_array(R_quant_weight)  # noqa: N806
+
+        W_quant_array = numpy.reshape(W_quant_array, (W_num_dir, W_4_hidden_size, W_input_size))  # noqa: N806
+        R_quant_array = numpy.reshape(R_quant_array, (R_num_dir, R_4_hidden_size, R_hidden_size))  # noqa: N806
+
+        W_quant_array = numpy.transpose(W_quant_array, (0, 2, 1))  # noqa: N806
+        R_quant_array = numpy.transpose(R_quant_array, (0, 2, 1))  # noqa: N806
+
+        W_quant_transposed = onnx.numpy_helper.from_array(W_quant_array, quant_input_weight_tuple[0])  # noqa: N806
+        R_quant_transposed = onnx.numpy_helper.from_array(R_quant_array, quant_recurrent_weight_tuple[0])  # noqa: N806
+
+        model.remove_initializers([W_quant_weight, R_quant_weight])
+        model.add_initializer(W_quant_transposed)
+        model.add_initializer(R_quant_transposed)
+
+        W_quant_zp = model.get_initializer(quant_input_weight_tuple[1])  # noqa: N806
+        R_quant_zp = model.get_initializer(quant_recurrent_weight_tuple[1])  # noqa: N806
+        W_quant_scale = model.get_initializer(quant_input_weight_tuple[2])  # noqa: N806
+        R_quant_scale = model.get_initializer(quant_recurrent_weight_tuple[2])  # noqa: N806
+
+        if self.quantizer.is_per_channel():
+            W_quant_zp.dims[:] = [W_num_dir, W_4_hidden_size]
+            R_quant_zp.dims[:] = [R_num_dir, R_4_hidden_size]
+            W_quant_scale.dims[:] = [W_num_dir, W_4_hidden_size]
+            R_quant_scale.dims[:] = [R_num_dir, R_4_hidden_size]
+
+        inputs = []
+        input_len = len(node.input)
+        inputs.extend([node.input[0]])
+        inputs.extend([quant_input_weight_tuple[0], quant_recurrent_weight_tuple[0]])
+        inputs.extend([node.input[3] if input_len > 3 else ""])
+        inputs.extend([node.input[4] if input_len > 4 else ""])
+        inputs.extend([node.input[5] if input_len > 5 else ""])
+        inputs.extend([node.input[6] if input_len > 6 else ""])
+        inputs.extend([node.input[7] if input_len > 7 else ""])
+        inputs.extend(
+            [
+                quant_input_weight_tuple[2],
+                quant_input_weight_tuple[1],
+                quant_recurrent_weight_tuple[2],
+                quant_recurrent_weight_tuple[1],
+            ]
+        )
+
+        kwargs = {}
+        for attribute in node.attribute:
+            if attribute.name == "layout":
+                continue
+            kwargs.update(attribute_to_kwarg(attribute))
+        kwargs["domain"] = ms_domain
+
+        quant_lstm_name = "" if not node.name else node.name + "_quant"
+        quant_lstm_node = onnx.helper.make_node("DynamicQuantizeLSTM", inputs, node.output, quant_lstm_name, **kwargs)
+        self.quantizer.new_nodes.append(quant_lstm_node)
+
+        dequantize_node = self.quantizer._dequantize_value(node.input[0])
+        if dequantize_node is not None:
+            self.quantizer.new_nodes.append(dequantize_node)
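
As the reshape/transpose above shows, DynamicQuantizeLSTM is fed the weight matrices with the last two axes swapped relative to LSTM's [num_directions, 4*hidden_size, input_size] layout: per-channel quantization works on a flattened 2-D view, and the result is restored to 3-D and transposed before being re-added as an initializer. The same round trip in isolation, with invented sizes:

    import numpy as np

    num_dir, hidden_size, input_size = 1, 16, 32

    # Per-channel quantization saw W flattened to [num_dir * 4*hidden, input_size];
    # restore 3-D and transpose to [num_dir, input_size, 4*hidden].
    w_q = np.random.randint(-127, 128, size=(num_dir * 4 * hidden_size, input_size)).astype(np.int8)
    w_t = w_q.reshape(num_dir, 4 * hidden_size, input_size).transpose(0, 2, 1)
    print(w_t.shape)  # (1, 32, 64)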