onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0

onnxruntime/quantization/registry.py
@@ -0,0 +1,110 @@
+ from .operators.activation import QDQRemovableActivation, QLinearActivation
+ from .operators.argmax import QArgMax
+ from .operators.attention import AttentionQuant
+ from .operators.base_operator import QuantOperatorBase
+ from .operators.binary_op import QLinearBinaryOp
+ from .operators.concat import QLinearConcat
+ from .operators.conv import ConvInteger, QDQConv, QLinearConv
+ from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
+ from .operators.embed_layernorm import EmbedLayerNormalizationQuant
+ from .operators.gather import GatherQuant, QDQGather
+ from .operators.gavgpool import QGlobalAveragePool
+ from .operators.gemm import QDQGemm, QLinearGemm
+ from .operators.lstm import LSTMQuant
+ from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
+ from .operators.maxpool import QDQMaxPool, QMaxPool
+ from .operators.norm import QDQNormalization
+ from .operators.pad import QDQPad, QPad
+ from .operators.pooling import QLinearPool
+ from .operators.qdq_base_operator import QDQOperatorBase
+ from .operators.resize import QDQResize, QResize
+ from .operators.softmax import QLinearSoftmax
+ from .operators.split import QDQSplit, QSplit
+ from .operators.where import QDQWhere, QLinearWhere
+ from .quant_utils import QuantizationMode
+
+ CommonOpsRegistry = {
+     "Gather": GatherQuant,
+     "Transpose": Direct8BitOp,
+     "EmbedLayerNormalization": EmbedLayerNormalizationQuant,
+ }
+
+ IntegerOpsRegistry = {
+     "Conv": ConvInteger,
+     "MatMul": MatMulInteger,
+     "Attention": AttentionQuant,
+     "LSTM": LSTMQuant,
+ }
+ IntegerOpsRegistry.update(CommonOpsRegistry)
+
+ QLinearOpsRegistry = {
+     "ArgMax": QArgMax,
+     "Conv": QLinearConv,
+     "Gemm": QLinearGemm,
+     "MatMul": QLinearMatMul,
+     "Add": QLinearBinaryOp,
+     "Mul": QLinearBinaryOp,
+     "Relu": QLinearActivation,
+     "Clip": QLinearActivation,
+     "LeakyRelu": QLinearActivation,
+     "Sigmoid": QLinearActivation,
+     "MaxPool": QMaxPool,
+     "GlobalAveragePool": QGlobalAveragePool,
+     "Split": QSplit,
+     "Pad": QPad,
+     "Reshape": Direct8BitOp,
+     "Squeeze": Direct8BitOp,
+     "Unsqueeze": Direct8BitOp,
+     "Resize": QResize,
+     "AveragePool": QLinearPool,
+     "Concat": QLinearConcat,
+     "Softmax": QLinearSoftmax,
+     "Where": QLinearWhere,
+ }
+ QLinearOpsRegistry.update(CommonOpsRegistry)
+
+ QDQRegistry = {
+     "Conv": QDQConv,
+     "ConvTranspose": QDQConv,
+     "Gemm": QDQGemm,
+     "Clip": QDQRemovableActivation,
+     "Relu": QDQRemovableActivation,
+     "Reshape": QDQDirect8BitOp,
+     "Transpose": QDQDirect8BitOp,
+     "Squeeze": QDQDirect8BitOp,
+     "Unsqueeze": QDQDirect8BitOp,
+     "Resize": QDQResize,
+     "MaxPool": QDQMaxPool,
+     "AveragePool": QDQDirect8BitOp,
+     "Slice": QDQDirect8BitOp,
+     "Pad": QDQPad,
+     "MatMul": QDQMatMul,
+     "Split": QDQSplit,
+     "Gather": QDQGather,
+     "GatherElements": QDQGather,
+     "Where": QDQWhere,
+     "InstanceNormalization": QDQNormalization,
+     "LayerNormalization": QDQNormalization,
+     "BatchNormalization": QDQNormalization,
+     "TopK": QDQDirect8BitOp,
+     "CumSum": QDQOperatorBase,
+ }
+
+
+ def CreateDefaultOpQuantizer(onnx_quantizer, node):  # noqa: N802
+     return QuantOperatorBase(onnx_quantizer, node)
+
+
+ def CreateOpQuantizer(onnx_quantizer, node):  # noqa: N802
+     registry = IntegerOpsRegistry if onnx_quantizer.mode == QuantizationMode.IntegerOps else QLinearOpsRegistry
+     if node.op_type in registry:
+         op_quantizer = registry[node.op_type](onnx_quantizer, node)
+         if op_quantizer.should_quantize():
+             return op_quantizer
+     return QuantOperatorBase(onnx_quantizer, node)
+
+
+ def CreateQDQQuantizer(onnx_quantizer, node):  # noqa: N802
+     if node.op_type in QDQRegistry:
+         return QDQRegistry[node.op_type](onnx_quantizer, node)
+     return QDQOperatorBase(onnx_quantizer, node)

onnxruntime/quantization/shape_inference.py
@@ -0,0 +1,204 @@
+ # --------------------------------------------------------------------------
+ # Copyright (c) Microsoft, Intel Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+
+ import logging
+ import tempfile
+ import traceback
+ from pathlib import Path
+
+ import onnx
+
+ import onnxruntime
+ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
+ from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
+
+ from .fusions import ReplaceUpsampleWithResize
+ from .onnx_model import ONNXModel
+ from .quant_utils import add_pre_process_metadata, save_and_reload_model_with_shape_infer
+
+ logger = logging.getLogger(__name__)
+
+
+ def quant_pre_process(
+     input_model: str | Path | onnx.ModelProto | None = None,
+     output_model_path: str | Path | None = None,
+     skip_optimization: bool = False,
+     skip_onnx_shape: bool = False,
+     skip_symbolic_shape: bool = False,
+     auto_merge: bool = False,
+     int_max: int = 2**31 - 1,
+     guess_output_rank: bool = False,
+     verbose: int = 0,
+     save_as_external_data: bool = False,
+     all_tensors_to_one_file: bool = False,
+     external_data_location: str | None = None,
+     external_data_size_threshold: int = 1024,
+     **deprecated_kwargs,
+ ) -> None:
+     """Shape inference and model optimization, in preparation for quantization.
+
+     Args:
+         input_model: Path to the input model file or ModelProto
+         output_model_path: Path to the output model file
+         skip_optimization: Skip model optimization step if true. This may result in ONNX shape
+             inference failure for some models.
+         skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
+             with transformer based models. Skipping all shape inferences may
+             reduce the effectiveness of quantization, as a tensor with unknown
+             shape can not be quantized.
+         skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
+             effective with transformer based models. Skipping all shape
+             inferences may reduce the effectiveness of quantization, as a tensor
+             with unknown shape can not be quantized.
+         auto_merge: For symbolic shape inference, automatically merge symbolic dims when
+             conflict happens.
+         int_max: For symbolic shape inference, specify the maximum value for integer to be
+             treated as boundless for ops like slice
+         guess_output_rank: Guess output rank to be the same as input 0 for unknown ops
+         verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed
+         save_as_external_data: Saving an ONNX model to external data
+         all_tensors_to_one_file: Saving all the external data to one file
+         external_data_location: The file location to save the external file
+         external_data_size_threshold: The size threshold for external data
+     """
+
+     if input_model is None:
+         input_model = deprecated_kwargs.pop("input_model_path", None)
+     assert input_model is not None
+
+     assert output_model_path is not None, "output_model_path is required."
+
+     with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
+         temp_path = Path(quant_tmp_dir)
+         model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
+
+         # Since Upsample is deprecated after opset v10, and the model's opset will
+         # be upgraded to at least v11 during quantization, we need to replace Upsample
+         # with Resize first to avoid generating an invalid model.
+         ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
+         if len(ai_onnx_domain) == 1:
+             opset_version = ai_onnx_domain[0].version
+             if opset_version <= 10:
+                 ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply()
+                 model = onnx.version_converter.convert_version(model, 11)
+                 model = save_and_reload_model_with_shape_infer(model)
+
+         if not skip_symbolic_shape:
+             logger.info("Performing symbolic shape inference...")
+             model = SymbolicShapeInference.infer_shapes(
+                 model,
+                 int_max,
+                 auto_merge,
+                 guess_output_rank,
+                 verbose,
+             )
+
+         if not skip_optimization:
+             # Use ORT optimizers (native code) to optimize model
+             if not skip_symbolic_shape:
+                 # Need to save the inferenced model to file so as to run the optimizer
+                 input_model = str(temp_path / "symbolic_shape_inferred.onnx")
+                 if save_as_external_data:
+                     onnx.save_model(
+                         model,
+                         input_model,
+                         save_as_external_data=True,
+                         all_tensors_to_one_file=all_tensors_to_one_file,
+                         size_threshold=external_data_size_threshold,
+                         convert_attribute=False,
+                     )
+                 else:
+                     onnx.save(model, input_model)
+                 model = None
+
+             opt_model_path = str(temp_path / "optimized.onnx")
+             try:
+                 sess_option = onnxruntime.SessionOptions()
+                 sess_option.optimized_model_filepath = opt_model_path
+                 sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
+                 # For large model, extract external data from model and add to session options
+                 if isinstance(input_model, onnx.ModelProto):
+                     if has_external_data(input_model):
+                         raise ValueError(
+                             "ModelProto has external data not loaded into memory, ORT cannot create session. "
+                             "Please load external data before calling this function. "
+                             "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
+                         )
+                     external_names, external_values = extract_raw_data_from_model(input_model)
+                     sess_option.add_external_initializers(list(external_names), list(external_values))
+                     input_model = input_model.SerializeToString()
+                 # the saved optimized model otherwise points to the original external data file name
+                 # which is not available relative to the optimized model file
+                 elif skip_symbolic_shape and save_as_external_data:
+                     sess_option.add_session_config_entry(
+                         "session.optimized_model_external_initializers_file_name", "optimized.onnx.data"
+                     )
+
+                 sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
+                 # Close the session to avoid the cleanup error on Windows for temp folders
+                 # https://github.com/microsoft/onnxruntime/issues/17627
+                 del sess
+             except Exception:
+                 logger.error(
+                     "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
+                 )
+                 logger.error(traceback.format_exc())
+
+             input_model = opt_model_path
+
+         if not skip_onnx_shape:
+             # ONNX shape inference.
+             # According to docs, infer_shapes_path should be used for 2G+ models.
+             # If the skip optimization is specified, we could be dealing with a
+             # large model. So be on the safe side, save the model
+             if model is not None:
+                 input_model = str(temp_path / "symbolic_shape_inferred.onnx")
+                 if save_as_external_data:
+                     onnx.save_model(
+                         model,
+                         input_model,
+                         save_as_external_data=True,
+                         all_tensors_to_one_file=all_tensors_to_one_file,
+                         size_threshold=external_data_size_threshold,
+                         convert_attribute=False,
+                     )
+                 else:
+                     onnx.save(model, input_model)
+                 model = None
+
+             if isinstance(input_model, onnx.ModelProto):
+                 input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
+                 onnx.save_model(
+                     model,
+                     input_model,
+                     save_as_external_data=True,
+                     all_tensors_to_one_file=all_tensors_to_one_file,
+                     size_threshold=external_data_size_threshold,
+                     convert_attribute=False,
+                 )
+
+             inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
+             onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
+             model = onnx.load(inferred_model_path)
+
+         if model is None:
+             model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
+
+         add_pre_process_metadata(model)
+
+         if save_as_external_data:
+             onnx.save_model(
+                 model,
+                 output_model_path,
+                 save_as_external_data=True,
+                 all_tensors_to_one_file=all_tensors_to_one_file,
+                 location=external_data_location,
+                 size_threshold=external_data_size_threshold,
+                 convert_attribute=False,
+             )
+         else:
+             onnx.save(model, output_model_path)
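
The hunk above is onnxruntime/quantization/shape_inference.py, whose quant_pre_process runs up to three preparation steps before quantization: symbolic shape inference, ORT graph optimization at the ORT_ENABLE_BASIC level, and ONNX shape inference, each of which can be skipped with a flag. A minimal sketch of calling it (file names are placeholders):

    from onnxruntime.quantization.shape_inference import quant_pre_process

    quant_pre_process(
        input_model="model.onnx",                # a path or a loaded onnx.ModelProto
        output_model_path="model_preproc.onnx",
        skip_optimization=False,                 # set True if ORT optimization fails
        save_as_external_data=False,             # enable for models over 2 GB
    )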

onnxruntime/quantization/static_quantize_runner.py
@@ -0,0 +1,256 @@
+ import argparse
+ import json
+ import os
+
+ import numpy as np
+ import onnx
+
+ import onnxruntime
+ from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize
+ from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod
+
+
+ class OnnxModelCalibrationDataReader(CalibrationDataReader):
+     def __init__(self, model_path):
+         self.model_dir = os.path.dirname(model_path)
+         data_dirs = [
+             os.path.join(self.model_dir, a) for a in os.listdir(self.model_dir) if a.startswith("test_data_set_")
+         ]
+         model_inputs = onnxruntime.InferenceSession(model_path).get_inputs()
+         name2tensors = []
+         for data_dir in data_dirs:
+             name2tensor = {}
+             data_paths = [os.path.join(data_dir, f"input_{input_idx}.pb") for input_idx in range(len(model_inputs))]
+             data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
+             for model_input, data_ndarray in zip(model_inputs, data_ndarrays, strict=False):
+                 name2tensor[model_input.name] = data_ndarray
+             name2tensors.append(name2tensor)
+         assert len(name2tensors) == len(data_dirs)
+         assert len(name2tensors[0]) == len(model_inputs)
+
+         self.calibration_data = iter(name2tensors)
+
+     def get_next(self) -> dict:
+         """generate the input data dict for ONNXinferenceSession run"""
+         return next(self.calibration_data, None)
+
+     def read_onnx_pb_data(self, file_pb):
+         tensor = onnx.TensorProto()
+         with open(file_pb, "rb") as f:
+             tensor.ParseFromString(f.read())
+         ret = onnx.numpy_helper.to_array(tensor)
+         return ret
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description="The arguments for static quantization")
+     parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input onnx model")
+     parser.add_argument(
+         "-o", "--output_quantized_model_path", required=True, help="Path to the output quantized onnx model"
+     )
+     parser.add_argument(
+         "--activation_type",
+         choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
+         default="quint8",
+         help="Activation quantization type used",
+     )
+     parser.add_argument(
+         "--weight_type",
+         choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
+         default="qint8",
+         help="Weight quantization type used",
+     )
+     parser.add_argument("--enable_subgraph", action="store_true", help="If set, subgraph will be quantized.")
+     parser.add_argument(
+         "--force_quantize_no_input_check",
+         action="store_true",
+         help="By default, some latent operators like maxpool, transpose, do not quantize if their input is not"
+         " quantized already. Setting to True to force such operator always quantize input and so generate"
+         " quantized output. Also the True behavior could be disabled per node using the nodes_to_exclude.",
+     )
+     parser.add_argument(
+         "--matmul_const_b_only",
+         action="store_true",
+         help="If set, only MatMul with const B will be quantized.",
+     )
+     parser.add_argument(
+         "--add_qdq_pair_to_weight",
+         action="store_true",
+         help="If set, it remains floating-point weight and inserts both QuantizeLinear/DeQuantizeLinear"
+         " nodes to weight.",
+     )
+     parser.add_argument(
+         "--dedicated_qdq_pair",
+         action="store_true",
+         help="If set, it will create identical and dedicated QDQ pair for each node.",
+     )
+     parser.add_argument(
+         "--op_types_to_exclude_output_quantization",
+         nargs="+",
+         default=[],
+         help="If any op type is specified, it won't quantize the output of ops with this specific op types.",
+     )
+     parser.add_argument(
+         "--calibration_method",
+         default="minmax",
+         choices=["minmax", "entropy", "percentile", "distribution"],
+         help="Calibration method used",
+     )
+     parser.add_argument("--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used")
+     parser.add_argument(
+         "--calib_tensor_range_symmetric",
+         action="store_true",
+         help="If enabled, the final range of tensor during calibration will be explicitly"
+         " set to symmetric to central point 0",
+     )
+     # TODO: --calib_strided_minmax"
+     # TODO: --calib_moving_average_constant"
+     # TODO: --calib_max_intermediate_outputs"
+     parser.add_argument(
+         "--calib_moving_average",
+         action="store_true",
+         help="If enabled, the moving average of"
+         " the minimum and maximum values will be computed when the calibration method selected is MinMax.",
+     )
+     parser.add_argument(
+         "--disable_quantize_bias",
+         action="store_true",
+         help="Whether to quantize floating-point biases by solely inserting a DeQuantizeLinear node"
+         " If not set, it remains floating-point bias and does not insert any quantization nodes"
+         " associated with biases.",
+     )
+
+     # TODO: Add arguments related to Smooth Quant
+
+     parser.add_argument(
+         "--use_qdq_contrib_ops",
+         action="store_true",
+         help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the com.microsoft domain,"
+         " which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op implementations.",
+     )
+     parser.add_argument(
+         "--minimum_real_range",
+         type=float,
+         default=0.0001,
+         help="If set to a floating-point value, the calculation of the quantization parameters"
+         " (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax-rmin)"
+         " is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is"
+         " necessary for EPs like QNN that require a minimum floating-point range when determining "
+         " quantization parameters.",
+     )
+     parser.add_argument(
+         "--qdq_keep_removable_activations",
+         action="store_true",
+         help="If set, removable activations (e.g., Clip or Relu) will not be removed,"
+         " and will be explicitly represented in the QDQ model.",
+     )
+     parser.add_argument(
+         "--qdq_disable_weight_adjust_for_int32_bias",
+         action="store_true",
+         help="If set, QDQ quantizer will not adjust the weight's scale when the bias"
+         " has a scale (input_scale * weight_scale) that is too small.",
+     )
+     parser.add_argument("--per_channel", action="store_true", help="Whether using per-channel quantization")
+     parser.add_argument(
+         "--nodes_to_quantize",
+         nargs="+",
+         default=None,
+         help="List of nodes names to quantize. When this list is not None only the nodes in this list are quantized.",
+     )
+     parser.add_argument(
+         "--nodes_to_exclude",
+         nargs="+",
+         default=None,
+         help="List of nodes names to exclude. The nodes in this list will be excluded from quantization when it is not None.",
+     )
+     parser.add_argument(
+         "--op_per_channel_axis",
+         nargs=2,
+         action="append",
+         metavar=("OP_TYPE", "PER_CHANNEL_AXIS"),
+         default=[],
+         help="Set channel axis for specific op type, for example: --op_per_channel_axis MatMul 1, and it's"
+         " effective only when per channel quantization is supported and per_channel is True. If specific"
+         " op type supports per channel quantization but not explicitly specified with channel axis,"
+         " default channel axis will be used.",
+     )
+     parser.add_argument("--tensor_quant_overrides", help="Set the json file for tensor quantization overrides.")
+     return parser.parse_args()
+
+
+ def get_tensor_quant_overrides(file):
+     # TODO: Enhance the function to handle more real cases of json file
+     if not file:
+         return {}
+     with open(file) as f:
+         quant_override_dict = json.load(f)
+     for tensor in quant_override_dict:
+         for enc_dict in quant_override_dict[tensor]:
+             enc_dict["scale"] = np.array(enc_dict["scale"], dtype=np.float32)
+             enc_dict["zero_point"] = np.array(enc_dict["zero_point"])
+     return quant_override_dict
+
+
+ def main():
+     args = parse_arguments()
+     data_reader = OnnxModelCalibrationDataReader(model_path=args.input_model_path)
+     arg2quant_type = {
+         "qint8": QuantType.QInt8,
+         "quint8": QuantType.QUInt8,
+         "qint16": QuantType.QInt16,
+         "quint16": QuantType.QUInt16,
+         "qint4": QuantType.QInt4,
+         "quint4": QuantType.QUInt4,
+         "qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN,
+     }
+     activation_type = arg2quant_type[args.activation_type]
+     weight_type = arg2quant_type[args.weight_type]
+     qdq_op_type_per_channel_support_to_axis = dict(args.op_per_channel_axis)
+     extra_options = {
+         "EnableSubgraph": args.enable_subgraph,
+         "ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
+         "MatMulConstBOnly": args.matmul_const_b_only,
+         "AddQDQPairToWeight": args.add_qdq_pair_to_weight,
+         "OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
+         "DedicatedQDQPair": args.dedicated_qdq_pair,
+         "QDQOpTypePerChannelSupportToAxis": qdq_op_type_per_channel_support_to_axis,
+         "CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
+         "CalibMovingAverage": args.calib_moving_average,
+         "QuantizeBias": not args.disable_quantize_bias,
+         "UseQDQContribOps": args.use_qdq_contrib_ops,
+         "MinimumRealRange": args.minimum_real_range,
+         "QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
+         "QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
+         # Load json file for encoding override
+         "TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides),
+     }
+     arg2calib_method = {
+         "minmax": CalibrationMethod.MinMax,
+         "entropy": CalibrationMethod.Entropy,
+         "percentile": CalibrationMethod.Percentile,
+         "distribution": CalibrationMethod.Distribution,
+     }
+     arg2quant_format = {
+         "qdq": QuantFormat.QDQ,
+         "qoperator": QuantFormat.QOperator,
+     }
+     sqc = StaticQuantConfig(
+         calibration_data_reader=data_reader,
+         calibrate_method=arg2calib_method[args.calibration_method],
+         quant_format=arg2quant_format[args.quant_format],
+         activation_type=activation_type,
+         weight_type=weight_type,
+         op_types_to_quantize=None,
+         nodes_to_quantize=args.nodes_to_quantize,
+         nodes_to_exclude=args.nodes_to_exclude,
+         per_channel=args.per_channel,
+         reduce_range=False,
+         use_external_data_format=False,
+         calibration_providers=None,  # Use CPUExecutionProvider
+         extra_options=extra_options,
+     )
+     quantize(model_input=args.input_model_path, model_output=args.output_quantized_model_path, quant_config=sqc)
+
+
+ if __name__ == "__main__":
+     main()
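
The hunk above is onnxruntime/quantization/static_quantize_runner.py, a command-line wrapper around static quantization: it reads calibration tensors from test_data_set_* directories placed next to the model, maps the CLI choices onto the QuantType, CalibrationMethod, and QuantFormat enums, and hands a StaticQuantConfig to quantize(). A minimal sketch of driving the same flow from Python (paths are placeholders):

    from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize
    from onnxruntime.quantization.calibrate import CalibrationMethod
    from onnxruntime.quantization.static_quantize_runner import OnnxModelCalibrationDataReader

    # Expects test_data_set_*/input_<n>.pb protobuf tensors beside model.onnx.
    reader = OnnxModelCalibrationDataReader(model_path="model.onnx")
    config = StaticQuantConfig(
        calibration_data_reader=reader,
        calibrate_method=CalibrationMethod.MinMax,
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,
        per_channel=False,
    )
    quantize(model_input="model.onnx", model_output="model_quant.onnx", quant_config=config)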