onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/quantization/qdq_quantizer.py
@@ -0,0 +1,1477 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Any
+
+ import numpy as np
+ import onnx
+ from onnx import TensorProto
+ from onnx import onnx_pb as onnx_proto
+
+ from .base_quantizer import BaseQuantizer, QuantizationParams
+ from .calibrate import TensorData
+ from .quant_utils import (
+     DEQUANT_OP_NAME,
+     ONNX_TYPE_TO_NP_TYPE,
+     QUANT_OP_NAME,
+     QuantizedValue,
+     QuantizedValueType,
+     __producer__,
+     __version__,
+     add_dequant_output_suffix,
+     add_dequant_suffix,
+     add_quant_input_suffix,
+     add_quant_output_suffix,
+     add_quant_suffix,
+     compute_data_quant_params,
+     compute_scale_zp,
+     compute_scale_zp_float8,
+     find_by_name,
+     get_qmin_qmax_for_qType,
+     ms_domain,
+     normalize_axis,
+     quantize_onnx_initializer,
+     tensor_proto_to_array,
+ )
+ from .registry import CreateQDQQuantizer
+
+
+ class QDQQuantTensorType(Enum):
+     ACTIVATION = 0
+     WEIGHT = 1
+     BIAS = 2
+
+
+ # Holds the name of the node input from which a node output will share the
+ # same quantization param initializers (zero-point and scale initializers).
+ # Ex: A Transpose node's output will use the same quant param initializers used at the input.
+ @dataclass
+ class QDQQuantParamProvider:
+     input_name: str
+     node_name: str
+
+
+ # Holds information for tensors that have been marked for quantization by operator quantizers.
+ # Does not hold information for bias tensors.
+ class QDQTensorQuantInfo:
+     def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None, data_type=None):
+         self.tensor_type = tensor_type
+         self.quant_para_provider = quant_para_provider
+         self.axis = axis
+         self.is_shared = quant_para_provider is not None
+         assert data_type is not None
+         self.data_type = data_type
+
+
+ # Holds information for bias tensors that have been marked for quantization by operator quantizers.
+ @dataclass
+ class QDQBiasQuantInfo:
+     node_name: str
+     input_name: str
+     weight_name: str
+     beta: float
+
+
+ # Holds quantization parameter values (scale, zp) for a tensor.
+ # A tensor typically has one set of quantization parameters, unless the tensor is
+ # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
+ @dataclass
+ class QDQTensorQuantParams:
+     original: QuantizationParams  # Generated by producer node.
+     converted: QuantizationParams | None  # Converted type consumed by some (or all/none) consumer nodes.
+     converted_recv_nodes: set[str] | None  # The name of nodes that consume the converted type.
+
+     def get_for_consumer(self, consumer_node_name) -> QuantizationParams:
+         if self.converted is None:  # Quantized value is not converted, return original
+             return self.original
+
+         if self.converted_recv_nodes is None:  # All consumers receive the converted value
+             return self.converted
+
+         # Check if consumer node name is in the list of nodes that
+         # receive the converted quantization value. If not, return the original value generated
+         # by the tensor's producer.
+         return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original
+
+
+ # Holds scale and zero_point initializer TensorProtos.
+ @dataclass
+ class QDQScaleZpInitializers:
+     scale: TensorProto
+     zero_point: TensorProto
+
+
+ # Holds all scale and zero-point initializers for a tensor.
+ # A tensor typically has one set of quantization parameters, unless the tensor is
+ # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
+ @dataclass
+ class QDQTensorScaleZpInitializers:
+     original: QDQScaleZpInitializers
+     converted: QDQScaleZpInitializers | None
+     converted_recv_nodes: set[str] | None
+
+
+ # Holds cached information of a tensor's quantized values (types, zp/scale initializer names, etc.).
+ # A tensor typically has one set of quantization parameters, unless the tensor is
+ # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
+ @dataclass
+ class QDQTensorQuantizedValue:
+     original: QuantizedValue
+     converted: QuantizedValue | None
+     converted_recv_nodes: set[str] | None
+
+     def get_for_consumer(self, consumer_node_name) -> QuantizedValue:
+         if self.converted is None:  # Quantized value is not converted, return original
+             return self.original
+
+         if self.converted_recv_nodes is None:  # All consumers receive the converted value
+             return self.converted
+
+         # Check if consumer node name is in the list of nodes that
+         # receive the converted quantization value. If not, return the original value generated
+         # by the tensor's producer.
+         return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original
+
+
+ class QDQQuantizer(BaseQuantizer):
+     def __init__(
+         self,
+         model,
+         per_channel,
+         reduce_range,
+         weight_qType,
+         activation_qType,
+         tensors_range,
+         nodes_to_quantize,
+         nodes_to_exclude,
+         op_types_to_quantize,
+         extra_options=None,
+     ):
+         BaseQuantizer.__init__(
+             self,
+             model,
+             per_channel,
+             reduce_range,
+             weight_qType,
+             activation_qType,
+             tensors_range,
+             nodes_to_quantize,
+             nodes_to_exclude,
+             op_types_to_quantize,
+             extra_options,
+         )
+         self.tensors_to_quantize: dict[str, QDQTensorQuantInfo] = {}
+         self.bias_to_quantize: dict[str, QDQBiasQuantInfo] = {}
+
+         self.nodes_to_remove = []
+
+         # Specific op types to exclude qdq quantization for their outputs.
+         # In TRT, it's not recommended to quantize outputs for weighted ops such as Conv, Matmul, Gemm
+         # because those ops may be followed by nodes that require high resolution inputs.
+         # Adding QDQ for those ops' output may end up with worse accuracy.
+         # So, we don't recommend adding QDQ to a node's output under such conditions.
+         self.op_types_to_exclude_output_quantization = extra_options.get("OpTypesToExcludeOutputQuantization", [])
+
+         # We do quantization on DequantizeLinear's input to remove QuantizeLinear for weight as an optimization.
+         # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair.
+         # Therefore, we need to disable this optimization and add qdq pair to weight.
+         self.add_qdq_pair_to_weight = extra_options.get("AddQDQPairToWeight", False)
+
+         # Some scenarios do not need the bias quantized. For example, in the case of Quantization Aware Training,
+         # quantizing the bias is not needed. This is because in QAT, all model parameters are expected to be in
+         # floating point format. To that end, we can use the FakeQuant operator for weights and activations that
+         # can always have QDQ pairs (by using AddQDQPairToWeight). But for biases in a quantized model, we can't use
+         # FakeQuant because it only ever appears before a DQ (since it is quantized as int32).
+         self.quantize_bias = extra_options.get("QuantizeBias", True)
+
+         # The default behavior is that multiple nodes can share a QDQ pair as their inputs.
+         # In TRT, a QDQ pair can't be shared between nodes, so it will create dedicated QDQ pairs for each node.
+         self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False)
+         self.tensor_to_its_receiving_nodes: dict[str, list[onnx.NodeProto]] = {}
+
+         # Maps a tensor to the DequantizeLinear node (in the original input model) that outputs the tensor.
+         # Populated for input models with some pre-quantized weights (typically via a different tool).
+         self.tensor_to_producing_dq: dict[str, onnx.NodeProto] = {}
+
+         # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True.
+         self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {})
+
+         self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None
+
+         # User can specify if removable activations, like Clip/Relu, should be kept in the graph.
+         # Used in the QDQRemovableActivation class.
+         self.qdq_keep_removable_activations = extra_options.get("QDQKeepRemovableActivations", False)
+
+         # Let user disable adjustment of weight scales for bias inputs that are quantized to int32.
+         self.qdq_disable_weight_adjust_for_int32_bias = extra_options.get("QDQDisableWeightAdjustForInt32Bias", False)
+
+         # The ONNX spec did not support 16-bit Q/DQ ops before opset 21.
+         # So, we may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types
+         # are 16-bit or 4-bit integers.
+         if self.opset_version < 21:
+             opset21_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4)
+             overrides_have_opset21_types = any(
+                 t.tensor_type in opset21_types for t in self.tensor_quant_override_qtypes
+             )
+             if not self.qdq_op_domain and (
+                 self.activation_qType in opset21_types
+                 or self.weight_qType in opset21_types
+                 or overrides_have_opset21_types
+             ):
+                 logging.warning(
+                     "ONNX QuantizeLinear and DequantizeLinear operators do not support "
+                     "16-bit/4-bit integer quantization types prior to opset 21. "
+                     f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
+                     "enable support."
+                 )
+                 self.qdq_op_domain = ms_domain
+
+         self.quantization_params = self.calc_graph_quant_params()
+         self.initializer_quant_params: dict[str, QuantizationParams] = {}
+
+         # Map of all original value names to quantized value names
+         self.quantized_value_map = {}
+
+     def _get_tensor_type(self, tensor_name):
+         """
+         Check if tensor can be quantized
+         """
+         weight = find_by_name(tensor_name, self.model.initializer())
+         if weight is not None:
+             return weight.data_type
+         elif tensor_name in self.value_infos:
+             vi = self.value_infos[tensor_name]
+             if vi.type.HasField("tensor_type"):
+                 return vi.type.tensor_type.elem_type
+         return None
+
+     def _is_tensor_quantizable(self, tensor_name):
+         """
+         Check if tensor can be quantized
+         """
+         weight = find_by_name(tensor_name, self.model.initializer())
+         if weight is not None:
+             if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
+                 return True
+         elif tensor_name in self.value_infos:
+             vi = self.value_infos[tensor_name]
+             if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
+                 TensorProto.FLOAT,
+                 TensorProto.FLOAT16,
+             ):
+                 return True
+         else:
+             logging.warning(
+                 f"failed to infer the type of tensor: {tensor_name}. Skip to quantize it. Please check if it is expected."
+             )
+
+         return False
+
+     def __quantize_tensor(self, tensor_name, quant_sharing_provider=None, tensor_type=QDQQuantTensorType.ACTIVATION):
+         """
+         Adds a tensor to the list (actually a dict) of tensors to quantize. Called indirectly by op quantizers that
+         want to quantize a tensor (i.e., "mark" a tensor for quantization).
+
+         If quant_sharing_provider is not None, tensor with name tensor_name will be quantized with the same
+         quantization parameters as the node input specified in quant_sharing_provider. Ex: A Transpose node's output
+         will typically use the same quantization parameter initializers used at the Transpose node's input.
+
+         Args:
+             tensor_name: name of the tensor to quantize
+             quant_sharing_provider: name of the tensor and node that provides quantization parameter
+             tensor_type: QDQQuantTensorType default ACTIVATION
+         """
+         if self._is_tensor_quantizable(tensor_name):
+             if quant_sharing_provider:
+                 if not isinstance(quant_sharing_provider, QDQQuantParamProvider):
+                     raise TypeError(
+                         f"quant_sharing_provider must be of type QDQQuantParamProvider, not {type(quant_sharing_provider)}."
+                     )
+
+                 data_type = self._get_tensor_type(tensor_name)
+                 self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                     tensor_type=tensor_type, quant_para_provider=quant_sharing_provider, data_type=data_type
+                 )
+             elif tensor_name not in self.tensors_to_quantize:
+                 data_type = self._get_tensor_type(tensor_name)
+                 self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type, data_type=data_type)
+
+     def quantize_activation_tensor(self, tensor_name: str):
+         """
+         Adds a tensor to the list of tensors to quantize. Called by op quantizers that
+         want to quantize a tensor (i.e., "mark" a tensor for quantization).
+
+         Args:
+             tensor_name: name of the tensor to quantize
+         """
+         return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.ACTIVATION)
+
+     def quantize_output_same_as_input(self, output_name: str, input_name: str, node_name: str):
+         """
+         Adds a tensor to the list of tensors to quantize. Called by op quantizers that
+         want to quantize an output tensor using the same quantization parameters as one of the node's inputs.
+
+         Ex: A Transpose node's output will typically use the same quantization parameter initializers used at
+         the Transpose node's input.
+
+         Args:
+             output_name: name of the node output to quantize so that it uses the same quantization params as an input.
+             input_name: name of the node input from which the output tensor will get its quantization params.
+             node_name: name of the node that consumes `input_name`.
+         """
+         return self.__quantize_tensor(
+             output_name, QDQQuantParamProvider(input_name, node_name), QDQQuantTensorType.ACTIVATION
+         )
+
+     def quantize_weight_tensor(self, tensor_name: str):
+         """
+         Adds a tensor to the list of weight tensors to quantize. Called by op quantizers that
+         want to quantize a weight (i.e., "mark" a weight for quantization).
+
+         Args:
+             tensor_name: name of the weight to quantize
+         """
+         return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.WEIGHT)
+
+     def quantize_weight_tensor_per_channel(self, tensor_name, axis):
+         weight = find_by_name(tensor_name, self.model.initializer())
+         if weight:
+             if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
+                 self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                     tensor_type=QDQQuantTensorType.WEIGHT, axis=axis, data_type=weight.data_type
+                 )
+         else:
+             logging.warning(f"only support per-channel quantization on weight. Tensor: {tensor_name} is not quantized.")
+
+     def _dup_initializer(self, initializer: onnx.TensorProto) -> onnx.TensorProto:
+         """
+         Duplicates an existing initializer and adds it to the model. Returns the new initializer.
+         """
+         name_suffix: int = self.model.get_largest_initializer_name_suffix(initializer.name) + 1
+         new_initializer_name = f"{initializer.name}{name_suffix}"
+         new_initializer = onnx.TensorProto()
+         new_initializer.CopyFrom(initializer)
+         new_initializer.name = new_initializer_name
+         self.model.add_initializer(new_initializer)
+         return new_initializer
+
+     def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0):
+         """
+         Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that
+         want to quantize a bias with bias_zero_point = 0 and bias_scale = input_scale * weight_scale * beta.
+         TODO: Explain the reasoning for using this formula.
+
+         Args:
+             node_name: name of the node that consumes the bias, input, and weight tensors.
+             bias_name: name of the bias tensor to quantize.
+             input_name: name of the input tensor whose scale is used to compute the bias's scale.
+             weight_name: name of the weight tensor whose scale is used to compute the bias's scale.
+             beta: Multiplier used to compute the bias's scale.
+         """
+         # If the user provided quantization overrides for this tensor, treat it as a regular weight.
+         if self.tensor_quant_overrides.get(bias_name):
+             logging.info(
+                 f"Quantizing bias tensor '{bias_name}' as a weight due to the presence of user-specified overrides"
+             )
+             is_per_channel, axis = self.is_tensor_per_channel(bias_name, default_axis=0)
+             if is_per_channel:
+                 self.quantize_weight_tensor_per_channel(bias_name, axis)
+             else:
+                 self.quantize_weight_tensor(bias_name)
+             return
+
+         bias_initializer = find_by_name(bias_name, self.model.initializer())
+         if bias_initializer is None:
+             logging.warning(f"Expected bias '{bias_name}' to be an initializer")
+             return
+
+         if bias_initializer.data_type not in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
+             logging.info(f"Expected bias '{bias_name}' to be an floating-point initializer")
+             return
+
+         actual_bias_name = bias_name
+         if bias_name in self.bias_to_quantize:
+             # This bias input is consumed by two different nodes. We need to duplicate the bias so that
+             # each node has its own bias input. This is necessary because the bias's scale is computed
+             # from the node's other input scales.
+             new_bias_initializer = self._dup_initializer(bias_initializer)
+             actual_bias_name = new_bias_initializer.name
+
+             # Replace this node's bias input
+             self.model.replace_input_of_nodes(bias_name, actual_bias_name, {node_name})
+             logging.info(f"Created a copy of bias input '{bias_name}' called '{actual_bias_name}'")
+
+         # Add this to our list of biases to quantize.
+         self.bias_to_quantize[actual_bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta)
+
+     def _adjust_weight_scale_for_int32_bias(
+         self,
+         input_scale: np.ndarray,
+         weight_scale: np.ndarray,
+         weight_name: str,
+         bias_tp: onnx.TensorProto,
+         is_per_channel: bool,
+     ) -> tuple[bool, np.ndarray | None]:
+         """
+         Checks if the bias scale (input_scale * weight_scale) that we intend to use is too small.
+         A bias scale that is too small leads to quantized bias values that fall outside the range of an int32 and have to
+         be clipped, which decreases accuracy. If this function detects such a scenario, the weight_scale value will be
+         increased to prevent this from happening.
+
+         Although the adjustment method and amount differ, the idea to adjust the weight's scale came from the following
+         reference:
+         https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/optimize/quantization_utils.cc#L252
+
+         :param input_scale: The input's scale.
+         :param weight_scale: The weight scale to potentially adjust.
+         :param weight_name: The weight initializer's name. Used for logging.
+         :param bias_tp: The bias ONNX initializer.
+         :param is_per_channel: True if the bias and weight are quantized per-channel.
+         :return: A tuple with a bool indicating if the weight's scale was adjusted and the new weight scale.
+         """
+         if not weight_scale.size:
+             return False, None
+
+         bias_float_data = tensor_proto_to_array(bias_tp)
+
+         int32_info = np.iinfo(np.int32)
+         multiplicative_epsilon = 1.0001
+         qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64)
+         weight_scale_dtype = weight_scale.dtype
+         updated_an_elem = False
+
+         if not is_per_channel:
+             rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64))
+             rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64))
+             absmax = np.maximum(np.abs(rmin), np.abs(rmax))
+             bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange
+
+             input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
+             weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64)
+             bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
+
+             if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
+                 # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
+                 ratio = bias_smallest_valid_scale / bias_candidate_scale
+                 logging.info(
+                     f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to "
+                     f"ensure bias input `{bias_tp.name}` has a valid scale."
+                 )
+                 new_scale = weight_scale_fp64 * ratio
+                 weight_scale = new_scale.astype(weight_scale_dtype)
+                 updated_an_elem = True
+         elif weight_scale.shape and len(weight_scale.shape) == 1:
+             # per-channel case
+             num_elems = weight_scale.shape[0]
+
+             for i in range(num_elems):
+                 bias_rmax = np.abs(bias_float_data[i])
+                 bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * bias_rmax) / qrange
+
+                 input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
+                 weight_scale_fp64 = np.array(weight_scale[i].item(), dtype=np.float64)
+                 bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
+                 if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
+                     # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
+                     ratio = bias_smallest_valid_scale / bias_candidate_scale
+                     logging.info(
+                         f"Increased scale[{i}] for weight `{weight_name}` by ratio {ratio} "
+                         f"to ensure bias input `{bias_tp.name}` has a valid scale."
+                     )
+                     new_scale = weight_scale_fp64 * ratio
+                     weight_scale[i] = new_scale.astype(weight_scale_dtype)
+                     updated_an_elem = True
+
+         return updated_an_elem, weight_scale
+
+     def _adjust_weight_quant_params_for_bias_tensors(self):
+         """
+         Iterates through all bias inputs that should be quantized to int32. If the intended
+         bias scale (equal to input_scale * weight_scale) is too small, this function will increase
+         the associated weight's scale to ensure the bias does not overflow the int32 range when quantized.
+         """
+
+         if self.qdq_disable_weight_adjust_for_int32_bias:
+             # User passed an extra_option to disable this adjustment.
+             return
+
+         for bias_name, bias_info in self.bias_to_quantize.items():
+             if (
+                 bias_info.input_name not in self.quantization_params
+                 or bias_info.input_name not in self.tensors_to_quantize
+                 or bias_info.weight_name not in self.initializer_quant_params
+             ):
+                 continue
+
+             # Get the associated input's scale.
+             input_qparams = self.quantization_params[bias_info.input_name].get_for_consumer(bias_info.node_name)
+             input_info = self.tensors_to_quantize[bias_info.input_name]
+             input_scale = np.asarray(
+                 input_qparams["scale"], dtype=onnx.helper.tensor_dtype_to_np_dtype(input_info.data_type)
+             )
+
+             weight_quant_params = self.initializer_quant_params[bias_info.weight_name]
+             weight_quant_type = weight_quant_params["quant_type"]
+             if weight_quant_type not in (onnx.TensorProto.INT8, onnx.TensorProto.INT16):
+                 continue
+
+             weight_zero_point: np.ndarray = weight_quant_params["zero_point"]
+             if weight_zero_point.any():
+                 # Skip if zero_point(s) are not all zero (i.e., symmetric quant)
+                 continue
+
+             weight_scale: np.ndarray = weight_quant_params["scale"]
+             is_per_channel = weight_quant_params.get("axis", None) is not None
+
+             # Get adjusted weight scales.
+             did_update_weight_scale, new_weight_scale = self._adjust_weight_scale_for_int32_bias(
+                 input_scale,
+                 weight_scale,
+                 bias_info.weight_name,
+                 find_by_name(bias_name, self.model.initializer()),
+                 is_per_channel,
+             )
+
+             if did_update_weight_scale:
+                 weight_quant_params["scale"] = new_weight_scale
+
+     def remove_node(self, node):
+         self.nodes_to_remove.append(node)
+
+     def remove_nodes(self):
+         self.model.remove_nodes(self.nodes_to_remove)
+
+     def quantize_model(self):
+         for node in self.model.nodes():
+             if self.should_quantize_node(node):
+                 op_quantizer = CreateQDQQuantizer(self, node)
+                 op_quantizer.quantize()
+
+                 for tensor_name in node.input:
+                     if tensor_name not in self.tensor_to_its_receiving_nodes:
+                         self.tensor_to_its_receiving_nodes[tensor_name] = []
+                     self.tensor_to_its_receiving_nodes[tensor_name].append(node)
+             if node.op_type == DEQUANT_OP_NAME:
+                 for tensor_name in node.output:
+                     self.tensor_to_producing_dq[tensor_name] = node
+
+         self.initializer_quant_params = self._calc_initializer_quant_params()
+         self._adjust_weight_quant_params_for_bias_tensors()
+         self._quantize_normal_tensors()
+         self._quantize_sharing_param_tensors()
+         if self.quantize_bias:
+             self._quantize_bias_tensors()
+         self.remove_nodes()
+         if not self.add_qdq_pair_to_weight:
+             self.model.clean_initializers()
+
+         self.model.model.producer_name = __producer__
+         self.model.model.producer_version = __version__
+         if self.qdq_op_domain == ms_domain:
+             self.model.set_opset_import(ms_domain, 1)
+
+         return self.model.model
+
+     def try_replacing_upstream_output(self, upstream_output_name, output_name):
+         if (
+             output_name in self.quantization_params
+             and self.quantization_params[output_name].converted is None
+             and self.quantization_params[upstream_output_name].converted is None
+             and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1
+             and not self.model.is_graph_output(upstream_output_name)
+             and not self.model.is_graph_input(upstream_output_name)
+         ):
+             self.model.replace_output_of_all_nodes(upstream_output_name, output_name)
+             if upstream_output_name in self.tensors_to_quantize:
+                 del self.tensors_to_quantize[upstream_output_name]
+             return True
+         return False
+
+     def _create_q_node(
+         self,
+         q_input: str,
+         q_output: str,
+         quant_node_name: str,
+         scale_name: str,
+         zp_name: str,
+         axis: int | None = None,
+     ):
+         """
+         Creates a QuantizeLinear node and adds it to the model.
+         """
+         qlinear_node = onnx.helper.make_node(
+             QUANT_OP_NAME,
+             [q_input, scale_name, zp_name],
+             [q_output],
+             quant_node_name,
+             axis=axis,
+             domain=self.qdq_op_domain,
+         )
+         self.model.add_nodes([qlinear_node])
+
+     def _create_dq_node(
+         self,
+         dq_input: str,
+         dq_output: str,
+         dequant_node_name: str,
+         scale_name: str,
+         zp_name: str,
+         axis: int | None = None,
+     ):
+         """
+         Creates a DequantizeLinear node and adds it to the model.
+         """
+         dequant_node = onnx.helper.make_node(
+             DEQUANT_OP_NAME,
+             [dq_input, scale_name, zp_name],
+             [dq_output],
+             dequant_node_name,
+             axis=axis,
+             domain=self.qdq_op_domain,
+         )
+         self.model.add_nodes([dequant_node])
+
+     def _create_qdq_nodes(
+         self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None
+     ):
+         qlinear_node = onnx.helper.make_node(
+             QUANT_OP_NAME,
+             [q_input, scale_name, zp_name],
+             [q_output],
+             quant_node_name,
+             axis=axis,
+             domain=self.qdq_op_domain,
+         )
+         dequant_node = onnx.helper.make_node(
+             DEQUANT_OP_NAME,
+             [dq_input, scale_name, zp_name],
+             [dq_output],
+             dequant_node_name,
+             axis=axis,
+             domain=self.qdq_op_domain,
+         )
+         self.model.add_nodes([qlinear_node, dequant_node])
+
+     def _add_qdq_nodes_for_initializer(self, weight_proto: onnx.TensorProto):
+         """
+         Adds Q/DQ nodes for an initializer. If `self.add_qdq_pair_to_weight` is true, creates
+         the sequence (weight_f32 -> Q -> DQ -> ). Otherwise, this function quantizes the initializer
+         and adds the sequence (weight_quant -> DQ ->).
+         """
+         weight_name = weight_proto.name
+         if weight_name in self.quantized_value_map:
+             return
+
+         quant_params: QuantizationParams = self.initializer_quant_params[weight_name]
+         axis: int = quant_params.get("axis")
+         scale_zp_initializers = self._make_scale_zp_initializers(weight_name, quant_params)
+         q_weight_name: str | None = None
+         weight_dequant_output = add_dequant_output_suffix(weight_name)
+         self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output)
+
+         if self.add_qdq_pair_to_weight:
+             # Don't actually quantize the weight. Instead, keep floating-point weight and create the node
+             # sequence (weight_f32 -> Q -> DQ -> weight_dequant)
+             weight_quant_output = add_quant_output_suffix(weight_name)
+
+             self._create_qdq_nodes(
+                 weight_name,
+                 weight_quant_output,
+                 add_quant_suffix(weight_name),
+                 weight_quant_output,
+                 weight_dequant_output,
+                 add_dequant_suffix(weight_name),
+                 scale_zp_initializers.scale.name,
+                 scale_zp_initializers.zero_point.name,
+                 axis,
+             )
+         else:
+             # Quantize the weight and create the node sequence:
+             # (weight_quantized -> DQ -> weight_dequant)
+             quant_weight = quantize_onnx_initializer(
+                 weight_proto,
+                 quant_params["quant_type"],
+                 quant_params["zero_point"],
+                 quant_params["scale"],
+                 axis,
+             )
+             self.model.add_initializer(quant_weight)
+
+             q_weight_name = quant_weight.name
+             dequant_node = onnx.helper.make_node(
+                 DEQUANT_OP_NAME,
+                 [quant_weight.name, scale_zp_initializers.scale.name, scale_zp_initializers.zero_point.name],
+                 [weight_dequant_output],
+                 add_dequant_suffix(weight_name),
+                 axis=axis,
+                 domain=self.qdq_op_domain,
+             )
+             self.model.add_node(dequant_node)
+
+         # Log entry for this quantized weight
+         quantized_value = QuantizedValue(
+             weight_name,
+             q_weight_name,
+             scale_zp_initializers.scale.name,
+             scale_zp_initializers.zero_point.name,
+             QuantizedValueType.Initializer,
+             axis=axis,
+         )
+         self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None)
+
+     def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_type=None):
+         if (
+             self.dedicated_qdq_pair
+             and tensor_name in self.tensor_to_its_receiving_nodes
+             and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
+         ):
+             num_dedicated_qdq_pair = len(self.tensor_to_its_receiving_nodes[tensor_name])
+             for i in range(num_dedicated_qdq_pair):
+                 postfix = f"_{i + 1}"
+                 tensor_name_quant_output_postfix = add_quant_output_suffix(tensor_name) + postfix
+                 tensor_name_dequant_output_postfix = add_dequant_output_suffix(tensor_name) + postfix
+                 quant_node_name_postfix = add_quant_suffix(tensor_name) + postfix
+                 dequant_node_name_postfix = add_dequant_suffix(tensor_name) + postfix
+                 self._create_qdq_nodes(
+                     tensor_name,
+                     tensor_name_quant_output_postfix,
+                     quant_node_name_postfix,
+                     tensor_name_quant_output_postfix,
+                     tensor_name_dequant_output_postfix,
+                     dequant_node_name_postfix,
+                     scale_name,
+                     zp_name,
+                 )
+
+                 node = self.tensor_to_its_receiving_nodes[tensor_name][i]
+                 self.model.replace_node_input(node, tensor_name, tensor_name_dequant_output_postfix)
+                 if i == 0:
+                     quantized_value = QuantizedValue(
+                         tensor_name,
+                         tensor_name_dequant_output_postfix,
+                         scale_name,
+                         zp_name,
+                         QuantizedValueType.Input,
+                         scale_type=data_type,
+                     )
+                     self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None)
+         else:
+             q_input = tensor_name
+             dq_output = add_dequant_output_suffix(tensor_name)
+             if self.model.is_graph_output(tensor_name):
+                 q_input = add_quant_input_suffix(tensor_name)
+                 dq_output = tensor_name
+                 self.model.replace_output_of_all_nodes(tensor_name, q_input)
+             else:
+                 self.model.replace_input_of_all_nodes(tensor_name, dq_output)
+
+             self._create_qdq_nodes(
+                 q_input,
+                 add_quant_output_suffix(tensor_name),
+                 add_quant_suffix(tensor_name),
+                 add_quant_output_suffix(tensor_name),
+                 dq_output,
+                 add_dequant_suffix(tensor_name),
+                 scale_name,
+                 zp_name,
+             )
+
+             quantized_value = QuantizedValue(
+                 tensor_name,
+                 dq_output,
+                 scale_name,
+                 zp_name,
+                 QuantizedValueType.Input,
+                 scale_type=data_type,
+             )
+             self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None)
+
796
+ def _add_qdq_ops_for_converted_activation(
797
+ self,
798
+ tensor_name,
799
+ first_scale_name,
800
+ first_zp_name,
801
+ scale_data_type,
802
+ convert_scale_name,
803
+ convert_zp_name,
804
+ convert_recv_nodes,
805
+ ):
806
+ """
807
+ Adds Q and DQ ops to a tensor whose quantized data type is converted. That is, some consumers may use the
808
+ original data type from the producer, while other consumers use the converted data type.
809
+ This is generally done by adding a sequence of ops that convert from one data type (e.g., uint8) to another (e.g., uint16).
810
+
811
+ T_float ---> Quant(to u8) ---> Convert(to u16) ---> Dequant(to float) ---> T_float'
812
+ where Convert(to u16) is equivalent to: ---> Dequant(to float) ---> Quant(to u16) --->
813
+
814
+ This function handles the following scenarios:
815
+
816
+ 1) Tensor T is not a graph output; all consumers use the converted type
817
+
818
+ <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Consumers>
819
+
820
+ 2) Tensor T is not a graph output; some consumers use the original type, others use the converted type
821
+
822
+ <Producer> ---> Q1 -+-> DQ1 ---> <Consumers of original type>
823
+ |
824
+ +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>
825
+
826
+ 3) Tensor T is a graph output; all consumers use the converted type
827
+
828
+ <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> <Consumers>
829
+ |
830
+ +-> <Graph output>
831
+
832
+ 4) Tensor T is a graph output; some consumers use the original type, others use the converted type
833
+
834
+ <Producer> ---> Q1 -+-> DQ1 -+-> <Consumers of original type>
835
+ | |
836
+ | +-> <Graph output>
837
+ |
838
+ +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>
839
+
840
+ 5) Tensor T is a graph output that is not consumed by any other nodes.
841
+
842
+ <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Graph output>
843
+ """
844
+ tensor_recv_nodes = {node.name for node in self.tensor_to_its_receiving_nodes.get(tensor_name, [])}
845
+
846
+ if (
847
+ self.dedicated_qdq_pair
848
+ and tensor_name in self.tensor_to_its_receiving_nodes
849
+ and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
850
+ ):
851
+ # TODO: Add support for dedicated_qdq_pair if/when needed.
852
+ raise ValueError(
853
+ "Do not currently support converted quant_types in TensorQuantOverrides when the `dedicated_qdq_pair` extra_option is enabled"
854
+ )
855
+
856
+ # Determine which nodes consume the original quantized type and which nodes
857
+ # consume the converted quantized type.
858
+ original_recv_nodes = tensor_recv_nodes
859
+ if convert_recv_nodes is None: # In this case, all consumers receive the converted type.
860
+ convert_recv_nodes = tensor_recv_nodes
861
+ original_recv_nodes = set()
862
+ else:
863
+ original_recv_nodes = original_recv_nodes - convert_recv_nodes
864
+
865
+ all_use_converted = len(convert_recv_nodes) == len(tensor_recv_nodes)
866
+ is_graph_output = self.model.is_graph_output(tensor_name)
867
+
868
+ # Create first Q op.
869
+ first_q_input = tensor_name
870
+ if is_graph_output:
871
+ first_q_input = add_quant_input_suffix(tensor_name)
872
+ self.model.replace_output_of_all_nodes(tensor_name, first_q_input)
873
+
874
+ first_q_output = add_quant_output_suffix(tensor_name)
875
+ self._create_q_node(
876
+ first_q_input, first_q_output, add_quant_suffix(tensor_name), first_scale_name, first_zp_name
877
+ )
878
+
879
+ # Create first DQ op.
880
+ first_dq_output = add_dequant_output_suffix(tensor_name)
881
+ if is_graph_output and not all_use_converted:
882
+ first_dq_output = tensor_name
883
+ if original_recv_nodes and first_dq_output != tensor_name:
884
+ self.model.replace_input_of_nodes(tensor_name, first_dq_output, original_recv_nodes)
885
+
886
+ self._create_dq_node(
887
+ first_q_output, first_dq_output, add_dequant_suffix(tensor_name), first_scale_name, first_zp_name
888
+ )
889
+
890
+ # Create parallel clone of first DQ op if _not all_ consumers use the converted type.
891
+ # --> DQ1' --> Q2 --> DQ2 --> <Consumers of converted type>
892
+ #
893
+ # This DQ clone would only have one consumer Q node (Q2) and could potentially be fused with
894
+ # it by some EPs (e.g., QNN) without breaking other "node units".
895
+ # Ex QNN fusion:
896
+ # --> Convert (fused) --> DQ2 --> <Consumers of converted type>
897
+ second_q_input = first_dq_output
898
+ if not all_use_converted:
899
+ second_q_input = add_quant_input_suffix(f"{tensor_name}_convert")
900
+ self._create_dq_node(
901
+ first_q_output,
902
+ second_q_input,
903
+ add_dequant_suffix(f"{tensor_name}_convert_clone"),
904
+ first_scale_name,
905
+ first_zp_name,
906
+ )
907
+
908
+ # Create second Q op.
909
+ second_q_output = add_quant_output_suffix(f"{tensor_name}_convert")
910
+ self._create_q_node(
911
+ second_q_input,
912
+ second_q_output,
913
+ add_quant_suffix(f"{tensor_name}_convert"),
914
+ convert_scale_name,
915
+ convert_zp_name,
916
+ )
917
+
918
+ # Create second DQ op.
919
+ second_dq_output = add_dequant_output_suffix(f"{tensor_name}_convert")
920
+ if is_graph_output and all_use_converted:
921
+ second_dq_output = tensor_name
922
+ if convert_recv_nodes and second_dq_output != tensor_name:
923
+ self.model.replace_input_of_nodes(tensor_name, second_dq_output, convert_recv_nodes)
924
+ self._create_dq_node(
925
+ second_q_output,
926
+ second_dq_output,
927
+ add_dequant_suffix(f"{tensor_name}_convert"),
928
+ convert_scale_name,
929
+ convert_zp_name,
930
+ )
931
+
932
+ # Store in quantized_value_map
933
+ original_quantized_value = QuantizedValue(
934
+ tensor_name,
935
+ first_dq_output,
936
+ first_scale_name,
937
+ first_zp_name,
938
+ QuantizedValueType.Input,
939
+ scale_type=scale_data_type,
940
+ )
941
+ converted_quantized_value = QuantizedValue(
942
+ tensor_name,
943
+ second_dq_output,
944
+ convert_scale_name,
945
+ convert_zp_name,
946
+ QuantizedValueType.Input,
947
+ scale_type=scale_data_type,
948
+ )
949
+ self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(
950
+ original_quantized_value, converted_quantized_value, convert_recv_nodes
951
+ )
952
+
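The Convert(to u16) step in the docstring above is realized as a DequantizeLinear followed by a QuantizeLinear, which is exactly the Q1 -> DQ1' -> Q2 -> DQ2 chain the method builds. A minimal numpy sketch of that equivalence, assuming illustrative uint8/uint16 scale and zero-point values and the standard round-half-to-even QuantizeLinear semantics:

```python
import numpy as np

def quantize(x, scale, zero_point, qmin, qmax):
    # QuantizeLinear: round half to even, then clamp to the target integer range.
    return np.clip(np.rint(x / scale) + zero_point, qmin, qmax)

def dequantize(q, scale, zero_point):
    # DequantizeLinear: map back to float.
    return (q - zero_point) * scale

x = np.float32(0.73)

# First Q/DQ pair uses the original uint8 parameters (Q1 / DQ1 or DQ1').
s8, zp8 = np.float32(0.01), 128
q8 = quantize(x, s8, zp8, 0, 255)
x8 = dequantize(q8, s8, zp8)

# Second Q/DQ pair re-quantizes to the converted uint16 type (Q2 / DQ2).
s16, zp16 = np.float32(0.0001), 32768
q16 = quantize(x8, s16, zp16, 0, 65535)
x16 = dequantize(q16, s16, zp16)

print(q8, x8, q16, x16)  # consumers of the converted type see x16
```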
953
+ def _quantize_normal_tensors(self):
954
+ """
955
+ Adds Q/DQ ops to tensors (activations and weights) that have been marked for quantization by op quantizers.
956
+ """
957
+ for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
958
+ if tensor_name in self.quantized_value_map:
959
+ continue
960
+
961
+ if not tensor_info.is_shared:
962
+ # Quantize the input
963
+ initializer = find_by_name(tensor_name, self.model.initializer())
964
+ if initializer:
965
+ self._add_qdq_nodes_for_initializer(initializer)
966
+ else:
967
+ # Check if this tensor is already a dequantized value. If so, skip it.
968
+ # This happens if the original input model already has some pre-quantized weights
969
+ # generated by a different tool.
970
+ # Ex: (quantized_weight -> DequantizeLinear -> this_tensor)
971
+ if tensor_name in self.tensor_to_producing_dq:
972
+ del self.tensors_to_quantize[tensor_name]
973
+ continue
974
+
975
+ tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name)
976
+ if not tensor_qparam_initializers:
977
+ raise ValueError(
978
+ f"Quantization parameters are not specified for param {tensor_name}. "
979
+ "In static mode quantization params for inputs and outputs of nodes to be quantized are required."
980
+ )
981
+
982
+ if tensor_qparam_initializers.converted is None:
983
+ # Normal case: <producer> --> Q --> DQ --> <consumers>
984
+ self._add_qdq_pair_for_activation(
985
+ tensor_name,
986
+ tensor_qparam_initializers.original.scale.name,
987
+ tensor_qparam_initializers.original.zero_point.name,
988
+ data_type=tensor_info.data_type,
989
+ )
990
+ else:
991
+ # Conversion case: <producer> ---> Q1 -+-> DQ1 --> <consumers of original type>
992
+ # |
993
+ # +-> DQ1' --> Q2 --> DQ2 --> <consumers of converted type>
994
+ assert tensor_info.data_type == tensor_qparam_initializers.original.scale.data_type
995
+ self._add_qdq_ops_for_converted_activation(
996
+ tensor_name,
997
+ tensor_qparam_initializers.original.scale.name,
998
+ tensor_qparam_initializers.original.zero_point.name,
999
+ tensor_info.data_type,
1000
+ tensor_qparam_initializers.converted.scale.name,
1001
+ tensor_qparam_initializers.converted.zero_point.name,
1002
+ tensor_qparam_initializers.converted_recv_nodes,
1003
+ )
1004
+
1005
+ del self.tensors_to_quantize[tensor_name]
1006
+
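The skip above relies on knowing which tensors are already produced by a DequantizeLinear node in the input model. A minimal sketch of how such a lookup could be built from an ONNX graph (an illustration only; the quantizer populates its own `tensor_to_producing_dq` elsewhere):

```python
import onnx

def build_producing_dq_map(graph: onnx.GraphProto) -> dict[str, onnx.NodeProto]:
    # Map each tensor name to the DequantizeLinear node that produces it, if any.
    producing_dq = {}
    for node in graph.node:
        if node.op_type == "DequantizeLinear":
            producing_dq[node.output[0]] = node
    return producing_dq
```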
1007
+ def _quantize_sharing_param_tensors(self):
1008
+ """
1009
+ Adds Q/DQ ops to tensors that have been marked for quantization by op quantizers.
1010
+ Only operates on tensors that want to use the quantization parameter initializers from an upstream tensor.
1011
+ For example, a Transpose node's output tensor will typically want to use the same quantization parameter
1012
+ initializers as the Transpose node's input.
1013
+ """
1014
+ while self.tensors_to_quantize:
1015
+ for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
1016
+ quant_provider = tensor_info.quant_para_provider
1017
+ if quant_provider and quant_provider.input_name in self.quantized_value_map:
1018
+ del self.tensors_to_quantize[tensor_name]
1019
+
1020
+ quantized_value = self.quantized_value_map[quant_provider.input_name].get_for_consumer(
1021
+ quant_provider.node_name
1022
+ )
1023
+ if self.is_input_a_initializer(tensor_name):
1024
+ raise ValueError("Quantization parameter shared mode is not supported for weight yet")
1025
+
1026
+ if tensor_name in self.tensor_to_producing_dq:
1027
+ raise ValueError(
1028
+ f"Quantization parameter sharing is invalid for tensor {tensor_name} "
1029
+ "because it has already been quantized"
1030
+ )
1031
+
1032
+ # Need to check if this tensor's quant_type is converted for some consumers.
1033
+ # If so, create new scale/zp initializers for these consumers.
1034
+ converted_qparam_inits = None
1035
+ converted_recv_nodes = None
1036
+ if tensor_name in self.quantization_params:
1037
+ tensor_params = self.quantization_params[tensor_name]
1038
+ if tensor_params.converted:
1039
+ converted_qparam_inits = self._make_scale_zp_initializers(
1040
+ tensor_name, tensor_params.converted, "_convert"
1041
+ )
1042
+ converted_recv_nodes = tensor_params.converted_recv_nodes
1043
+
1044
+ if converted_qparam_inits is None:
1045
+ # Normal case: <producer> --> Q_shared --> DQ_shared --> <consumers>
1046
+ self._add_qdq_pair_for_activation(
1047
+ tensor_name, quantized_value.scale_name, quantized_value.zp_name
1048
+ )
1049
+ else:
1050
+ # Conversion case: <producer> ---> Q_shared -+-> DQ_shared --> <consumers of original type>
1051
+ # |
1052
+ # +-> DQ_shared' --> Q2 --> DQ2 --> <consumers of converted type>
1053
+ self._add_qdq_ops_for_converted_activation(
1054
+ tensor_name,
1055
+ quantized_value.scale_name,
1056
+ quantized_value.zp_name,
1057
+ converted_qparam_inits.scale.data_type,
1058
+ converted_qparam_inits.scale.name,
1059
+ converted_qparam_inits.zero_point.name,
1060
+ converted_recv_nodes,
1061
+ )
1062
+
1063
+ def _quantize_bias_tensors(self):
1064
+ """
1065
+ Adds DQ ops (or Cast) for bias tensors that have been marked for quantization by op quantizers.
1066
+ """
1067
+ for bias_name, bias_info in self.bias_to_quantize.items():
1068
+ if bias_name in self.quantized_value_map:
1069
+ continue
1070
+ # Quantize the input
1071
+ self.quantize_bias_static(bias_name, bias_info)
1072
+ init = find_by_name(bias_name, self.model.initializer())
1073
+ self.model.remove_initializer(init)
1074
+ quant_value = self.quantized_value_map[bias_name].original
1075
+ if quant_value.node_type == "Cast":
1076
+ # Simple cast to float16 instead of DequantizeLinear.
1077
+ # cublasLtMatmul only supports (b)float16 or float bias.
1078
+ if not isinstance(init.data_type, int):
1079
+ raise TypeError(f"Unexpected type {type(init.data_type)} for input={bias_info.input_name!r}")
1080
+ node_name = add_dequant_suffix(bias_name)
1081
+ dequant_node = onnx.helper.make_node(
1082
+ "Cast",
1083
+ [quant_value.q_name],
1084
+ [bias_name],
1085
+ name=node_name,
1086
+ to=init.data_type,
1087
+ )
1088
+ elif quant_value.node_type in (None, "DequantizeLinear"):
1089
+ if quant_value.node_qtype in {
1090
+ onnx.TensorProto.FLOAT16,
1091
+ onnx.TensorProto.BFLOAT16,
1092
+ onnx.TensorProto.FLOAT,
1093
+ }:
1094
+ raise RuntimeError(f"Unexpected quantize type {quant_value.node_qtype} for DequantizeLinear.")
1095
+ inputs = [quant_value.q_name, quant_value.scale_name, quant_value.zp_name]
1096
+ node_name = add_dequant_suffix(bias_name)
1097
+ if quant_value.axis is not None:
1098
+ dequant_node = onnx.helper.make_node(
1099
+ "DequantizeLinear",
1100
+ inputs,
1101
+ [bias_name],
1102
+ node_name,
1103
+ axis=quant_value.axis,
1104
+ domain=self.qdq_op_domain,
1105
+ )
1106
+ else:
1107
+ dequant_node = onnx.helper.make_node(
1108
+ "DequantizeLinear",
1109
+ inputs,
1110
+ [bias_name],
1111
+ node_name,
1112
+ domain=self.qdq_op_domain,
1113
+ )
1114
+ else:
1115
+ raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.")
1116
+ self.model.add_node(dequant_node)
1117
+
1118
+ def is_tensor_quantized(self, tensor_name: str):
1119
+ return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize
1120
+
1121
+ def is_tensor_per_channel(
1122
+ self,
1123
+ tensor_name: str,
1124
+ default_axis: int,
1125
+ op_type: str | None = None,
1126
+ ) -> tuple[bool, int | None]:
1127
+ """
1128
+ Checks if a given tensor is configured to be quantized per-channel. If so, also returns the channel axis.
1129
+
1130
+ ORT only supports per-channel quantization on static weights (i.e., ONNX initializers). If the user did not provide
1131
+ tensor quantization overrides for this tensor, then the value of self.per_channel determines if the weight
1132
+ is to be quantized per-channel.
1133
+
1134
+ Params:
1135
+ tensor_name: The name of the tensor to check.
1136
+ default_axis: The default channel axis. This method checks if the normalized axis is within bounds.
1137
+ Can be overridden via the extra_options 'QDQOpTypePerChannelSupportToAxis'
1138
+ and 'TensorQuantOverrides'.
1139
+ op_type: Optional, defaults to None. The operator type that is the only consumer of this weight.
1140
+ Used to access the extra option 'QDQOpTypePerChannelSupportToAxis'.
1141
+ Returns:
1142
+ A tuple (is_per_channel, axis) in which the first element indicates whether the tensor is
1143
+ quantized per-channel and the second element is the channel axis.
1144
+ The returned axis is only None if the tensor is not per-channel or the axis is out of bounds.
1145
+ """
1146
+ weight_initializer = self.initializers.get(tensor_name)
1147
+ if weight_initializer is None:
1148
+ return False, None # Only support per-channel weights
1149
+
1150
+ if self.tensor_quant_overrides.has_per_tensor_overrides(tensor_name):
1151
+ return False, None # User provided per-tensor overrides for this initializer
1152
+
1153
+ has_per_chan_overrides = self.tensor_quant_overrides.has_per_channel_overrides(tensor_name)
1154
+ if not self.per_channel and not has_per_chan_overrides:
1155
+ return False, None # global self.per_channel is off and user did not provide per-channel overrides.
1156
+
1157
+ axis = self.qdq_op_type_per_channel_support_to_axis.get(op_type, default_axis) if op_type else default_axis
1158
+ if has_per_chan_overrides:
1159
+ per_chan_overrides = self.tensor_quant_overrides.get_per_channel_overrides(tensor_name)
1160
+ axis = per_chan_overrides[0]["axis"] # Prefer axis from user-specified tensor-level overrides if available
1161
+
1162
+ weight_rank = len(weight_initializer.dims)
1163
+ axis_valid, axis = normalize_axis(axis, weight_rank)
1164
+ if not axis_valid:
1165
+ logging.warning(f"Axis {axis} is out-of-range for weight '{tensor_name}' with rank {weight_rank}")
1166
+ return False, None
1167
+
1168
+ return True, axis
1169
+
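The bounds check above maps a possibly negative axis into `[0, rank)` before accepting it. A minimal sketch of that normalization, assuming the conventional semantics (the actual helper is `normalize_axis` from the quantization utilities):

```python
def normalize_axis_sketch(axis: int, rank: int) -> tuple[bool, int]:
    # Negative axes count from the end, e.g. -1 on a rank-4 weight means axis 3.
    norm_axis = axis + rank if axis < 0 else axis
    return (0 <= norm_axis < rank), norm_axis

print(normalize_axis_sketch(-1, 4))  # (True, 3)
print(normalize_axis_sketch(4, 4))   # (False, 4) -> treated as not per-channel
```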
1170
+ def _get_tensor_quantization_scale(self, tensor_name: str, consumer_node_name: str) -> np.ndarray | None:
1171
+ """
1172
+ Returns the quantization scale of a tensor that is consumed by the given node.
1173
+ :parameter tensor_name: The name of the tensor.
1174
+ :parameter consumer_node_name: The name of the node that consumes the tensor as input. Necessary in case
1175
+ the quantization type of the tensor was converted.
1176
+ Refer: QDQQuantizer::_add_qdq_ops_for_converted_activation.
1177
+ :returns: The quantization scale or None.
1178
+ """
1179
+ initializers = self.model.initializer()
1180
+ scale_initializer: onnx.TensorProto | None = None
1181
+
1182
+ if tensor_name in self.quantized_value_map:
1183
+ # Tensor was quantized by this tool, so get scale from initializer created by this tool run.
1184
+ scale_name = self.quantized_value_map[tensor_name].get_for_consumer(consumer_node_name).scale_name
1185
+ scale_initializer = find_by_name(scale_name, initializers)
1186
+ else:
1187
+ # Tensor was already quantized in original model, so get scale from DQ node that outputs the tensor.
1188
+ dq_node = self.tensor_to_producing_dq.get(tensor_name, None)
1189
+ if dq_node:
1190
+ scale_initializer = find_by_name(dq_node.input[1], initializers)
1191
+
1192
+ return tensor_proto_to_array(scale_initializer) if scale_initializer is not None else None
1193
+
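For tensors that were pre-quantized in the original model, the scale above is read from the producing DequantizeLinear node's second input (`x_scale` in the ONNX operator definition). A minimal sketch using `onnx.numpy_helper`, assuming a `dq_node` and the model's initializer list are in hand:

```python
import onnx
from onnx import numpy_helper

def scale_from_dq_node(dq_node: onnx.NodeProto, initializers: list[onnx.TensorProto]):
    # DequantizeLinear inputs are (x, x_scale, x_zero_point); the scale is input[1].
    scale_name = dq_node.input[1]
    for init in initializers:
        if init.name == scale_name:
            return numpy_helper.to_array(init)
    return None  # The scale is not a static initializer (e.g., computed at runtime).
```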
1194
+ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str:
1195
+ """
1196
+ Quantizes the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale.
1197
+ """
1198
+
1199
+ # Handle case where bias already in quantization map
1200
+ if bias_name in self.quantized_value_map:
1201
+ return self.quantized_value_map[bias_name].original.q_name
1202
+
1203
+ # get scale for weight.
1204
+ weight_scale = self._get_tensor_quantization_scale(bias_info.weight_name, bias_info.node_name)
1205
+ if weight_scale is None:
1206
+ raise ValueError(
1207
+ f"Unable to get valid quantization scale for weight input '{bias_info.weight_name}' "
1208
+ f"when quantizing bias '{bias_name}' to int32."
1209
+ )
1210
+
1211
+ # get scale for input.
1212
+ input_scale = self._get_tensor_quantization_scale(bias_info.input_name, bias_info.node_name)
1213
+ if input_scale is None:
1214
+ raise ValueError(
1215
+ f"Unable to get valid quantization scale for input '{bias_info.input_name}' "
1216
+ f"when quantizing bias '{bias_name}' to int32."
1217
+ )
1218
+
1219
+ (
1220
+ quantized_bias_name,
1221
+ quantized_bias_scale_name,
1222
+ quantized_bias_zp_name,
1223
+ bias_scale_data,
1224
+ node_type,
1225
+ node_qtype,
1226
+ ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, bias_info.beta)
1227
+
1228
+ quantized_value = QuantizedValue(
1229
+ bias_name,
1230
+ quantized_bias_name,
1231
+ quantized_bias_scale_name,
1232
+ quantized_bias_zp_name,
1233
+ QuantizedValueType.Initializer,
1234
+ 0 if bias_scale_data.size > 1 else None,
1235
+ node_type=node_type,
1236
+ node_qtype=node_qtype,
1237
+ )
1238
+ self.quantized_value_map[bias_name] = QDQTensorQuantizedValue(quantized_value, None, None)
1239
+
1240
+ return quantized_bias_name
1241
+
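A worked numpy sketch of the bias formula above (zero point fixed at 0, scale equal to input_scale * weight_scale), using illustrative per-tensor scales; the real `quantize_bias_static_impl` additionally handles per-channel weight scales and the `beta` multiplier:

```python
import numpy as np

bias = np.array([0.25, -0.5, 0.1], dtype=np.float32)
input_scale = np.float32(0.02)
weight_scale = np.float32(0.005)

bias_scale = input_scale * weight_scale                # Scale == Input_Scale * Weight_Scale
q_bias = np.round(bias / bias_scale).astype(np.int32)  # Zero Point == 0
print(bias_scale, q_bias)                              # ~0.0001, [ 2500 -5000  1000]
```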
1242
+ def _make_scale_zp_initializers(
1243
+ self, param_name: str, quant_params: QuantizationParams, init_name_suffix: str = ""
1244
+ ) -> QDQScaleZpInitializers:
1245
+ """
1246
+ Creates and returns scale and zero-point initializers for the given quantization params. The initializers are
1247
+ named:
1248
+ - {param_name}_zero_point{init_name_suffix}
1249
+ - {param_name}_scale{init_name_suffix}
1250
+ """
1251
+ zero_point = quant_params["zero_point"]
1252
+ scale = quant_params["scale"]
1253
+ zero_point_type = quant_params["quant_type"]
1254
+ axis: int | None = quant_params.get("axis")
1255
+ assert (axis is not None and len(scale.shape) == 1) or (axis is None and len(scale.shape) == 0), (
1256
+ "Wrong scale/zp shapes"
1257
+ )
1258
+ assert len(scale.shape) == len(zero_point.shape), "Scale and zero-point must have the same rank"
1259
+
1260
+ zero_point_name = param_name + "_zero_point" + init_name_suffix
1261
+ scale_name = param_name + "_scale" + init_name_suffix
1262
+
1263
+ # Add initializers to model
1264
+ init_zp = onnx.helper.make_tensor(
1265
+ zero_point_name, zero_point_type, zero_point.shape, zero_point.ravel().tolist()
1266
+ )
1267
+ self.model.add_initializer(init_zp)
1268
+
1269
+ if scale.dtype == np.float32:
1270
+ scale_type = onnx_proto.TensorProto.FLOAT
1271
+ elif scale.dtype == np.float16:
1272
+ scale_type = onnx_proto.TensorProto.FLOAT16
1273
+ else:
1274
+ raise ValueError(f"Unexpected dtype={scale.dtype} for param_name={param_name!r}")
1275
+ init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale.shape, scale.ravel().tolist())
1276
+ self.model.add_initializer(init_scale)
1277
+
1278
+ return QDQScaleZpInitializers(init_scale, init_zp)
1279
+
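A minimal sketch of the initializer pair the method above creates for a per-tensor uint8 activation, with an illustrative tensor name and values; per-channel parameters would instead carry a 1-D shape along the channel axis, and a "_convert" suffix is appended for converted types:

```python
import onnx

scale_init = onnx.helper.make_tensor("act_scale", onnx.TensorProto.FLOAT, [], [0.02])
zp_init = onnx.helper.make_tensor("act_zero_point", onnx.TensorProto.UINT8, [], [128])
# The inserted QuantizeLinear/DequantizeLinear nodes reference these names directly.
```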
1280
+ def _make_tensor_scale_zp_initializers(self, tensor_name: str) -> QDQTensorScaleZpInitializers | None:
1281
+ """
1282
+ Creates and returns all scale/zero_point initializers for a given tensor. If the tensor is converted
1283
+ to a different quantization type, this function creates two pairs of zp/scale initializers. Otherwise,
1284
+ only one pair of zp/scale initializers is created.
1285
+ """
1286
+ if self.quantization_params is None or tensor_name not in self.quantization_params:
1287
+ logging.info(f'Quantization parameters for tensor:"{tensor_name}" not specified')
1288
+ return None
1289
+
1290
+ tensor_params = self.quantization_params[tensor_name]
1291
+ if not isinstance(tensor_params, QDQTensorQuantParams):
1292
+ raise TypeError(f"Unexpected type {type(tensor_params)} for {tensor_name!r}.")
1293
+
1294
+ original_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.original)
1295
+ converted_inits = (
1296
+ self._make_scale_zp_initializers(tensor_name, tensor_params.converted, "_convert")
1297
+ if tensor_params.converted
1298
+ else None
1299
+ )
1300
+
1301
+ return QDQTensorScaleZpInitializers(original_inits, converted_inits, tensor_params.converted_recv_nodes)
1302
+
1303
+ def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, Any]) -> QuantizationParams:
1304
+ """
1305
+ Calculates quantization parameters (scale/zero-point) given a tensor's min/max range and optional
1306
+ user-provided overrides.
1307
+ """
1308
+ quant_type = self.activation_qType
1309
+ if "quant_type" in quant_overrides:
1310
+ quant_type = quant_overrides["quant_type"].tensor_type
1311
+
1312
+ if "scale" in quant_overrides and "zero_point" in quant_overrides:
1313
+ zero, scale = quant_overrides["zero_point"], quant_overrides["scale"]
1314
+ elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
1315
+ zero, scale = compute_scale_zp_float8(quant_type, tensor_data.avg_std[1])
1316
+ else:
1317
+ rmin = quant_overrides.get("rmin", tensor_data.range_value[0])
1318
+ rmax = quant_overrides.get("rmax", tensor_data.range_value[1])
1319
+ symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
1320
+ reduce_range = quant_overrides.get("reduce_range", False)
1321
+ qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
1322
+ zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)
1323
+
1324
+ return QuantizationParams(zero_point=zero.squeeze(), scale=scale.squeeze(), quant_type=quant_type)
1325
+
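A worked example of the asymmetric case that `compute_scale_zp` above covers, assuming uint8 (qmin=0, qmax=255) and ignoring the `min_real_range` clamp; the helper's exact rounding and symmetric handling may differ:

```python
import numpy as np

rmin, rmax = np.float32(-1.5), np.float32(2.5)
qmin, qmax = 0, 255

scale = (rmax - rmin) / (qmax - qmin)            # 4.0 / 255 ~= 0.0157
zero_point = int(np.rint(qmin - rmin / scale))   # ~= 96, so real 0.0 maps to 96
print(scale, zero_point)
```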
1326
+ def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]:
1327
+ """
1328
+ Calculates quantization parameters (scale/zero-point) for all tensors in the graph using each tensor's min/max range
1329
+ and optional user-provided overrides.
1330
+ """
1331
+ if self.tensors_range is None:
1332
+ return {}
1333
+
1334
+ self.adjust_tensor_ranges()
1335
+
1336
+ quantization_params = {}
1337
+ for tensor_name in self.tensors_range:
1338
+ td = self.tensors_range[tensor_name]
1339
+ if not isinstance(td, TensorData):
1340
+ raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")
1341
+
1342
+ quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={})
1343
+ original = self.calc_quant_params(td, quant_overrides)
1344
+ converted = None
1345
+ converted_recv_nodes = None
1346
+
1347
+ if "convert" in quant_overrides:
1348
+ converted = self.calc_quant_params(td, quant_overrides["convert"])
1349
+ converted_recv_nodes = quant_overrides["convert"].get("recv_nodes")
1350
+
1351
+ quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes)
1352
+
1353
+ return quantization_params
1354
+
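An illustrative shape for a per-tensor override that the loop above consumes, including the optional "convert" entry with its `quant_type` and `recv_nodes` keys; the key names follow the code above, while the concrete tensor name, node names, and values are assumptions:

```python
from onnxruntime.quantization import QuantType

tensor_quant_overrides = {
    "activation_out": [
        {
            "quant_type": QuantType.QUInt8,
            "symmetric": False,
            "convert": {
                "quant_type": QuantType.QUInt16,
                # Only these consumers read the converted (uint16) Q/DQ output.
                "recv_nodes": {"Conv_1", "Sigmoid_2"},
            },
        }
    ],
}
```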
1355
+ def _calc_initializer_quant_params(self) -> dict[str, QuantizationParams]:
1356
+ """
1357
+ Returns quantization parameters (scale/zero_point/quant_type) for all initializers.
1358
+ """
1359
+
1360
+ quantization_params: dict[str, QuantizationParams] = {}
1361
+ for tensor_name, tensor_info in self.tensors_to_quantize.items():
1362
+ initializer = find_by_name(tensor_name, self.model.initializer())
1363
+ if not initializer:
1364
+ continue
1365
+
1366
+ initializer_data = tensor_proto_to_array(initializer)
1367
+ initializer_rank = len(initializer_data.shape)
1368
+
1369
+ # Non-weight initializers (e.g., inputs to elementwise ops) use the quant_type for activations.
1370
+ is_weight = tensor_info.tensor_type is QDQQuantTensorType.WEIGHT
1371
+ quant_type = self.weight_qType if is_weight else self.activation_qType
1372
+
1373
+ # Try to get scale/zp directly from user's overrides and avoid computation.
1374
+ if self.tensor_quant_overrides.overrides_scale_zp(tensor_name):
1375
+ overrides = self.tensor_quant_overrides[tensor_name]
1376
+ if "quant_type" in overrides[0]:
1377
+ quant_type = overrides[0]["quant_type"].tensor_type
1378
+
1379
+ zp_dtype = ONNX_TYPE_TO_NP_TYPE[quant_type]
1380
+ is_per_channel = "axis" in overrides[0]
1381
+ if not is_per_channel:
1382
+ quantization_params[tensor_name] = QuantizationParams(
1383
+ zero_point=np.array(overrides[0]["zero_point"], dtype=zp_dtype),
1384
+ scale=np.array(overrides[0]["scale"], initializer_data.dtype),
1385
+ quant_type=quant_type,
1386
+ )
1387
+ else:
1388
+ zero_points_list = []
1389
+ scales_list = []
1390
+ for chan_overrides in overrides:
1391
+ zero_points_list.append(np.array(chan_overrides["zero_point"], zp_dtype))
1392
+ scales_list.append(np.array(chan_overrides["scale"], dtype=initializer_data.dtype))
1393
+
1394
+ channel_axis = overrides[0]["axis"]
1395
+ is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
1396
+ if not is_axis_valid:
1397
+ raise ValueError(
1398
+ f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
1399
+ f"out-of-bounds for rank {initializer_rank}"
1400
+ )
1401
+
1402
+ quantization_params[tensor_name] = QuantizationParams(
1403
+ zero_point=np.array(zero_points_list),
1404
+ scale=np.array(scales_list),
1405
+ quant_type=quant_type,
1406
+ axis=norm_channel_axis,
1407
+ )
1408
+
1409
+ continue
1410
+
1411
+ # Compute scale/zp normally. User's overrides may still override parameters
1412
+ # used to compute the scale/zp (e.g., rmin, rmax, symmetric, etc.)
1413
+ overrides = self.tensor_quant_overrides.get(tensor_name, [{}])
1414
+ if "quant_type" in overrides[0]:
1415
+ quant_type = overrides[0]["quant_type"].tensor_type
1416
+
1417
+ channel_axis = overrides[0].get("axis", tensor_info.axis)
1418
+ is_per_channel = channel_axis is not None
1419
+
1420
+ # Note: always quantize per-channel initializers as symmetric because QLinear* ops require the
1421
+ # same zero-point in every channel, which is necessarily the case for symmetric quantization.
1422
+ is_symmetric_default = is_per_channel or (
1423
+ self.is_weight_symmetric(quant_type) if is_weight else self.is_activation_symmetric
1424
+ )
1425
+ is_symmetric = overrides[0].get("symmetric", is_symmetric_default)
1426
+ reduce_range = overrides[0].get("reduce_range", self.reduce_range)
1427
+ zero_point: np.ndarray | None = None
1428
+ scale: np.ndarray | None = None
1429
+
1430
+ if not is_per_channel:
1431
+ zero_point, scale = compute_data_quant_params(
1432
+ initializer_data.flatten(),
1433
+ quant_type,
1434
+ is_symmetric,
1435
+ reduce_range=reduce_range,
1436
+ min_real_range=self.min_real_range,
1437
+ rmin_override=overrides[0].get("rmin"),
1438
+ rmax_override=overrides[0].get("rmax"),
1439
+ )
1440
+ else:
1441
+ is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
1442
+ if not is_axis_valid:
1443
+ raise ValueError(
1444
+ f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
1445
+ f"out-of-bounds for rank {initializer_rank}"
1446
+ )
1447
+
1448
+ channel_axis = norm_channel_axis
1449
+ channel_count = initializer_data.shape[channel_axis]
1450
+ zero_points_list = []
1451
+ scales_list = []
1452
+ for i in range(channel_count):
1453
+ per_channel_data = initializer_data.take(i, channel_axis)
1454
+ channel_overrides = overrides[i] if overrides and i < len(overrides) else {}
1455
+ channel_zero_point, channel_scale = compute_data_quant_params(
1456
+ per_channel_data.ravel(),
1457
+ quant_type,
1458
+ is_symmetric,
1459
+ reduce_range=reduce_range,
1460
+ min_real_range=self.min_real_range,
1461
+ rmin_override=channel_overrides.get("rmin"),
1462
+ rmax_override=channel_overrides.get("rmax"),
1463
+ )
1464
+ zero_points_list.append(channel_zero_point)
1465
+ scales_list.append(channel_scale)
1466
+
1467
+ zero_point = np.asarray(zero_points_list)
1468
+ scale = np.asarray(scales_list)
1469
+
1470
+ quantization_params[tensor_name] = QuantizationParams(
1471
+ zero_point=zero_point,
1472
+ scale=scale,
1473
+ quant_type=quant_type,
1474
+ axis=channel_axis,
1475
+ )
1476
+
1477
+ return quantization_params
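A minimal numpy sketch of the per-channel split performed above: each slice along `channel_axis` gets its own parameters computed from that slice's range (illustrative symmetric int8 math; the real computation goes through `compute_data_quant_params` and honors overrides per channel):

```python
import numpy as np

weight = np.random.randn(8, 4, 3, 3).astype(np.float32)  # e.g., a Conv weight, channel_axis=0
channel_axis = 0

scales = []
for i in range(weight.shape[channel_axis]):
    channel = weight.take(i, axis=channel_axis).ravel()
    absmax = np.max(np.abs(channel))
    scales.append(absmax / 127.0)  # symmetric int8: the zero point is 0 for every channel
scale = np.asarray(scales, dtype=np.float32)  # shape (8,): one scale per output channel
```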