onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/quantization/quant_utils.py
@@ -0,0 +1,1051 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
+ import copy
+ import logging
+ import os
+ import tempfile
+ from enum import Enum
+ from pathlib import Path
+
+ import numpy
+ import onnx
+ from ml_dtypes import float8_e4m3fn, int4, uint4
+ from onnx import ModelProto, TensorProto, external_data_helper
+ from onnx import onnx_pb as onnx_proto
+ from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+ from onnx.reference import ReferenceEvaluator
+
+ from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
+
+ try:
+     from onnx.reference.op_run import to_array_extended
+ except ImportError:
+     # old version of onnx.
+     to_array_extended = None
+
+
+ __producer__ = "onnx.quantize"
+ __version__ = "0.1.0"
+ onnx_domain = "ai.onnx"
+ ms_domain = "com.microsoft"
+ QUANT_OP_NAME = "QuantizeLinear"
+ QUANT_INPUT_SUFFIX = "_QuantizeLinear_Input"
+ DEQUANT_OP_NAME = "DequantizeLinear"
+ DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
+ TENSOR_NAME_QUANT_SUFFIX = "_quantized"
+ MODEL_SIZE_THRESHOLD = 2147483648  # Quant model should use external data if >= 2GB
+
+ FLOAT8_DISTRIBUTIONS = {}
+
+ type_to_name = {getattr(TensorProto, k): k for k in dir(TensorProto) if isinstance(getattr(TensorProto, k), int)}
+
+ # Quantization mode
+ # IntegerOps: Use IntegerOps in quantized model. Only ConvInteger and MatMulInteger ops are supported now.
+ # QLinearOps: Use QLinearOps in quantized model. Only QLinearConv and QLinearMatMul ops are supported now.
+
+
+ class QuantizationMode(Enum):
+     IntegerOps = 0
+     QLinearOps = 1
+
+     def __str__(self):
+         return self.name
+
+     @staticmethod
+     def from_string(mode):
+         try:
+             return QuantizationMode[mode]
+         except KeyError:
+             raise ValueError()  # noqa: B904
+
+
+ class QuantizedValueType(Enum):
+     Input = 0
+     Initializer = 1
+
+     def __str__(self):
+         return self.name
+
+     @staticmethod
+     def from_string(v):
+         try:
+             return QuantizedValueType[v]
+         except KeyError:
+             raise ValueError()  # noqa: B904
+
+
+ class QuantType(Enum):
+     QInt8 = 0
+     QUInt8 = 1
+     QFLOAT8E4M3FN = 2
+     QInt16 = 3
+     QUInt16 = 4
+     QInt4 = 5
+     QUInt4 = 6
+
+     def __str__(self):
+         return self.name
+
+     @staticmethod
+     def from_string(t):
+         try:
+             return QuantType[t]
+         except KeyError:
+             raise ValueError()  # noqa: B904
+
+     @property
+     def tensor_type(self):
+         if self == QuantType.QInt8:
+             return TensorProto.INT8
+         if self == QuantType.QUInt8:
+             return TensorProto.UINT8
+         if self == QuantType.QUInt16:
+             return TensorProto.UINT16
+         if self == QuantType.QInt16:
+             return TensorProto.INT16
+         if self == QuantType.QFLOAT8E4M3FN:
+             return TensorProto.FLOAT8E4M3FN
+         if self == QuantType.QUInt4:
+             return TensorProto.UINT4
+         if self == QuantType.QInt4:
+             return TensorProto.INT4
+         raise ValueError(f"Unexpected value qtype={self!r}.")
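As a quick orientation, a minimal usage sketch (not part of the packaged file; it assumes the import path onnxruntime.quantization.quant_utils, which matches file 77 in the list above):

    from onnx import TensorProto
    from onnxruntime.quantization.quant_utils import QuantType

    wt = QuantType.from_string("QInt8")        # parse a config/CLI-style string
    assert wt.tensor_type == TensorProto.INT8  # map onto the ONNX element type
    assert str(wt) == "QInt8"                  # __str__ returns the member name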
+
+
+ class QuantFormat(Enum):
+     QOperator = 0
+     QDQ = 1
+
+     def __str__(self):
+         return self.name
+
+     @staticmethod
+     def from_string(format):
+         try:
+             return QuantFormat[format]
+         except KeyError:
+             raise ValueError()  # noqa: B904
+
+
+ ONNX_TYPE_TO_NP_TYPE = {
+     onnx_proto.TensorProto.INT8: numpy.dtype("int8"),
+     onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"),
+     onnx_proto.TensorProto.INT16: numpy.dtype("int16"),
+     onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"),
+     onnx_proto.TensorProto.FLOAT8E4M3FN: float8_e4m3fn,
+     onnx_proto.TensorProto.INT4: int4,
+     onnx_proto.TensorProto.UINT4: uint4,
+ }
+
+ ONNX_INT_TYPE_RANGE = {
+     onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(255, dtype=numpy.uint8)),
+     onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
+     onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)),
+     onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
+     onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(15, dtype=uint4)),
+     onnx_proto.TensorProto.INT4: (numpy.array(-8, dtype=int4), numpy.array(7, dtype=int4)),
+ }
+
+ ONNX_INT_TYPE_SYMMETRIC_RANGE = {
+     onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
+     onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
+ }
+
+ ONNX_INT_TYPE_REDUCED_RANGE = {
+     onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(127, dtype=numpy.uint8)),
+     onnx_proto.TensorProto.INT8: (numpy.array(-64, dtype=numpy.int8), numpy.array(64, dtype=numpy.int8)),
+     onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(32767, dtype=numpy.uint16)),
+     onnx_proto.TensorProto.INT16: (numpy.array(-16384, dtype=numpy.int16), numpy.array(16384, dtype=numpy.int16)),
+     onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(7, dtype=uint4)),
+     onnx_proto.TensorProto.INT4: (numpy.array(-4, dtype=int4), numpy.array(3, dtype=int4)),
+ }
+
+
+ def _check_type(*args, zero_point_index=-1):
+     new_args = []
+     for i, a in enumerate(args):
+         if numpy.issubdtype(type(a), numpy.number):
+             new_args.append(numpy.array(a))
+         elif isinstance(a, numpy.ndarray):
+             new_args.append(a)
+         else:
+             raise TypeError(f"arg {i} is not an array: {a}")
+         if i == zero_point_index:
+             v = new_args[-1]
+             if v.dtype == numpy.float32 or v.dtype == numpy.float16:
+                 raise TypeError(f"zero_point cannot be {v.dtype}")
+     return tuple(new_args) if len(new_args) > 1 else new_args[0]
+
+
+ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
+     assert qType in ONNX_TYPE_TO_NP_TYPE, (
+         f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported."
+     )
+     if qType in (
+         onnx_proto.TensorProto.FLOAT8E4M3FN,
+         onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
+         onnx_proto.TensorProto.FLOAT8E5M2,
+         onnx_proto.TensorProto.FLOAT8E5M2FNUZ,
+     ):
+         if zero_point != 0:
+             raise NotImplementedError(f"zero_point is expected to be null for float 8 not {zero_point!r}.")
+         if arr.dtype == numpy.float32:
+             onnx_type = TensorProto.FLOAT
+         elif arr.dtype == numpy.float16:
+             onnx_type = TensorProto.FLOAT16
+         else:
+             raise ValueError(f"Unexpected dtype {arr.dtype}.")
+         onnx_model = make_model(
+             make_graph(
+                 [
+                     make_node(
+                         "Constant", [], ["zero_point"], value=onnx.helper.make_tensor("zero_point", qType, [], [0])
+                     ),
+                     make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]),
+                 ],
+                 "qu",
+                 [
+                     make_tensor_value_info("X", onnx_type, None),
+                     make_tensor_value_info("scale", onnx_type, None),
+                 ],
+                 [make_tensor_value_info("Y", qType, None)],
+             )
+         )
+         ref = ReferenceEvaluator(onnx_model)
+         return _check_type(ref.run(None, {"X": arr, "scale": scale})[0])
+     else:
+         # Quantizes data for all integer types.
+         #
+         # For int4 types, the quantized data is returned as either np.int8 or np.uint8,
+         # which matches the python reference ONNX implementation of QuantizeLinear.
+         # This data can be packed into 4-bit elements by using pack_bytes_to_4bit().
+         dtype = ONNX_TYPE_TO_NP_TYPE[qType]
+         qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False)
+
+         cliplow = max(qmin, low) if low is not None else qmin
+         cliphigh = min(qmax, high) if high is not None else qmax
+         arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point)
+         numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32)
+         return _check_type(arr_fp32.astype(dtype))
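To make the integer branch concrete, a small worked sketch (input values chosen arbitrarily):

    import numpy
    from onnx import TensorProto
    from onnxruntime.quantization.quant_utils import quantize_nparray

    arr = numpy.array([-1.0, 0.0, 0.5, 1.0], dtype=numpy.float32)
    q = quantize_nparray(TensorProto.UINT8, arr, numpy.float32(1 / 127), 128)
    # round(x / scale) + zero_point, clipped to [0, 255]:
    # [-127 + 128, 0 + 128, 64 + 128, 127 + 128] -> [1, 128, 192, 255] as uint8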
+
+
+ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=None):
+     """Calculate the scale s and zero point z for the quantization relation
+     r = s(q-z), where r are the original values and q are the corresponding
+     quantized values.
+
+     s and z are calculated such that every value within [rmin,rmax] has an
+     approximate representation within [qmin,qmax]. In addition, qmin <= z <=
+     qmax is enforced. If the symmetric flag is set to True, the interval
+     [rmin,rmax] is symmetrized to [-absmax, +absmax], where
+     absmax = max(abs(rmin), abs(rmax)).
+
+     :parameter rmin: minimum value of r
+     :parameter rmax: maximum value of r
+     :parameter qmin: minimum value representable by the target quantization data type
+     :parameter qmax: maximum value representable by the target quantization data type
+     :parameter symmetric: True if the floating-point range should be made symmetric. Defaults to False.
+     :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
+     :return: zero and scale [z, s]
+
+     """
+     if qmin > 0 or qmax < 0:
+         raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmax:{qmax}")
+
+     # Adjust rmin and rmax such that 0 is included in the range. This is
+     # required to make sure zero can be represented by the quantization data
+     # type (i.e. to make sure qmin <= zero_point <= qmax)
+     rmin = numpy.minimum(rmin, numpy.array(0, dtype=rmin.dtype))
+     rmax = numpy.maximum(rmax, numpy.array(0, dtype=rmax.dtype))
+
+     # Ensure a minimum floating-point range if specified.
+     if min_real_range is not None:
+         rmax = max(rmax, rmin + numpy.asarray(min_real_range, dtype=rmin.dtype))
+
+     if symmetric:
+         absmax = numpy.maximum(numpy.abs(rmin), numpy.abs(rmax))
+         rmin = -absmax
+         rmax = +absmax
+
+     assert qmin <= qmax, f"qmin={qmin} > qmax={qmax}"
+     dr = numpy.array(rmax - rmin, dtype=numpy.float64)
+     dq = numpy.array(qmax, dtype=numpy.float64) - numpy.array(qmin, dtype=numpy.float64)
+     scale = numpy.array(dr / dq)
+     assert scale >= 0, "scale issue"
+     if scale < numpy.finfo(rmax.dtype).tiny:
+         scale = numpy.array(1.0, dtype=rmax.dtype)
+         zero_point = numpy.array(0, dtype=qmin.dtype)
+     else:
+         if symmetric:
+             # When symmetric (i.e., rmax == -rmin), the zero_point formula reduces to round((qmax + qmin) / 2.0).
+             # This simpler formula doesn't depend on scale and guarantees that the zero point values
+             # for int8, uint8, int16, and uint16 are always 0, 128, 0, and 32768, respectively.
+             # This is important for per-channel/symmetric QLinearConv on CPU EP, which requires all channels to have
+             # the exact same zero_point values.
+             zero_point = numpy.array(
+                 numpy.round((qmin + qmax) / numpy.array(2.0, dtype=numpy.float64)), dtype=qmin.dtype
+             )
+         else:
+             zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype)
+         scale = scale.astype(rmax.dtype)
+
+     return [zero_point, scale]
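A worked example of the asymmetric path (numbers illustrative): for rmin=-1, rmax=3 mapped onto uint8,

    import numpy
    from onnxruntime.quantization.quant_utils import compute_scale_zp

    zp, scale = compute_scale_zp(
        numpy.array(-1.0, dtype=numpy.float32), numpy.array(3.0, dtype=numpy.float32),
        numpy.array(0, dtype=numpy.uint8), numpy.array(255, dtype=numpy.uint8),
    )
    # scale = (3 - (-1)) / 255 ≈ 0.0156863 and zp = round(0 - (-1) / scale) = 64,
    # so the real value 0.0 maps exactly onto the quantized value 64.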
+
+
+ def compute_scale_zp_float8(element_type, std):
+     """Calculate the scale s for a float8 type (E4M3FN).
+     The function assumes the coefficient distribution and the float 8
+     distribution are similar to two gaussian laws.
+
+     :return: zero and scale [z, s]
+
+     More details in notebook `quantization_fp8.ipynb
+     <https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/quantization_fp8.ipynb>`_.
+     """
+     zp_dtype = None
+     if element_type not in FLOAT8_DISTRIBUTIONS:
+         if element_type == TensorProto.FLOAT8E4M3FN:
+             from ml_dtypes import float8_e4m3fn  # noqa: PLC0415
+
+             zp_dtype = float8_e4m3fn
+             all_values = [float(i) for i in range(256)]
+             values = numpy.array(
+                 [f for f in all_values if not numpy.isnan(f) and not numpy.isinf(f)], dtype=numpy.float32
+             )
+         else:
+             raise ValueError(f"Quantization to element_type={element_type} not implemented.")
+         FLOAT8_DISTRIBUTIONS[element_type] = values
+     elif element_type == TensorProto.FLOAT8E4M3FN:
+         from ml_dtypes import float8_e4m3fn  # noqa: PLC0415
+
+         zp_dtype = float8_e4m3fn
+
+     if zp_dtype is None:
+         raise TypeError(f"Unexpected element_type {element_type}.")
+     std_f8 = numpy.std(FLOAT8_DISTRIBUTIONS[element_type])
+     zero = numpy.array(0, dtype=zp_dtype)
+     scale = numpy.array(std / std_f8, dtype=std.dtype)
+     return [zero, scale]
+
+
+ def compute_data_quant_params(
+     data: numpy.ndarray,
+     quant_type: onnx.TensorProto.DataType,
+     symmetric: bool,
+     reduce_range: bool = False,
+     min_real_range: float | None = None,
+     rmin_override: float | None = None,
+     rmax_override: float | None = None,
+ ) -> tuple[numpy.ndarray, numpy.ndarray]:
+     """
+     Returns the zero_point and scale for the given data.
+
+     :param data: The data for which to compute quantization parameters.
+     :param quant_type: The quantization data type.
+     :param symmetric: whether symmetric quantization is used or not.
+     :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
+     :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
+     :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
+     :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
+     :return: zero point and scale
+     """
+     if not isinstance(data, numpy.ndarray):
+         raise TypeError(f"Weight must be given as an array not {type(data)}.")
+     if rmin_override is not None:
+         rmin = rmin_override
+     else:
+         rmin = data.min() if len(data) else 0.0
+
+     if rmax_override is not None:
+         rmax = rmax_override
+     else:
+         rmax = data.max() if len(data) else 0.0
+
+     rmin = numpy.array(rmin, dtype=data.dtype)
+     rmax = numpy.array(rmax, dtype=data.dtype)
+     scale = numpy.array(1.0, dtype=data.dtype)
+
+     if quant_type == TensorProto.FLOAT8E4M3FN:
+         if reduce_range:
+             raise RuntimeError("Unsupported option reduce_range=True for float 8.")
+         std = numpy.std(data)
+         zero_point, scale = compute_scale_zp_float8(quant_type, std)
+         return _check_type(zero_point, scale, zero_point_index=0)
+
+     if quant_type in (
+         TensorProto.INT8,
+         TensorProto.UINT8,
+         TensorProto.INT16,
+         TensorProto.UINT16,
+         TensorProto.INT4,
+         TensorProto.UINT4,
+     ):
+         qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range, symmetric=symmetric)
+         if len(data):
+             zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range)
+         else:
+             zero_point = numpy.array(0, dtype=qmin.dtype)
+         return _check_type(zero_point, scale, zero_point_index=0)
+
+     raise ValueError(f"Unexpected value for quant_type={quant_type}.")
+
+
+ def quantize_data(
+     data, qType, symmetric, reduce_range=False, min_real_range=None, rmin_override=None, rmax_override=None
+ ) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
+     """
+     :param data: data to quantize
+     :param qType: data type to quantize to.
+     :param symmetric: whether symmetric quantization is used or not.
+     :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
+     :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
+     :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
+     :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
+     :return: zero point, scale, and quantized data
+
+     To pack weights, we compute a linear transformation
+
+     - when data `type == uint8`, from `[rmin, rmax]` -> :math:`[0, 2^{b}-1]` and
+     - when data `type == int8`, from `[-m , m]` -> :math:`[-(2^{b-1}-1), 2^{b-1}-1]` where
+       `m = max(abs(rmin), abs(rmax))`
+
+     and add necessary intermediate nodes to transform quantized weight to full weight using the equation
+
+     :math:`r = S(q-z)`, where
+
+     - *r*: real original value
+     - *q*: quantized value
+     - *S*: scale
+     - *z*: zero point
+     """
+     zero_point, scale = compute_data_quant_params(
+         data,
+         qType,
+         symmetric,
+         reduce_range,
+         min_real_range,
+         rmin_override,
+         rmax_override,
+     )
+     if qType == TensorProto.FLOAT8E4M3FN:
+         quantized_data = quantize_nparray(qType, data, scale, zero_point)
+         if any((quantized_data.view(numpy.uint8).ravel() & 127) == 127):
+             np_data = numpy.asarray(data)
+             raise RuntimeError(
+                 f"One of the quantized values is NaN: data in [{np_data.min()}, {np_data.max()}], "
+                 f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]."
+             )
+         return zero_point, scale, quantized_data
+
+     if qType in (
+         TensorProto.INT8,
+         TensorProto.UINT8,
+         TensorProto.INT16,
+         TensorProto.UINT16,
+         TensorProto.INT4,
+         TensorProto.UINT4,
+     ):
+         quantized_data = quantize_nparray(qType, data, scale, zero_point)
+         return zero_point, scale, quantized_data
+
+     raise ValueError(f"Unexpected value for qType={qType}.")
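A usage sketch for symmetric int8 weights (values arbitrary):

    import numpy
    from onnx import TensorProto
    from onnxruntime.quantization.quant_utils import quantize_data

    w = numpy.array([-0.5, -0.1, 0.0, 0.25, 0.5], dtype=numpy.float32)
    zp, scale, qw = quantize_data(w, TensorProto.INT8, symmetric=True)
    # symmetric int8 uses [-127, 127]: scale = 1.0 / 254 ≈ 0.0039370, zp = 0,
    # and qw = round(w / scale) -> [-127, -25, 0, 64, 127] (int8)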
+
+
+ def quantize_onnx_initializer(
+     weight: onnx.TensorProto,
+     quant_type: onnx.TensorProto.DataType,
+     zero_point: numpy.ndarray,
+     scale: numpy.ndarray,
+     axis: int | None = None,
+     quant_weight_name: str | None = None,
+ ) -> onnx.TensorProto:
+     """
+     Returns a quantized version of the given ONNX initializer.
+
+     :param weight: The ONNX initializer to quantize.
+     :param quant_type: The final quantized data type.
+     :param zero_point: The zero-point value to use for quantization.
+     :param scale: The scale value to use for quantization.
+     :param axis: The quantization axis if quantizing per-channel. Defaults to None.
+     :param quant_weight_name: The name of the quantized initializer.
+                               If not specified, the quantized name is generated.
+     :return: The quantized ONNX initializer.
+     """
+     weight_data = tensor_proto_to_array(weight)
+     q_weight_data: numpy.ndarray | None = None
+
+     if axis is None:  # Per-tensor quantization
+         q_weight_data = quantize_nparray(quant_type, weight_data.ravel(), scale, zero_point)
+     else:  # Per-channel quantization
+         channel_count = weight_data.shape[axis]
+         channel_dims = list(weight_data.shape)  # deep copy
+         channel_dims[axis] = 1  # only one per channel for reshape
+         quantized_channel_data_list = []
+
+         for i in range(channel_count):
+             channel_data = weight_data.take(i, axis)
+             channel_scale = scale[i]
+             channel_zero_point = zero_point[i]
+             quantized_channel_data = quantize_nparray(
+                 quant_type, channel_data.ravel(), channel_scale, channel_zero_point
+             )
+             quantized_channel_data_list.append(numpy.asarray(quantized_channel_data).reshape(channel_dims))
+
+         q_weight_data = numpy.concatenate(quantized_channel_data_list, axis)
+
+     q_weight_name = quant_weight_name if quant_weight_name else f"{weight.name}{TENSOR_NAME_QUANT_SUFFIX}"
+
+     if quant_type == onnx.TensorProto.FLOAT8E4M3FN:
+         q_weight_initializer = onnx.TensorProto()
+         q_weight_initializer.data_type = quant_type
+         q_weight_initializer.dims.extend(weight.dims)
+         q_weight_initializer.name = q_weight_name
+         # Do not remove .flatten().copy(); numpy is not clear about data persistence.
+         q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
+         if to_array_extended is not None:
+             # This test should not be needed but it helped catch some issues
+             # with data persistence and tobytes.
+             check = to_array_extended(q_weight_initializer)
+             if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
+                 raise RuntimeError(
+                     f"The initializer of shape {weight_data.shape} could not be created, expecting "
+                     f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
+                     f"\nraw={str(q_weight_initializer)[:200]}."
+                 )
+     elif quant_type in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
+         if q_weight_data.dtype not in (int4, uint4):
+             raise RuntimeError(f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values.")
+
+         # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
+         # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
+         packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
+
+         # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
+         q_weight_initializer = onnx.helper.make_tensor(q_weight_name, quant_type, weight.dims, packed_data, raw=True)
+     else:
+         quant_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(quant_type)
+         q_weight_data = numpy.asarray(q_weight_data, dtype=quant_np_dtype).reshape(weight.dims)
+         q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
+
+     return q_weight_initializer
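Putting the pieces together for a per-tensor int8 weight (a sketch; the initializer name "fc.weight" is made up):

    import numpy
    import onnx
    from onnx import TensorProto
    from onnxruntime.quantization.quant_utils import quantize_data, quantize_onnx_initializer

    w = onnx.numpy_helper.from_array(numpy.random.rand(4, 3).astype(numpy.float32), "fc.weight")
    zp, scale, _ = quantize_data(onnx.numpy_helper.to_array(w), TensorProto.INT8, symmetric=True)
    qw = quantize_onnx_initializer(w, TensorProto.INT8, zp, scale)
    # qw.name == "fc.weight_quantized" (TENSOR_NAME_QUANT_SUFFIX appended)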
+
+
+ def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False):  # noqa: N802
+     """
+     Return qmin and qmax, the minimum and maximum value representable by the given qType.
+     :parameter qType: e.g. onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.INT8
+     :return: qmin, qmax
+     """
+     if qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
+         raise NotImplementedError("This function is not implemented for float 8 as it is not needed.")
+
+     qrange = None
+
+     if reduce_range:
+         qrange = ONNX_INT_TYPE_REDUCED_RANGE.get(qType)
+     elif symmetric and qType in ONNX_INT_TYPE_SYMMETRIC_RANGE:
+         qrange = ONNX_INT_TYPE_SYMMETRIC_RANGE[qType]
+     else:
+         qrange = ONNX_INT_TYPE_RANGE.get(qType)
+
+     if not qrange:
+         raise ValueError(f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported.")
+
+     qmin, qmax = qrange
+     if qmin > 0 or qmax < 0:
+         raise ValueError(
+             f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while "
+             f"qmin:{qmin}, qmax:{qmax}, dtype={qmin.dtype}, reduce_range={reduce_range}, "
+             f"symmetric={symmetric}, qType={qType}"
+         )
+
+     return qrange
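The three range tables above combine as follows (a quick illustration):

    from onnx import TensorProto
    from onnxruntime.quantization.quant_utils import get_qmin_qmax_for_qType

    get_qmin_qmax_for_qType(TensorProto.INT8)                      # (-128, 127), full range
    get_qmin_qmax_for_qType(TensorProto.INT8, symmetric=True)      # (-127, 127)
    get_qmin_qmax_for_qType(TensorProto.INT8, reduce_range=True)   # (-64, 64)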
+
+
+ def get_qrange_for_qType(qType, reduce_range=False, symmetric=False):  # noqa: N802
+     """
+     Helper function to get the quantization range for a type.
+     :parameter qType: quantization type.
+     :return: quantization range.
+     """
+     qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
+     return qmax - qmin
+
+
+ def normalize_axis(axis: int, rank: int) -> tuple[bool, int]:
+     """
+     Helper function that tries to return a normalized axis in the range [0, rank - 1].
+     :parameter axis: The axis to normalize.
+     :parameter rank: The tensor rank (number of dimensions).
+     :return: (is_valid, axis_norm)
+     """
+     axis_norm = axis + rank if axis < 0 else axis
+     is_valid = axis_norm >= 0 and axis_norm < rank
+     return is_valid, axis_norm
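For instance:

    from onnxruntime.quantization.quant_utils import normalize_axis

    assert normalize_axis(-1, rank=4) == (True, 3)   # negative axes count from the end
    assert normalize_axis(4, rank=4) == (False, 4)   # out of range for a rank-4 tensor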
+
+
+ def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray:
+     """
+     Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values.
+     Assumes that the source values are already in the appropriate int4 range.
+     :parameter src_8bit: The 8-bit element values to pack.
+     :return: A bytearray with every two 8-bit src elements packed into a single byte.
+     """
+     num_elems = len(src_8bit)
+     if num_elems == 0:
+         return bytearray()
+
+     dst_size = (num_elems + 1) // 2  # Ex: 5 8-bit elems packed into 3 bytes
+     dst = bytearray(dst_size)
+
+     src_i: int = 0
+     dst_i: int = 0
+
+     # Pack two 8-bit elements into a single byte in each iteration.
+     while src_i < num_elems - 1:
+         dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF)
+         dst_i += 1
+         src_i += 2
+
+     if src_i < num_elems:
+         # Odd number of elements.
+         dst[dst_i] = src_8bit[src_i] & 0xF
+
+     return dst
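The even-indexed source element lands in the low nibble and the following element in the high nibble. A small check (values arbitrary):

    from onnxruntime.quantization.quant_utils import pack_bytes_to_4bit

    packed = pack_bytes_to_4bit(bytes([1, 2, 3]))
    # byte 0 = (2 << 4) | 1 = 0x21; the odd trailing element occupies byte 1 alone
    assert packed == bytearray([0x21, 0x03])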
+
+
+ class QuantizedInitializer:
+     """
+     Represents a linearly quantized weight input from ONNX operators.
+     """
+
+     def __init__(
+         self,
+         name,
+         initializer,
+         rmins,
+         rmaxs,
+         zero_points,
+         scales,
+         data=[],  # noqa: B006
+         quantized_data=[],  # noqa: B006
+         axis=None,
+     ):
+         self.name = name
+         self.initializer = initializer  # TensorProto initializer in ONNX graph
+         self.rmins = rmins  # List of minimum range for each axis
+         self.rmaxs = rmaxs  # List of maximum range for each axis
+         # 1D tensor of zero points computed for each axis. scalar if axis is empty
+         self.zero_points = zero_points
+         self.scales = scales  # 1D tensor of scales computed for each axis. scalar if axis is empty
+         self.data = data  # original data from initializer TensorProto
+         self.quantized_data = quantized_data  # weight-packed data from data
+         # Scalar to specify which dimension in the initializer to weight pack.
+         self.axis = axis
+         # If empty, single zero point and scales computed from a single rmin and rmax
+
+
+ class QuantizedValue:
+     """
+     Represents a linearly quantized value (input\\output\\initializer).
+     """
+
+     def __init__(
+         self,
+         name,
+         new_quantized_name,
+         scale_name,
+         zero_point_name,
+         quantized_value_type,
+         axis=None,
+         node_type=None,
+         node_qtype=None,
+         scale_type=None,
+     ):
+         self.original_name = name
+         self.q_name = new_quantized_name
+         self.scale_name = scale_name
+         self.zp_name = zero_point_name
+         self.value_type = quantized_value_type
+         self.axis = axis
+         self.node_type = node_type
+         self.node_qtype = node_qtype
+         self.scale_type = scale_type
+
+
+ class BiasToQuantize:
+     """
+     Represents a bias to be quantized.
+     """
+
+     def __init__(self, bias_name, input_name, weight_name):
+         self.bias_name = bias_name
+         self.input_name = input_name
+         self.weight_name = weight_name
+
+
+ def attribute_to_kwarg(attribute):
+     """
+     Convert attribute to kwarg format for use with onnx.helper.make_node.
+     :parameter attribute: attribute in AttributeProto format.
+     :return: attribute in {key: value} format.
+     """
+     if attribute.type == 0:
+         raise ValueError(f"attribute {attribute.name} does not have type specified.")
+
+     # Based on attribute type definitions from AttributeProto
+     # definition in https://github.com/onnx/onnx/blob/main/onnx/onnx.proto
+     if attribute.type == 1:
+         value = attribute.f
+     elif attribute.type == 2:
+         value = attribute.i
+     elif attribute.type == 3:
+         value = attribute.s
+     elif attribute.type == 4:
+         value = attribute.t
+     elif attribute.type == 5:
+         value = attribute.g
+     elif attribute.type == 6:
+         value = attribute.floats
+     elif attribute.type == 7:
+         value = attribute.ints
+     elif attribute.type == 8:
+         value = attribute.strings
+     elif attribute.type == 9:
+         value = attribute.tensors
+     elif attribute.type == 10:
+         value = attribute.graphs
+     else:
+         raise ValueError(f"attribute {attribute.name} has unsupported type {attribute.type}.")
+
+     return {attribute.name: value}
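A short round-trip sketch showing the intended use with onnx.helper.make_node (node and attribute are illustrative):

    import onnx
    from onnxruntime.quantization.quant_utils import attribute_to_kwarg

    node = onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[1, 0])
    kwargs = {}
    for attr in node.attribute:
        kwargs.update(attribute_to_kwarg(attr))  # {"perm": [1, 0]}
    rebuilt = onnx.helper.make_node("Transpose", node.input, node.output, **kwargs)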
+
+
+ def find_by_name(item_name, item_list):
+     """
+     Helper function to find an item by name in a list.
+     :parameter item_name: name of the item.
+     :parameter item_list: list of items.
+     :return: item if found, None otherwise.
+     """
+     items = [item for item in item_list if item.name == item_name]
+     return items[0] if len(items) > 0 else None
+
+
+ def get_elem_index(elem_name, elem_list):
+     """
+     Helper function to return the index of an item in a node list.
+     """
+     elem_idx = -1
+     for i in range(len(elem_list)):
+         if elem_list[i] == elem_name:
+             elem_idx = i
+     return elem_idx
+
+
+ def get_mul_node(inputs, output, name):
+     """
+     Helper function to create a Mul node.
+     :parameter inputs: list of input names.
+     :parameter output: output name.
+     :parameter name: name of the node.
+     :return: Mul node in NodeProto format.
+     """
+     return onnx.helper.make_node("Mul", inputs, [output], name)
+
+
+ def generate_identified_filename(filename: Path, identifier: str) -> Path:
+     """
+     Helper function to generate an identifiable filepath by concatenating the given identifier as a suffix.
+     """
+     return filename.parent.joinpath(filename.stem + identifier + filename.suffix)
+
+
+ def apply_plot(hist, hist_edges):
+     import sys  # noqa: PLC0415
+
+     import matplotlib.pyplot as plt  # noqa: PLC0415
+     import numpy  # noqa: PLC0415
+
+     numpy.set_printoptions(threshold=sys.maxsize)
+     print("Histogram:")
+     print(hist)
+     print("Histogram Edges:")
+     print(hist_edges)
+     plt.stairs(hist, hist_edges, fill=True)
+     plt.xlabel("Tensor value")
+     plt.ylabel("Counts")
+     plt.title("Tensor value V.S. Counts")
+     plt.show()
+
+
+ def write_calibration_table(calibration_cache, dir="."):
+     """
+     Helper function to write calibration table to files.
+     """
+
+     import json  # noqa: PLC0415
+
+     import flatbuffers  # noqa: PLC0415
+     import numpy as np  # noqa: PLC0415
+
+     import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue  # noqa: PLC0415
+     import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable  # noqa: PLC0415
+     from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData  # noqa: PLC0415
+
+     logging.info(f"calibration cache: {calibration_cache}")
+
+     class MyEncoder(json.JSONEncoder):
+         def default(self, obj):
+             if isinstance(obj, (TensorData, TensorsData)):
+                 return obj.to_dict()
+             if isinstance(obj, np.ndarray):
+                 return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
+             if isinstance(obj, CalibrationMethod):
+                 return {"CLS": obj.__class__.__name__, "value": str(obj)}
+             return json.JSONEncoder.default(self, obj)
+
+     json_data = json.dumps(calibration_cache, cls=MyEncoder)
+
+     with open(os.path.join(dir, "calibration.json"), "w") as file:
+         file.write(json_data)  # use `json.loads` to do the reverse
+
+     # Serialize data using FlatBuffers
+     zero = np.array(0)
+     builder = flatbuffers.Builder(1024)
+     key_value_list = []
+     for key in sorted(calibration_cache.keys()):
+         values = calibration_cache[key]
+         d_values = values.to_dict()
+         floats = [
+             float(d_values.get("highest", zero).item()),
+             float(d_values.get("lowest", zero).item()),
+         ]
+         value = str(max(floats))
+
+         flat_key = builder.CreateString(key)
+         flat_value = builder.CreateString(value)
+
+         KeyValue.KeyValueStart(builder)
+         KeyValue.KeyValueAddKey(builder, flat_key)
+         KeyValue.KeyValueAddValue(builder, flat_value)
+         key_value = KeyValue.KeyValueEnd(builder)
+
+         key_value_list.append(key_value)
+
+     TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
+     for key_value in key_value_list:
+         builder.PrependUOffsetTRelative(key_value)
+     main_dict = builder.EndVector()
+
+     TrtTable.TrtTableStart(builder)
+     TrtTable.TrtTableAddDict(builder, main_dict)
+     cal_table = TrtTable.TrtTableEnd(builder)
+
+     builder.Finish(cal_table)
+     buf = builder.Output()
+
+     with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
+         file.write(buf)
+
+     # Deserialize data (for validation)
+     if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"):
+         cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
+         dict_len = cal_table.DictLength()
+         for i in range(dict_len):
+             key_value = cal_table.Dict(i)
+             logging.info(key_value.Key())
+             logging.info(key_value.Value())
+
+     # write plain text
+     with open(os.path.join(dir, "calibration.cache"), "w") as file:
+         for key in sorted(calibration_cache.keys()):
+             values = calibration_cache[key]
+             d_values = values.to_dict()
+             floats = [
+                 float(d_values.get("highest", zero).item()),
+                 float(d_values.get("lowest", zero).item()),
+             ]
+             value = key + " " + str(max(floats))
+             file.write(value)
+             file.write("\n")
+
+
+ def smooth_distribution(p, eps=0.0001):
+     """Given a discrete distribution (which may not have been normalized to 1),
+     smooth it by replacing zeros with eps multiplied by a scaling factor
+     and taking the corresponding amount off the non-zero values.
+     Ref: http://web.engr.illinois.edu/~hanj/cs412/bk3/KL-divergence.pdf
+     https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
+     """
+     is_zeros = (p == 0).astype(numpy.float32)
+     is_nonzeros = (p != 0).astype(numpy.float32)
+     n_zeros = is_zeros.sum()
+     n_nonzeros = p.size - n_zeros
+
+     if not n_nonzeros:
+         # raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
+         return None
+     eps1 = eps * float(n_zeros) / float(n_nonzeros)
+     assert eps1 < 1.0, f"n_zeros={n_zeros}, n_nonzeros={n_nonzeros}, eps1={eps1}"
+
+     hist = p.astype(numpy.float32)
+     hist += eps * is_zeros + (-eps1) * is_nonzeros
+     assert (hist <= 0).sum() == 0
+
+     return hist
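A quick numeric check (illustrative): each zero bin receives eps while each non-zero bin gives up eps1, so the total mass is preserved (up to float32 rounding):

    import numpy
    from onnxruntime.quantization.quant_utils import smooth_distribution

    p = numpy.array([2.0, 0.0, 2.0], dtype=numpy.float32)
    hist = smooth_distribution(p, eps=1e-4)
    # eps1 = 1e-4 * 1 / 2 = 5e-5 -> [1.99995, 0.0001, 1.99995]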
+
+
+ def model_has_external_data(model_path: Path):
+     model = onnx.load(model_path.as_posix(), load_external_data=False)
+     return any(external_data_helper.uses_external_data(initializer) for initializer in model.graph.initializer)
+
+
+ def optimize_model(model_path: Path, opt_model_path: Path):
+     """
+     Generate a model that applies graph optimization (constant folding, etc.).
+     :parameter model_path: path to the original onnx model.
+     :parameter opt_model_path: path to the optimized onnx model.
+     :return: optimized onnx model
+     """
+     sess_option = SessionOptions()
+     sess_option.optimized_model_filepath = opt_model_path.as_posix()
+     sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
+     kwargs = {}
+     # This would rename constant initializer names; disable it to make tests pass.
+     kwargs["disabled_optimizers"] = ["ConstantSharing"]
+     _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"], **kwargs)
+
+
+ def add_pre_process_metadata(model: ModelProto):
+     """Tag the model to indicate that it went through quantization pre-processing."""
+     metadata_props = {"onnx.quant.pre_process": "onnxruntime.quant"}
+     if model.metadata_props:
+         for prop in model.metadata_props:
+             metadata_props.update({prop.key: prop.value})
+     onnx.helper.set_model_props(model, metadata_props)
+
+
+ def model_has_pre_process_metadata(model: ModelProto) -> bool:
+     """Check whether the model went through quantization pre-processing."""
+     if model.metadata_props:
+         for prop in model.metadata_props:
+             if prop.key == "onnx.quant.pre_process" and prop.value == "onnxruntime.quant":
+                 return True
+     return False
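The tag/check pair works on any ModelProto; a self-contained sketch (the tiny Identity model is made up for illustration):

    import onnx
    from onnx import TensorProto, helper
    from onnxruntime.quantization.quant_utils import add_pre_process_metadata, model_has_pre_process_metadata

    graph = helper.make_graph(
        [helper.make_node("Identity", ["x"], ["y"])], "g",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2])],
    )
    model = helper.make_model(graph)
    assert not model_has_pre_process_metadata(model)
    add_pre_process_metadata(model)  # sets {"onnx.quant.pre_process": "onnxruntime.quant"}
    assert model_has_pre_process_metadata(model)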
+
+
+ def add_infer_metadata(model: ModelProto):
+     metadata_props = {"onnx.infer": "onnxruntime.quant"}
+     if model.metadata_props:
+         for p in model.metadata_props:
+             metadata_props.update({p.key: p.value})
+     onnx.helper.set_model_props(model, metadata_props)
+
+
+ def model_has_infer_metadata(model: ModelProto) -> bool:
+     if model.metadata_props:
+         for p in model.metadata_props:
+             if p.key == "onnx.infer" and p.value == "onnxruntime.quant":
+                 return True
+     return False
+
+
+ def get_opset_version(model: ModelProto) -> int:
+     ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
+     if len(ai_onnx_domain) != 1:
+         raise ValueError("Failed to find proper ai.onnx domain")
+     opset_version = ai_onnx_domain[0].version
+
+     return opset_version
+
+
+ def update_opset_version(model: ModelProto, weight_type: QuantType) -> ModelProto:
+     opset_version = get_opset_version(model)
+     target_opset_version = opset_version
+     weight_quant_type = getattr(weight_type, "tensor_type", weight_type)
+
+     if opset_version < 19 and weight_quant_type == onnx.TensorProto.FLOAT8E4M3FN:
+         logging.warning(
+             f"The original model opset version is {opset_version}, which does not support quantization to float 8. "
+             "Please update the model to opset >= 19. Automatically updating the model to opset 19. "
+             "Please verify the quantized model."
+         )
+         target_opset_version = 19
+
+     elif opset_version == 10:
+         logging.warning(
+             f"The original model opset version is {opset_version}, which does not support node fusions. "
+             "Please update the model to opset >= 11 for better performance."
+         )
+
+     elif opset_version < 10:
+         logging.warning(
+             f"The original model opset version is {opset_version}, which does not support quantization. "
+             "Please update the model to opset >= 11. Automatically updating the model to opset 11. "
+             "Please verify the quantized model."
+         )
+         target_opset_version = 11
+
+     if target_opset_version != opset_version:
+         model = onnx.version_converter.convert_version(model, target_opset_version)
+         # Additional nodes may be added to the model during the opset version conversion. Run shape inference
+         # to ensure all nodes are included in model.graph.value_info.
+         model = save_and_reload_model_with_shape_infer(model)
+
+     return model
+
+
+ def load_model_with_shape_infer(model_path: Path) -> ModelProto:
+     inferred_model_path = generate_identified_filename(model_path, "-inferred")
+     onnx.shape_inference.infer_shapes_path(str(model_path), str(inferred_model_path))
+     model = onnx.load(inferred_model_path.as_posix())
+     add_infer_metadata(model)
+     inferred_model_path.unlink()
+     return model
+
+
+ def save_and_reload_model_with_shape_infer(model: ModelProto) -> ModelProto:
+     with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
+         model_copy = copy.deepcopy(model)
+         model_path = Path(quant_tmp_dir).joinpath("model.onnx")
+         onnx.save_model(model_copy, model_path.as_posix(), save_as_external_data=True)
+         return load_model_with_shape_infer(model_path)
+
+
+ def tensor_proto_to_array(initializer: TensorProto) -> numpy.ndarray:
+     if initializer.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
+         return onnx.numpy_helper.to_array(initializer)
+
+     raise ValueError(
+         f"Only float types are supported. Weight {initializer.name} is {type_to_name[initializer.data_type]}."
+     )
+
+
+ def add_quant_suffix(tensor_name: str) -> str:
+     return tensor_name + "_QuantizeLinear"
+
+
+ def add_quant_input_suffix(tensor_name: str) -> str:
+     return tensor_name + QUANT_INPUT_SUFFIX
+
+
+ def add_quant_output_suffix(tensor_name) -> str:
+     return tensor_name + "_QuantizeLinear_Output"
+
+
+ def add_dequant_suffix(tensor_name) -> str:
+     return tensor_name + "_DequantizeLinear"
+
+
+ def add_dequant_input_suffix(tensor_name) -> str:
+     return tensor_name + "_DequantizeLinear_Input"
+
+
+ def add_dequant_output_suffix(tensor_name) -> str:
+     return tensor_name + DEQUANT_OUTPUT_SUFFIX