onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/quantization/base_quantizer.py
@@ -0,0 +1,532 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ import logging
+ from typing import Any, Dict
+
+ import numpy as np
+ import onnx
+ import onnx.numpy_helper
+
+ try:
+     from onnx.reference.op_run import to_array_extended
+ except ImportError:
+     # old version of onnx.
+     to_array_extended = None
+
+ from .calibrate import TensorData
+ from .onnx_model import ONNXModel
+ from .quant_utils import (
+     ONNX_TYPE_TO_NP_TYPE,
+     TENSOR_NAME_QUANT_SUFFIX,
+     QuantType,
+     find_by_name,
+     model_has_infer_metadata,
+     normalize_axis,
+     pack_bytes_to_4bit,
+     quantize_data,
+     quantize_nparray,
+     save_and_reload_model_with_shape_infer,
+     tensor_proto_to_array,
+ )
+ from .tensor_quant_overrides import TensorQuantOverridesHelper
+
+
+ class QuantizationParams:
+     def __init__(self, **data: Dict[str, Any]):
+         self.data = {}
+         for k, v in data.items():
+             if not isinstance(k, str):
+                 raise TypeError(f"Keys must be strings, not {type(k)} for k={k!r}.")
+             if not isinstance(v, (int, float, str, np.ndarray)):
+                 raise TypeError(f"Values must be numpy arrays, int, float, str, not {type(v)} for k={k!r}.")
+             if k == "scale" and v.dtype not in (np.float32, np.float16):
+                 raise ValueError(f"scale must be a float32 or float16 numpy element but is {v.dtype} for k={k!r}")
+             self.data[k] = v
+
+     def __iter__(self):
+         yield from self.data
+
+     def __getitem__(self, key):
+         return self.data[key]
+
+     def __len__(self):
+         return len(self.data)
+
+
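Note: `QuantizationParams` validates its keyword arguments eagerly, so malformed parameters fail at construction time rather than deep inside the quantizer. A minimal sketch of the intended usage (the key names `scale` and `zero_point` follow the conventions used elsewhere in this file):

    import numpy as np

    # Valid: scale is a float32 numpy array, zero_point an integer array.
    params = QuantizationParams(zero_point=np.array([0], dtype=np.int8),
                                scale=np.array([0.05], dtype=np.float32))
    assert len(params) == 2

    # Invalid: a float64 scale is rejected at construction time.
    try:
        QuantizationParams(scale=np.array([0.05], dtype=np.float64))
    except ValueError as e:
        print(e)  # scale must be a float32 or float16 numpy element ...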
+ class BaseQuantizer:
+     def __init__(
+         self,
+         model,
+         per_channel,
+         reduce_range,
+         weight_qType,
+         activation_qType,
+         tensors_range,
+         nodes_to_quantize,
+         nodes_to_exclude,
+         op_types_to_quantize,
+         extra_options=None,
+     ):
+         if not model_has_infer_metadata(model):
+             model = save_and_reload_model_with_shape_infer(model)
+         self.value_infos = {vi.name: vi for vi in model.graph.value_info}
+         self.value_infos.update({ot.name: ot for ot in model.graph.output})
+         self.value_infos.update({it.name: it for it in model.graph.input})
+
+         self.model = ONNXModel(model)
+         self.per_channel = per_channel  # weight-pack per channel
+         self.reduce_range = reduce_range
+
+         self.extra_options = extra_options if extra_options else {}
+         self.enable_subgraph_quantization = (
+             "EnableSubgraph" in self.extra_options and self.extra_options["EnableSubgraph"]
+         )
+         self.parent = None
+         self.force_quantize_no_input_check = (
+             "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"]
+         )
+         self.is_weight_symmetric = self.extra_options.get(
+             "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN)
+         )
+         self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False)
+         self.min_real_range = self.extra_options.get("MinimumRealRange")
+
+         self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType)
+         self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType)
+
+         """
+         Dictionary specifying the min and max values for tensors (as TensorData). Conceptually, it has the following format:
+             {
+                 "param_name": [min, max]
+             }
+         example:
+             {
+                 'Conv_3:0': [np.float32(0), np.float32(0.5)],
+                 'Conv_4:0': [np.float32(1), np.float32(3.5)]
+             }
+         """
+         if tensors_range is not None and any(map(lambda t: not isinstance(t, TensorData), tensors_range.values())):
+             raise TypeError(
+                 f"tensors_range contains unexpected types {set(type(v) for v in tensors_range.values())}, not TensorData."
+             )
+         self.tensors_range = tensors_range
+         self.nodes_to_quantize = nodes_to_quantize  # specific nodes to quantize
+         self.nodes_to_exclude = nodes_to_exclude  # specific nodes to exclude
+         self.op_types_to_quantize = op_types_to_quantize
+
+         self.opset_version = self.check_opset_version()
+
+         # Get tensor-level quantization overrides and ensure they are valid.
+         self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {}))
+
+         self.initializers = {initzer.name: initzer for initzer in self.model.initializer()}
+         overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid(
+             self.initializers, self.value_infos.keys(), activation_qType
+         )
+         if not overrides_valid:
+             raise ValueError(overrides_err)
+
+         self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types()
+
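The `extra_options` dictionary drives most of the behavior switches read above. A minimal sketch of how these keys are typically passed through the public `quantize_static` API shipped in this wheel (the paths are placeholders, and `reader` stands in for a CalibrationDataReader implementation you supply):

    from onnxruntime.quantization import QuantType, quantize_static

    quantize_static(
        "model.onnx",          # hypothetical input path
        "model.quant.onnx",    # hypothetical output path
        reader,                # your CalibrationDataReader
        weight_type=QuantType.QInt8,
        extra_options={
            "WeightSymmetric": True,     # overrides the default derived from weight_qType
            "ActivationSymmetric": False,
            "MinimumRealRange": 0.0001,  # lower bound enforced on (rmax - rmin)
        },
    )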
+     def quantize_model(self):
+         raise NotImplementedError
+
+     def is_input_a_initializer(self, input_name):
+         initializer = find_by_name(input_name, self.model.initializer())
+         return initializer is not None
+
+     def is_per_channel(self):
+         return self.per_channel
+
+     def is_valid_quantize_weight(self, weight_name):
+         weight = find_by_name(weight_name, self.model.initializer())
+         if weight is not None:
+             return weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16)
+         if (not self.enable_subgraph_quantization) or (self.parent is None):
+             return False
+         return self.parent.is_valid_quantize_weight(weight_name)
+
+     def should_quantize_node(self, node):
+         if (
+             self.nodes_to_quantize is not None
+             and len(self.nodes_to_quantize) != 0
+             and node.name not in self.nodes_to_quantize
+         ):
+             return False
+
+         if node.op_type not in self.op_types_to_quantize:
+             return False
+
+         if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude:
+             return False
+
+         return True
+
+     def check_opset_version(self):
+         ai_onnx_domain = [
+             opset for opset in self.model.model.opset_import if not opset.domain or opset.domain == "ai.onnx"
+         ]
+         if len(ai_onnx_domain) != 1:
+             raise ValueError("Failed to find proper ai.onnx domain")
+         opset_version = ai_onnx_domain[0].version
+
+         if opset_version == 10:
+             logging.warning(
+                 f"The original model opset version is {opset_version}, which does not support node fusions. Please update the model to opset >= 11 for better performance."
+             )
+             return 10
+
+         if opset_version < 10:
+             logging.warning(
+                 f"The original model opset version is {opset_version}, which does not support quantization. Updating the model automatically to opset 11. Please verify the quantized model."
+             )
+             self.model.model.opset_import.remove(ai_onnx_domain[0])
+             self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 11)])
+             opset_version = 11
+
+         if opset_version < 19 and self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+             logging.warning(
+                 f"The original model opset version is {opset_version}, which does not support quantization to float 8. "
+                 "Updating the model automatically to opset 19. Please verify the quantized model."
+             )
+             self.model.model.opset_import.remove(ai_onnx_domain[0])
+             self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 19)])
+             self.model.model.ir_version = 9
+             opset_version = 19
+
+         return opset_version
+
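For reference, the ai.onnx opset that this check inspects lives in the model protobuf itself. A small sketch of reading it directly with the `onnx` package (the model path is a placeholder):

    import onnx

    model = onnx.load("model.onnx")  # hypothetical path
    for opset in model.opset_import:
        if not opset.domain or opset.domain == "ai.onnx":
            print("ai.onnx opset:", opset.version)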
+     def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0):
+         """
+         Quantizes the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
+         """
+
+         # get bias
+         bias_initializer = find_by_name(bias_name, self.model.initializer())
+         bias_data = tensor_proto_to_array(bias_initializer)
+         quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX
+
+         # quantize bias
+         if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+             data = np.asarray(bias_data)
+             if data.dtype == np.float16:
+                 node_qtype = onnx.TensorProto.FLOAT16
+             elif data.dtype == np.float32:
+                 node_qtype = onnx.TensorProto.FLOAT
+             else:
+                 raise TypeError(f"Only float16 or float32 are supported with float 8 but bias dtype is {data.dtype}.")
+             quantized_data = data.astype(np.float32)
+             bias_scale = np.array([1], dtype=quantized_data.dtype)
+             bias_scale_data = bias_scale.reshape(-1)
+             packed_bias_initializer = onnx.numpy_helper.from_array(quantized_data, quantized_bias_name)
+             self.model.initializer_extend([packed_bias_initializer])
+             node_type = "Cast"
+         else:
+             # calculate scale for bias
+             # TODO: This formula should be explained including why the scale is not estimated for the bias as well.
+             bias_scale = input_scale * weight_scale * beta
+
+             quantized_data = (np.asarray(bias_data) / bias_scale).round()
+             quantized_data = np.clip(quantized_data, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
+             quantized_data = quantized_data.astype(np.int32)
+
+             # update bias initializer
+             bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
+             packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
+             self.model.initializer_extend([packed_bias_initializer])
+
+             # Bias's scale dtype should match the original bias data's unquantized type (float32 or float16).
+             bias_scale_data = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)
+             node_type = "DequantizeLinear"
+             node_qtype = self.weight_qType
+
+         # update scale initializer
+         quantized_bias_scale_name = quantized_bias_name + "_scale"
+         packed_bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, quantized_bias_scale_name)
+         self.model.initializer_extend([packed_bias_scale_initializer])
+
+         # update zero initializer
+         if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+             tensor_type = self.weight_qType
+         else:
+             tensor_type = onnx.TensorProto.INT32
+
+         quantized_bias_zp_name = quantized_bias_name + "_zero_point"
+         if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+             packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, self.weight_qType, [1], [0.0])
+         elif bias_scale.size > 1:
+             bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1)
+             packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name)
+         else:
+             packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0])
+         self.model.initializer_extend([packed_bias_zp_initializer])
+
+         return (
+             quantized_bias_name,
+             quantized_bias_scale_name,
+             quantized_bias_zp_name,
+             bias_scale_data,
+             node_type,
+             node_qtype,
+         )
+
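The docstring's formula is easy to check by hand. A self-contained numpy sketch of the non-float8 branch (the values are made up for illustration):

    import numpy as np

    input_scale = np.float32(0.02)
    weight_scale = np.float32(0.005)
    bias = np.array([0.15, -0.30], dtype=np.float32)

    bias_scale = input_scale * weight_scale          # 1e-4
    q_bias = np.clip((bias / bias_scale).round(),
                     np.iinfo(np.int32).min, np.iinfo(np.int32).max).astype(np.int32)
    print(q_bias)  # [ 1500 -3000], stored with zero_point == 0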
+     def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_float_weight=False):
+         """
+         :param weight: TensorProto initializer
+         :param qType: type to quantize to
+         :param keep_float_weight: Whether to keep the weight in float. In some cases we only want to compute
+             the scale and zero point; if keep_float_weight is True, the weight itself is not quantized.
+         :return: quantized weight name, zero point name, scale name
+         """
+         q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
+         zp_name = weight.name + "_zero_point"
+         scale_name = weight.name + "_scale"
+
+         # Quantize weight data. Use quantization overrides if provided by the user.
+         weight_data = tensor_proto_to_array(weight)
+         quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name, default_val={})
+         if "quant_type" in quant_overrides:
+             qType = quant_overrides["quant_type"].tensor_type  # noqa: N806
+
+         if "scale" in quant_overrides and "zero_point" in quant_overrides:
+             zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
+             scale = np.array(quant_overrides["scale"])
+             q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
+             assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+             assert (
+                 zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+             ), f"Unexpected dtype {zero_point.dtype}"
+             assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
+         else:
+             _, _, zero_point, scale, q_weight_data = quantize_data(
+                 weight_data.flatten(),
+                 qType,
+                 quant_overrides.get("symmetric", self.is_weight_symmetric),
+                 reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
+                 min_real_range=self.min_real_range,
+                 rmin_override=quant_overrides.get("rmin"),
+                 rmax_override=quant_overrides.get("rmax"),
+             )
+
+             assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+             assert (
+                 zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+             ), f"Unexpected dtype {zero_point.dtype}"
+             assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
+         scale_dtype = weight.data_type
+         scale_initializer = onnx.helper.make_tensor(scale_name, scale_dtype, [], scale.reshape((-1,)).tolist())
+         zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], zero_point.reshape((-1,)).tolist())
+         self.model.initializer_extend([scale_initializer, zero_initializer])
+
+         if not keep_float_weight:
+             if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+                 q_weight_initializer = onnx.TensorProto()
+                 q_weight_initializer.data_type = self.weight_qType
+                 q_weight_initializer.dims.extend(weight.dims)
+                 q_weight_initializer.name = q_weight_name
+                 # Do not remove .flatten().copy(); numpy is not clear about data persistence.
+                 q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
+                 if to_array_extended is not None:
+                     # This test should not be needed but it helped catch some issues
+                     # with data persistence and tobytes.
+                     check = to_array_extended(q_weight_initializer)
+                     if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
+                         raise RuntimeError(
+                             f"The initializer of shape {weight_data.shape} could not be created, expecting "
+                             f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={check.shape}"
+                             f"\nraw={str(q_weight_initializer)[:200]}."
+                         )
+             elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
+                 if q_weight_data.dtype not in (np.int8, np.uint8):
+                     raise RuntimeError(
+                         f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
+                     )
+
+                 # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
+                 # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
+                 packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
+
+                 # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
+                 q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True)
+             else:
+                 q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
+                     weight.dims
+                 )
+                 q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
+             self.model.initializer_extend([q_weight_initializer])
+
+         return q_weight_name, zp_name, scale_name
+
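For intuition, the affine quantization this method delegates to quantize_data() reduces, in the simple asymmetric uint8 case, to the textbook formula q = clip(round(w / scale) + zero_point, 0, 255). A small numpy sketch (illustrative only; the real helper also handles symmetric ranges, reduce_range, and the 4-bit/float8 types):

    import numpy as np

    w = np.array([-0.5, 0.0, 1.5], dtype=np.float32)
    rmin, rmax = float(w.min()), float(w.max())
    scale = (rmax - rmin) / 255.0                   # map [rmin, rmax] onto [0, 255]
    zero_point = int(round(-rmin / scale))          # real value 0.0 maps here
    q = np.clip(np.round(w / scale) + zero_point, 0, 255).astype(np.uint8)
    print(scale, zero_point, q)                     # ~0.00784, 64, [  0  64 255]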
366
+ def quantize_weight_per_channel_impl(
367
+ self,
368
+ weight_name,
369
+ weight_qType,
370
+ channel_axis,
371
+ reduce_range=True,
372
+ keep_float_weight=False,
373
+ ):
374
+ initializer = find_by_name(weight_name, self.model.initializer())
375
+ if initializer is None:
376
+ raise ValueError("{} is not an initializer", weight_name)
377
+
378
+ weights = tensor_proto_to_array(initializer)
379
+ weights_rank = len(weights.shape)
380
+ is_axis_valid, axis_norm = normalize_axis(channel_axis, weights_rank)
381
+ if not is_axis_valid:
382
+ raise ValueError(
383
+ f"Weight {weight_name} has a per-channel axis with value {channel_axis} that is "
384
+ f"out-of-bounds for rank {weights_rank}"
385
+ )
386
+
387
+ channel_axis = axis_norm
388
+ channel_count = weights.shape[channel_axis]
389
+ quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(
390
+ weight_name, default_val=[{"axis": channel_axis}]
391
+ )
392
+
393
+ num_channel_overrides = len(quant_overrides_for_channels)
394
+ if num_channel_overrides != 1 and num_channel_overrides != channel_count:
395
+ raise ValueError(
396
+ f"Per-channel tensor quantization overrides for {weight_name} must have "
397
+ f"either 1 or {channel_count} elements in the list of dictionaries."
398
+ )
399
+
400
+ is_axis_override_valid, axis_override = normalize_axis(quant_overrides_for_channels[0]["axis"], weights_rank)
401
+ if not is_axis_override_valid or axis_override != channel_axis:
402
+ raise ValueError(
403
+ f"Tensor quantization overrides for {weight_name} specify an unexpected axis. "
404
+ f"Expected {channel_axis}, but got {quant_overrides_for_channels[0]['axis']}."
405
+ )
406
+
407
+ # If user provides per-channel quantization overrides, all channels must use the same quant_type,
408
+ # axis, symmetric, and reduce_range values. So, just use the first channel's values.
409
+ if "quant_type" in quant_overrides_for_channels[0]:
410
+ weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type # noqa: N806
411
+
412
+ symmetric = quant_overrides_for_channels[0].get(
413
+ "symmetric",
414
+ (
415
+ self.is_weight_symmetric
416
+ or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4)
417
+ ),
418
+ )
419
+ reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range)
420
+ zero_point_list = []
421
+ scale_list = []
422
+ quantized_per_channel_data_list = []
423
+ weights_shape = list(weights.shape)
424
+ reshape_dims = list(weights_shape) # deep copy
425
+ reshape_dims[channel_axis] = 1 # only one per channel for reshape
426
+ for i in range(channel_count):
427
+ per_channel_data = weights.take(i, channel_axis)
428
+ channel_override_index = i if i < num_channel_overrides else 0
429
+ channel_quant_overrides = quant_overrides_for_channels[channel_override_index]
430
+
431
+ if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides:
432
+ zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType])
433
+ scale = np.array(channel_quant_overrides["scale"])
434
+ quantized_per_channel_data = quantize_nparray(
435
+ weight_qType, per_channel_data.flatten(), scale, zero_point
436
+ )
437
+ assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
438
+ assert (
439
+ zero_point.dtype != np.float32 and zero_point.dtype != np.float16
440
+ ), f"Unexpected dtype {zero_point.dtype}"
441
+ assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
442
+ assert isinstance(
443
+ quantized_per_channel_data, np.ndarray
444
+ ), f"Unexpected type {type(quantized_per_channel_data)}"
445
+
446
+ else:
447
+ _, _, zero_point, scale, quantized_per_channel_data = quantize_data(
448
+ per_channel_data.flatten(),
449
+ weight_qType,
450
+ symmetric,
451
+ reduce_range=reduce_range,
452
+ min_real_range=self.min_real_range,
453
+ rmin_override=channel_quant_overrides.get("rmin"),
454
+ rmax_override=channel_quant_overrides.get("rmax"),
455
+ )
456
+
457
+ assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
458
+ assert (
459
+ zero_point.dtype != np.float32 and zero_point.dtype != np.float16
460
+ ), f"Unexpected dtype {zero_point.dtype}"
461
+ assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
462
+ assert isinstance(
463
+ quantized_per_channel_data, np.ndarray
464
+ ), f"Unexpected type {type(quantized_per_channel_data)}"
465
+
466
+ zero_point_list.append(zero_point)
467
+ scale_list.append(scale)
468
+ quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))
469
+
470
+ # combine per_channel_data into one
471
+ quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
472
+ q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
473
+ zp_name = weight_name + "_zero_point"
474
+ scale_name = weight_name + "_scale"
475
+
476
+ # Update packed weight, zero point, and scale initializers
477
+ zero_scale_shape = [initializer.dims[channel_axis]]
478
+ scale_initializer = onnx.helper.make_tensor(
479
+ scale_name, initializer.data_type, zero_scale_shape, np.hstack(scale_list).tolist()
480
+ )
481
+ zero_initializer = onnx.helper.make_tensor(
482
+ zp_name, weight_qType, zero_scale_shape, np.hstack(zero_point_list).tolist()
483
+ )
484
+
485
+ self.model.initializer_extend([scale_initializer, zero_initializer])
486
+
487
+ if not keep_float_weight:
488
+ if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
489
+ if quantized_weights.dtype not in (np.int8, np.uint8):
490
+ raise RuntimeError(
491
+ f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
492
+ )
493
+
494
+ # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
495
+ # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
496
+ packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes()))
497
+
498
+ # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
499
+ q_weight_initializer = onnx.helper.make_tensor(
500
+ q_weight_name, weight_qType, weights_shape, packed_data, raw=True
501
+ )
502
+ self.model.initializer_extend([q_weight_initializer])
503
+ else:
504
+ quantized_weights = np.asarray(
505
+ quantized_weights,
506
+ dtype=onnx.helper.tensor_dtype_to_np_dtype(weight_qType),
507
+ ).reshape(initializer.dims)
508
+ q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name)
509
+ self.model.initializer_extend([q_weight_initializer])
510
+
511
+ return q_weight_name, zp_name, scale_name
512
+
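The per-channel loop above is the standard "one scale per output channel" scheme. A compact numpy illustration for a rank-2 weight quantized symmetrically to int8 along axis 0 (illustrative only; it omits overrides, zero points, and 4-bit packing):

    import numpy as np

    w = np.array([[0.1, -0.2, 0.4],
                  [2.0, -1.0, 0.5]], dtype=np.float32)
    scales = np.abs(w).max(axis=1) / 127.0                       # one scale per channel
    q = np.round(w / scales[:, None]).clip(-127, 127).astype(np.int8)
    print(scales)  # approx. [0.00315, 0.01575]
    print(q)       # each row uses its own scale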
+     def adjust_tensor_ranges(self):
+         if self.tensors_range is None:
+             return
+
+         for node in self.model.nodes():
+             # adjust tensor_ranges for input of Clip and Relu node
+             if node.op_type in ["Clip", "Relu"]:
+                 if not self.should_quantize_node(node):
+                     continue
+                 if len(self.model.input_name_to_nodes()[node.input[0]]) != 1:
+                     continue
+                 if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range:
+                     continue
+                 td = self.tensors_range[node.output[0]]
+                 if not isinstance(td, TensorData):
+                     raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
+                 self.tensors_range[node.input[0]] = td
+             # Adjust Softmax to range from 0.0 to 1.0
+             elif node.op_type == "Softmax":
+                 self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))
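Taken together, BaseQuantizer is the shared machinery behind the public entry points in onnxruntime/quantization/quantize.py, also new in this wheel. A minimal end-to-end sketch using the dynamic-quantization entry point (file paths are placeholders):

    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        "model.onnx",          # hypothetical float32 model
        "model.quant.onnx",    # hypothetical output path
        weight_type=QuantType.QInt8,
    )

Static quantization (quantize_static) follows the same shape but additionally requires a CalibrationDataReader to populate tensors_range, which is what adjust_tensor_ranges() post-processes above.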