onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
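
The diff below is the full content of the new file onnxruntime/quantization/onnx_quantizer.py (entry 42 above, +1008 lines). For context on what the wheel itself provides at runtime, a minimal sketch of running a model on the DirectML execution provider; the model path, input name, and shape are placeholders:

import numpy as np
import onnxruntime as ort

# DirectML builds of onnxruntime register "DmlExecutionProvider" alongside the CPU provider.
print(ort.get_available_providers())

session = ort.InferenceSession(
    "model.onnx",  # placeholder path to any ONNX model
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
input_name = session.get_inputs()[0].name
x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # placeholder input shape
outputs = session.run(None, {input_name: x})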
@@ -0,0 +1,1008 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ import logging
+
+ import numpy as np
+ import onnx
+ import onnx.numpy_helper
+ from onnx import onnx_pb as onnx_proto
+
+ from .base_quantizer import BaseQuantizer, QuantizationParams
+ from .calibrate import TensorData
+ from .onnx_model import ONNXModel
+ from .quant_utils import (
+     TENSOR_NAME_QUANT_SUFFIX,
+     QuantizationMode,
+     QuantizedValue,
+     QuantizedValueType,
+     __producer__,
+     __version__,
+     add_infer_metadata,
+     attribute_to_kwarg,
+     compute_scale_zp,
+     compute_scale_zp_float8,
+     find_by_name,
+     get_qmin_qmax_for_qType,
+     get_qrange_for_qType,
+     ms_domain,
+     save_and_reload_model_with_shape_infer,
+     tensor_proto_to_array,
+ )
+ from .registry import CreateOpQuantizer
+
+
+ class ONNXQuantizer(BaseQuantizer):
+     def __init__(
+         self,
+         model,
+         per_channel,
+         reduce_range,
+         mode,
+         static,
+         weight_qType,
+         activation_qType,
+         tensors_range,
+         nodes_to_quantize,
+         nodes_to_exclude,
+         op_types_to_quantize,
+         extra_options=None,
+     ):
+         BaseQuantizer.__init__(
+             self,
+             model,
+             per_channel,
+             reduce_range,
+             weight_qType,
+             activation_qType,
+             tensors_range,
+             nodes_to_quantize,
+             nodes_to_exclude,
+             op_types_to_quantize,
+             extra_options,
+         )
+
+         if not static:
+             self.model.replace_gemm_with_matmul()
+             # We need to update value_infos.
+             model = save_and_reload_model_with_shape_infer(self.model.model)
+             self.value_infos = {vi.name: vi for vi in model.graph.value_info}
+             self.value_infos.update({ot.name: ot for ot in model.graph.output})
+             self.value_infos.update({it.name: it for it in model.graph.input})
+             self.model = ONNXModel(model)
+
+         self.mode = mode  # QuantizationMode.Value
+         self.static = static  # use static quantization for inputs.
+         self.fuse_dynamic_quant = self.opset_version > 10
+
+         self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"]
+
+         self.new_nodes = []
+         self.graph_scope = "/"  # for human-readable debug information
+         self.tensor_names = {}  # in case shape inference is not fully working
+         self.tensor_names.update({ot.name: 1 for ot in model.graph.output})
+         self.tensor_names.update({it.name: 1 for it in model.graph.input})
+         for node in self.model.model.graph.node:
+             self.tensor_names.update({output_name: 1 for output_name in node.output})
+
+         if self.mode not in QuantizationMode:
+             raise ValueError(f"unsupported quantization mode {self.mode}")
+
+         self.quantization_params = self.calculate_quantization_params()
+
+         # QuantizeRange tensor name and zero tensor name for scale and zero point calculation.
+         # Used when static is False.
+         self.fixed_qrange_uint8_name = "fixed_quantization_range_uint8"
+         self.fixed_qrange_int8_name = "fixed_quantization_range_int8"
+         # For the uint8 data type, to compute the zero point we subtract rmin from 0 (represented by the fixed_zero_name tensor).
+         self.fixed_zero_name = "fixed_zero"
+         # For the int8 data type, the zero point is always zero (represented by the fixed_zero_zp_name tensor).
+         self.fixed_zero_zp_name = "fixed_zero_zp"
+
+         # Map of all original value names to quantized value names
+         self.quantized_value_map = {}
+         # Some node outputs will be quantized, yet should be treated as already existing,
+         # so that no dequantization is applied to them when needed later.
+         self.generated_value_names = self.model.get_non_initializer_inputs()
+
+     # routines for subgraph support
+     def quantize_subgraph(self, subgraph, graph_key):
+         """
+         Generate a submodel for the subgraph so that the current quantization implementation can be reused,
+         quantize the submodel,
+         then update the subgraph and set it back on the node.
+         """
+         warped_model = onnx.helper.make_model(
+             subgraph,
+             producer_name="onnx-quantizer",
+             opset_imports=self.model.model.opset_import,
+         )
+         add_infer_metadata(warped_model)
+         sub_quantizer = ONNXQuantizer(
+             warped_model,
+             self.per_channel,
+             self.reduce_range,
+             self.mode,
+             self.static,
+             self.weight_qType,
+             self.activation_qType,
+             self.tensors_range,
+             self.nodes_to_quantize,
+             self.nodes_to_exclude,
+             self.op_types_to_quantize,
+             self.extra_options,
+         )
+         sub_quantizer.parent = self
+         sub_quantizer.graph_scope = f"{self.graph_scope}{graph_key}/"
+         sub_quantizer.quantize_model()
+         return sub_quantizer.model.model.graph
+
+     def quantize_node_with_sub_graph(self, node):
+         """
+         Check the node's subgraphs, if any; quantize and replace them.
+         Return a new node whose graph attributes contain the quantized subgraphs.
+         """
+         graph_attrs = [
+             attr
+             for attr in node.attribute
+             if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
+         ]
+         if len(graph_attrs) == 0:
+             return node
+         node_name = node.name if node.name else f"{node.op_type}_node_count_{len(self.new_nodes)}"
+         kwargs = {}
+         for attr in node.attribute:
+             if attr.type == onnx.AttributeProto.GRAPH:
+                 kv = {attr.name: self.quantize_subgraph(attr.g, f"{node_name}:{attr.name}")}
+             elif attr.type == onnx.AttributeProto.GRAPHS:
+                 value = []
+                 for subgraph in attr.graphs:
+                     value.extend(
+                         [
+                             self.quantize_subgraph(
+                                 subgraph,
+                                 f"{node_name}:{attr.name}:{len(value)}",
+                             )
+                         ]
+                     )
+                 kv = {attr.name: value}
+             else:
+                 kv = attribute_to_kwarg(attr)
+             kwargs.update(kv)
+         return onnx.helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
+
+     def has_QDQ_nodes(self):  # noqa: N802
+         """
+         Detect if the model already has QuantizeLinear or DequantizeLinear nodes.
+         """
+         return any(
+             node.op_type == "QuantizeLinear" or node.op_type == "DequantizeLinear" for node in self.model.nodes()
+         )
+
+     def find_initializer_in_path(self, initializer_name):
+         if find_by_name(initializer_name, self.model.initializer()) is not None:
+             return True
+         if self.parent is not None:
+             return self.parent.find_initializer_in_path(initializer_name)
+         return False
+
+     def add_new_nodes(self, nodes):
+         self.new_nodes.extend(nodes)
+         for node in nodes:
+             for output_name in node.output:
+                 self.generated_value_names.add(output_name)
+
+     def quantize_model(self):
+         if self.has_QDQ_nodes():
+             logging.warning(
+                 "Please check if the model is already quantized. "
+                 "Note you don't need to quantize a QAT model. OnnxRuntime supports running QAT models directly."
+             )
+
+         for node in self.model.nodes():
+             # quantize subgraphs if there are any
+             if self.enable_subgraph_quantization:
+                 node = self.quantize_node_with_sub_graph(node)  # noqa: PLW2901
+
+             number_of_existing_new_nodes = len(self.new_nodes)
+             op_quantizer = CreateOpQuantizer(self, node)
+             op_quantizer.quantize()
+             for i in range(number_of_existing_new_nodes, len(self.new_nodes)):
+                 for output_name in self.new_nodes[i].output:
+                     self.generated_value_names.add(output_name)
+
+         self._dequantize_outputs()
+
+         # extend is used to append to the list for protobuf fields
+         # https://developers.google.com/protocol-buffers/docs/reference/python-generated?csw=1#fields
+         self.model.graph().ClearField("node")
+         self.model.graph().node.extend(self.new_nodes)
+
+         # Remove unused initializers from the graph, starting from the top-level graph.
+         if self.parent is None:
+             _, initializers_not_found = self.model.clean_initializers()
+             if len(initializers_not_found) > 0:
+                 raise RuntimeError("Invalid model with unknown initializers/tensors." + str(initializers_not_found))
+
+         self.model.model.producer_name = __producer__
+         self.model.model.producer_version = __version__
+         # Add the ms domain if needed
+         ms_opset = [opset for opset in self.model.model.opset_import if opset.domain == ms_domain]
+         if not ms_opset:
+             ms_nodes = [node for node in self.new_nodes if node.domain == "com.microsoft"]
+             if ms_nodes:
+                 opset = self.model.model.opset_import.add()
+                 opset.version = 1
+                 opset.domain = ms_domain
+
+         return self.model.model
+
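Most users do not drive ONNXQuantizer directly; the helpers in onnxruntime/quantization/quantize.py (entry 71 above) construct it internally. A minimal sketch of the dynamic-quantization entry point; the paths are placeholders:

from onnxruntime.quantization import QuantType, quantize_dynamic

# quantize_dynamic builds an ONNXQuantizer with static=False and
# runs quantize_model(), as defined above.
quantize_dynamic("model.onnx", "model.quant.onnx", weight_type=QuantType.QUInt8)
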
+     def _get_default_tensor_type(self, tensor_name):
+         if "DefaultTensorType" in self.extra_options:
+             logging.info(
+                 "get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
+                 tensor_name,
+                 self.extra_options["DefaultTensorType"],
+             )
+             return self.extra_options["DefaultTensorType"]
+         raise RuntimeError(
+             f"Unable to find data type for weight_name={tensor_name!r}. "
+             f"shape_inference failed to return a type, probably because this node is "
+             f"from a different domain or uses an input produced by such an operator. "
+             f"This may happen if you quantize a model that is already quantized. "
+             f"You may use the extra_options key `DefaultTensorType` to indicate "
+             f"the default weight type, usually `onnx.TensorProto.FLOAT`."
+         )
+
+     def get_tensor_type(self, tensor_name, mandatory=False):
+         weight = find_by_name(tensor_name, self.model.initializer())
+         if weight is not None:
+             return weight.data_type
+         if tensor_name in self.value_infos:
+             vi = self.value_infos[tensor_name]
+             if vi.type.HasField("tensor_type"):
+                 if mandatory and vi.type.tensor_type.elem_type == 0:
+                     return self._get_default_tensor_type(tensor_name)
+                 return vi.type.tensor_type.elem_type
+         if (not self.enable_subgraph_quantization) or (self.parent is None):
+             if mandatory:
+                 return self._get_default_tensor_type(tensor_name)
+             return None
+         otype = self.parent.is_valid_quantize_weight(tensor_name)
+         if otype is not None:
+             return otype
+         if self.enable_subgraph_quantization and self.parent:
+             res = self.parent.get_tensor_type(tensor_name)
+             if res is not None:
+                 return res
+         if mandatory:
+             return self._get_default_tensor_type(tensor_name)
+         return None
+
+     def is_float_tensor(self, tensor_name):
+         if self.is_input_a_initializer(tensor_name):
+             return self.is_valid_quantize_weight(tensor_name)
+
+         if tensor_name in self.value_infos:
+             vi = self.value_infos[tensor_name]
+             if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
+                 onnx_proto.TensorProto.FLOAT,
+                 onnx_proto.TensorProto.FLOAT16,
+             ):
+                 return True
+             logging.warning(
+                 f"Inference failed or unsupported type to quantize for tensor {tensor_name!r}, type is {vi.type}."
+             )
+             return False
+
+         if self.enable_subgraph_quantization and self.parent:
+             return self.parent.is_float_tensor(tensor_name)
+
+         logging.warning(
+             f"Failed to infer data type of tensor: {tensor_name!r}. Please add data type info for this tensor "
+             f"if your model has customized operators."
+         )
+         return False
+
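The two methods above walk the ValueInfoProto structures cached in self.value_infos. A small sketch of that protobuf layout, using only public onnx helpers:

import onnx

vi = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT16, [1, 4])
assert vi.type.HasField("tensor_type")
print(vi.type.tensor_type.elem_type == onnx.TensorProto.FLOAT16)  # True
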
+     def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType, initial_type):
+         """
+         Create nodes for dynamic quantization of an input and add them to nodes_list.
+         parameter input_name: name of the input.
+         parameter nodes_list: new nodes are appended to this list.
+         parameter qType: type to quantize to.
+         parameter initial_type: type to quantize from.
+         return: scale_name, zero_point_name, scale_shape, zero_point_shape.
+         """
+         if qType == onnx_proto.TensorProto.INT8:
+             return self._get_dynamic_input_quantization_params_int8(input_name, nodes_list, initial_type)
+         if qType == onnx_proto.TensorProto.UINT8:
+             return self._get_dynamic_input_quantization_params_uint8(input_name, nodes_list, initial_type)
+         raise ValueError(f"Unexpected value for qType={qType}.")
+
+     def _get_dynamic_input_quantization_params_int8(self, input_name, nodes_list, initial_type):
+         """
+         Create nodes for dynamic quantization of an input to int8 and add them to nodes_list.
+         parameter input_name: name of the input.
+         parameter nodes_list: new nodes are appended to this list.
+         parameter initial_type: initial weight type (FLOAT or FLOAT16).
+         return: scale_name, zero_point_name, scale_shape, zero_point_shape.
+         """
+         qType = onnx_proto.TensorProto.INT8  # noqa: N806
+
+         # ReduceMin and ReduceMax
+         input_scale_name = input_name + "_scale"
+
+         reduce_min_name = input_name + "_ReduceMin"
+         reduce_min_node = onnx.helper.make_node(
+             "ReduceMin",
+             [input_name],
+             [reduce_min_name + ":0"],
+             reduce_min_name,
+             keepdims=0,
+         )
+         nodes_list.append(reduce_min_node)
+
+         reduce_max_name = input_name + "_ReduceMax"
+         reduce_max_node = onnx.helper.make_node(
+             "ReduceMax",
+             [input_name],
+             [reduce_max_name + ":0"],
+             reduce_max_name,
+             keepdims=0,
+         )
+         nodes_list.append(reduce_max_node)
+
+         # Compute scale
+         # Find abs(rmin)
+         reduce_min_abs_name = reduce_min_name + "_Abs"
+         reduce_min_abs_node = onnx.helper.make_node(
+             "Abs",
+             [reduce_min_node.output[0]],
+             [reduce_min_abs_name + ":0"],
+             reduce_min_abs_name,
+         )
+         nodes_list.append(reduce_min_abs_node)
+         # Find abs(rmax)
+         reduce_max_abs_name = reduce_max_name + "_Abs"
+         reduce_max_abs_node = onnx.helper.make_node(
+             "Abs",
+             [reduce_max_node.output[0]],
+             [reduce_max_abs_name + ":0"],
+             reduce_max_abs_name,
+         )
+         nodes_list.append(reduce_max_abs_node)
+         # Compute max of abs(rmin) and abs(rmax)
+         abs_max_name = input_name + "_Abs_Max"
+         abs_max_node = onnx.helper.make_node(
+             "Max",
+             [reduce_min_abs_node.output[0], reduce_max_abs_node.output[0]],
+             [abs_max_name + ":0"],
+             abs_max_name,
+         )
+         nodes_list.append(abs_max_node)
+         # and divide by (quantize_range/2.0), which is equal to max(...)*2.0/quantize_range
+         initializer_div = onnx.helper.make_tensor(
+             self.fixed_qrange_int8_name,
+             initial_type,
+             [],
+             [get_qrange_for_qType(qType) / 2.0],
+         )
+         self.model.add_initializer(initializer_div)
+         scale_div_name = input_name + "scale_Div"
+         scale_div_node = onnx.helper.make_node(
+             "Div",
+             [abs_max_node.output[0], self.fixed_qrange_int8_name],
+             [input_scale_name],
+             scale_div_name,
+         )
+         nodes_list.append(scale_div_node)
+
+         # Zero point
+         initializer_zp = onnx.helper.make_tensor(self.fixed_zero_zp_name, qType, [], [0])
+         self.model.add_initializer(initializer_zp)
+
+         return input_scale_name, self.fixed_zero_zp_name, [], []
+
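The node graph built above reduces to a simple symmetric formula. A numpy sketch of the same math, assuming get_qrange_for_qType returns 254 for int8 (the symmetric range [-127, 127]):

import numpy as np

x = np.array([-2.5, 0.0, 1.0, 4.0], dtype=np.float32)
qrange = 254.0  # assumed int8 quantize range
scale = max(abs(float(x.min())), abs(float(x.max()))) / (qrange / 2.0)
zero_point = 0  # int8 dynamic quantization is symmetric
q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
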
+     def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list, initial_type):
+         """
+         Create nodes for dynamic quantization of an input to uint8 and add them to nodes_list.
+         parameter input_name: name of the input.
+         parameter nodes_list: new nodes are appended to this list.
+         parameter initial_type: initial weight type (FLOAT or FLOAT16).
+         return: scale_name, zero_point_name, scale_shape, zero_point_shape.
+         """
+         qType = onnx_proto.TensorProto.UINT8  # noqa: N806
+         # ReduceMin and ReduceMax
+         input_scale_name = input_name + "_scale"
+         input_zp_name = input_name + "_zero_point"
+
+         reduce_min_name = input_name + "_ReduceMin"
+         reduce_min_node = onnx.helper.make_node(
+             "ReduceMin",
+             [input_name],
+             [reduce_min_name + ":0"],
+             reduce_min_name,
+             keepdims=0,
+         )
+         nodes_list.append(reduce_min_node)
+
+         reduce_max_name = input_name + "_ReduceMax"
+         reduce_max_node = onnx.helper.make_node(
+             "ReduceMax",
+             [input_name],
+             [reduce_max_name + ":0"],
+             reduce_max_name,
+             keepdims=0,
+         )
+         nodes_list.append(reduce_max_node)
+
+         # Add tensors for the quantize range and the zero value.
+         initializer_qrange = onnx.helper.make_tensor(
+             self.fixed_qrange_uint8_name,
+             initial_type,
+             [],
+             [get_qrange_for_qType(qType)],
+         )
+         self.model.add_initializer(initializer_qrange)
+         initializer_qvalue = onnx.helper.make_tensor(self.fixed_zero_name, initial_type, [], [0.0])
+         self.model.add_initializer(initializer_qvalue)
+
+         # Compute scale
+         # Subtract rmin from rmax
+         scale_sub_name = input_name + "_scale_Sub"
+         scale_sub_node = onnx.helper.make_node(
+             "Sub",
+             [reduce_max_node.output[0], reduce_min_node.output[0]],
+             [scale_sub_name + ":0"],
+             scale_sub_name,
+         )
+         nodes_list.append(scale_sub_node)
+         # and divide by the quantize range
+         scale_div_name = input_name + "_scale_Div"
+         scale_div_node = onnx.helper.make_node(
+             "Div",
+             [scale_sub_node.output[0], self.fixed_qrange_uint8_name],
+             [input_scale_name],
+             scale_div_name,
+         )
+         nodes_list.append(scale_div_node)
+
+         # Compute zero point
+         # Subtract rmin from zero
+         zp_sub_name = input_name + "_zero_point_Sub"
+         zp_sub_node = onnx.helper.make_node(
+             "Sub",
+             [self.fixed_zero_name, reduce_min_node.output[0]],
+             [zp_sub_name + ":0"],
+             zp_sub_name,
+         )
+         nodes_list.append(zp_sub_node)
+         # Divide by scale
+         zp_div_name = input_name + "_zero_point_Div"
+         zp_div_node = onnx.helper.make_node(
+             "Div",
+             [zp_sub_node.output[0], input_scale_name],
+             [zp_div_name + ":0"],
+             zp_div_name,
+         )
+         nodes_list.append(zp_div_node)
+         # Compute floor
+         zp_floor_name = input_name + "_zero_point_Floor"
+         zp_floor_node = onnx.helper.make_node("Floor", zp_div_node.output, [zp_floor_name + ":0"], zp_floor_name)
+         nodes_list.append(zp_floor_node)
+         # Cast to integer
+         zp_cast_name = input_name + "_zero_point_Cast"
+         zp_cast_node = onnx.helper.make_node("Cast", zp_floor_node.output, [input_zp_name], zp_cast_name, to=qType)
+         nodes_list.append(zp_cast_node)
+
+         return input_scale_name, input_zp_name, [], []
+
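A numpy sketch of the asymmetric math the uint8 node graph above encodes, assuming get_qrange_for_qType returns 255 for uint8:

import numpy as np

x = np.array([-1.0, 0.0, 2.0, 4.1], dtype=np.float32)
rmin, rmax = float(x.min()), float(x.max())
scale = (rmax - rmin) / 255.0                # Sub then Div
zero_point = np.floor((0.0 - rmin) / scale)  # Sub, Div, Floor
zero_point = np.uint8(zero_point)            # Cast
q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
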
+     def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None):
+         """
+         Create initializers and inputs in the graph for zero point and scale of output.
+         Zero point and scale values are obtained from self.quantization_params if specified.
+         parameter param_name: Name of the quantization parameter.
+         return: result, scale_name, zero_point_name, scale_shape, zero_point_shape.
+         """
+         zero_point_type = self.activation_qType
+
+         if use_scale is None or use_zeropoint is None:
+             if self.quantization_params is None or param_name not in self.quantization_params:
+                 logging.info(f'Quantization parameters for tensor:"{param_name}" not specified')
+                 return False, "", "", "", ""
+
+             params = self.quantization_params[param_name]
+             if not isinstance(params, QuantizationParams):
+                 raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.")
+             if params is None or len(params) != 3:
+                 raise ValueError(
+                     "Quantization parameters should contain zero point, scale, quant type. "
+                     f"Specified values for output {param_name}: {params}"
+                 )
+
+             zero_point_values = np.array([params["zero_point"]])
+             if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16):
+                 raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
+             scale_values = np.array([params["scale"]])
+             assert scale_values.dtype != np.float64
+             zero_point_type = params["quant_type"]
+         else:
+             zero_point_values = np.array([use_zeropoint])
+             scale_values = np.array([use_scale])
+             params = self.quantization_params[param_name]
+             if "scale" in params:
+                 dtype = params["scale"].dtype
+                 scale_values = scale_values.astype(dtype)
+             assert scale_values.dtype != np.float64
+
+         zero_point_shape = []
+         zero_point_name = param_name + "_zero_point"
+         scale_shape = []
+         scale_name = param_name + "_scale"
+
+         # Add initializers
+         init_zp = onnx.helper.make_tensor(
+             zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist()
+         )
+         self.model.add_initializer(init_zp)
+         if scale_values.dtype == np.float32:
+             scale_type = onnx_proto.TensorProto.FLOAT
+         elif scale_values.dtype == np.float16:
+             scale_type = onnx_proto.TensorProto.FLOAT16
+         else:
+             raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}")
+         init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist())
+         self.model.add_initializer(init_scale)
+
+         return True, scale_name, zero_point_name, scale_shape, zero_point_shape
+
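The scale and zero-point initializers created here feed QuantizeLinear nodes. For reference, a sketch of the standard per-tensor QuantizeLinear semantics (uint8 case, rounding half to even):

import numpy as np

def quantize_linear(x, scale, zero_point):
    # y = saturate(round(x / scale) + zero_point)
    return np.clip(np.rint(x / scale) + zero_point, 0, 255).astype(np.uint8)

y = quantize_linear(np.array([0.0, 0.5, 1.0], dtype=np.float32), np.float32(1.0 / 255.0), 0)
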
+     def _get_quantize_input_nodes(
+         self, node, input_index, qType, given_scale_name=None, given_zp_name=None, initial_type=None
+     ):
+         """
+         Given an input for a node (which is not an initializer), this function
+
+         - adds nodes to compute the zero point and scale for this input if they don't exist.
+         - adds a new QuantizeLinear node to quantize the input.
+
+         :param node: node being quantized in NodeProto format.
+         :param input_index: index of the input in node.input.
+         :param qType: type to quantize to.
+         :param given_scale_name: if set, the input is quantized using this scale tensor.
+         :param given_zp_name: if set, the input is quantized using this zero point tensor.
+         :param initial_type: type of the weight to quantize.
+         :return: List of newly created nodes in NodeProto format.
+         """
+         input_name = node.input[input_index]
+         assert input_name != "", "Cannot access undefined variable in graph."
+         output_name = input_name + TENSOR_NAME_QUANT_SUFFIX
+         ql_node_name = input_name + "_QuantizeLinear"
+
+         if (given_scale_name is not None) and (given_zp_name is not None):
+             data_found, scale_name, zp_name = (True, given_scale_name, given_zp_name)
+         else:
+             data_found, scale_name, zp_name, _, _ = self._get_quantization_params(input_name)
+
+         nodes = []
+         if data_found:
+             qlinear_node = onnx.helper.make_node(
+                 "QuantizeLinear",
+                 [input_name, scale_name, zp_name],
+                 [output_name],
+                 ql_node_name,
+             )
+         else:
+             if self.static:
+                 return None
+             # dynamic mode
+             # Scale and zero point are not available for this input. Add nodes to compute them dynamically.
+             if self.fuse_dynamic_quant and qType == onnx_proto.TensorProto.UINT8:
+                 scale_name = input_name + "_scale"
+                 zp_name = input_name + "_zero_point"
+                 qlinear_node = onnx.helper.make_node(
+                     "DynamicQuantizeLinear",
+                     [input_name],
+                     [output_name, scale_name, zp_name],
+                     ql_node_name,
+                 )
+             else:
+                 assert initial_type is not None, (
+                     f"Cannot quantize input without knowing the initial type, "
+                     f"input_name={input_name!r}, input_index={input_index}, qType={qType}, node={node}"
+                 )
+                 (
+                     scale_name,
+                     zp_name,
+                     scale_shape,
+                     zp_shape,
+                 ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType, initial_type=initial_type)
+                 qlinear_node = onnx.helper.make_node(
+                     "QuantizeLinear",
+                     [input_name, scale_name, zp_name],
+                     [output_name],
+                     ql_node_name,
+                 )
+
+         self.quantized_value_map[input_name] = QuantizedValue(input_name, output_name, scale_name, zp_name, qType)
+         return [*nodes, qlinear_node]
+
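When fuse_dynamic_quant is set and the target type is uint8, the scale/zero-point subgraph collapses into a single DynamicQuantizeLinear node. A numpy sketch of that op's standard semantics (the range is extended to include 0 so that zero is exactly representable):

import numpy as np

def dynamic_quantize_linear(x):
    rmin = min(float(x.min()), 0.0)
    rmax = max(float(x.max()), 0.0)
    scale = (rmax - rmin) / 255.0
    zp = np.uint8(np.clip(np.rint((0.0 - rmin) / scale), 0, 255))
    q = np.clip(np.rint(x / scale) + zp, 0, 255).astype(np.uint8)
    return q, np.float32(scale), zp

q, scale, zp = dynamic_quantize_linear(np.array([-1.0, 0.0, 2.5], dtype=np.float32))
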
+     def find_quantized_value(self, input_name):
+         if input_name in self.quantized_value_map:
+             return self.quantized_value_map[input_name]
+         if self.parent is not None:
+             return self.parent.find_quantized_value(input_name)
+         return None
+
+     def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0):
+         """
+         Quantize the bias. Zero point == 0 and scale == input_scale * weight_scale.
+         """
+
+         # Handle the case where the bias is already in the quantization map
+         if bias_name in self.quantized_value_map:
+             return self.quantized_value_map[bias_name].q_name
+
+         # get the scale for the weight
+         weight_scale_name = self.quantized_value_map[weight_name].scale_name
+         weight_initializer = find_by_name(weight_scale_name, self.model.initializer())
+         weight_scale = tensor_proto_to_array(weight_initializer)
+
+         # get the scale for the input
+         if input_name in self.quantized_value_map:
+             input_scale_name = self.quantized_value_map[input_name].scale_name
+         elif input_name in self.quantization_params:
+             _, input_scale_name, _, _, _ = self._get_quantization_params(input_name)
+         else:
+             raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization")
+
+         inputscale_initializer = find_by_name(input_scale_name, self.model.initializer())
+         input_scale = tensor_proto_to_array(inputscale_initializer)
+
+         (
+             quantized_bias_name,
+             quantized_bias_scale_name,
+             quantized_bias_zp_name,
+             bias_scale_data,
+             node_type,
+             node_qtype,
+         ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, beta)
+
+         assert bias_name not in self.quantized_value_map
+         quantized_value = QuantizedValue(
+             bias_name,
+             quantized_bias_name,
+             quantized_bias_scale_name,
+             quantized_bias_zp_name,
+             QuantizedValueType.Initializer,
+             0 if bias_scale_data.size > 1 else None,
+             node_type=node_type,
+             node_qtype=node_qtype,
+         )
+         self.quantized_value_map[bias_name] = quantized_value
+
+         return quantized_bias_name
+
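The rule in the docstring is the whole story for bias quantization. A numpy sketch of what quantize_bias_static_impl computes, assuming (as in onnxruntime's implementation) the bias is stored as int32 with zero point 0; the values are placeholders:

import numpy as np

input_scale, weight_scale, beta = np.float32(0.02), np.float32(0.005), 1.0
bias = np.array([0.13, -0.07], dtype=np.float32)
bias_scale = input_scale * weight_scale * beta  # zero point is 0
q_bias = np.round(bias / bias_scale).astype(np.int32)
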
+     def contains_tensor(self, tensor_name):
+         """
+         Only check value info and newly generated tensor names; initializers are checked separately.
+         """
+         return (
+             (tensor_name in self.value_infos)
+             or (tensor_name in self.tensor_names)
+             or (tensor_name in self.generated_value_names)
+         )
+
+     def quantize_activation(self, node, indices, from_subgraph=False):
+         return self.__quantize_inputs(
+             node=node,
+             indices=indices,
+             initializer_use_weight_qType=False,
+             reduce_range=False,
+             op_level_per_channel=False,
+             axis=-1,
+             from_subgraph=from_subgraph,
+         )
+
+     # In some circumstances a weight is not an initializer. For example, for MatMul, if both A and B are not
+     # initializers, B can still be considered a weight.
+     def quantize_weight(
+         self,
+         node,
+         indices,
+         reduce_range=False,
+         op_level_per_channel=False,
+         axis=-1,
+         from_subgraph=False,
+     ):
+         return self.__quantize_inputs(
+             node=node,
+             indices=indices,
+             initializer_use_weight_qType=True,
+             reduce_range=reduce_range,
+             op_level_per_channel=op_level_per_channel,
+             axis=axis,
+             from_subgraph=from_subgraph,
+         )
+
+     def __quantize_inputs(
+         self,
+         node,
+         indices,
+         initializer_use_weight_qType=True,
+         reduce_range=False,
+         op_level_per_channel=False,
+         axis=-1,
+         from_subgraph=False,
+     ):
+         """
+         Given a node, this function quantizes the inputs as follows:
+             - If the input is an initializer, quantize the initializer data and replace the old initializer
+               with the new one.
+             - Else, add QuantizeLinear nodes to perform quantization.
+         parameter node: node being quantized in NodeProto format.
+         parameter indices: input indices to quantize.
+         return: (List of quantized input names,
+                  List of zero point names used for input quantization,
+                  List of scale names used for input quantization,
+                  List of new QuantizeLinear nodes created)
+         """
+
+         scale_names = []
+         zero_point_names = []
+         quantized_input_names = []
+         nodes = []
+
+         for input_index in indices:
+             node_input = node.input[input_index]
+
+             # Find whether this input is already quantized
+             if node_input in self.quantized_value_map:
+                 quantized_value = self.quantized_value_map[node_input]
+                 scale_names.append(quantized_value.scale_name)
+                 zero_point_names.append(quantized_value.zp_name)
+                 quantized_input_names.append(quantized_value.q_name)
+                 continue
+             # Handle the case where embed_layernorm.py has an optional segment_embedding
+             if not node_input:
+                 quantized_input_names.append("")
+                 scale_names.append("")
+                 zero_point_names.append("")
+                 continue
+             # Quantize the input
+             initializer = find_by_name(node_input, self.model.initializer())
+             if initializer is not None:
+                 if self.per_channel and op_level_per_channel:
+                     (
+                         q_weight_name,
+                         zp_name,
+                         scale_name,
+                     ) = self.quantize_weight_per_channel(
+                         initializer.name,
+                         self.weight_qType if initializer_use_weight_qType else self.activation_qType,
+                         axis,
+                         reduce_range,
+                     )
+                 else:
+                     q_weight_name, zp_name, scale_name = self.quantize_initializer(
+                         initializer,
+                         self.weight_qType if initializer_use_weight_qType else self.activation_qType,
+                         reduce_range,
+                     )
+
+                 quantized_input_names.append(q_weight_name)
+                 zero_point_names.append(zp_name)
+                 scale_names.append(scale_name)
+             elif self.contains_tensor(node_input):
+                 # Add a QuantizeLinear node.
+                 qlinear_node = self.model.find_node_by_name(
+                     node_input + "_QuantizeLinear", self.new_nodes, self.model.graph()
+                 )
+                 if qlinear_node is None:
+                     input_name = node.input[input_index]
+                     if input_name in self.value_infos:
+                         value_info = self.value_infos[input_name]
+                         assert value_info.HasField("type"), f"value_info={value_info} has no type."
+                         assert value_info.type.HasField("tensor_type"), f"value_info={value_info} is not a tensor."
+                         initial_type = value_info.type.tensor_type.elem_type
+                     else:
+                         # Shape inference failed. Fall back to self.tensor_names.
+                         assert input_name in self.tensor_names, (
+                             f"shape inference failed for {input_name!r} and "
+                             f"attribute 'tensor_names' does not have any value for "
+                             f"this tensor."
+                         )
+                         initial_type = self.tensor_names[input_name]
+                     quantize_input_nodes = self._get_quantize_input_nodes(
+                         node, input_index, self.activation_qType, initial_type=initial_type
+                     )
+                     if quantize_input_nodes is None:
+                         return (None, None, None, None)
+                     if from_subgraph:
+                         self.add_new_nodes(quantize_input_nodes)
+                     else:
+                         nodes.extend(quantize_input_nodes)
+                     qlinear_node = quantize_input_nodes[-1]
+
+                 if qlinear_node.op_type == "QuantizeLinear":
+                     quantized_input_names.extend(qlinear_node.output)
+                     scale_names.append(qlinear_node.input[1])
+                     zero_point_names.append(qlinear_node.input[2])
+                 else:
+                     quantized_input_names.append(qlinear_node.output[0])
+                     scale_names.append(qlinear_node.output[1])
+                     zero_point_names.append(qlinear_node.output[2])
+             elif self.parent is not None:
+                 (
+                     parent_quantized_input_names,
+                     parent_zero_point_names,
+                     parent_scale_names,
+                     _,
+                 ) = self.parent.__quantize_inputs(
+                     node,
+                     [input_index],
+                     initializer_use_weight_qType=initializer_use_weight_qType,
+                     reduce_range=reduce_range,
+                     op_level_per_channel=op_level_per_channel,
+                     axis=axis,
+                     from_subgraph=True,
+                 )
+                 quantized_input_names.append(parent_quantized_input_names[0])
+                 scale_names.append(parent_scale_names[0])
+                 zero_point_names.append(parent_zero_point_names[0])
+                 # the node should not be added at this child level here
+             else:
+                 raise ValueError(f"Invalid tensor name to quantize: {node_input} @graph scope{self.graph_scope}")
+
+         return quantized_input_names, zero_point_names, scale_names, nodes
+
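The name lists returned here are consumed by the op-level quantizers created through CreateOpQuantizer. As a hypothetical illustration (all tensor names are placeholders), a MatMul quantized in QOperator style becomes a standard ONNX QLinearMatMul node whose inputs interleave the quantized tensors with their scales and zero points:

import onnx

node = onnx.helper.make_node(
    "QLinearMatMul",
    inputs=[
        "A_quantized", "A_scale", "A_zero_point",
        "B_quantized", "B_scale", "B_zero_point",
        "Y_scale", "Y_zero_point",
    ],
    outputs=["Y_quantized"],
    name="MatMul_quant",
)
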
+     def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False):
+         """
+         :param weight: TensorProto initializer
+         :param qType: type to quantize to
+         :param keep_float_weight: whether to keep the float weight. In some cases we only want to quantize
+             the scale and zero point; if keep_float_weight is False, the weight is quantized as well.
+         :return: quantized weight name, zero point name, scale name
+         """
+         # Find whether this input is already quantized
+         if weight.name in self.quantized_value_map:
+             quantized_value = self.quantized_value_map[weight.name]
+             return (
+                 quantized_value.q_name,
+                 quantized_value.zp_name,
+                 quantized_value.scale_name,
+             )
+
+         q_weight_name, zp_name, scale_name = self.quantize_initializer_impl(
+             weight, qType, reduce_range, keep_float_weight
+         )
+
+         # Log entry for this quantized weight
+         quantized_value = QuantizedValue(
+             weight.name,
+             q_weight_name,
+             scale_name,
+             zp_name,
+             QuantizedValueType.Initializer,
+             None,
+         )
+         self.quantized_value_map[weight.name] = quantized_value
+         return q_weight_name, zp_name, scale_name
+
+     def quantize_weight_per_channel(
+         self,
+         weight_name,
+         weight_qType,
+         channel_axis,
+         reduce_range=True,
+         keep_float_weight=False,
+     ):
+         # Find whether this input is already quantized
+         if weight_name in self.quantized_value_map:
+             quantized_value = self.quantized_value_map[weight_name]
+             return (
+                 quantized_value.q_name,
+                 quantized_value.zp_name,
+                 quantized_value.scale_name,
+             )
+
+         q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl(
+             weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight
+         )
+         quantized_value = QuantizedValue(
+             weight_name,
+             q_weight_name,
+             scale_name,
+             zp_name,
+             QuantizedValueType.Initializer,
+             None,
+         )
+         self.quantized_value_map[weight_name] = quantized_value
+
+         return q_weight_name, zp_name, scale_name
+
+     def _dequantize_value(self, value_name):
+         """
+         Given a value (input/output) that is quantized, add a DequantizeLinear node to dequantize
+         it back to float32 or float16.
+         parameter value_name: value to dequantize.
+         return: None if there is already a DequantizeLinear node that dequantizes it;
+                 a DequantizeLinear node otherwise.
+         """
+         if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names):
+             quantized_value = self.quantized_value_map[value_name]
+             # Add a DequantizeLinear node for this input
+
+             scale_init = find_by_name(quantized_value.scale_name, self.model.initializer())
+
+             # When working with subgraphs, the graph producer_name is set to "onnx-quantizer" in the
+             # quantize_subgraph method. In that case the scale initializer may be on the top-level
+             # graph, so the check below cannot be done.
+             if self.model.model.producer_name != "onnx-quantizer" or (
+                 self.model.model.producer_name == "onnx-quantizer" and scale_init is not None
+             ):
+                 # axis is not specified, so scale_init must be a scalar.
+                 assert onnx.numpy_helper.to_array(scale_init).size == 1
+
+             dqlinear_name = value_name + "_DequantizeLinear"
+             dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph())
+             if dqlinear_node is None:
+                 dqlinear_inputs = [
+                     quantized_value.q_name,
+                     quantized_value.scale_name,
+                     quantized_value.zp_name,
+                 ]
+                 dequantize_node = onnx.helper.make_node(
+                     "DequantizeLinear", dqlinear_inputs, [value_name], dqlinear_name
+                 )
+                 return dequantize_node
+             else:
+                 # The DQ op is already present; assert that its output matches the input of the current node.
+                 assert value_name == dqlinear_node.output[0]
+         return None
+
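For reference, the standard per-tensor DequantizeLinear semantics that the node created above relies on:

import numpy as np

q = np.array([0, 128, 255], dtype=np.uint8)
scale, zero_point = np.float32(0.01), 128
x = (q.astype(np.int32) - zero_point) * scale  # x = (q - zp) * scale
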
+     def _dequantize_outputs(self):
+         """
+         Dequantize graph outputs if they are quantized.
+         """
+
+         for output in self.model.graph().output:
+             dequantize_node = self._dequantize_value(output.name)
+             if dequantize_node is not None:
+                 self.new_nodes.append(dequantize_node)
+
+     def calculate_quantization_params(self):
+         if self.tensors_range is None:
+             return None
+
+         self.adjust_tensor_ranges()
+
+         quantization_params = {}
+         for tensor_name in self.tensors_range:
+             td = self.tensors_range[tensor_name]
+             if not isinstance(td, TensorData):
+                 raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")
+
+             quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={})
+
+             quant_type = self.activation_qType
+             if "quant_type" in quant_overrides:
+                 quant_type = quant_overrides["quant_type"].tensor_type
+
+             if "scale" in quant_overrides and "zero_point" in quant_overrides:
+                 zero, scale = quant_overrides["zero_point"], quant_overrides["scale"]
+             elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
+                 zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1])
+             else:
+                 rmin = quant_overrides.get("rmin", td.range_value[0])
+                 rmax = quant_overrides.get("rmax", td.range_value[1])
+                 symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
+                 reduce_range = quant_overrides.get("reduce_range", False)
+                 qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
+                 zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)
+
+             quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type)
+
+         return quantization_params
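
A numpy sketch of the affine mapping compute_scale_zp implements for the common asymmetric case (simplified: the min_real_range clamp is omitted, and the real range is extended to include zero):

import numpy as np

def scale_zp(rmin, rmax, qmin, qmax):
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = int(round(qmin - rmin / scale))
    return np.float32(scale), zero_point

print(scale_zp(-1.0, 3.0, 0, 255))  # scale ~= 0.0157, zero point = 64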