onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1163 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ import logging
7
+
8
+ import numpy as np
9
+ import onnx
10
+ import onnx.numpy_helper
11
+ from onnx import onnx_pb as onnx_proto
12
+
13
+ from .base_quantizer import BaseQuantizer, QuantizationParams
14
+ from .calibrate import TensorData
15
+ from .onnx_model import ONNXModel
16
+ from .quant_utils import (
17
+ TENSOR_NAME_QUANT_SUFFIX,
18
+ QuantizationMode,
19
+ QuantizedValue,
20
+ QuantizedValueType,
21
+ __producer__,
22
+ __version__,
23
+ add_infer_metadata,
24
+ attribute_to_kwarg,
25
+ compute_scale_zp,
26
+ compute_scale_zp_float8,
27
+ find_by_name,
28
+ get_qmin_qmax_for_qType,
29
+ get_qrange_for_qType,
30
+ ms_domain,
31
+ quantize_onnx_initializer,
32
+ save_and_reload_model_with_shape_infer,
33
+ tensor_proto_to_array,
34
+ )
35
+ from .registry import CreateOpQuantizer
36
+
37
+
38
+ class ONNXQuantizer(BaseQuantizer):
39
+ def __init__(
40
+ self,
41
+ model,
42
+ per_channel,
43
+ reduce_range,
44
+ mode,
45
+ static,
46
+ weight_qType,
47
+ activation_qType,
48
+ tensors_range,
49
+ nodes_to_quantize,
50
+ nodes_to_exclude,
51
+ op_types_to_quantize,
52
+ extra_options=None,
53
+ ):
54
+ BaseQuantizer.__init__(
55
+ self,
56
+ model,
57
+ per_channel,
58
+ reduce_range,
59
+ weight_qType,
60
+ activation_qType,
61
+ tensors_range,
62
+ nodes_to_quantize,
63
+ nodes_to_exclude,
64
+ op_types_to_quantize,
65
+ extra_options,
66
+ )
67
+
68
+ if not static:
69
+ self.model.replace_gemm_with_matmul()
70
+ # We need to update value_infos.
71
+ model = save_and_reload_model_with_shape_infer(self.model.model)
72
+ self.value_infos = {vi.name: vi for vi in model.graph.value_info}
73
+ self.value_infos.update({ot.name: ot for ot in model.graph.output})
74
+ self.value_infos.update({it.name: it for it in model.graph.input})
75
+ self.model = ONNXModel(model)
76
+
77
+ self.mode = mode # QuantizationMode.Value
78
+ self.static = static # use static quantization for inputs.
79
+ self.fuse_dynamic_quant = self.opset_version > 10
80
+
81
+ self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"]
82
+
83
+ self.new_nodes = []
84
+ self.graph_scope = "/" # for human readable debug information
85
+ self.tensor_names = {} # in case shape inference does not work completely
86
+ self.tensor_names.update({ot.name: 1 for ot in model.graph.output})
87
+ self.tensor_names.update({it.name: 1 for it in model.graph.input})
88
+ for node in self.model.model.graph.node:
89
+ self.tensor_names.update(dict.fromkeys(node.output, 1))
90
+
91
+ if self.mode not in QuantizationMode:
92
+ raise ValueError(f"unsupported quantization mode {self.mode}")
93
+
94
+ self.quantization_params = self.calculate_quantization_params()
95
+
96
+ # QuantizeRange tensor name and zero tensor name for scale and zero point calculation.
97
+ # Used when static is False
98
+ self.fixed_qrange_uint8_name = "fixed_quantization_range_uint8"
99
+ self.fixed_qrange_int8_name = "fixed_quantization_range_int8"
100
+ # For uint8 data-type, to compute zero point, we subtract rmin from 0 (represented by fixed_zero_name tensor)
101
+ self.fixed_zero_name = "fixed_zero"
102
+ # For int8 data-type, zero point is always zero (represented by fixed_zero_zp_name tensor)
103
+ self.fixed_zero_zp_name = "fixed_zero_zp"
104
+
105
+ # Map of all original value names to quantized value names
106
+ self.quantized_value_map = {}
107
+ # Some node outputs will be quantized, yet they should be treated as already existing, so
108
+ # no dequantization will be applied to them later when needed
109
+ self.generated_value_names = self.model.get_non_initializer_inputs()
110
+
111
+ # routines for subgraph support
112
+ def quantize_subgraph(self, subgraph, graph_key):
113
+ """
114
+ Generate a submodel for the subgraph so that we can reuse the current quantization implementation,
115
+ quantize that submodel,
116
+ and update the subgraph, setting it back on the node.
117
+ """
118
+ warped_model = onnx.helper.make_model(
119
+ subgraph,
120
+ producer_name="onnx-quantizer",
121
+ opset_imports=self.model.model.opset_import,
122
+ )
123
+ add_infer_metadata(warped_model)
124
+ sub_quantizer = ONNXQuantizer(
125
+ warped_model,
126
+ self.per_channel,
127
+ self.reduce_range,
128
+ self.mode,
129
+ self.static,
130
+ self.weight_qType,
131
+ self.activation_qType,
132
+ self.tensors_range,
133
+ self.nodes_to_quantize,
134
+ self.nodes_to_exclude,
135
+ self.op_types_to_quantize,
136
+ self.extra_options,
137
+ )
138
+ sub_quantizer.parent = self
139
+ sub_quantizer.graph_scope = f"{self.graph_scope}{graph_key}/"
140
+ sub_quantizer.quantize_model()
141
+ return sub_quantizer.model.model.graph
142
+
143
+ def quantize_node_with_sub_graph(self, node):
144
+ """
145
+ Check for subgraph attributes; if any are present, quantize each subgraph and replace it.
146
+ Return the node with its subgraph attributes replaced by their quantized versions.
147
+ """
148
+ graph_attrs = [
149
+ attr
150
+ for attr in node.attribute
151
+ if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
152
+ ]
153
+ if len(graph_attrs) == 0:
154
+ return node
155
+ node_name = node.name if node.name else f"{node.op_type}_node_count_{len(self.new_nodes)}"
156
+ kwargs = {}
157
+ for attr in node.attribute:
158
+ if attr.type == onnx.AttributeProto.GRAPH:
159
+ kv = {attr.name: self.quantize_subgraph(attr.g, f"{node_name}:{attr.name}")}
160
+ elif attr.type == onnx.AttributeProto.GRAPHS:
161
+ value = []
162
+ for subgraph in attr.graphs:
163
+ value.extend(
164
+ [
165
+ self.quantize_subgraph(
166
+ subgraph,
167
+ f"{node_name}:{attr.name}:{len(value)}",
168
+ )
169
+ ]
170
+ )
171
+ kv = {attr.name: value}
172
+ else:
173
+ kv = attribute_to_kwarg(attr)
174
+ kwargs.update(kv)
175
+ return onnx.helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
176
+
177
+ def has_QDQ_nodes(self): # noqa: N802
178
+ """
179
+ Detect if model already has QuantizeLinear or DequantizeLinear.
180
+ """
181
+ return any(
182
+ node.op_type == "QuantizeLinear" or node.op_type == "DequantizeLinear" for node in self.model.nodes()
183
+ )
184
+
185
+ def find_initializer_in_path(self, initializer_name):
186
+ if find_by_name(initializer_name, self.model.initializer()) is not None:
187
+ return True
188
+ if self.parent is not None:
189
+ return self.parent.find_initializer_in_path(initializer_name)
190
+ return False
191
+
192
+ def add_new_nodes(self, nodes):
193
+ self.new_nodes.extend(nodes)
194
+ for node in nodes:
195
+ for output_name in node.output:
196
+ self.generated_value_names.add(output_name)
197
+
198
+ def quantize_model(self):
199
+ if self.has_QDQ_nodes():
200
+ logging.warning(
201
+ "Please check if the model is already quantized. "
202
+ "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly."
203
+ )
204
+
205
+ for node in self.model.nodes():
206
+ # quantize subgraphs if present
207
+ if self.enable_subgraph_quantization:
208
+ node = self.quantize_node_with_sub_graph(node) # noqa: PLW2901
209
+
210
+ number_of_existing_new_nodes = len(self.new_nodes)
211
+ op_quantizer = CreateOpQuantizer(self, node)
212
+ op_quantizer.quantize()
213
+ for i in range(number_of_existing_new_nodes, len(self.new_nodes)):
214
+ for output_name in self.new_nodes[i].output:
215
+ self.generated_value_names.add(output_name)
216
+
217
+ self._dequantize_outputs()
218
+
219
+ # extend is used to append to the list for a repeated protobuf field
220
+ # https://developers.google.com/protocol-buffers/docs/reference/python-generated?csw=1#fields
221
+ self.model.graph().ClearField("node")
222
+ self.model.graph().node.extend(self.new_nodes)
223
+
224
+ # Remove unused initializers from graph, starting from the top level graph.
225
+ if self.parent is None:
226
+ _, initializers_not_found = self.model.clean_initializers()
227
+ if len(initializers_not_found) > 0:
228
+ raise RuntimeError("Invalid model with unknown initializers/tensors." + str(initializers_not_found))
229
+
230
+ self.model.model.producer_name = __producer__
231
+ self.model.model.producer_version = __version__
232
+ # Add ms domain if needed
233
+ ms_opset = [opset for opset in self.model.model.opset_import if opset.domain == ms_domain]
234
+ if not ms_opset:
235
+ ms_nodes = [node for node in self.new_nodes if node.domain == "com.microsoft"]
236
+ if ms_nodes:
237
+ opset = self.model.model.opset_import.add()
238
+ opset.version = 1
239
+ opset.domain = ms_domain
240
+
241
+ return self.model.model
242
+
243
+ def _get_default_tensor_type(self, tensor_name):
244
+ if "DefaultTensorType" in self.extra_options:
245
+ logging.info(
246
+ "get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
247
+ tensor_name,
248
+ self.extra_options["DefaultTensorType"],
249
+ )
250
+ return self.extra_options["DefaultTensorType"]
251
+ raise RuntimeError(
252
+ f"Unable to find data type for weight_name={tensor_name!r}. "
253
+ f"shape_inference failed to return a type probably this node is "
254
+ f"from a different domain or using an input produced by such an operator. "
255
+ f"This may happen if you quantize a model already quantized. "
256
+ f"You may use extra_options `DefaultTensorType` to indicate "
257
+ f"the default weight type, usually `onnx.TensorProto.FLOAT`."
258
+ )
259
+
260
+ def get_tensor_type(self, tensor_name, mandatory=False):
261
+ weight = find_by_name(tensor_name, self.model.initializer())
262
+ if weight is not None:
263
+ return weight.data_type
264
+ if tensor_name in self.value_infos:
265
+ vi = self.value_infos[tensor_name]
266
+ if vi.type.HasField("tensor_type"):
267
+ if mandatory and vi.type.tensor_type.elem_type == 0:
268
+ return self._get_default_tensor_type(tensor_name)
269
+ return vi.type.tensor_type.elem_type
270
+ if (not self.enable_subgraph_quantization) or (self.parent is None):
271
+ if mandatory:
272
+ return self._get_default_tensor_type(tensor_name)
273
+ return None
274
+ otype = self.parent.is_valid_quantize_weight(tensor_name)
275
+ if otype is not None:
276
+ return otype
277
+ if self.enable_subgraph_quantization and self.parent:
278
+ res = self.parent.get_tensor_type(tensor_name)
279
+ if res is not None:
280
+ return res
281
+ if mandatory:
282
+ return self._get_default_tensor_type(tensor_name)
283
+ return None
284
+
285
+ def is_float_tensor(self, tensor_name):
286
+ if self.is_input_a_initializer(tensor_name):
287
+ return self.is_valid_quantize_weight(tensor_name)
288
+
289
+ if tensor_name in self.value_infos:
290
+ vi = self.value_infos[tensor_name]
291
+ if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
292
+ onnx_proto.TensorProto.FLOAT,
293
+ onnx_proto.TensorProto.FLOAT16,
294
+ ):
295
+ return True
296
+ logging.warning(
297
+ f"Inference failed or unsupported type to quantize for tensor {tensor_name!r}, type is {vi.type}."
298
+ )
299
+ return False
300
+
301
+ if self.enable_subgraph_quantization and self.parent:
302
+ return self.parent.is_float_tensor(tensor_name)
303
+
304
+ logging.warning(
305
+ f"Failed to infer data type of tensor: {tensor_name!r}. Please add data type info for this tensor "
306
+ f"if your model has customized operators."
307
+ )
308
+ return False
309
+
310
+ def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType, initial_type):
311
+ """
312
+ Create nodes for dynamic quantization of input and add them to nodes_list.
313
+ parameter input_name: Name of the input.
314
+ parameter nodes_list: new nodes are appended to this list.
315
+ parameter qType: type to quantize to.
316
+ parameter initial_type: type to quantize from
317
+ return: scale_name, zero_point_name, scale_shape, zero_point_shape.
318
+ """
319
+ if qType == onnx_proto.TensorProto.INT8:
320
+ return self._get_dynamic_input_quantization_params_int8(input_name, nodes_list, initial_type)
321
+ if qType == onnx_proto.TensorProto.UINT8:
322
+ return self._get_dynamic_input_quantization_params_uint8(input_name, nodes_list, initial_type)
323
+ raise ValueError(f"Unexpected value for qType={qType}.")
324
+
325
+ def _get_dynamic_input_quantization_params_int8(self, input_name, nodes_list, initial_type):
326
+ """
327
+ Create nodes for dynamic quantization of input to int8 and add them to nodes_list
328
+ parameter input_name: Name of the input.
329
+ parameter nodes_list: new nodes are appended to this list.
330
+ parameter initial_type: initial weight type (FLOAT or FLOAT16)
331
+ return: scale_name, zero_point_name, scale_shape, zero_point_shape.
332
+ """
333
+ qType = onnx_proto.TensorProto.INT8 # noqa: N806
334
+
335
+ # Reduce min and Reduce max
336
+ input_scale_name = input_name + "_scale"
337
+
338
+ reduce_min_name = input_name + "_ReduceMin"
339
+ reduce_min_node = onnx.helper.make_node(
340
+ "ReduceMin",
341
+ [input_name],
342
+ [reduce_min_name + ":0"],
343
+ reduce_min_name,
344
+ keepdims=0,
345
+ )
346
+ nodes_list.append(reduce_min_node)
347
+
348
+ reduce_max_name = input_name + "_ReduceMax"
349
+ reduce_max_node = onnx.helper.make_node(
350
+ "ReduceMax",
351
+ [input_name],
352
+ [reduce_max_name + ":0"],
353
+ reduce_max_name,
354
+ keepdims=0,
355
+ )
356
+ nodes_list.append(reduce_max_node)
357
+
358
+ # Compute scale
359
+ # Find abs(rmin)
360
+ reduce_min_abs_name = reduce_min_name + "_Abs"
361
+ reduce_min_abs_node = onnx.helper.make_node(
362
+ "Abs",
363
+ [reduce_min_node.output[0]],
364
+ [reduce_min_abs_name + ":0"],
365
+ reduce_min_abs_name,
366
+ )
367
+ nodes_list.append(reduce_min_abs_node)
368
+ # Find abs(rmax)
369
+ reduce_max_abs_name = reduce_max_name + "_Abs"
370
+ reduce_max_abs_node = onnx.helper.make_node(
371
+ "Abs",
372
+ [reduce_max_node.output[0]],
373
+ [reduce_max_abs_name + ":0"],
374
+ reduce_max_abs_name,
375
+ )
376
+ nodes_list.append(reduce_max_abs_node)
377
+ # Compute max of abs(rmin) and abs(rmax)
378
+ abs_max_name = input_name + "_Abs_Max"
379
+ abs_max_node = onnx.helper.make_node(
380
+ "Max",
381
+ [reduce_min_abs_node.output[0], reduce_max_abs_node.output[0]],
382
+ [abs_max_name + ":0"],
383
+ abs_max_name,
384
+ )
385
+ nodes_list.append(abs_max_node)
386
+ # and divide by (quantize_range/2.0) which will be equal to max(...)*2.0/quantize_range
387
+ initializer_div = onnx.helper.make_tensor(
388
+ self.fixed_qrange_int8_name,
389
+ initial_type,
390
+ [],
391
+ [get_qrange_for_qType(qType) / 2.0],
392
+ )
393
+ self.model.add_initializer(initializer_div)
394
+ scale_div_name = input_name + "scale_Div"
395
+ scale_div_node = onnx.helper.make_node(
396
+ "Div",
397
+ [abs_max_node.output[0], self.fixed_qrange_int8_name],
398
+ [input_scale_name],
399
+ scale_div_name,
400
+ )
401
+ nodes_list.append(scale_div_node)
402
+
403
+ # Zero point
404
+ initializer_zp = onnx.helper.make_tensor(self.fixed_zero_zp_name, qType, [], [0])
405
+ self.model.add_initializer(initializer_zp)
406
+
407
+ return input_scale_name, self.fixed_zero_zp_name, [], []
408
+
409
+ def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list, initial_type):
410
+ """
411
+ Create nodes for dynamic quantization of input to uint8 and add them to nodes_list
412
+ parameter input_name: Name of the input.
413
+ parameter nodes_list: new nodes are appended to this list.
414
+ parameter initial_type: initial weight type (FLOAT or FLOAT16)
415
+ return: scale_name, zero_point_name, scale_shape, zero_point_shape.
416
+ """
417
+ qType = onnx_proto.TensorProto.UINT8 # noqa: N806
418
+ # Reduce min and Reduce max
419
+ input_scale_name = input_name + "_scale"
420
+ input_zp_name = input_name + "_zero_point"
421
+
422
+ reduce_min_name = input_name + "_ReduceMin"
423
+ reduce_min_node = onnx.helper.make_node(
424
+ "ReduceMin",
425
+ [input_name],
426
+ [reduce_min_name + ":0"],
427
+ reduce_min_name,
428
+ keepdims=0,
429
+ )
430
+ nodes_list.append(reduce_min_node)
431
+
432
+ reduce_max_name = input_name + "_ReduceMax"
433
+ reduce_max_node = onnx.helper.make_node(
434
+ "ReduceMax",
435
+ [input_name],
436
+ [reduce_max_name + ":0"],
437
+ reduce_max_name,
438
+ keepdims=0,
439
+ )
440
+ nodes_list.append(reduce_max_node)
441
+
442
+ # Add tensors for quantize range and zero value.
443
+ initializer_qrange = onnx.helper.make_tensor(
444
+ self.fixed_qrange_uint8_name,
445
+ initial_type,
446
+ [],
447
+ [get_qrange_for_qType(qType)],
448
+ )
449
+ self.model.add_initializer(initializer_qrange)
450
+ initializer_qvalue = onnx.helper.make_tensor(self.fixed_zero_name, initial_type, [], [0.0])
451
+ self.model.add_initializer(initializer_qvalue)
452
+
453
+ # Compute Scale
454
+ # Subtract rmax and rmin
455
+ scale_sub_name = input_name + "_scale_Sub"
456
+ scale_sub_node = onnx.helper.make_node(
457
+ "Sub",
458
+ [reduce_max_node.output[0], reduce_min_node.output[0]],
459
+ [scale_sub_name + ":0"],
460
+ scale_sub_name,
461
+ )
462
+ nodes_list.append(scale_sub_node)
463
+ # and divide by quantize range
464
+ scale_div_name = input_name + "_scale_Div"
465
+ scale_div_node = onnx.helper.make_node(
466
+ "Div",
467
+ [scale_sub_node.output[0], self.fixed_qrange_uint8_name],
468
+ [input_scale_name],
469
+ scale_div_name,
470
+ )
471
+ nodes_list.append(scale_div_node)
472
+
473
+ # Compute zero point
474
+ # Subtract zero and rmin
475
+ zp_sub_name = input_name + "_zero_point_Sub"
476
+ zp_sub_node = onnx.helper.make_node(
477
+ "Sub",
478
+ [self.fixed_zero_name, reduce_min_node.output[0]],
479
+ [zp_sub_name + ":0"],
480
+ zp_sub_name,
481
+ )
482
+ nodes_list.append(zp_sub_node)
483
+ # Divide by scale
484
+ zp_div_name = input_name + "_zero_point_Div"
485
+ zp_div_node = onnx.helper.make_node(
486
+ "Div",
487
+ [zp_sub_node.output[0], input_scale_name],
488
+ [zp_div_name + ":0"],
489
+ zp_div_name,
490
+ )
491
+ nodes_list.append(zp_div_node)
492
+ # Compute floor
493
+ zp_floor_name = input_name + "_zero_point_Floor"
494
+ zp_floor_node = onnx.helper.make_node("Floor", zp_div_node.output, [zp_floor_name + ":0"], zp_floor_name)
495
+ nodes_list.append(zp_floor_node)
496
+ # Cast to integer
497
+ zp_cast_name = input_name + "_zero_point_Cast"
498
+ zp_cast_node = onnx.helper.make_node("Cast", zp_floor_node.output, [input_zp_name], zp_cast_name, to=qType)
499
+ nodes_list.append(zp_cast_node)
500
+
501
+ return input_scale_name, input_zp_name, [], []
502
+
503
+ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None):
504
+ """
505
+ Create initializers and inputs in the graph for zero point and scale of output.
506
+ Zero point and scale values are obtained from self.quantization_params if specified.
507
+ parameter param_name: Name of the quantization parameter.
508
+ return: result, scale_name, zero_point_name, scale_shape, zero_point_shape.
509
+ """
510
+ zero_point_type = self.activation_qType
511
+
512
+ if use_scale is None or use_zeropoint is None:
513
+ if self.quantization_params is None or param_name not in self.quantization_params:
514
+ logging.info(f'Quantization parameters for tensor:"{param_name}" not specified')
515
+ return False, "", "", "", ""
516
+
517
+ params = self.quantization_params[param_name]
518
+ if not isinstance(params, QuantizationParams):
519
+ raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.")
520
+ if params is None or len(params) != 3:
521
+ raise ValueError(
522
+ "Quantization parameters should contain zero point, scale, quant type. "
523
+ f"Specified values for output {param_name}: {params}"
524
+ )
525
+
526
+ zero_point_values = np.array([params["zero_point"]])
527
+ if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16):
528
+ raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
529
+ scale_values = np.array([params["scale"]])
530
+ assert scale_values.dtype != np.float64
531
+ zero_point_type = params["quant_type"]
532
+ else:
533
+ zero_point_values = np.array([use_zeropoint])
534
+ scale_values = np.array([use_scale])
535
+ params = self.quantization_params[param_name]
536
+ if "scale" in params:
537
+ dtype = params["scale"].dtype
538
+ scale_values = scale_values.astype(dtype)
539
+ assert scale_values.dtype != np.float64
540
+
541
+ zero_point_shape = []
542
+ zero_point_name = param_name + "_zero_point"
543
+ scale_shape = []
544
+ scale_name = param_name + "_scale"
545
+
546
+ # Add initializers
547
+ init_zp = onnx.helper.make_tensor(
548
+ zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist()
549
+ )
550
+ self.model.add_initializer(init_zp)
551
+ if scale_values.dtype == np.float32:
552
+ scale_type = onnx_proto.TensorProto.FLOAT
553
+ elif scale_values.dtype == np.float16:
554
+ scale_type = onnx_proto.TensorProto.FLOAT16
555
+ else:
556
+ raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}")
557
+ init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist())
558
+ self.model.add_initializer(init_scale)
559
+
560
+ return True, scale_name, zero_point_name, scale_shape, zero_point_shape
561
+
562
+ def _get_quantize_input_nodes(
563
+ self, node, input_index, qType, given_scale_name=None, given_zp_name=None, initial_type=None
564
+ ):
565
+ """
566
+ Given an input for a node (which is not an initializer), this function
567
+
568
+ - add nodes to compute zero point and scale for this input if they don't exist.
569
+ - add new QuantizeLinear node to quantize the input.
570
+
571
+ :param node: node being quantized in NodeProto format.
572
+ :param input_index: index of input in node.input.
573
+ :param qType: type to quantize to.
574
+ :param given_scale_name: if provided, quantize this input using the given scale tensor.
575
+ :param given_zp_name: if provided, quantize this input using the given zero-point tensor.
576
+ :param initial_type: type of the weight to quantize
577
+ :return: List of newly created nodes in NodeProto format.
578
+ """
579
+ input_name = node.input[input_index]
580
+ assert input_name != "", "Cannot access undefined variable in graph."
581
+ output_name = input_name + TENSOR_NAME_QUANT_SUFFIX
582
+ ql_node_name = input_name + "_QuantizeLinear"
583
+
584
+ if (given_scale_name is not None) and (given_zp_name is not None):
585
+ data_found, scale_name, zp_name = (True, given_scale_name, given_zp_name)
586
+ else:
587
+ data_found, scale_name, zp_name, _, _ = self._get_quantization_params(input_name)
588
+
589
+ nodes = []
590
+ if data_found:
591
+ qlinear_node = onnx.helper.make_node(
592
+ "QuantizeLinear",
593
+ [input_name, scale_name, zp_name],
594
+ [output_name],
595
+ ql_node_name,
596
+ )
597
+ else:
598
+ if self.static:
599
+ return None
600
+ # dynamic mode
601
+ # Scale and zero point are not available for this input. Add nodes to compute them dynamically
602
+ if self.fuse_dynamic_quant and qType == onnx_proto.TensorProto.UINT8:
603
+ scale_name = input_name + "_scale"
604
+ zp_name = input_name + "_zero_point"
605
+ qlinear_node = onnx.helper.make_node(
606
+ "DynamicQuantizeLinear",
607
+ [input_name],
608
+ [output_name, scale_name, zp_name],
609
+ ql_node_name,
610
+ )
611
+ else:
612
+ assert initial_type is not None, (
613
+ f"Cannot quantize input without knowing the initial type, "
614
+ f"input_name={input_name!r}, input_index={input_index}, qType={qType}, node={node}"
615
+ )
616
+ (
617
+ scale_name,
618
+ zp_name,
619
+ scale_shape,
620
+ zp_shape,
621
+ ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType, initial_type=initial_type)
622
+ qlinear_node = onnx.helper.make_node(
623
+ "QuantizeLinear",
624
+ [input_name, scale_name, zp_name],
625
+ [output_name],
626
+ ql_node_name,
627
+ )
628
+
629
+ self.quantized_value_map[input_name] = QuantizedValue(input_name, output_name, scale_name, zp_name, qType)
630
+ return [*nodes, qlinear_node]
631
+
632
+ def find_quantized_value(self, input_name):
633
+ if input_name in self.quantized_value_map:
634
+ return self.quantized_value_map[input_name]
635
+ if self.parent is not None:
636
+ return self.parent.find_quantized_value(input_name)
637
+ return None
638
+
639
+ def adjust_single_weight_scale_if_needed(
640
+ self,
641
+ bias_val,
642
+ input_scale,
643
+ weight_scale,
644
+ weight_scale_dtype,
645
+ weight_name,
646
+ bias_name,
647
+ qrange,
648
+ multiplicative_epsilon,
649
+ idx=None,
650
+ ):
651
+ """Adjust a single weight scale to ensure the int32 bias does not overflow."""
652
+ absmax = np.abs(bias_val)
653
+ bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange
654
+
655
+ input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
656
+ weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64)
657
+ bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
658
+
659
+ if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
660
+ ratio = bias_smallest_valid_scale / bias_candidate_scale
661
+ new_scale = weight_scale_fp64 * ratio
662
+ if idx is None:
663
+ logging.info(
664
+ f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to "
665
+ f"ensure bias `{bias_name}` has a valid scale."
666
+ )
667
+ return True, np.array(new_scale, dtype=weight_scale_dtype)
668
+ else:
669
+ logging.info(
670
+ f"Increased scale[{idx}] for weight `{weight_name}` by ratio {ratio} "
671
+ f"to ensure bias `{bias_name}` has a valid scale."
672
+ )
673
+ return True, new_scale.astype(weight_scale_dtype)
674
+ return False, weight_scale
675
+
676
+ def _adjust_weight_scale_for_int32_bias(
677
+ self,
678
+ input_scale: np.ndarray,
679
+ weight_scale: np.ndarray,
680
+ weight_name: str,
681
+ bias_tp: onnx.TensorProto,
682
+ is_per_channel: bool,
683
+ ) -> tuple[bool, np.ndarray | None]:
684
+ """Checks if the bias scale is too small and increases the weight scale if needed."""
685
+
686
+ if not weight_scale.size:
687
+ return False, None
688
+
689
+ bias_float_data = tensor_proto_to_array(bias_tp)
690
+ int32_info = np.iinfo(np.int32)
691
+ multiplicative_epsilon = 1.0001
692
+ qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64)
693
+ weight_scale_dtype = weight_scale.dtype
694
+ updated = False
695
+
696
+ if not is_per_channel:
697
+ rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64))
698
+ rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64))
699
+ absmax = np.maximum(np.abs(rmin), np.abs(rmax))
700
+ changed, new_scale = self.adjust_single_weight_scale_if_needed(
701
+ absmax,
702
+ input_scale,
703
+ weight_scale,
704
+ weight_scale_dtype,
705
+ weight_name,
706
+ bias_tp.name,
707
+ qrange,
708
+ multiplicative_epsilon,
709
+ )
710
+ if changed:
711
+ weight_scale = new_scale
712
+ updated = True
713
+ elif weight_scale.shape and len(weight_scale.shape) == 1:
714
+ for i in range(weight_scale.shape[0]):
715
+ changed, new_scale = self.adjust_single_weight_scale_if_needed(
716
+ bias_float_data[i],
717
+ input_scale,
718
+ weight_scale[i],
719
+ weight_scale_dtype,
720
+ weight_name,
721
+ bias_tp.name,
722
+ qrange,
723
+ multiplicative_epsilon,
724
+ idx=i,
725
+ )
726
+ if changed:
727
+ weight_scale[i] = new_scale
728
+ updated = True
729
+
730
+ return updated, weight_scale
731
+
732
+ def _requantize_weight(self, weight_name: str, new_scale: np.ndarray) -> None:
733
+ """Re-quantizes the given weight initializer using the provided scale."""
734
+
735
+ if weight_name not in self.quantized_value_map:
736
+ return
737
+
738
+ qv = self.quantized_value_map[weight_name]
739
+
740
+ weight_tp = find_by_name(weight_name, self.model.initializer())
741
+ scale_init = find_by_name(qv.scale_name, self.model.initializer())
742
+ zp_init = find_by_name(qv.zp_name, self.model.initializer())
743
+ q_weight_init = find_by_name(qv.q_name, self.model.initializer())
744
+
745
+ if weight_tp is None or scale_init is None or zp_init is None or q_weight_init is None:
746
+ return
747
+
748
+ self.model.remove_initializer(scale_init)
749
+ self.model.remove_initializer(q_weight_init)
750
+
751
+ weight_zero_point = onnx.numpy_helper.to_array(zp_init)
752
+ axis = qv.axis
753
+
754
+ # Add new scale initializer
755
+ scale_np = np.asarray(new_scale, dtype=onnx.helper.tensor_dtype_to_np_dtype(weight_tp.data_type))
756
+ new_scale_init = onnx.numpy_helper.from_array(scale_np.reshape(scale_init.dims), qv.scale_name)
757
+ self.model.add_initializer(new_scale_init)
758
+
759
+ # Add new quantized weight initializer
760
+ new_q_weight = quantize_onnx_initializer(
761
+ weight_tp,
762
+ self.weight_qType,
763
+ weight_zero_point,
764
+ scale_np,
765
+ axis,
766
+ quant_weight_name=qv.q_name,
767
+ )
768
+ self.model.add_initializer(new_q_weight)
769
+
770
+ def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0):
771
+ """
772
+ Quantize the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
773
+ """
774
+
775
+ # Handle case where bias already in quantization map
776
+ if bias_name in self.quantized_value_map:
777
+ return self.quantized_value_map[bias_name].q_name
778
+
779
+ # get scale for weight
780
+ weight_scale_name = self.quantized_value_map[weight_name].scale_name
781
+ weight_initializer = find_by_name(weight_scale_name, self.model.initializer())
782
+ weight_scale = tensor_proto_to_array(weight_initializer)
783
+
784
+ # get scale for input
785
+ if input_name in self.quantized_value_map:
786
+ input_scale_name = self.quantized_value_map[input_name].scale_name
787
+ elif input_name in self.quantization_params:
788
+ _, input_scale_name, _, _, _ = self._get_quantization_params(input_name)
789
+ else:
790
+ raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization")
791
+
792
+ inputscale_initializer = find_by_name(input_scale_name, self.model.initializer())
793
+ input_scale = tensor_proto_to_array(inputscale_initializer)
794
+
795
+ # Adjust weight scale if quantizing to int32 may overflow due to a small scale
796
+ weight_zp_name = self.quantized_value_map[weight_name].zp_name
797
+ weight_zp_init = find_by_name(weight_zp_name, self.model.initializer())
798
+ weight_zero_point = onnx.numpy_helper.to_array(weight_zp_init) if weight_zp_init is not None else None
799
+ is_per_channel = self.per_channel
800
+ if (
801
+ weight_zero_point is not None
802
+ and weight_zero_point.size
803
+ and not weight_zero_point.any()
804
+ and self.weight_qType in (onnx_proto.TensorProto.INT8,)
805
+ ):
806
+ bias_initializer = find_by_name(bias_name, self.model.initializer())
807
+ did_update, new_weight_scale = self._adjust_weight_scale_for_int32_bias(
808
+ input_scale,
809
+ weight_scale,
810
+ weight_name,
811
+ bias_initializer,
812
+ is_per_channel,
813
+ )
814
+ if did_update:
815
+ self._requantize_weight(weight_name, new_weight_scale)
816
+ weight_scale = new_weight_scale
817
+
818
+ (
819
+ quantized_bias_name,
820
+ quantized_bias_scale_name,
821
+ quantized_bias_zp_name,
822
+ bias_scale_data,
823
+ node_type,
824
+ node_qtype,
825
+ ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, beta)
826
+
827
+ assert bias_name not in self.quantized_value_map
828
+ quantized_value = QuantizedValue(
829
+ bias_name,
830
+ quantized_bias_name,
831
+ quantized_bias_scale_name,
832
+ quantized_bias_zp_name,
833
+ QuantizedValueType.Initializer,
834
+ 0 if bias_scale_data.size > 1 else None,
835
+ node_type=node_type,
836
+ node_qtype=node_qtype,
837
+ )
838
+ self.quantized_value_map[bias_name] = quantized_value
839
+
840
+ return quantized_bias_name
841
+
842
+ def contains_tensor(self, tensor_name):
843
+ """
844
+ Only checks value_info and newly generated tensor names; initializers are checked separately
845
+ """
846
+ return (
847
+ (tensor_name in self.value_infos)
848
+ or (tensor_name in self.tensor_names)
849
+ or (tensor_name in self.generated_value_names)
850
+ )
851
+
852
+ def quantize_activation(self, node, indices, from_subgraph=False):
853
+ return self.__quantize_inputs(
854
+ node=node,
855
+ indices=indices,
856
+ initializer_use_weight_qType=False,
857
+ reduce_range=False,
858
+ op_level_per_channel=False,
859
+ axis=-1,
860
+ from_subgraph=from_subgraph,
861
+ )
862
+
863
+ # In some circumstances a weight is not an initializer. For example, in MatMul, if neither A nor B is an
864
+ # initializer, B can still be considered a weight
865
+ def quantize_weight(
866
+ self,
867
+ node,
868
+ indices,
869
+ reduce_range=False,
870
+ op_level_per_channel=False,
871
+ axis=-1,
872
+ from_subgraph=False,
873
+ ):
874
+ return self.__quantize_inputs(
875
+ node=node,
876
+ indices=indices,
877
+ initializer_use_weight_qType=True,
878
+ reduce_range=reduce_range,
879
+ op_level_per_channel=op_level_per_channel,
880
+ axis=axis,
881
+ from_subgraph=from_subgraph,
882
+ )
883
+
884
+ def __quantize_inputs(
885
+ self,
886
+ node,
887
+ indices,
888
+ initializer_use_weight_qType=True,
889
+ reduce_range=False,
890
+ op_level_per_channel=False,
891
+ axis=-1,
892
+ from_subgraph=False,
893
+ ):
894
+ """
895
+ Given a node, this function quantizes the inputs as follows:
896
+ - If input is an initializer, quantize the initializer data, replace old initializer
897
+ with new initializer
898
+ - Else, add QuantizeLinear nodes to perform quantization
899
+ parameter node: node being quantized in NodeProto format.
900
+ parameter indices: input indices to quantize.
901
+ return: (List of quantized input names,
902
+ List of zero point names used for input quantization,
903
+ List of scale names used for input quantization,
904
+ List of new QuantizeLinear nodes created)
905
+ """
906
+
907
+ scale_names = []
908
+ zero_point_names = []
909
+ quantized_input_names = []
910
+ nodes = []
911
+
912
+ for input_index in indices:
913
+ node_input = node.input[input_index]
914
+
915
+ # Find if this input is already quantized
916
+ if node_input in self.quantized_value_map:
917
+ quantized_value = self.quantized_value_map[node_input]
918
+ scale_names.append(quantized_value.scale_name)
919
+ zero_point_names.append(quantized_value.zp_name)
920
+ quantized_input_names.append(quantized_value.q_name)
921
+ continue
922
+ # Handle the case where embed_layernorm.py has an optional segment_embedding input
923
+ if not node_input:
924
+ quantized_input_names.append("")
925
+ scale_names.append("")
926
+ zero_point_names.append("")
927
+ continue
928
+ # Quantize the input
929
+ initializer = find_by_name(node_input, self.model.initializer())
930
+ if initializer is not None:
931
+ if self.per_channel and op_level_per_channel:
932
+ (
933
+ q_weight_name,
934
+ zp_name,
935
+ scale_name,
936
+ ) = self.quantize_weight_per_channel(
937
+ initializer.name,
938
+ self.weight_qType if initializer_use_weight_qType else self.activation_qType,
939
+ axis,
940
+ reduce_range,
941
+ )
942
+ else:
943
+ q_weight_name, zp_name, scale_name = self.quantize_initializer(
944
+ initializer,
945
+ self.weight_qType if initializer_use_weight_qType else self.activation_qType,
946
+ reduce_range,
947
+ )
948
+
949
+ quantized_input_names.append(q_weight_name)
950
+ zero_point_names.append(zp_name)
951
+ scale_names.append(scale_name)
952
+ elif self.contains_tensor(node_input):
953
+ # Add QuantizeLinear node.
954
+ qlinear_node = self.model.find_node_by_name(
955
+ node_input + "_QuantizeLinear", self.new_nodes, self.model.graph()
956
+ )
957
+ if qlinear_node is None:
958
+ input_name = node.input[input_index]
959
+ if input_name in self.value_infos:
960
+ value_info = self.value_infos[input_name]
961
+ assert value_info.HasField("type"), f"value_info={value_info} has no type."
962
+ assert value_info.type.HasField("tensor_type"), f"value_info={value_info} is not a tensor."
963
+ initial_type = value_info.type.tensor_type.elem_type
964
+ else:
965
+ # Shape inference failed. Fallback to self.tensor_names.
966
+ assert input_name in self.tensor_names, (
967
+ f"shape inference failed for {input_name!r} and "
968
+ f"attribute 'tensor_names' does not have any value for "
969
+ f"this tensor."
970
+ )
971
+ initial_type = self.tensor_names[input_name]
972
+ quantize_input_nodes = self._get_quantize_input_nodes(
973
+ node, input_index, self.activation_qType, initial_type=initial_type
974
+ )
975
+ if quantize_input_nodes is None:
976
+ return (None, None, None, None)
977
+ if from_subgraph:
978
+ self.add_new_nodes(quantize_input_nodes)
979
+ else:
980
+ nodes.extend(quantize_input_nodes)
981
+ qlinear_node = quantize_input_nodes[-1]
982
+
983
+ if qlinear_node.op_type == "QuantizeLinear":
984
+ quantized_input_names.extend(qlinear_node.output)
985
+ scale_names.append(qlinear_node.input[1])
986
+ zero_point_names.append(qlinear_node.input[2])
987
+ else:
988
+ quantized_input_names.append(qlinear_node.output[0])
989
+ scale_names.append(qlinear_node.output[1])
990
+ zero_point_names.append(qlinear_node.output[2])
991
+ elif self.parent is not None:
992
+ (
993
+ parent_quantized_input_names,
994
+ parent_zero_point_names,
995
+ parent_scale_names,
996
+ _,
997
+ ) = self.parent.__quantize_inputs(
998
+ node,
999
+ [input_index],
1000
+ initializer_use_weight_qType=initializer_use_weight_qType,
1001
+ reduce_range=reduce_range,
1002
+ op_level_per_channel=op_level_per_channel,
1003
+ axis=axis,
1004
+ from_subgraph=True,
1005
+ )
1006
+ quantized_input_names.append(parent_quantized_input_names[0])
1007
+ scale_names.append(parent_scale_names[0])
1008
+ zero_point_names.append(parent_zero_point_names[0])
1009
+ # node should not be add this child level here
1010
+ else:
1011
+ raise ValueError(f"Invalid tensor name to quantize: {node_input} @graph scope{self.graph_scope}")
1012
+
1013
+ return quantized_input_names, zero_point_names, scale_names, nodes
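When an input is neither an initializer nor already quantized, the code above inserts a QuantizeLinear node named `<input>_QuantizeLinear`. For reference, a small numpy sketch of the per-tensor affine mapping such a node performs at runtime (uint8 chosen for illustration; `quantize_linear_ref` is not part of this module):

```python
import numpy as np

def quantize_linear_ref(x, scale, zero_point, qmin=0, qmax=255):
    """Reference per-tensor affine quantization: q = clip(round(x / scale) + zero_point)."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.uint8)

x = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
scale, zero_point = 0.02, 85                    # values a calibrator might have produced
q = quantize_linear_ref(x, scale, zero_point)   # array([ 35,  85, 110, 185], dtype=uint8)
```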
+
+     def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False):
+         """
+         :param weight: TensorProto initializer
+         :param qType: type to quantize to
+         :param keep_float_weight: Whether to keep the weight in float. In some cases we only want to quantize the
+                                   scale and zero point; if keep_float_weight is False, the weight itself is
+                                   quantized as well.
+         :return: quantized weight name, zero point name, scale name
+         """
+         # Find if this input is already quantized
+         if weight.name in self.quantized_value_map:
+             quantized_value = self.quantized_value_map[weight.name]
+             return (
+                 quantized_value.q_name,
+                 quantized_value.zp_name,
+                 quantized_value.scale_name,
+             )
+
+         q_weight_name, zp_name, scale_name = self.quantize_initializer_impl(
+             weight, qType, reduce_range, keep_float_weight
+         )
+
+         # Log entry for this quantized weight
+         quantized_value = QuantizedValue(
+             weight.name,
+             q_weight_name,
+             scale_name,
+             zp_name,
+             QuantizedValueType.Initializer,
+             None,
+         )
+         self.quantized_value_map[weight.name] = quantized_value
+         return q_weight_name, zp_name, scale_name
+
+     def quantize_weight_per_channel(
+         self,
+         weight_name,
+         weight_qType,
+         channel_axis,
+         reduce_range=True,
+         keep_float_weight=False,
+     ):
+         # Find if this input is already quantized
+         if weight_name in self.quantized_value_map:
+             quantized_value = self.quantized_value_map[weight_name]
+             return (
+                 quantized_value.q_name,
+                 quantized_value.zp_name,
+                 quantized_value.scale_name,
+             )
+
+         q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl(
+             weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight
+         )
+         quantized_value = QuantizedValue(
+             weight_name,
+             q_weight_name,
+             scale_name,
+             zp_name,
+             QuantizedValueType.Initializer,
+             None,
+         )
+         self.quantized_value_map[weight_name] = quantized_value
+
+         return q_weight_name, zp_name, scale_name
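Per-channel quantization computes one scale (and zero point) per slice along `channel_axis` rather than a single pair for the whole tensor. A rough numpy sketch of the symmetric int8 variant (the helper name is illustrative; the real work happens in `quantize_weight_per_channel_impl`):

```python
import numpy as np

def per_channel_symmetric_int8(weight, channel_axis=0):
    """Illustrative symmetric per-channel int8 quantization along channel_axis."""
    reduce_axes = tuple(i for i in range(weight.ndim) if i != channel_axis)
    max_abs = np.max(np.abs(weight), axis=reduce_axes, keepdims=True)
    scales = np.where(max_abs == 0, 1.0, max_abs / 127.0)   # avoid dividing by zero for all-zero channels
    q = np.clip(np.round(weight / scales), -127, 127).astype(np.int8)
    zero_points = np.zeros(weight.shape[channel_axis], dtype=np.int8)  # symmetric => zero point is 0
    return q, scales.reshape(-1), zero_points

w = np.random.randn(8, 3, 3, 3).astype(np.float32)  # e.g. a Conv weight, output channels on axis 0
q, scales, zps = per_channel_symmetric_int8(w, channel_axis=0)
```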
+
+     def _dequantize_value(self, value_name):
+         """
+         Given a value (input/output) which is quantized, add a DequantizeLinear node to dequantize
+         it back to float32 or float16.
+         parameter value_name: value to dequantize
+         return: None if there is already a DequantizeLinear node that dequantizes it,
+                 a new DequantizeLinear node otherwise
+         """
+         if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names):
+             quantized_value = self.quantized_value_map[value_name]
+             # Add DequantizeLinear Node for this input
+
+             scale_init = find_by_name(quantized_value.scale_name, self.model.initializer())
+
+             # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the
+             # `quantize_subgraph` method. In this case, the scale initializer may be on the top-level graph, so the
+             # check below cannot be done.
+             if self.model.model.producer_name != "onnx-quantizer" or (
+                 self.model.model.producer_name == "onnx-quantizer" and scale_init is not None
+             ):
+                 # axis is not specified, so scale_init must be a scalar.
+                 assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1
+
+             dqlinear_name = value_name + "_DequantizeLinear"
+             dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph())
+             if dqlinear_node is None:
+                 dqlinear_inputs = [
+                     quantized_value.q_name,
+                     quantized_value.scale_name,
+                     quantized_value.zp_name,
+                 ]
+                 dequantize_node = onnx.helper.make_node(
+                     "DequantizeLinear", dqlinear_inputs, [value_name], dqlinear_name
+                 )
+                 return dequantize_node
+             else:
+                 # A DequantizeLinear op is already present; assert that its output matches the input of the current node.
+                 assert value_name == dqlinear_node.output[0]
+         return None
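DequantizeLinear is the inverse affine map of QuantizeLinear, so a quantize/dequantize round trip reproduces values to within half a quantization step (plus any clipping error). A quick numpy illustration with arbitrary uint8 parameters:

```python
import numpy as np

scale, zero_point = 0.02, 128

x = np.array([-2.56, -0.5, 0.0, 1.27], dtype=np.float32)
q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)  # what QuantizeLinear computes
x_hat = (q.astype(np.float32) - zero_point) * scale                     # what DequantizeLinear computes
assert np.max(np.abs(x - x_hat)) <= scale / 2 + 1e-6                    # error bounded by half a step
```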
+
+     def _dequantize_outputs(self):
+         """
+         Dequantize each graph output that was quantized, appending the new DequantizeLinear
+         nodes to self.new_nodes.
+         """
+
+         for output in self.model.graph().output:
+             dequantize_node = self._dequantize_value(output.name)
+             if dequantize_node is not None:
+                 self.new_nodes.append(dequantize_node)
+
+     def calculate_quantization_params(self):
+         if self.tensors_range is None:
+             return None
+
+         self.adjust_tensor_ranges()
+
+         quantization_params = {}
+         for tensor_name in self.tensors_range:
+             td = self.tensors_range[tensor_name]
+             if not isinstance(td, TensorData):
+                 raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")
+
+             quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={})
+
+             quant_type = self.activation_qType
+             if "quant_type" in quant_overrides:
+                 quant_type = quant_overrides["quant_type"].tensor_type
+
+             if "scale" in quant_overrides and "zero_point" in quant_overrides:
+                 zero, scale = quant_overrides["zero_point"], quant_overrides["scale"]
+             elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
+                 zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1])
+             else:
+                 rmin = quant_overrides.get("rmin", td.range_value[0])
+                 rmax = quant_overrides.get("rmax", td.range_value[1])
+                 symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
+                 reduce_range = quant_overrides.get("reduce_range", False)
+                 qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
+                 zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)
+
+             quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type)
+
+         return quantization_params
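`compute_scale_zp` maps a calibrated float range [rmin, rmax] onto the integer range [qmin, qmax]. A worked asymmetric-uint8 example of that mapping; this is a simplified re-derivation for illustration, not the library routine itself (the real one also handles symmetric mode and the `min_real_range` floor):

```python
def compute_scale_zp_ref(rmin, rmax, qmin, qmax):
    """Simplified asymmetric scale/zero-point derivation for illustration."""
    # Extend the range to include 0 so that zero is exactly representable.
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = int(round(qmin - rmin / scale))
    return zero_point, scale

# Calibrated activation range for a uint8 tensor:
zp, scale = compute_scale_zp_ref(rmin=-1.5, rmax=3.0, qmin=0, qmax=255)
# scale = 4.5 / 255 ≈ 0.01765, zero_point = round(1.5 / 0.01765) = 85
```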