onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/quantization/quantize.py
@@ -0,0 +1,953 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+from __future__ import annotations
+
+import copy
+import logging
+import tempfile
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+import onnx
+
+from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator
+from .onnx_quantizer import ONNXQuantizer
+from .qdq_quantizer import QDQQuantizer
+from .quant_utils import (
+    MODEL_SIZE_THRESHOLD,
+    QuantFormat,
+    QuantizationMode,
+    QuantType,
+    load_model_with_shape_infer,
+    model_has_pre_process_metadata,
+    save_and_reload_model_with_shape_infer,
+    update_opset_version,
+)
+from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry
+from .tensor_quant_overrides import TensorQuantOverridesHelper
+
+
+class QuantConfig:
+    def __init__(
+        self,
+        activation_type=QuantType.QUInt8,
+        weight_type=QuantType.QInt8,
+        op_types_to_quantize=None,
+        nodes_to_quantize=None,
+        nodes_to_exclude=None,
+        per_channel=False,
+        reduce_range=False,
+        use_external_data_format=False,
+    ):
+        """
+        This is the base class for both the static and dynamic quantization configurations.
+        Args:
+            activation_type:
+                quantization data type of activation. Please refer to
+                https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
+            weight_type:
+                quantization data type of weight. Please refer to
+                https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
+            op_types_to_quantize:
+                specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
+                It quantizes all supported operators by default.
+            nodes_to_quantize:
+                List of node names to quantize. When this list is not None, only the nodes in this list
+                are quantized.
+                example:
+                [
+                    'Conv__224',
+                    'Conv__252'
+                ]
+            nodes_to_exclude:
+                List of node names to exclude. The nodes in this list will be excluded from quantization
+                when it is not None.
+            per_channel: quantize weights per channel
+            reduce_range:
+                quantize weights with 7 bits. It may improve the accuracy for some models running on non-VNNI machines,
+                especially for per-channel mode
+            use_external_data_format: option used for large (>2GB) models. Set to False by default.
+        """
+
+        nodes_to_exclude = nodes_to_exclude or []
+        nodes_to_quantize = nodes_to_quantize or []
+        op_types_to_quantize = op_types_to_quantize or []
+        self.op_types_to_quantize = op_types_to_quantize
+        self.per_channel = per_channel
+        self.reduce_range = reduce_range
+        self.weight_type = weight_type
+        self.activation_type = activation_type
+        self.nodes_to_quantize = nodes_to_quantize
+        self.nodes_to_exclude = nodes_to_exclude
+        self.use_external_data_format = use_external_data_format
+
+
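For orientation, a minimal sketch of how the base options above map onto a config object. The onnxruntime.quantization.quantize module path and direct QuantConfig construction are assumptions for illustration; in practice one of the derived classes below is used:

    from onnxruntime.quantization import QuantType
    from onnxruntime.quantization.quantize import QuantConfig  # assumed import path

    # Quantize only Conv and MatMul, skip two nodes by name, use per-channel weights.
    cfg = QuantConfig(
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,
        op_types_to_quantize=["Conv", "MatMul"],
        nodes_to_exclude=["Conv__224", "Conv__252"],
        per_channel=True,
    )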
+class StaticQuantConfig(QuantConfig):
+    def __init__(
+        self,
+        calibration_data_reader: CalibrationDataReader,
+        calibrate_method=CalibrationMethod.MinMax,
+        quant_format=QuantFormat.QDQ,
+        activation_type=QuantType.QInt8,
+        weight_type=QuantType.QInt8,
+        op_types_to_quantize=None,
+        nodes_to_quantize=None,
+        nodes_to_exclude=None,
+        per_channel=False,
+        reduce_range=False,
+        use_external_data_format=False,
+        calibration_providers=None,
+        extra_options=None,
+    ):
+        """
+        This is the derived class for the static quantization configuration.
+
+        Args:
+            calibration_data_reader:
+                a calibration data reader. It enumerates calibration data and generates inputs for the original model.
+            calibrate_method:
+                Current calibration methods supported are MinMax, Entropy and Percentile.
+            quant_format: QuantFormat{QOperator, QDQ}.
+                QOperator format quantizes the model with quantized operators directly.
+                QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
+            calibration_providers: Execution providers to run the session during calibration. Default is None which uses
+                [ "CPUExecutionProvider" ].
+            extra_options:
+                key-value pair dictionary for various options in different cases. Currently used:
+                    extra.Sigmoid.nnapi = True/False (Default is False)
+                    ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
+                    WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
+                    EnableSubgraph = True/False : Default is False. If enabled, subgraphs will be quantized.
+                        Dynamic mode is currently supported. More will be supported in the future.
+                    ForceQuantizeNoInputCheck = True/False :
+                        By default, some latent operators like maxpool, transpose, do not quantize if their input is not
+                        quantized already. Set to True to force such operators to always quantize their input and thus
+                        generate quantized output. The True behavior can also be disabled per node using nodes_to_exclude.
+                    MatMulConstBOnly = True/False:
+                        Default is False for static mode. If enabled, only MatMul with const B will be quantized.
+                    AddQDQPairToWeight = True/False :
+                        Default is False, which quantizes the floating-point weight and feeds it to a solely inserted
+                        DeQuantizeLinear node. If True, the weight remains floating-point and both
+                        QuantizeLinear/DeQuantizeLinear nodes are inserted on it.
+                    OpTypesToExcludeOutputQuantization = list of op type :
+                        Default is []. If any op type is specified, the output of ops with these
+                        specific op types will not be quantized.
+                    DedicatedQDQPair = True/False :
+                        Default is False. When inserting a QDQ pair, multiple nodes can share a single QDQ pair as their
+                        inputs. If True, an identical and dedicated QDQ pair will be created for each node.
+                    QDQOpTypePerChannelSupportToAxis = dictionary :
+                        Default is {}. Set the channel axis for a specific op type, for example: {'MatMul': 1}. It's
+                        effective only when per-channel quantization is supported and per_channel is True. If a specific
+                        op type supports per-channel quantization but the channel axis is not explicitly specified, the
+                        default channel axis will be used.
+                    CalibTensorRangeSymmetric = True/False :
+                        Default is False. If enabled, the final range of a tensor during calibration will be explicitly
+                        set to be symmetric around the central point "0".
+                    CalibMovingAverage = True/False :
+                        Default is False. If enabled, the moving average of the minimum and maximum values will be
+                        computed when the calibration method selected is MinMax.
+                    CalibMovingAverageConstant = float :
+                        Default is 0.01. Constant smoothing factor to use when computing the moving average of the
+                        minimum and maximum values. Effective only when the calibration method selected is MinMax and
+                        when CalibMovingAverage is set to True.
+                    QuantizeBias = True/False :
+                        Default is True, which quantizes floating-point biases and solely inserts
+                        a DeQuantizeLinear node. If False, the bias remains floating-point and no
+                        quantization nodes associated with biases are inserted.
+                        This extra option is only effective when quant_format is QuantFormat.QDQ.
+                    SmoothQuant = True/False :
+                        Default is False. If enabled, the SmoothQuant algorithm will be applied before quantization to do
+                        fake input channel quantization.
+                    SmoothQuantAlpha = float :
+                        Default is 0.5. It only works if SmoothQuant is True. It controls the difficulty of weight
+                        and activation quantization. A larger alpha value could be used on models with more significant
+                        activation outliers to migrate more quantization difficulty to weights.
+                    SmoothQuantFolding = True/False :
+                        Default is True. It only works if SmoothQuant is True. If enabled, Mul ops inserted during
+                        SmoothQuant will be folded into the previous op if the previous op is foldable.
+                    UseQDQContribOps = True/False :
+                        Default is False. If enabled, the inserted QuantizeLinear and DequantizeLinear ops will have the
+                        `com.microsoft` domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear
+                        contrib op implementations. The contrib op implementations may support features not standardized
+                        into the ONNX specification (e.g., 16-bit quantization types).
+                    MinimumRealRange = float|None :
+                        Default is None. If set to a floating-point value, the calculation of the quantization parameters
+                        (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax-rmin)
+                        is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is
+                        necessary for EPs like QNN that require a minimum floating-point range when determining
+                        quantization parameters.
+                    TensorQuantOverrides = dictionary :
+                        Default is {}. Set tensor quantization overrides. The key is a tensor name and the value is a
+                        list of dictionaries. For per-tensor quantization, the list contains a single dictionary. For
+                        per-channel quantization, the list contains a dictionary for each channel in the tensor.
+                        Each dictionary contains optional overrides with the following keys and values.
+                            'quant_type' = QuantType : The tensor's quantization data type.
+                            'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
+                            'zero_point' = Int : The zero-point value to use. Must also specify `scale` if set.
+                            'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
+                                set `scale` or `zero_point`.
+                            'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
+                                set `scale` or `zero_point`.
+                            'rmax' = Float : Override the maximum real tensor value in calibration data.
+                                Invalid if also set `scale` or `zero_point`.
+                            'rmin' = Float : Override the minimum real tensor value in calibration data.
+                                Invalid if also set `scale` or `zero_point`.
+                    QDQKeepRemovableActivations = True/False:
+                        Default is False. If true, "removable" activations (e.g., Clip or Relu) will not be removed, and
+                        will be explicitly represented in the QDQ model. If false, these activations are automatically
+                        removed if activations are asymmetrically quantized. Keeping these activations is necessary if
+                        optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear
+                        operators from the model.
+                    QDQDisableWeightAdjustForInt32Bias = True/False:
+                        Default is False. If true, the QDQ quantizer will not adjust the weight's scale when the bias
+                        has a scale (input_scale * weight_scale) that is too small.
+            execution_provider : An enum indicating the Execution Provider, such as: CPU, TRT, NNAPI, SNE, etc.
+        Raises:
+            ValueError: Raise ValueError if execution provider is unknown
+        """
+
+        super().__init__(
+            activation_type=activation_type,
+            weight_type=weight_type,
+            op_types_to_quantize=op_types_to_quantize,
+            nodes_to_quantize=nodes_to_quantize,
+            nodes_to_exclude=nodes_to_exclude,
+            per_channel=per_channel,
+            reduce_range=reduce_range,
+            use_external_data_format=use_external_data_format,
+        )
+        self.calibration_data_reader = calibration_data_reader
+        self.calibrate_method = calibrate_method
+        self.quant_format = quant_format
+        self.calibration_providers = calibration_providers
+        self.extra_options = extra_options or {}
+
+
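A hedged usage sketch for the class above: StaticQuantConfig needs a CalibrationDataReader, whose one required method is get_next(), returning a feed dict or None when the data is exhausted. The input name, shape, and random data here are placeholders; real calibration should use representative samples:

    import numpy as np
    from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType
    from onnxruntime.quantization.quantize import StaticQuantConfig  # assumed import path

    class ToyDataReader(CalibrationDataReader):
        """Yields a few random batches; substitute real calibration samples."""

        def __init__(self, input_name="input", shape=(1, 3, 224, 224), n=8):
            self._batches = iter(
                [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(n)]
            )

        def get_next(self):
            # Return the next feed dict, or None to signal the end of calibration data.
            return next(self._batches, None)

    config = StaticQuantConfig(
        calibration_data_reader=ToyDataReader(),
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        extra_options={"ActivationSymmetric": True},  # symmetric ranges, e.g. for GPU/TRT targets
    )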
+def get_qdq_config(
+    model_input: str | Path | onnx.ModelProto,
+    calibration_data_reader: CalibrationDataReader,
+    calibrate_method=CalibrationMethod.MinMax,
+    calibrate_args: dict[str, Any] | None = None,
+    activation_type=QuantType.QUInt8,
+    weight_type=QuantType.QInt8,
+    activation_symmetric: bool = False,
+    weight_symmetric: bool | None = None,
+    per_channel: bool = False,
+    reduce_range: bool = False,
+    keep_removable_activations: bool = False,
+    min_real_range: float | None = None,
+    tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None,
+    calibration_providers: list[str] | None = None,
+    op_types_to_quantize: list[str] | None = None,
+    nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None,
+    extra_options: dict | None = None,
+) -> StaticQuantConfig:
+    """
+    Returns a configuration suitable for quantizing the entire model to integer precision.
+
+    Params:
+        model_input: Path to the input model file or ModelProto.
+        calibration_data_reader: Calibration data reader.
+        calibrate_method: The calibration method. Defaults to MinMax.
+        activation_type: The default activation quantization type. Defaults to QUInt8.
+        weight_type: The default weight quantization type. Defaults to QInt8.
+        activation_symmetric: True if activations should be quantized symmetrically (i.e., rmax == -rmin) by default.
+            Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16,
+            the zero-point values are 127 and 32,767, respectively.
+        weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default.
+            Defaults to None. If set to None, weight_symmetric is assumed true if a weight's quant type is a signed int.
+        per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel.
+            Defaults to false. Alternatively, use the tensor-level `tensor_quant_overrides` to select individual operators
+            and their quantization axes.
+        reduce_range: quantize weights with 1 less bit of precision (e.g., 7 bits for QInt8). Defaults to false.
+            May improve the accuracy for some models running on non-VNNI machines, especially for per-channel mode.
+        keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not
+            be removed, and will be explicitly represented in the QDQ model. If false, these activations
+            are automatically removed if activations are asymmetrically quantized. Keeping these activations
+            is necessary if optimizations or EP transformations will later remove
+            QuantizeLinear/DequantizeLinear operators from the model.
+        min_real_range: Default is None. If set to a floating-point value, the calculation of the quantization parameters
+            (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)
+            is less than the specified minimum range, rmax will be set to rmin + min_real_range.
+        tensor_quant_overrides: tensor-level quantization overrides. Defaults to None.
+            The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list
+            contains a single dictionary. For per-channel quantization, the list contains either a dictionary for
+            each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis'
+            key must be present in the first dictionary for per-channel quantization.
+
+            Each dictionary contains optional overrides with the following keys and values.
+                'quant_type' = QuantType : The tensor's quantization data type.
+                'axis' = Int : The per-channel axis. Must be present for per-channel weights.
+                'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
+                'zero_point' = Int : The zero-point value to use. Must also specify `scale` if set.
+                'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
+                    set `scale` or `zero_point`.
+                'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
+                    set `scale` or `zero_point`. Only valid for initializers.
+                'rmax' = Float : Override the maximum real tensor value in calibration data.
+                    Invalid if also set `scale` or `zero_point`.
+                'rmin' = Float : Override the minimum real tensor value in calibration data.
+                    Invalid if also set `scale` or `zero_point`.
+                'convert' = Dict : A nested dictionary with the same keys for an activation
+                    tensor that should be converted to another quantization type.
+                'convert["recv_nodes"] = Set : Set of node names that consume the converted activation,
+                    other nodes get the original type. If not specified,
+                    assume all consumer nodes get the converted type.
+        calibration_providers: Execution providers to run the session during calibration. Default is None which uses
+            [ "CPUExecutionProvider" ].
+        op_types_to_quantize: List of operator types to quantize. If None, all operators other than Cast, DequantizeLinear,
+            and QuantizeLinear are quantized.
+        nodes_to_exclude: List of node names to exclude from quantization. Alternatively, a function can be provided that
+            accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the given onnx.NodeProto
+            should be excluded from quantization.
+        extra_options: Additional options specified as string key/value pairs. Refer to the documentation for
+            `quantize_static` for valid keys and values.
+
+    Returns:
+        A StaticQuantConfig object.
+    """
+    q16_types = {QuantType.QInt16, QuantType.QUInt16}
+    q4_types = {QuantType.QInt4, QuantType.QUInt4}
+    op_types_to_exclude = {"Cast", "DequantizeLinear", "QuantizeLinear"}
+
+    model = (
+        model_input
+        if isinstance(model_input, onnx.ModelProto)
+        else onnx.load_model(model_input, load_external_data=False)
+    )
+
+    op_types = set()
+    model_has_external_data = False
+    overrides_helper = TensorQuantOverridesHelper(
+        copy.deepcopy(tensor_quant_overrides) if tensor_quant_overrides else {}
+    )
+
+    # Check if the model has external data.
+    for initializer in model.graph.initializer:
+        if onnx.external_data_helper.uses_external_data(initializer):
+            model_has_external_data = True
+
+    op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None
+    nodes_to_exclude_set = set(nodes_to_exclude) if isinstance(nodes_to_exclude, list) else set()
+
+    # Iterate through nodes to get all operator types in the model and
+    # call the user's function to filter out nodes from quantization.
+    for node in model.graph.node:
+        if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set:
+            continue
+        if node.name in nodes_to_exclude_set:
+            continue
+        if callable(nodes_to_exclude) and nodes_to_exclude(model, node):
+            nodes_to_exclude_set.add(node.name)
+        else:
+            op_types.add(node.op_type)
+
+    final_extra_options = {
+        "MinimumRealRange": min_real_range,
+        "QDQKeepRemovableActivations": keep_removable_activations,
+        "ActivationSymmetric": activation_symmetric,
+        "WeightSymmetric": weight_symmetric,
+        "ForceQuantizeNoInputCheck": True,
+        "TensorQuantOverrides": overrides_helper.get_dict(),
+    }
+
+    # Pass along known calibration options
+    if calibrate_args:
+        calib_extra_options_keys = [
+            ("symmetric", "CalibTensorRangeSymmetric"),
+            ("moving_average", "CalibMovingAverage"),
+            ("averaging_constant", "CalibMovingAverageConstant"),
+            ("max_intermediate_outputs", "CalibMaxIntermediateOutputs"),
+            ("percentile", "CalibPercentile"),
+        ]
+        calib_extra_options = {
+            key: calibrate_args.get(name) for (name, key) in calib_extra_options_keys if name in calibrate_args
+        }
+        final_extra_options.update(calib_extra_options)
+
+    # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain
+    # on Q/DQ operators if using 16-bit or 4-bit quantization.
+    onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")
+    if onnx_opset.version < 21:
+        opset21_types = q16_types.union(q4_types)
+        overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types())
+        if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types:
+            final_extra_options["UseQDQContribOps"] = True
+
+    # Allow the user's extra_options to override our final_extra_options.
+    if extra_options:
+        final_extra_options.update(extra_options)
+
+    return StaticQuantConfig(
+        calibration_data_reader,
+        calibrate_method=calibrate_method,
+        quant_format=QuantFormat.QDQ,
+        activation_type=activation_type,
+        weight_type=weight_type,
+        op_types_to_quantize=(
+            op_types_to_quantize if op_types_to_quantize else list(op_types.difference(op_types_to_exclude))
+        ),
+        nodes_to_exclude=list(nodes_to_exclude_set),
+        per_channel=per_channel,
+        reduce_range=reduce_range,
+        use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
+        calibration_providers=calibration_providers,
+        extra_options=final_extra_options,
+    )
+
+
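A sketch of calling get_qdq_config with a callable exclusion filter and a 16-bit activation type, which (per the opset check above) turns UseQDQContribOps on for models below opset 21. The model path and data reader are placeholders:

    from onnxruntime.quantization import QuantType
    from onnxruntime.quantization.quantize import get_qdq_config  # assumed import path

    qdq_config = get_qdq_config(
        "model.onnx",                       # placeholder path
        data_reader,                        # any CalibrationDataReader implementation
        activation_type=QuantType.QUInt16,  # 16-bit type: contrib Q/DQ ops used if opset < 21
        weight_type=QuantType.QInt8,
        per_channel=True,
        nodes_to_exclude=lambda model, node: node.op_type == "Softmax",
    )

The returned StaticQuantConfig is then passed to the quantize() entry point defined later in this file, outside the excerpted range.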
+class DynamicQuantConfig(QuantConfig):
+    def __init__(
+        self,
+        weight_type=QuantType.QInt8,
+        op_types_to_quantize=None,
+        nodes_to_quantize=None,
+        nodes_to_exclude=None,
+        per_channel=False,
+        reduce_range=False,
+        use_external_data_format=False,
+        extra_options=None,
+    ):
+        """
+        This is a class for the dynamic quantization configuration.
+
+        Args:
+            extra_options: key-value pair dictionary for various options in different cases. Currently used:
+                extra.Sigmoid.nnapi = True/False (Default is False)
+                ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
+                WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
+                EnableSubgraph = True/False :
+                    Default is False. If enabled, subgraphs will be quantized. Dynamic mode is currently supported. More
+                    will be supported in the future.
+                ForceQuantizeNoInputCheck = True/False :
+                    By default, some latent operators like maxpool, transpose, do not quantize if their input is not
+                    quantized already. Set to True to force such operators to always quantize their input and thus
+                    generate quantized output. The True behavior can also be disabled per node using nodes_to_exclude.
+                MatMulConstBOnly = True/False:
+                    Default is True for dynamic mode. If enabled, only MatMul with const B will be quantized.
+            execution_provider : An enum indicating the Execution Provider, such as: CPU, TRT, NNAPI, SNE, etc.
+
+        Raises:
+            ValueError: Raise ValueError if execution provider is unknown
+        """
+        super().__init__(
+            op_types_to_quantize=op_types_to_quantize,
+            per_channel=per_channel,
+            reduce_range=reduce_range,
+            weight_type=weight_type,
+            nodes_to_quantize=nodes_to_quantize,
+            nodes_to_exclude=nodes_to_exclude,
+            use_external_data_format=use_external_data_format,
+        )
+        self.extra_options = extra_options or {}
+
+
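For contrast with the static path, a minimal dynamic-quantization sketch. quantize_dynamic is the usual shortcut; alternatively a DynamicQuantConfig can be passed to the quantize() entry point later in this file. File names are placeholders:

    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        "model.onnx",        # placeholder input path
        "model.quant.onnx",  # placeholder output path
        weight_type=QuantType.QInt8,
        per_channel=False,
        extra_options={"MatMulConstBOnly": True},  # already the dynamic-mode default
    )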
+def check_static_quant_arguments(quant_format: QuantFormat, activation_type: QuantType, weight_type: QuantType):
+    if activation_type == QuantType.QInt8 and weight_type == QuantType.QUInt8:
+        raise ValueError(
+            "ONNXRuntime quantization doesn't support data format: "
+            "activation_type=QuantType.QInt8, weight_type=QuantType.QUInt8"
+        )
+    if activation_type != QuantType.QFLOAT8E4M3FN and weight_type == QuantType.QFLOAT8E4M3FN:
+        raise ValueError(
+            f"ONNXRuntime quantization doesn't support data format: activation_type={activation_type} "
+            "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN."
+        )
+
+    if activation_type == QuantType.QFLOAT8E4M3FN and weight_type != QuantType.QFLOAT8E4M3FN:
+        raise ValueError(
+            "ONNXRuntime quantization doesn't support data format: activation_type=QuantType.QFLOAT8E4M3FN, "
+            f"weight_type={weight_type}!=QuantType.QFLOAT8E4M3FN"
+        )
+
+    q16_types = [QuantType.QInt16, QuantType.QUInt16]
+
+    if (activation_type in q16_types or weight_type in q16_types) and quant_format != QuantFormat.QDQ:
+        raise ValueError("Only QuantFormat.QDQ supports 16-bit quantization types.")
+
+    if activation_type == QuantType.QInt8 and weight_type == QuantType.QInt8 and quant_format != QuantFormat.QDQ:
+        logging.warning(
+            "Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. "
+            "Otherwise, it will lead to bad performance on x64."
+        )
+
+
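A quick demonstration of the validation rules implemented above, assuming the same module import path:

    from onnxruntime.quantization import QuantFormat, QuantType
    from onnxruntime.quantization.quantize import check_static_quant_arguments  # assumed import path

    # QInt8 activations with QUInt8 weights are rejected outright:
    try:
        check_static_quant_arguments(QuantFormat.QDQ, QuantType.QInt8, QuantType.QUInt8)
    except ValueError as err:
        print(err)

    # QInt8/QInt8 with QOperator only logs a performance warning; with QDQ it is silent:
    check_static_quant_arguments(QuantFormat.QOperator, QuantType.QInt8, QuantType.QInt8)
    check_static_quant_arguments(QuantFormat.QDQ, QuantType.QInt8, QuantType.QInt8)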
480
+ model_input: str | Path | onnx.ModelProto,
481
+ model_output: str | Path,
482
+ calibration_data_reader: CalibrationDataReader,
483
+ quant_format=QuantFormat.QDQ,
484
+ op_types_to_quantize=None,
485
+ per_channel=False,
486
+ reduce_range=False,
487
+ activation_type=QuantType.QInt8,
488
+ weight_type=QuantType.QInt8,
489
+ nodes_to_quantize=None,
490
+ nodes_to_exclude=None,
491
+ use_external_data_format=False,
492
+ calibrate_method=CalibrationMethod.MinMax,
493
+ calibration_providers=None,
494
+ extra_options=None,
495
+ ):
496
+ """
497
+ Given an onnx model and calibration data reader, create a quantized onnx model and save it into a file
498
+ It is recommended to use QuantFormat.QDQ format from 1.11 with activation_type = QuantType.QInt8 and weight_type
499
+ = QuantType.QInt8. If model is targeted to GPU/TRT, symmetric activation and weight are required. If model is
500
+ targeted to CPU, asymmetric activation and symmetric weight are recommended for balance of performance and
501
+ accuracy.
502
+
503
+ Args:
504
+
505
+ model_input: file path of model or ModelProto to quantize
506
+ model_output: file path of quantized model
507
+ calibration_data_reader: a calibration data reader. It
508
+ enumerates calibration data and generates inputs for the
509
+ original model.
510
+ quant_format: QuantFormat{QOperator, QDQ}.
511
+ QOperator format quantizes the model with quantized operators directly.
512
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
513
+ activation_type:
514
+ quantization data type of activation. Please refer to
515
+ https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
516
+ calibrate_method:
517
+ Current calibration methods supported are MinMax and Entropy.
518
+ Please use CalibrationMethod.MinMax or CalibrationMethod.Entropy as options.
519
+ op_types_to_quantize:
520
+ specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
521
+ It quantizes all supported operators by default.
522
+ per_channel: quantize weights per channel
523
+ reduce_range:
524
+ quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine,
525
+ especially for per-channel mode
526
+ weight_type:
527
+ quantization data type of weight. Please refer to
528
+ https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
529
+ nodes_to_quantize:
530
+ List of nodes names to quantize. When this list is not None only the nodes in this list
531
+ are quantized.
532
+ example:
533
+ [
534
+ 'Conv__224',
535
+ 'Conv__252'
536
+ ]
537
+ nodes_to_exclude:
538
+ List of nodes names to exclude. The nodes in this list will be excluded from quantization
539
+ when it is not None.
540
+ use_external_data_format: option used for large size (>2GB) model. Set to False by default.
541
+ calibration_providers: Execution providers to run the session during calibration. Default is None which uses
542
+ [ "CPUExecutionProvider" ]
543
+ extra_options:
544
+ key value pair dictionary for various options in different case. Current used:
545
+ extra.Sigmoid.nnapi = True/False (Default is False)
546
+ ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
547
+ WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
548
+ EnableSubgraph = True/False : Default is False. If enabled, subgraph will be quantized.
549
+ Dyanmic mode currently is supported. Will support more in the future.
550
+ ForceQuantizeNoInputCheck = True/False :
551
+ By default, some latent operators like maxpool, transpose, do not quantize if their input is not
552
+ quantized already. Setting to True to force such operator always quantize input and so generate
553
+ quantized output. Also, the True behavior could be disabled per node using the nodes_to_exclude.
554
+ MatMulConstBOnly = True/False:
555
+ Default is False for static mode. If enabled, only MatMul with const B will be quantized.
556
+ AddQDQPairToWeight = True/False :
557
+ Default is False which quantizes floating-point weight and feeds it to solely inserted
558
+ DeQuantizeLinear node. If True, it remains floating-point weight and inserts both
559
+ QuantizeLinear/DeQuantizeLinear nodes to weight.
560
+ OpTypesToExcludeOutputQuantization = list of op type :
561
+ Default is []. If any op type is specified, it won't quantize the output of ops with this
562
+ specific op types.
563
+ DedicatedQDQPair = True/False :
564
+ Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their
565
+ inputs. If True, it will create identical and dedicated QDQ pair for each node.
566
+ QDQOpTypePerChannelSupportToAxis = dictionary :
567
+ Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1}, and it's
568
+ effective only when per channel quantization is supported and per_channel is True. If specific
569
+ op type supports per channel quantization but not explicitly specified with channel axis,
570
+ default channel axis will be used.
571
+ CalibTensorRangeSymmetric = True/False :
572
+ Default is False. If enabled, the final range of tensor during calibration will be explicitly
573
+ set to symmetric to central point "0".
574
+ CalibStridedMinMax = Optional[int] :
575
+ Default is None. If set to an integer, during calculation of the min-max, only stride amount of
576
+ data will be used and then all results will be merged in the end.
577
+ CalibMovingAverage = True/False :
578
+ Default is False. If enabled, the moving average of the minimum and maximum values will be
579
+ computed when the calibration method selected is MinMax.
580
+ CalibMovingAverageConstant = float :
581
+ Default is 0.01. Constant smoothing factor to use when computing the moving average of the
582
+ minimum and maximum values. Effective only when the calibration method selected is MinMax and
583
+ when CalibMovingAverage is set to True.
584
+ CalibMaxIntermediateOutputs = Optional[int] :
585
+ Default is None. If set to an integer, during calculation of the min-max range of the tensors
586
+ it will load at max value number of outputs before computing and merging the range. This will
587
+ produce the same result as all computing with None, but is more memory efficient.
588
+ SmoothQuant = True/False :
589
+ Default is False. If enabled, SmoothQuant algorithm will be applied before quantization to do
590
+ fake input channel quantization.
591
+ SmoothQuantAlpha = float :
592
+ Default is 0.5. It only works if SmoothQuant is True. It controls the difficulty of weight
593
+ and activation quantization. A larger alpha value could be used on models with more significant
594
+ activation outliers to migrate more quantization difficulty to weights.
595
+ SmoothQuantFolding = True/False :
596
+ Default is True. It only works if SmoothQuant is True. If enabled, inserted Mul ops during
597
+ SmoothQuant will be folded into the previous op if the previous op is foldable.
598
+ UseQDQContribOps = True/False :
599
+ Default is False. If enabled, the inserted QuantizeLinear and DequantizeLinear ops will have the
600
+ `com.microsoft` domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear
601
+ contrib op implementations. The contrib op implementations may support features not standardized
602
+ into the ONNX specification (e.g., 16-bit quantization types).
603
+ MinimumRealRange = float|None :
604
+ Default is None. If set to a floating-point value, the calculation of the quantization parameters
605
+ (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)
606
+ is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is
607
+ necessary for EPs like QNN that require a minimum floating-point range when determining
608
+ quantization parameters.
609
+ TensorQuantOverrides = dictionary :
610
+ Default is {}. Set tensor quantization overrides. The key is a tensor name and the value is a
611
+ list of dictionaries. For per-tensor quantization, the list contains a single dictionary. For
612
+ per-channel quantization, the list contains a dictionary for each channel in the tensor.
613
+ Each dictionary contains optional overrides with the following keys and values.
614
+ 'quant_type' = QuantType : The tensor's quantization data type.
615
+ 'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
616
+ 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set.
617
+ 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
618
+ set `scale` or `zero_point`.
619
+ 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
620
+ set `scale` or `zero_point`.
621
+ 'rmax' = Float : Override the maximum real tensor value in calibration data.
622
+ Invalid if also set `scale` or `zero_point`.
623
+ 'rmin' = Float : Override the minimum real tensor value in calibration data.
624
+ Invalid if also set `scale` or `zero_point`.
625
+ QDQKeepRemovableActivations = True/False:
626
+ Default is False. If true, "removable" activations (e.g., Clip or Relu) will not be removed, and
627
+ will be explicitly represented in the QDQ model. If false, these activations are automatically
628
+ removed if activations are asymmetrically quantized. Keeping these activations is necessary if
629
+ optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear
630
+ operators from the model.
631
+                QDQDisableWeightAdjustForInt32Bias = True/False:
+                    Default is False. If true, the QDQ quantizer will not adjust the weight's scale when the
+                    bias has a scale (input_scale * weight_scale) that is too small.
+
+        (A usage sketch illustrating several of these options follows this function's definition.)
+    """
+    if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN:
+        if calibrate_method != CalibrationMethod.Distribution:
+            raise ValueError("Only Distribution calibration method is supported for float quantization.")
+
+    extra_options = extra_options or {}
+    nodes_to_exclude = nodes_to_exclude or []
+    nodes_to_quantize = nodes_to_quantize or []
+    op_types_to_quantize = op_types_to_quantize or []
+    mode = QuantizationMode.QLinearOps
+
+    if not op_types_to_quantize:
+        q_linear_ops = list(QLinearOpsRegistry.keys())
+        qdq_ops = list(QDQRegistry.keys())
+        op_types_to_quantize = list(set(q_linear_ops + qdq_ops))
+
+    model = (
+        save_and_reload_model_with_shape_infer(model_input)
+        if isinstance(model_input, onnx.ModelProto)
+        else load_model_with_shape_infer(Path(model_input))
+    )
+
+    pre_processed: bool = model_has_pre_process_metadata(model)
+    if not pre_processed:
+        logging.warning(
+            "Please consider running pre-processing before quantization. Refer to example: "
+            "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
+            "/cpu/ReadMe.md"
+        )
+
+    calib_extra_options_keys = [
+        ("CalibTensorRangeSymmetric", "symmetric"),
+        ("CalibMovingAverage", "moving_average"),
+        ("CalibMovingAverageConstant", "averaging_constant"),
+        ("CalibMaxIntermediateOutputs", "max_intermediate_outputs"),
+        ("CalibPercentile", "percentile"),
+    ]
+    calib_extra_options = {
+        key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options
+    }
+
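+    # If requested, run Intel Neural Compressor's SmoothQuant transform before
+    # calibration; the Mul nodes it inserts are appended to nodes_to_exclude below
+    # so that they are not themselves quantized.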
+    if extra_options.get("SmoothQuant", False):
+        import importlib  # noqa: PLC0415
+
+        try:
+            importlib.import_module("neural_compressor.adaptor.ox_utils.smooth_quant")
+        except Exception as e:
+            logging.error(f"{e}.")
+            raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e
+
+        from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant  # noqa: PLC0415
+
+        def inc_dataloader():
+            data_reader = copy.deepcopy(calibration_data_reader)
+            for data in data_reader:
+                yield data, None
+
+        orig_nodes = [i.name for i in model.graph.node]
+        dataloader = inc_dataloader()
+        sq = ORTSmoothQuant(model_input, dataloader, reduce_range)
+        del dataloader
+        model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True))
+        sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.")
+        model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix()
+        model.save(model_input)
+        nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes])
+        model = load_model_with_shape_infer(Path(model_input))  # use smooth quant model for calibration
+
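+    # update_opset_version may return a new ModelProto when the opset has to be
+    # raised to support the requested weight type (e.g., newer quantization data
+    # types); the identity comparison below detects whether that happened.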
+    updated_model = update_opset_version(model, weight_type)
+    is_model_updated = updated_model is not model
+    if is_model_updated:
+        model = updated_model
+
+    with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
+        if is_model_updated:
+            # Update model_input to avoid using the original one
+            model_input = copy.deepcopy(model)
+
+        if isinstance(model_input, onnx.ModelProto):
+            output_path = Path(quant_tmp_dir).joinpath("model_input.onnx").as_posix()
+            onnx.save_model(
+                model_input,
+                output_path,
+                save_as_external_data=True,
+            )
+            model_input = output_path
+
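+        # The calibrator augments the model with extra outputs so that intermediate
+        # tensor values can be observed while calibration data is collected.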
+        calibrator = create_calibrator(
+            Path(model_input),
+            op_types_to_quantize,
+            augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(),
+            calibrate_method=calibrate_method,
+            use_external_data_format=use_external_data_format,
+            providers=calibration_providers,
+            extra_options=calib_extra_options,
+        )
+
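+        # CalibStridedMinMax: feed the calibration data in fixed-size strides so that
+        # ranges are collected and merged incrementally instead of holding all
+        # intermediate outputs in memory at once.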
+        stride = extra_options.get("CalibStridedMinMax", None)
+        if stride:
+            total_data_size = len(calibration_data_reader)
+            if total_data_size % stride != 0:
+                raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).")
+
+            for start in range(0, total_data_size, stride):
+                end_index = start + stride
+                calibration_data_reader.set_range(start_index=start, end_index=end_index)
+                calibrator.collect_data(calibration_data_reader)
+        else:
+            calibrator.collect_data(calibration_data_reader)
+        tensors_range = calibrator.compute_data()
+        if not isinstance(tensors_range, TensorsData):
+            raise TypeError(
+                f"Unexpected type {type(tensors_range)} for tensors_range and calibrator={type(calibrator)}."
+            )
+        del calibrator
+
+    check_static_quant_arguments(quant_format, activation_type, weight_type)
+
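+    # QOperator emits fused quantized operators (e.g., QLinearConv), while the QDQ
+    # format keeps the original operators and wraps tensors in
+    # QuantizeLinear/DequantizeLinear pairs.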
+    if quant_format is QuantFormat.QOperator:
+        quantizer = ONNXQuantizer(
+            model,
+            per_channel,
+            reduce_range,
+            mode,
+            True,  # static
+            weight_type,
+            activation_type,
+            tensors_range,
+            nodes_to_quantize,
+            nodes_to_exclude,
+            op_types_to_quantize,
+            extra_options,
+        )
+    else:
+        quantizer = QDQQuantizer(
+            model,
+            per_channel,
+            reduce_range,
+            weight_type,
+            activation_type,
+            tensors_range,
+            nodes_to_quantize,
+            nodes_to_exclude,
+            op_types_to_quantize,
+            extra_options,
+        )
+
+    quantizer.quantize_model()
+    quantizer.model.save_model_to_file(model_output, use_external_data_format)
+    if not pre_processed:
+        logging.warning(
+            "Please consider running pre-processing before quantization. Refer to example: "
+            "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
+            "/cpu/ReadMe.md"
+        )
+
+    if extra_options.get("SmoothQuant", False):
+        sq_path.cleanup()
+
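+
+ # A minimal usage sketch (illustrative; not part of the original module) showing how
+ # some of the extra_options documented above might be passed to quantize_static.
+ # "model.onnx" and MyCalibrationDataReader are hypothetical placeholders for a real
+ # model path and a CalibrationDataReader implementation.
+ #
+ #     from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
+ #
+ #     quantize_static(
+ #         "model.onnx",
+ #         "model.quant.onnx",
+ #         MyCalibrationDataReader(),
+ #         quant_format=QuantFormat.QDQ,
+ #         activation_type=QuantType.QUInt8,
+ #         weight_type=QuantType.QInt8,
+ #         extra_options={
+ #             "SmoothQuant": True,  # requires the neural-compressor package
+ #             "SmoothQuantAlpha": 0.5,
+ #             "MinimumRealRange": 0.0001,
+ #             "TensorQuantOverrides": {
+ #                 # hypothetical tensor name; clamp its calibration range
+ #                 "conv1_output": [{"rmin": 0.0, "rmax": 6.0}],
+ #             },
+ #         },
+ #     )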
+
+ def quantize_dynamic(
+     model_input: str | Path | onnx.ModelProto,
+     model_output: str | Path,
+     op_types_to_quantize=None,
+     per_channel=False,
+     reduce_range=False,
+     weight_type=QuantType.QInt8,
+     nodes_to_quantize=None,
+     nodes_to_exclude=None,
+     use_external_data_format=False,
+     extra_options=None,
+ ):
+ """Given an onnx model, create a quantized onnx model and save it into a file
807
+
808
+ Args:
809
+ model_input: file path of model or ModelProto to quantize
810
+ model_output: file path of quantized model
811
+ op_types_to_quantize:
812
+ specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
813
+ It quantizes all supported operators by default.
814
+ per_channel: quantize weights per channel
815
+ reduce_range:
816
+ quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine,
817
+ especially for per-channel mode
818
+ weight_type:
819
+ quantization data type of weight. Please refer to
820
+ https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
821
+         nodes_to_quantize:
+             List of node names to quantize. When this list is not None, only the nodes in this list
+             are quantized.
+             example:
+                 [
+                     'Conv__224',
+                     'Conv__252'
+                 ]
+         nodes_to_exclude:
+             List of node names to exclude. The nodes in this list will be excluded from quantization
+             when it is not None.
+         use_external_data_format: option used for large (>2GB) models. Set to False by default.
+         extra_options:
+             key-value pair dictionary for various options. Currently used:
+                 extra.Sigmoid.nnapi = True/False (Default is False)
+                 ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
+                 WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
+                 EnableSubgraph = True/False :
+                     Default is False. If enabled, subgraphs will be quantized. Dynamic mode is currently
+                     supported; more will be supported in the future.
+                 ForceQuantizeNoInputCheck = True/False :
+                     By default, some latent operators such as MaxPool and Transpose are not quantized if
+                     their input is not already quantized. Set to True to force such operators to always
+                     quantize their input, and thus generate quantized output. This behavior can still be
+                     disabled per node using nodes_to_exclude.
+                 MatMulConstBOnly = True/False:
+                     Default is True for dynamic mode. If enabled, only MatMul nodes with a const B input
+                     will be quantized.
+
+     (A usage sketch follows this function's definition.)
+     """
+     extra_options = extra_options or {}
+     nodes_to_exclude = nodes_to_exclude or []
+     nodes_to_quantize = nodes_to_quantize or []
+     op_types_to_quantize = op_types_to_quantize or []
+
+     mode = QuantizationMode.IntegerOps
+
+     if not op_types_to_quantize:
+         op_types_to_quantize = list(IntegerOpsRegistry.keys())
+
+     model = (
+         save_and_reload_model_with_shape_infer(model_input)
+         if isinstance(model_input, onnx.ModelProto)
+         else load_model_with_shape_infer(Path(model_input))
+     )
+
+     pre_processed: bool = model_has_pre_process_metadata(model)
+     if not pre_processed:
+         logging.warning(
+             "Please consider running pre-processing before quantization. Refer to example: "
+             "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
+             "/cpu/ReadMe.md"
+         )
+
+     if "MatMulConstBOnly" not in extra_options:
+         extra_options["MatMulConstBOnly"] = True
+
+     model = update_opset_version(model, weight_type)
+
+     quantizer = ONNXQuantizer(
+         model,
+         per_channel,
+         reduce_range,
+         mode,
+         False,  # static is False, i.e. dynamic quantization
+         weight_type,
+         QuantType.QUInt8,  # dynamic activation only supports uint8
+         None,
+         nodes_to_quantize,
+         nodes_to_exclude,
+         op_types_to_quantize,
+         extra_options,
+     )
+
+     quantizer.quantize_model()
+     quantizer.model.save_model_to_file(model_output, use_external_data_format)
+
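+
+ # A minimal usage sketch (illustrative; not part of the original module): dynamic
+ # quantization of the MatMul weights in a hypothetical "model.onnx".
+ #
+ #     from onnxruntime.quantization import QuantType, quantize_dynamic
+ #
+ #     quantize_dynamic(
+ #         "model.onnx",
+ #         "model.quant.onnx",
+ #         weight_type=QuantType.QInt8,
+ #         op_types_to_quantize=["MatMul"],
+ #         extra_options={"MatMulConstBOnly": True},
+ #     )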
+
+ def quantize(
+     model_input: str | Path | onnx.ModelProto,
+     model_output: str | Path,
+     quant_config: QuantConfig,
+ ):
+     """Quantize a model with a QuantConfig.
+
+     Args:
+         model_input (str | Path | ModelProto): Path to the model or ModelProto to quantize.
+         model_output (str | Path): Path to save the quantized model.
+         quant_config (QuantConfig | WeightOnlyQuantConfig): Quantization configuration.
+
+     (A usage sketch follows this function's definition.)
+     """
+     if isinstance(quant_config, StaticQuantConfig):
+         quantize_static(
+             model_input,
+             model_output,
+             quant_config.calibration_data_reader,
+             calibrate_method=quant_config.calibrate_method,
+             quant_format=quant_config.quant_format,
+             activation_type=quant_config.activation_type,
+             weight_type=quant_config.weight_type,
+             op_types_to_quantize=quant_config.op_types_to_quantize,
+             nodes_to_quantize=quant_config.nodes_to_quantize,
+             nodes_to_exclude=quant_config.nodes_to_exclude,
+             per_channel=quant_config.per_channel,
+             reduce_range=quant_config.reduce_range,
+             use_external_data_format=quant_config.use_external_data_format,
+             calibration_providers=quant_config.calibration_providers,
+             extra_options=quant_config.extra_options,
+         )
+
+     elif isinstance(quant_config, DynamicQuantConfig):
+         quantize_dynamic(
+             model_input,
+             model_output,
+             weight_type=quant_config.weight_type,
+             op_types_to_quantize=quant_config.op_types_to_quantize,
+             nodes_to_quantize=quant_config.nodes_to_quantize,
+             nodes_to_exclude=quant_config.nodes_to_exclude,
+             per_channel=quant_config.per_channel,
+             reduce_range=quant_config.reduce_range,
+             use_external_data_format=quant_config.use_external_data_format,
+             extra_options=quant_config.extra_options,
+         )
+     else:
+         # the training package doesn't have quantize_matmul_4bits; avoid a global import
+         from .matmul_nbits_quantizer import MatMulNBitsQuantizer, WeightOnlyQuantConfig  # noqa: PLC0415
+
+         if isinstance(quant_config, WeightOnlyQuantConfig):
+             model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load(model_input)
+             quant = MatMulNBitsQuantizer(model, algo_config=quant_config)
+             quant.process()
+             quant.model.save_model_to_file(model_output, True)
+         else:
+             raise TypeError(
+                 "Invalid quantization config type, it must be either StaticQuantConfig, "
+                 "DynamicQuantConfig, or WeightOnlyQuantConfig."
+             )
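+
+ # A minimal usage sketch (illustrative; not part of the original module): driving the
+ # same entry point with a config object. DynamicQuantConfig is defined earlier in
+ # this module; "model.onnx" is a hypothetical placeholder.
+ #
+ #     from onnxruntime.quantization import QuantType
+ #     from onnxruntime.quantization.quantize import DynamicQuantConfig, quantize
+ #
+ #     config = DynamicQuantConfig(weight_type=QuantType.QInt8, per_channel=False)
+ #     quantize("model.onnx", "model.quant.onnx", config)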