onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,529 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ import logging
+ from typing import Any
+
+ import numpy as np
+ import onnx
+ import onnx.numpy_helper
+
+ try:
+     from onnx.reference.op_run import to_array_extended
+ except ImportError:
+     # old version of onnx.
+     to_array_extended = None
+
+ from .calibrate import TensorData
+ from .onnx_model import ONNXModel
+ from .quant_utils import (
+     DEQUANT_OP_NAME,
+     ONNX_TYPE_TO_NP_TYPE,
+     QUANT_OP_NAME,
+     TENSOR_NAME_QUANT_SUFFIX,
+     find_by_name,
+     get_opset_version,
+     model_has_infer_metadata,
+     normalize_axis,
+     pack_bytes_to_4bit,
+     quantize_data,
+     quantize_nparray,
+     save_and_reload_model_with_shape_infer,
+     tensor_proto_to_array,
+ )
+ from .tensor_quant_overrides import TensorQuantOverridesHelper
+
+
+ class QuantizationParams:
+     def __init__(self, **data: dict[str, Any]):
+         self.data = {}
+         for k, v in data.items():
+             if not isinstance(k, str):
+                 raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.")
+             if k != "axis" and not isinstance(v, (int, str, np.ndarray, float)):
+                 raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.")
+             if k == "axis" and not isinstance(v, int) and v is not None:
+                 raise TypeError(f"Axis value must be an int or None, not {type(v)}.")
+             if k == "scale" and v.dtype not in (np.float32, np.float16):
+                 raise ValueError(f"scale must be a float32 or float16 numpy element but is {v.dtype} for k={k!r}")
+             self.data[k] = v
+
+     def get(self, key, default_value=None):
+         return self.data.get(key, default_value)
+
+     def __iter__(self):
+         yield from self.data
+
+     def __getitem__(self, key):
+         return self.data[key]
+
+     def __setitem__(self, key, value):
+         self.data[key] = value
+
+     def __len__(self):
+         return len(self.data)
+
+
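For orientation, QuantizationParams behaves like a small validating dict around per-tensor parameters (scale, zero point, optional axis). A minimal usage sketch with hypothetical values, relying only on the constructor checks above:

    import numpy as np

    params = QuantizationParams(
        zero_point=np.array([0], dtype=np.int8),   # ndarray values are accepted
        scale=np.array([0.02], dtype=np.float32),  # scale must be float32/float16
        axis=None,                                 # axis may be an int or None
    )
    assert params["scale"].dtype == np.float32
    assert params.get("quant_type") is None        # missing keys fall back to the default
    # QuantizationParams(scale=np.array([0.02])) would raise ValueError:
    # a plain np.array defaults to float64, which the scale check rejects.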
+ class BaseQuantizer:
+     def __init__(
+         self,
+         model,
+         per_channel,
+         reduce_range,
+         weight_qType,
+         activation_qType,
+         tensors_range,
+         nodes_to_quantize,
+         nodes_to_exclude,
+         op_types_to_quantize,
+         extra_options=None,
+     ):
+         if not model_has_infer_metadata(model):
+             model = save_and_reload_model_with_shape_infer(model)
+         self.value_infos = {vi.name: vi for vi in model.graph.value_info}
+         self.value_infos.update({ot.name: ot for ot in model.graph.output})
+         self.value_infos.update({it.name: it for it in model.graph.input})
+
+         self.model = ONNXModel(model)
+         self.opset_version = get_opset_version(model)
+         self.per_channel = per_channel  # weight-pack per channel
+         self.reduce_range = reduce_range
+
+         self.extra_options = extra_options if extra_options else {}
+         self.enable_subgraph_quantization = (
+             "EnableSubgraph" in self.extra_options and self.extra_options["EnableSubgraph"]
+         )
+         self.parent = None
+         self.force_quantize_no_input_check = (
+             "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"]
+         )
+
+         # If the user does not explicitly set "WeightSymmetric", the weight's quantization type determines
+         # the symmetry (i.e., signed integer types use symmetric quantization). See `def is_weight_symmetric()`.
+         self._is_weight_symmetric: bool | None = self.extra_options.get("WeightSymmetric", None)
+         self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False)
+         self.min_real_range = self.extra_options.get("MinimumRealRange")
+
+         self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType)
+         self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType)
+
+         """
+         Dictionary specifying the min and max values for tensors. It has the following format:
+             {
+                 "param_name": [min, max]
+             }
+         example:
+             {
+                 'Conv_3:0': [np.float32(0), np.float32(0.5)],
+                 'Conv_4:0': [np.float32(1), np.float32(3.5)]
+             }
+         """
+         if tensors_range is not None and any(not isinstance(t, TensorData) for t in tensors_range.values()):
+             raise TypeError(
+                 f"tensors_range contains unexpected types { {type(v) for v in tensors_range.values()} }, not TensorData."
+             )
+         self.tensors_range = tensors_range
+         self.nodes_to_quantize = nodes_to_quantize  # specific nodes to quantize
+         self.nodes_to_exclude = nodes_to_exclude  # specific nodes to exclude
+         self.op_types_to_quantize = op_types_to_quantize
+
+         # Get tensor-level quantization overrides and ensure they are valid.
+         self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {}))
+
+         self.initializers = {initzer.name: initzer for initzer in self.model.initializer()}
+         overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid(
+             self.initializers, self.value_infos.keys(), activation_qType
+         )
+         if not overrides_valid:
+             raise ValueError(overrides_err)
+
+         self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types()
+
+     def is_weight_symmetric(self, weight_quant_type: onnx.TensorProto.DataType) -> bool:
+         if self._is_weight_symmetric is not None:
+             return self._is_weight_symmetric  # Return value explicitly set by user.
+         return weight_quant_type in (
+             onnx.TensorProto.INT4,
+             onnx.TensorProto.INT8,
+             onnx.TensorProto.INT16,
+             onnx.TensorProto.FLOAT8E4M3FN,
+         )
+
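For intuition on what is_weight_symmetric() implies downstream: symmetric quantization pins the zero point at 0, while asymmetric quantization maps the full calibrated range onto the integer range. A standalone illustrative sketch for int8 (not the library's quantize_data, whose richer signature appears later in this file):

    import numpy as np

    def int8_params(rmin: float, rmax: float, symmetric: bool):
        qmin, qmax = -128, 127
        if symmetric:
            # Symmetric: zero point pinned to 0, scale covers the larger magnitude.
            absmax = max(abs(rmin), abs(rmax))
            scale = (2.0 * absmax) / (qmax - qmin)
            zero_point = 0
        else:
            # Asymmetric: map the full [rmin, rmax] range (which must include 0).
            rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
            scale = (rmax - rmin) / (qmax - qmin)
            zero_point = round(qmin - rmin / scale)
        return np.float32(scale), np.int8(zero_point)

    print(int8_params(-1.0, 3.0, symmetric=True))   # (scale ~0.0235, zp 0)
    print(int8_params(-1.0, 3.0, symmetric=False))  # (scale ~0.0157, zp -64)

Symmetric quantization wastes part of the integer range when the data is skewed (here, half the negative codes), but it keeps zero exactly representable at code 0, which simplifies kernels.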
+     def quantize_model(self):
+         raise NotImplementedError
+
+     def is_input_a_initializer(self, input_name):
+         initializer = find_by_name(input_name, self.model.initializer())
+         return initializer is not None
+
+     def is_per_channel(self):
+         return self.per_channel
+
+     def is_valid_quantize_weight(self, weight_name):
+         weight = find_by_name(weight_name, self.model.initializer())
+         if weight is not None:
+             return weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16)
+         if (not self.enable_subgraph_quantization) or (self.parent is None):
+             return False
+         return self.parent.is_valid_quantize_weight(weight_name)
+
+     def should_quantize_node(self, node):
+         if (
+             self.nodes_to_quantize is not None
+             and len(self.nodes_to_quantize) != 0
+             and node.name not in self.nodes_to_quantize
+         ):
+             return False
+
+         if node.op_type not in self.op_types_to_quantize:
+             return False
+
+         if node.op_type in (DEQUANT_OP_NAME, QUANT_OP_NAME):
+             return False
+
+         if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude:
+             return False
+
+         return True
+
+     def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0):
+         """
+         Quantize the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale.
+         """
+
+         # get bias
+         bias_initializer = find_by_name(bias_name, self.model.initializer())
+         bias_data = tensor_proto_to_array(bias_initializer)
+         quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX
+
+         # quantize bias
+         if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+             data = np.asarray(bias_data)
+             if data.dtype == np.float16:
+                 node_qtype = onnx.TensorProto.FLOAT16
+             elif data.dtype == np.float32:
+                 node_qtype = onnx.TensorProto.FLOAT
+             else:
+                 raise TypeError(f"Only float16 or float32 are supported with float 8 but bias dtype is {data.dtype}.")
+             quantized_data = data.astype(np.float32)
+             bias_scale = np.array([1], dtype=quantized_data.dtype)
+             bias_scale_data = bias_scale.reshape(-1)
+             packed_bias_initializer = onnx.numpy_helper.from_array(quantized_data, quantized_bias_name)
+             self.model.initializer_extend([packed_bias_initializer])
+             node_type = "Cast"
+         else:
+             # calculate scale for bias
+             # TODO: This formula should be explained, including why the scale is not estimated for the bias as well.
+             bias_scale = input_scale * weight_scale * beta
+
+             # Quantize by dividing by bias_scale
+             quantized_data = np.asarray(bias_data, dtype=np.float64) / np.asarray(bias_scale, dtype=np.float64)
+             quantized_data = quantized_data.round()
+
+             # Clip quantized data to the range of an int32
+             int32_min = np.float64(np.iinfo(np.int32).min)
+             int32_max = np.float64(np.iinfo(np.int32).max)
+             if np.any(quantized_data < int32_min) or np.any(quantized_data > int32_max):
+                 logging.warning(
+                     f"Quantized bias `{bias_name}` exceeds the range of an int32. The bias scale is too small."
+                 )
+
+             quantized_data = np.clip(quantized_data, int32_min, int32_max).astype(np.int32)
+
+             # update bias initializer
+             bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
+             packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
+             self.model.initializer_extend([packed_bias_initializer])
+
+             # Bias's scale dtype should match the original bias data's unquantized type (float32 or float16).
+             bias_scale_data = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)
+             node_type = "DequantizeLinear"
+             node_qtype = self.weight_qType
+
+         # update scale initializer
+         quantized_bias_scale_name = quantized_bias_name + "_scale"
+         packed_bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, quantized_bias_scale_name)
+         self.model.initializer_extend([packed_bias_scale_initializer])
+
+         # update zero initializer
+         if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+             tensor_type = self.weight_qType
+         else:
+             tensor_type = onnx.TensorProto.INT32
+
+         quantized_bias_zp_name = quantized_bias_name + "_zero_point"
+         if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+             packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, self.weight_qType, [1], [0.0])
+         elif bias_scale.size > 1:
+             bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1)
+             packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name)
+         else:
+             packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0])
+         self.model.initializer_extend([packed_bias_zp_initializer])
+
+         return (
+             quantized_bias_name,
+             quantized_bias_scale_name,
+             quantized_bias_zp_name,
+             bias_scale_data,
+             node_type,
+             node_qtype,
+         )
+
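The int32 bias path above reduces to divide, round, clip. A worked numeric sketch with hypothetical scales (standalone NumPy, not the class method):

    import numpy as np

    input_scale, weight_scale = np.float32(0.02), np.float32(0.005)
    bias = np.array([0.25, -1.5], dtype=np.float32)

    # Scale == Input_Scale * Weight_Scale (beta == 1.0), zero point == 0.
    bias_scale = np.float64(input_scale) * np.float64(weight_scale)  # ~1e-4
    q_bias = np.clip((bias / bias_scale).round(),
                     np.iinfo(np.int32).min, np.iinfo(np.int32).max).astype(np.int32)
    print(q_bias)  # [  2500 -15000]
    # DequantizeLinear(q_bias, scale=1e-4, zero_point=0) recovers [0.25, -1.5].

Tying the bias scale to the product of the input and weight scales lets the int32 bias be added directly to the int32 accumulator of a quantized MatMul/Conv without rescaling.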
+     def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_float_weight=False):
+         """
+         :param weight: TensorProto initializer
+         :param qType: type to quantize to
+         :param keep_float_weight: Whether to keep the weight unquantized. In some cases, we only want to quantize
+                                   scale and zero point. If keep_float_weight is False, the weight is quantized.
+         :return: quantized weight name, zero point name, scale name
+         """
+         # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there.
+         q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
+         zp_name = weight.name + "_zero_point"
+         scale_name = weight.name + "_scale"
+
+         # Quantize weight data. Use quantization overrides if provided by the user.
+         weight_data = tensor_proto_to_array(weight)
+         quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name, default_val={})
+         if "quant_type" in quant_overrides:
+             qType = quant_overrides["quant_type"].tensor_type  # noqa: N806
+
+         if "scale" in quant_overrides and "zero_point" in quant_overrides:
+             zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
+             scale = np.array(quant_overrides["scale"])
+             q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
+             assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+             assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+                 f"Unexpected dtype {zero_point.dtype}"
+             )
+             assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+         else:
+             symmetric = self.is_weight_symmetric(qType) if qType == self.weight_qType else self.is_activation_symmetric
+             zero_point, scale, q_weight_data = quantize_data(
+                 weight_data.flatten(),
+                 qType,
+                 quant_overrides.get("symmetric", symmetric),
+                 reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
+                 min_real_range=self.min_real_range,
+                 rmin_override=quant_overrides.get("rmin"),
+                 rmax_override=quant_overrides.get("rmax"),
+             )
+
+             assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+             assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+                 f"Unexpected dtype {zero_point.dtype}"
+             )
+             assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
+         scale_dtype = weight.data_type
+         scale_initializer = onnx.helper.make_tensor(scale_name, scale_dtype, [], scale.reshape((-1,)).tolist())
+         zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], zero_point.reshape((-1,)).tolist())
+         self.model.initializer_extend([scale_initializer, zero_initializer])
+
+         if not keep_float_weight:
+             if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+                 q_weight_initializer = onnx.TensorProto()
+                 q_weight_initializer.data_type = self.weight_qType
+                 q_weight_initializer.dims.extend(weight.dims)
+                 q_weight_initializer.name = q_weight_name
+                 # Do not remove .flatten().copy(); numpy is not clear about data persistence.
+                 q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
+                 if to_array_extended is not None:
+                     # This test should not be needed, but it helped catch some issues
+                     # with data persistence and tobytes.
+                     check = to_array_extended(q_weight_initializer)
+                     if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
+                         raise RuntimeError(
+                             f"The initializer of shape {weight_data.shape} could not be created, expecting "
+                             f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.dims}"
+                             f"\nraw={str(q_weight_initializer)[:200]}."
+                         )
+             elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
+                 if q_weight_data.dtype not in (np.int8, np.uint8):
+                     raise RuntimeError(
+                         f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
+                     )
+
+                 # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
+                 # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
+                 packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
+
+                 # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
+                 q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True)
+             else:
+                 q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
+                     weight.dims
+                 )
+                 q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
+             self.model.initializer_extend([q_weight_initializer])
+
+         return q_weight_name, zp_name, scale_name
+
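On the 4-bit path above, pack_bytes_to_4bit stores two 4-bit values per byte. An illustrative re-implementation of the packing idea, assuming low-nibble-first order (the library helper's exact convention and signature may differ):

    import numpy as np

    def pack_to_4bit(values: np.ndarray) -> bytes:
        """Pack 8-bit values already in 4-bit range, two per byte, low nibble first."""
        flat = values.astype(np.uint8) & 0x0F
        if flat.size % 2:                          # pad odd-length input with a zero nibble
            flat = np.append(flat, np.uint8(0))
        return bytes((flat[1::2] << 4) | flat[0::2])

    q = np.array([1, 2, 15, 7], dtype=np.uint8)
    print(pack_to_4bit(q).hex())  # '217f': 0x21 = (2 << 4) | 1, 0x7f = (7 << 4) | 15

Packing halves the weight payload relative to int8 storage, which is why the comment above notes such a large difference in quantization time for big models when a fast byte-level packer is used.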
+     def quantize_weight_per_channel_impl(
+         self,
+         weight_name,
+         weight_qType,
+         channel_axis,
+         reduce_range=True,
+         keep_float_weight=False,
+     ):
+         # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there.
+         initializer = find_by_name(weight_name, self.model.initializer())
+         if initializer is None:
+             raise ValueError(f"{weight_name} is not an initializer")
+
+         weights = tensor_proto_to_array(initializer)
+         weights_rank = len(weights.shape)
+         is_axis_valid, axis_norm = normalize_axis(channel_axis, weights_rank)
+         if not is_axis_valid:
+             raise ValueError(
+                 f"Weight {weight_name} has a per-channel axis with value {channel_axis} that is "
+                 f"out-of-bounds for rank {weights_rank}"
+             )
+
+         channel_axis = axis_norm
+         channel_count = weights.shape[channel_axis]
+         quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(
+             weight_name, default_val=[{"axis": channel_axis}]
+         )
+
+         num_channel_overrides = len(quant_overrides_for_channels)
+         if num_channel_overrides != 1 and num_channel_overrides != channel_count:
+             raise ValueError(
+                 f"Per-channel tensor quantization overrides for {weight_name} must have "
+                 f"either 1 or {channel_count} elements in the list of dictionaries."
+             )
+
+         is_axis_override_valid, axis_override = normalize_axis(quant_overrides_for_channels[0]["axis"], weights_rank)
+         if not is_axis_override_valid or axis_override != channel_axis:
+             raise ValueError(
+                 f"Tensor quantization overrides for {weight_name} specify an unexpected axis. "
+                 f"Expected {channel_axis}, but got {quant_overrides_for_channels[0]['axis']}."
+             )
+
+         # If the user provides per-channel quantization overrides, all channels must use the same quant_type,
+         # axis, symmetric, and reduce_range values. So, just use the first channel's values.
+         if "quant_type" in quant_overrides_for_channels[0]:
+             weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type  # noqa: N806
+
+         symmetric = quant_overrides_for_channels[0].get("symmetric", self.is_weight_symmetric(weight_qType))
+         reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range)
+         zero_point_list = []
+         scale_list = []
+         quantized_per_channel_data_list = []
+         weights_shape = list(weights.shape)
+         reshape_dims = list(weights_shape)  # deep copy
+         reshape_dims[channel_axis] = 1  # only one per channel for reshape
+         for i in range(channel_count):
+             per_channel_data = weights.take(i, channel_axis)
+             channel_override_index = i if i < num_channel_overrides else 0
+             channel_quant_overrides = quant_overrides_for_channels[channel_override_index]
+
+             if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides:
+                 zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType])
+                 scale = np.array(channel_quant_overrides["scale"])
+                 quantized_per_channel_data = quantize_nparray(
+                     weight_qType, per_channel_data.flatten(), scale, zero_point
+                 )
+                 assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                 assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+                     f"Unexpected dtype {zero_point.dtype}"
+                 )
+                 assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                 assert isinstance(quantized_per_channel_data, np.ndarray), (
+                     f"Unexpected type {type(quantized_per_channel_data)}"
+                 )
+             else:
+                 zero_point, scale, quantized_per_channel_data = quantize_data(
+                     per_channel_data.flatten(),
+                     weight_qType,
+                     symmetric,
+                     reduce_range=reduce_range,
+                     min_real_range=self.min_real_range,
+                     rmin_override=channel_quant_overrides.get("rmin"),
+                     rmax_override=channel_quant_overrides.get("rmax"),
+                 )
+
+                 assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                 assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+                     f"Unexpected dtype {zero_point.dtype}"
+                 )
+                 assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                 assert isinstance(quantized_per_channel_data, np.ndarray), (
+                     f"Unexpected type {type(quantized_per_channel_data)}"
+                 )
+
+             zero_point_list.append(zero_point)
+             scale_list.append(scale)
+             quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))
+
+         # combine per_channel_data into one
+         quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
+         q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
+         zp_name = weight_name + "_zero_point"
+         scale_name = weight_name + "_scale"
+
+         # Update packed weight, zero point, and scale initializers
+         zero_scale_shape = [initializer.dims[channel_axis]]
+         scale_initializer = onnx.helper.make_tensor(
+             scale_name, initializer.data_type, zero_scale_shape, np.hstack(scale_list).tolist()
+         )
+         zero_initializer = onnx.helper.make_tensor(
+             zp_name, weight_qType, zero_scale_shape, np.hstack(zero_point_list).tolist()
+         )
+
+         self.model.initializer_extend([scale_initializer, zero_initializer])
+
+         if not keep_float_weight:
+             if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
+                 if quantized_weights.dtype not in (np.int8, np.uint8):
+                     raise RuntimeError(
+                         f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
+                     )
+
+                 # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
+                 # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
+                 packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes()))
+
+                 # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
+                 q_weight_initializer = onnx.helper.make_tensor(
+                     q_weight_name, weight_qType, weights_shape, packed_data, raw=True
+                 )
+                 self.model.initializer_extend([q_weight_initializer])
+             else:
+                 quantized_weights = np.asarray(
+                     quantized_weights,
+                     dtype=onnx.helper.tensor_dtype_to_np_dtype(weight_qType),
+                 ).reshape(initializer.dims)
+                 q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name)
+                 self.model.initializer_extend([q_weight_initializer])
+
+         return q_weight_name, zp_name, scale_name
+
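The per-channel loop above computes one (scale, zero_point) pair per slice along channel_axis, which preserves accuracy when channel magnitudes differ widely. A compact illustrative NumPy sketch of the same idea for symmetric int8 (standalone, not the class method):

    import numpy as np

    w = np.array([[0.5, -1.0, 0.25],
                  [4.0,  2.0, -8.0]], dtype=np.float32)  # axis 0: two output channels

    absmax = np.abs(w).max(axis=1, keepdims=True)        # per-channel max magnitude
    scales = (absmax / 127.0).astype(np.float32)         # symmetric int8, zero point 0
    q = np.clip(np.round(w / scales), -127, 127).astype(np.int8)
    print(scales.ravel())  # [0.00787402 0.06299213]
    print(q)               # [[  64 -127   32]
                           #  [  64   32 -127]]

With a single per-tensor scale, channel 0 (max 1.0) would be forced onto channel 1's much coarser grid (max 8.0); per-channel scales avoid that.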
+     def adjust_tensor_ranges(self):
+         if self.tensors_range is None:
+             return
+
+         for node in self.model.nodes():
+             # adjust tensor_ranges for the input of Clip and Relu nodes
+             if node.op_type in ["Clip", "Relu"]:
+                 if not self.should_quantize_node(node):
+                     continue
+                 if len(self.model.input_name_to_nodes()[node.input[0]]) != 1:
+                     continue
+                 if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range:
+                     continue
+                 td = self.tensors_range[node.output[0]]
+                 if not isinstance(td, TensorData):
+                     raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
+                 self.tensors_range[node.input[0]] = td
+             # Adjust Softmax to range from 0.0 to 1.0
+             elif node.op_type == "Softmax":
+                 if not self.should_quantize_node(node):
+                     continue
+                 self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))
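Why the Clip/Relu adjustment helps: giving the activation's input the same calibrated range as its output (only safe when the activation is that tensor's sole consumer) means both ends receive identical scale and zero point, so the activation can be absorbed into the surrounding QuantizeLinear/DequantizeLinear pair. A minimal sketch with hypothetical tensor names, using TensorData as imported from .calibrate above:

    import numpy as np
    from onnxruntime.quantization.calibrate import TensorData

    # Calibrated ranges around X -> Relu -> Y (hypothetical values).
    tensors_range = {
        "X": TensorData(lowest=np.float32(-2.0), highest=np.float32(3.0)),
        "Y": TensorData(lowest=np.float32(0.0), highest=np.float32(3.0)),
    }
    # adjust_tensor_ranges() replaces X's range with Y's, so the quantization
    # parameters derived for X and Y coincide.
    tensors_range["X"] = tensors_range["Y"]

The Softmax branch is simpler still: its output is known analytically to lie in [0, 1], so the calibrated range is overwritten with that exact interval.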