onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
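The single hunk that follows reproduces the contents of one of the newly added files. Judging by the +1,638 line count and the weight-only quantization classes it defines, it appears to be onnxruntime/quantization/matmul_nbits_quantizer.py (item 43 in the list above).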
@@ -0,0 +1,1638 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import copy
11
+ import logging
12
+ import os
13
+
14
+ import ml_dtypes
15
+ import numpy as np
16
+ import numpy.typing as npt
17
+ import onnx
18
+ import onnx_ir as ir
19
+ from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
20
+
21
+ from onnxruntime.capi._pybind_state import (
22
+ quantize_matmul_2bits,
23
+ quantize_matmul_4bits,
24
+ quantize_matmul_8bits,
25
+ quantize_qdq_matmul_4bits,
26
+ )
27
+
28
+ from .calibrate import CalibrationDataReader
29
+ from .neural_compressor import gptq_quantize, rtn_quantize
30
+ from .onnx_model import ONNXModel
31
+ from .quant_utils import QuantFormat, attribute_to_kwarg
32
+
33
+ logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO)
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class WeightOnlyQuantConfig:
38
+ def __init__(
39
+ self,
40
+ algorithm: str,
41
+ quant_format: QuantFormat,
42
+ op_types_to_quantize: tuple[str, ...] | None = None,
43
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
44
+ customized_weight_config: dict | None = None,
45
+ ):
46
+ """This is the Base class for Weight Only blockwise quantization Configuration.
47
+
48
+ Args:
49
+ algorithm:
50
+ weight-only quantization algorithm name.
51
+ quant_format: QuantFormat{QOperator, QDQ}.
52
+ QOperator format quantizes the model with quantized operators directly.
53
+ QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
54
+ op_types_to_quantize (optional):
55
+ set of operator types to quantize. Default {MatMul}
56
+ quant_axes (dict[str, int], optional):
57
+ op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
58
+ customized_weight_config:
59
+ customized weight config for nodes if needed. It is a dictionary with node name as key,
60
+ and the value is a dict of customized config.
61
+ """
62
+ self.algorithm = algorithm
63
+ self.quant_format = quant_format
64
+ self.op_types_to_quantize = set(op_types_to_quantize) if op_types_to_quantize else {"MatMul"}
65
+ self.quant_axes = dict(quant_axes) if quant_axes else {"MatMul": 0, "Gather": 1}
66
+ self.customized_weight_config = customized_weight_config
67
+
68
+
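For orientation, a minimal sketch of how the defaults above resolve; the values are illustrative, and the imports assume the module path suggested by the file list (onnxruntime.quantization.matmul_nbits_quantizer):

    from onnxruntime.quantization import QuantFormat
    from onnxruntime.quantization.matmul_nbits_quantizer import WeightOnlyQuantConfig

    # omitting op_types_to_quantize / quant_axes falls back to MatMul-centric defaults
    cfg = WeightOnlyQuantConfig(algorithm="RTN", quant_format=QuantFormat.QOperator)
    assert cfg.op_types_to_quantize == {"MatMul"}
    assert cfg.quant_axes == {"MatMul": 0, "Gather": 1}

    # passing tuples overrides the defaults and is normalized to set/dict form
    cfg = WeightOnlyQuantConfig(
        algorithm="RTN",
        quant_format=QuantFormat.QOperator,
        op_types_to_quantize=("MatMul", "Gather"),
        quant_axes=(("Gather", 1),),
    )
    assert cfg.quant_axes == {"Gather": 1}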
69
+ class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
70
+ def __init__(
71
+ self,
72
+ ratios=None,
73
+ quant_format=QuantFormat.QOperator,
74
+ op_types_to_quantize: tuple[str, ...] | None = None,
75
+ customized_weight_config: dict | None = None,
76
+ ):
77
+ """
78
+ This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
79
+ RTN is the most straightforward way to quantize weights using scale maps.
80
+
81
+ Args:
82
+ ratios:
83
+ percentile of clip. Defaults to {}.
84
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
85
+ QOperator format quantizes the model with quantized operators directly.
86
+ QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
87
+ Defaults to QuantFormat.QOperator.
88
+ op_types_to_quantize (optional):
89
+ set of operator types to quantize.
90
+ customized_weight_config:
91
+ customized weight config for nodes if needed. It is a dictionary with node name as key,
92
+ and the value is a dict of customized config.
93
+ """
94
+ assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
95
+
96
+ if ratios is None:
97
+ ratios = {}
98
+ super().__init__(
99
+ algorithm="RTN",
100
+ quant_format=quant_format,
101
+ op_types_to_quantize=op_types_to_quantize,
102
+ customized_weight_config=customized_weight_config,
103
+ )
104
+ self.ratios = ratios
105
+
106
+
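As a usage illustration, a hedged sketch of driving this RTN config end to end. The MatMulNBitsQuantizer entry point is assumed to be defined later in this file (it is not visible in this excerpt), so treat the call pattern as illustrative rather than authoritative:

    import onnx
    from onnxruntime.quantization.matmul_nbits_quantizer import (
        MatMulNBitsQuantizer,  # assumed entry point, defined further down in this file
        RTNWeightOnlyQuantConfig,
    )

    model = onnx.load("model.onnx")  # hypothetical input path
    quantizer = MatMulNBitsQuantizer(model, algo_config=RTNWeightOnlyQuantConfig())
    quantizer.process()  # replaces eligible MatMul weights with 4-bit blockwise-quantized equivalents
    onnx.save(quantizer.model.model, "model_int4.onnx")  # assumes quantizer.model wraps an ONNXModel; hypothetical output path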
107
+ class KQuantWeightOnlyQuantConfig(WeightOnlyQuantConfig):
108
+ def __init__(
109
+ self,
110
+ ratios=None,
111
+ quant_format=QuantFormat.QOperator,
112
+ op_types_to_quantize: tuple[str, ...] | None = None,
113
+ customized_weight_config: dict | None = None,
114
+ ):
115
+ """
116
+ This is a class for k-quant algorithm Weight Only Quant Configuration.
117
+
118
+ Args:
119
+ ratios:
120
+ percentile of clip. Defaults to {}.
121
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
122
+ QOperator format quantizes the model with quantized operators directly.
123
+ QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
124
+ Defaults to QuantFormat.QOperator.
125
+ op_types_to_quantize (optional):
126
+ set of operator types to quantize.
127
+ """
128
+ assert quant_format == QuantFormat.QOperator, "k-quant only supports QOperator format"
129
+
130
+ if ratios is None:
131
+ ratios = {}
132
+ super().__init__(
133
+ algorithm="k_quant",
134
+ quant_format=quant_format,
135
+ op_types_to_quantize=op_types_to_quantize,
136
+ customized_weight_config=customized_weight_config,
137
+ )
138
+ self.ratios = ratios
139
+
140
+
141
+ class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
142
+ def __init__(
143
+ self,
144
+ calibration_data_reader: CalibrationDataReader | None = None,
145
+ percdamp=0.01,
146
+ block_size=128,
147
+ actorder=False,
148
+ mse=False,
149
+ perchannel=True,
150
+ quant_format=QuantFormat.QOperator,
151
+ op_types_to_quantize: tuple[str, ...] | None = None,
152
+ ):
153
+ """
154
+ This is a class for GPTQ algorithm Weight Only Quant Configuration.
155
+ GPTQ algorithm provides more accurate quantization but requires more computational resources.
156
+
157
+ Args:
158
+ calibration_data_reader:
159
+ a calibration data reader. It enumerates calibration data and generates inputs for the original model.
160
+ percdamp:
161
+ percent of the average Hessian diagonal to use for dampening.
162
+ block_size (int, optional):
163
+ channel number in one block to execute a GPTQ quantization iteration.
164
+ actorder (bool, optional):
165
+ whether rearrange Hessian matrix considering the diag's value.
166
+ mse (bool, optional):
167
+ whether to compute scale and zero point with MSE error.
168
+ perchannel (bool, optional):
169
+ whether to quantize the weight per-channel.
170
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
171
+ QOperator format quantizes the model with quantized operators directly.
172
+ QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
173
+ Defaults to QuantFormat.QOperator.
174
+ op_types_to_quantize (optional):
175
+ set of operator types to quantize.
176
+ """
177
+ assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format"
178
+
179
+ super().__init__(
180
+ algorithm="GPTQ",
181
+ quant_format=quant_format,
182
+ op_types_to_quantize=op_types_to_quantize,
183
+ )
184
+ self.calibration_data_reader = calibration_data_reader
185
+ self.percdamp = percdamp
186
+ self.block_size = block_size
187
+ self.actorder = actorder
188
+ self.mse = mse
189
+ self.perchannel = perchannel
190
+
191
+
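GPTQ is the only configuration above that consumes calibration data, so a minimal sketch of the reader protocol it expects may help; the input name and shapes are hypothetical, and CalibrationDataReader is the class imported from .calibrate at the top of this file:

    import numpy as np
    from onnxruntime.quantization import CalibrationDataReader
    from onnxruntime.quantization.matmul_nbits_quantizer import GPTQWeightOnlyQuantConfig

    class RandomDataReader(CalibrationDataReader):
        """Feeds a few random batches; real calibration should use representative inputs."""

        def __init__(self, num_batches=8):
            self._batches = iter(
                [{"input_ids": np.random.randint(0, 32000, (1, 128), dtype=np.int64)} for _ in range(num_batches)]
            )

        def get_next(self):
            return next(self._batches, None)  # returning None signals the end of the calibration data

    gptq_config = GPTQWeightOnlyQuantConfig(
        calibration_data_reader=RandomDataReader(),
        block_size=128,
        percdamp=0.01,
    )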
192
+ class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
193
+ def __init__(
194
+ self,
195
+ block_size=128,
196
+ bits=4,
197
+ axis=1,
198
+ quant_format=QuantFormat.QOperator,
199
+ op_types_to_quantize: tuple[str, ...] | None = None,
200
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
201
+ ):
202
+ """
203
+ This is a class for HQQ algorithm Weight Only Quant Configuration.
204
+ The HQQ algorithm quantizes weights without needing calibration data.
205
+
206
+ Args:
207
+ block_size (int, optional):
208
+ channel number in one block to execute a HQQ quantization iteration.
209
+ bits (int, optional):
210
+ how many bits to represent weight.
211
+ axis (int, optional):
212
+ 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf
213
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
214
+ QOperator format quantizes the model with quantized operators directly.
215
+ QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
216
+ Defaults to QuantFormat.QOperator.
217
+ op_types_to_quantize (optional):
218
+ set of operator types to quantize.
219
+ quant_axes (dict[str, int], optional):
220
+ op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
221
+ """
222
+ assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format"
223
+
224
+ super().__init__(
225
+ algorithm="HQQ",
226
+ quant_format=quant_format,
227
+ op_types_to_quantize=op_types_to_quantize,
228
+ quant_axes=quant_axes,
229
+ )
230
+ self.block_size = block_size
231
+ self.bits = bits
232
+ self.axis = axis
233
+
234
+
235
+ class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig):
236
+ def __init__(
237
+ self,
238
+ block_size: int = 128,
239
+ is_symmetric: bool = False,
240
+ accuracy_level: int | None = None,
241
+ quant_format=QuantFormat.QOperator,
242
+ op_types_to_quantize: tuple[str, ...] | None = None,
243
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
244
+ bits: int = 4,
245
+ channel_wised_quantize: bool = False,
246
+ ):
247
+ """
248
+ This is a class for weight only affine quantization configuration.
249
+
250
+ Args:
251
+ block_size (int, optional):
252
+ channel number in one block to execute an affine quantization iteration.
253
+ is_symmetric (bool, optional):
254
+ whether to quantize the weight symmetrically.
255
+ accuracy_level (int, optional):
256
+ Accuracy level of the 4-bit quantized MatMul computation.
257
+ Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details.
258
+ (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)
259
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
260
+ QOperator format quantizes the model with quantized operators directly.
261
+ QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
262
+ Defaults to QuantFormat.QOperator.
263
+ op_types_to_quantize (optional):
264
+ set of operator types to quantize.
265
+ quant_axes (dict[str, int], optional):
266
+ op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
267
+ bits (int, optional):
268
+ number of bits per element after quantization. Default 4.
269
+ """
270
+ super().__init__(
271
+ algorithm="DEFAULT",
272
+ quant_format=quant_format,
273
+ op_types_to_quantize=op_types_to_quantize,
274
+ quant_axes=quant_axes,
275
+ )
276
+ self.block_size = block_size
277
+ self.is_symmetric = is_symmetric
278
+ self.bits = bits
279
+ self.accuracy_level = accuracy_level
280
+ self.channel_wised_quantize = channel_wised_quantize
281
+ if channel_wised_quantize and quant_format == QuantFormat.QOperator:
282
+ raise NotImplementedError("QuantFormat.QOperator is not supported channel_wised_quantize yet")
283
+
284
+
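A short sketch contrasting the two formats this default config accepts, reflecting the constraint enforced in the constructor above (channel-wise quantization is only allowed together with QDQ); imports follow the module path suggested by the file list:

    from onnxruntime.quantization import QuantFormat
    from onnxruntime.quantization.matmul_nbits_quantizer import DefaultWeightOnlyQuantConfig

    # QOperator: 4-bit blockwise weights emitted as MatMulNBits contrib ops
    qop_config = DefaultWeightOnlyQuantConfig(block_size=128, is_symmetric=True, bits=4)

    # QDQ: per-channel (block size = K) DequantizeLinear -> MatMul pairs
    qdq_config = DefaultWeightOnlyQuantConfig(
        quant_format=QuantFormat.QDQ,
        is_symmetric=True,
        channel_wised_quantize=True,
    )

    # channel_wised_quantize combined with the default QOperator format raises NotImplementedError
    try:
        DefaultWeightOnlyQuantConfig(channel_wised_quantize=True)
    except NotImplementedError:
        pass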
285
+ class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
286
+ def __init__(
287
+ self,
288
+ tokenizer_dir,
289
+ dataset_name="cnn",
290
+ cache_dir="./cache",
291
+ calibration_method="awq_lite",
292
+ ):
293
+ """
294
+ Configuration for the nvidia_awq quantization method.
295
+
296
+ Args:
297
+ tokenizer_dir (str): path of the tokenizer dir.
298
+ dataset_name (str): Name of the dataset.
299
+ cache_dir (str): Directory for caching.
300
+ calibration_method (str): calibration method for nvidia_awq.
301
+ """
302
+ # Import torch and DataLoader
303
+ try:
304
+ import torch # noqa: PLC0415
305
+ from torch.utils.data import DataLoader # noqa: PLC0415
306
+
307
+ self.torch = torch
308
+ self.DataLoader = DataLoader
309
+ except ImportError:
310
+ print(
311
+ "Error: The 'torch' library is required but not installed. Please install it using 'pip install torch'."
312
+ )
313
+ raise ImportError("torch is not installed. Exiting.") from None
314
+
315
+ # Import datasets
316
+ try:
317
+ from datasets import load_dataset # noqa: PLC0415
318
+
319
+ self.load_dataset = load_dataset
320
+ except ImportError:
321
+ print(
322
+ "Error: The 'datasets' library is required but not installed. Please install it using 'pip install datasets'."
323
+ )
324
+ raise ImportError("datasets is not installed. Exiting.") from None
325
+
326
+ # Import transformers
327
+ try:
328
+ from transformers import AutoConfig, AutoTokenizer # noqa: PLC0415
329
+
330
+ self.AutoConfig = AutoConfig
331
+ self.AutoTokenizer = AutoTokenizer
332
+ except ImportError:
333
+ print(
334
+ "Error: The 'transformers' library is required but not installed. Please install it using 'pip install transformers'."
335
+ )
336
+ raise ImportError("transformers is not installed. Exiting.") from None
337
+
338
+ super().__init__(
339
+ algorithm="nvidia_awq",
340
+ quant_format=QuantFormat.QDQ,
341
+ op_types_to_quantize=None, # Assuming op_types_to_quantize is handled elsewhere
342
+ quant_axes=None, # Assuming quant_axes is handled elsewhere
343
+ )
344
+
345
+ # Determine the device
346
+ device = self.torch.device("cuda" if self.torch.cuda.is_available() else "cpu")
347
+
348
+ calib_inputs = self.get_calib_inputs(
349
+ dataset_name=dataset_name,
350
+ model_name=tokenizer_dir,
351
+ cache_dir=cache_dir,
352
+ calib_size=32,
353
+ batch_size=1,
354
+ block_size=512,
355
+ device=device,
356
+ use_fp16=True,
357
+ use_buffer_share=False,
358
+ add_past_kv_inputs=True,
359
+ max_calib_rows_to_load=128,
360
+ add_position_ids=True,
361
+ )
362
+
363
+ self.calibration_data_reader = calib_inputs
364
+ self.calibration_method = calibration_method
365
+
366
+ def make_model_input(
367
+ self,
368
+ config,
369
+ input_ids_arg,
370
+ attention_mask_arg,
371
+ add_past_kv_inputs,
372
+ device,
373
+ use_fp16,
374
+ use_buffer_share,
375
+ add_position_ids,
376
+ ):
377
+ # Access torch from the instance variable
378
+ torch = self.torch
379
+
380
+ input_ids = input_ids_arg
381
+ attention_mask = attention_mask_arg
382
+
383
+ if isinstance(input_ids_arg, list):
384
+ input_ids = torch.tensor(input_ids_arg, device=device, dtype=torch.int64)
385
+ attention_mask = torch.tensor(attention_mask_arg, device=device, dtype=torch.int64)
386
+
387
+ inputs = {
388
+ "input_ids": input_ids.contiguous(),
389
+ "attention_mask": attention_mask.contiguous(),
390
+ }
391
+
392
+ if add_position_ids:
393
+ position_ids = attention_mask.long().cumsum(-1) - 1
394
+ position_ids.masked_fill_(attention_mask == 0, 1)
395
+ inputs["position_ids"] = position_ids.contiguous()
396
+
397
+ if add_past_kv_inputs:
398
+ torch_dtype = torch.float16 if use_fp16 else torch.float32
399
+ batch_size, sequence_length = input_ids.shape
400
+ max_sequence_length = config.max_position_embeddings
401
+ num_heads, head_size = (
402
+ config.num_key_value_heads,
403
+ config.hidden_size // config.num_attention_heads,
404
+ )
405
+ for i in range(config.num_hidden_layers):
406
+ past_key = torch.zeros(
407
+ batch_size,
408
+ num_heads,
409
+ max_sequence_length if use_buffer_share else 0,
410
+ head_size,
411
+ device=device,
412
+ dtype=torch_dtype,
413
+ )
414
+ past_value = torch.zeros(
415
+ batch_size,
416
+ num_heads,
417
+ max_sequence_length if use_buffer_share else 0,
418
+ head_size,
419
+ device=device,
420
+ dtype=torch_dtype,
421
+ )
422
+ inputs.update(
423
+ {
424
+ f"past_key_values.{i}.key": past_key.contiguous(),
425
+ f"past_key_values.{i}.value": past_value.contiguous(),
426
+ }
427
+ )
428
+
429
+ return inputs
430
+
431
+ def get_calib_inputs(
432
+ self,
433
+ dataset_name,
434
+ model_name,
435
+ cache_dir,
436
+ calib_size,
437
+ batch_size,
438
+ block_size,
439
+ device,
440
+ use_fp16,
441
+ use_buffer_share,
442
+ add_past_kv_inputs,
443
+ max_calib_rows_to_load,
444
+ add_position_ids,
445
+ ):
446
+ # Access transformers and datasets from the instance variables
447
+ auto_config = self.AutoConfig
448
+ auto_tokenizer = self.AutoTokenizer
449
+ load_dataset = self.load_dataset
450
+
451
+ config = auto_config.from_pretrained(
452
+ model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
453
+ )
454
+ tokenizer = auto_tokenizer.from_pretrained(
455
+ model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
456
+ )
457
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
458
+ tokenizer.pad_token = tokenizer.eos_token
459
+
460
+ assert calib_size <= max_calib_rows_to_load, "calib size should be no more than max_calib_rows_to_load"
461
+
462
+ if "cnn" in dataset_name:
463
+ dataset2 = load_dataset("cnn_dailymail", name="3.0.0", split="train").select(range(max_calib_rows_to_load))
464
+ column = "article"
465
+ elif "pile" in dataset_name:
466
+ dataset2 = load_dataset("mit-han-lab/pile-val-backup", split="validation")
467
+ column = "text"
468
+ else:
469
+ raise ValueError(f'dataset "{dataset_name}" not supported')
470
+
471
+ dataset2 = dataset2[column][:calib_size]
472
+ batch_encoded = tokenizer.batch_encode_plus(
473
+ dataset2, return_tensors="pt", padding=True, truncation=True, max_length=block_size
474
+ )
475
+ batch_encoded = batch_encoded.to(device)
476
+ batch_encoded_input_ids = batch_encoded["input_ids"]
477
+ batch_encoded_attention_mask = batch_encoded["attention_mask"]
478
+
479
+ # Access DataLoader from the instance variable
480
+ data_loader = self.DataLoader
481
+
482
+ calib_dataloader_input_ids = data_loader(batch_encoded_input_ids, batch_size=batch_size, shuffle=False)
483
+ calib_dataloader_attention_mask = data_loader(
484
+ batch_encoded_attention_mask, batch_size=batch_size, shuffle=False
485
+ )
486
+
487
+ assert len(calib_dataloader_input_ids.dataset) == len(calib_dataloader_attention_mask.dataset)
488
+ assert len(calib_dataloader_input_ids) == len(calib_dataloader_attention_mask)
489
+
490
+ number_of_batched_samples = calib_size // batch_size
491
+
492
+ batched_input_ids = []
493
+ for idx, data in enumerate(calib_dataloader_input_ids):
494
+ batched_input_ids.append(data)
495
+ if idx == (number_of_batched_samples - 1):
496
+ break
497
+
498
+ batched_attention_mask = []
499
+ for idx, data in enumerate(calib_dataloader_attention_mask):
500
+ batched_attention_mask.append(data)
501
+ if idx == (number_of_batched_samples - 1):
502
+ break
503
+
504
+ print(
505
+ f"\n--Quantize-Script-- number_of_batched_samples={number_of_batched_samples}, "
506
+ f"batch-input-ids-list-len={len(batched_input_ids)}, batched_attention_mask={len(batched_attention_mask)}\n"
507
+ )
508
+
509
+ batched_inputs_list = []
510
+ for i in range(number_of_batched_samples):
511
+ input_ids = batched_input_ids[i]
512
+ attention_mask = batched_attention_mask[i]
513
+
514
+ inputs = self.make_model_input(
515
+ config,
516
+ input_ids,
517
+ attention_mask,
518
+ add_past_kv_inputs,
519
+ device,
520
+ use_fp16,
521
+ use_buffer_share,
522
+ add_position_ids,
523
+ )
524
+ inputs = {input_name: torch_tensor.cpu().numpy() for input_name, torch_tensor in inputs.items()}
525
+ batched_inputs_list.append(inputs)
526
+
527
+ print(f"\n--Quantize-Script-- number of batched inputs = {len(batched_inputs_list)}\n")
528
+ return batched_inputs_list
529
+
530
+
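Note that constructing this config is heavyweight: the __init__ above immediately imports torch/datasets/transformers, downloads the calibration dataset, and tokenizes 32 samples. A hedged construction sketch (the tokenizer directory is a hypothetical Hugging Face model id):

    from onnxruntime.quantization.matmul_nbits_quantizer import NVAWQWeightOnlyQuantConfig

    nvawq_config = NVAWQWeightOnlyQuantConfig(
        tokenizer_dir="meta-llama/Llama-2-7b-hf",  # hypothetical model id / local tokenizer path
        dataset_name="cnn",                        # "cnn" or "pile", per get_calib_inputs above
        cache_dir="./cache",
        calibration_method="awq_lite",
    )
    # nvawq_config.calibration_data_reader now holds the pre-tokenized calibration batches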
531
+ def is_divisible(val1, val2):
532
+ return int(val2 * np.ceil(val1 / val2)) == val1
533
+
534
+
535
+ class HQQWeightOnlyQuantizer:
536
+ def __init__(
537
+ self,
538
+ config: HQQWeightOnlyQuantConfig,
539
+ ):
540
+ self.config = config
541
+
542
+ # Proximal solver || weight - dequantize(quantize(weight))||_p^p
543
+ @staticmethod
544
+ def optimize_weights(
545
+ tensor,
546
+ scale,
547
+ zero,
548
+ min_max: list[int],
549
+ axis: int = 0,
550
+ opt_params: dict | None = None,
551
+ verbose=False,
552
+ ):
553
+ import torch # noqa: PLC0415
554
+
555
+ opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params
556
+ lp_norm, beta, kappa, iters = (
557
+ opt_params["lp_norm"],
558
+ opt_params["beta"],
559
+ opt_params["kappa"],
560
+ opt_params["iters"],
561
+ )
562
+
563
+ dtype = torch.float16 if tensor.is_cuda else torch.float32
564
+ w_f = tensor.to(dtype)
565
+ scale = scale.to(dtype)
566
+ zero = zero.to(dtype)
567
+
568
+ def shrink_op(x, beta, p=lp_norm):
569
+ if p == 1:
570
+ return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
571
+ else:
572
+ return torch.sign(x) * torch.nn.functional.relu(
573
+ torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1)
574
+ )
575
+
576
+ best_error = 1e4
577
+ for i in range(iters):
578
+ w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1])
579
+ w_r = (w_q - zero) / scale
580
+ w_e = shrink_op(w_f - w_r, beta)
581
+ zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True)
582
+ beta *= kappa
583
+
584
+ current_error = float(torch.abs(w_f - w_r).mean())
585
+ if verbose:
586
+ print(i, np.round(current_error, 6))
587
+ if current_error < best_error:
588
+ best_error = current_error
589
+ else:
590
+ break
591
+
592
+ del w_f, w_q, w_r, w_e
593
+
594
+ return scale, zero
595
+
596
+ @staticmethod
597
+ def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits):
598
+ if pack_tensor.shape[0] == ori_int_tensor.shape[0]:
599
+ ori_int_tensor = ori_int_tensor.T
600
+ pack_tensor = pack_tensor.T
601
+ if bits in [2, 4, 8]:
602
+ compress_ratio = pack_tensor.element_size() * 8 // bits
603
+ for j in range(compress_ratio):
604
+ pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j))
605
+ else:
606
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
607
+
608
+ # from Official implementation of Half-Quadratic Quantization (HQQ)
609
+ def quantize_internal(
610
+ self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1
611
+ ):
612
+ import torch # noqa: PLC0415
613
+
614
+ weight = tensor.float()
615
+ ori_shape = weight.shape
616
+
617
+ pad_len = (group_size - ori_shape[axis] % group_size) % group_size
618
+ if axis == 1:
619
+ weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0)
620
+ else:
621
+ weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0)
622
+ shape = weight.shape
623
+
624
+ # Reshape for grouping
625
+ if (group_size is not None) and channel_wise:
626
+ weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1])
627
+
628
+ # Get min/max values
629
+ if channel_wise is False:
630
+ _min, _max = weight.min(), weight.max()
631
+ optimize = False
632
+ else:
633
+ _min = weight.min(axis=axis, keepdim=True)[0]
634
+ _max = weight.max(axis=axis, keepdim=True)[0]
635
+
636
+ max_v = 2**bits - 1
637
+ min_v = 0
638
+ min_max = [min_v, max_v]
639
+
640
+ # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on.
641
+ # clamp to avoid half-precision problems
642
+ scale = (max_v / (_max - _min)).clamp(max=2e4)
643
+ #!!!!!!!!!!!!!!!
644
+ min_max_axis = _max - _min
645
+ if (min_max_axis == 0).sum().item() > 0:
646
+ min_max_axis[min_max_axis == 0] = max_v
647
+ scale = (max_v / min_max_axis).clamp(max=2e4)
648
+ zero = -_min * scale
649
+
650
+ if round_zero:
651
+ zero = torch.round(zero)
652
+
653
+ # Fine-tune weights
654
+ if optimize:
655
+ scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis)
656
+
657
+ # Quantize
658
+ # Necessary for fake quantization backprop
659
+ w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1])
660
+ w_q = w_q.reshape(shape).int()
661
+
662
+ scale = 1.0 / scale
663
+ if axis == 1:
664
+ scale = scale.reshape(shape[0], -1)
665
+ zero = zero.reshape(shape[0], -1)
666
+ else:
667
+ scale = scale.reshape(-1, shape[-1])
668
+ zero = zero.reshape(-1, shape[-1])
669
+ # cleanup
670
+ del weight, _min, _max
671
+
672
+ return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype)
673
+
674
+ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
675
+ """
676
+ Target node: QOperator node: QDQ nodes:
677
+ MatMul MatMulNBits DeQuantizeLinear -> MatMul
678
+ Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
679
+ If the node is a target node with fp32 or fp16 const weight, quantize the weight to int4 and
680
+ return the new nodes.
681
+ If QOperator format, return the corresponding QOperator nodes.
682
+ If QDQ format, return the corresponding QDQ nodes.
683
+ Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
684
+ not supported yet because Gather does not support int4 data.
685
+ """
686
+ # With HQQ, zero points are in float. Current GatherBlockQuantized does not support float zero points.
687
+ if node.op_type == "Gather":
688
+ raise NotImplementedError("Gather quantization is not supported yet in HQQ")
689
+
690
+ import torch # noqa: PLC0415
691
+
692
+ logger.info(f"start to quantize {node.name} ...")
693
+ input_b = node.input[1]
694
+ b_pb, bs_graph = get_initializer(input_b, graph_stack)
695
+ if b_pb is None:
696
+ logger.info("MatMul doesn't have const weight. Skip to quantize")
697
+ return [node] # only care about constant weight
698
+
699
+ b_array = onnx.numpy_helper.to_array(b_pb)
700
+ if len(b_array.shape) != 2:
701
+ logger.info("MatMul weight is not 2D. Skip to quantize")
702
+ return [node] # can only process 2-D matrix
703
+ b_array_torch = torch.from_numpy(b_array)
704
+ if torch.cuda.is_available():
705
+ b_array_torch = b_array_torch.cuda()
706
+
707
+ bits = self.config.bits
708
+ quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal(
709
+ b_array_torch.T, bits=bits, group_size=self.config.block_size
710
+ )
711
+ quant_weight_torch = quant_weight_torch.contiguous()
712
+ scales_torch = scales_torch.contiguous()
713
+ zero_points_torch = zero_points_torch.contiguous()
714
+
715
+ packed_size = 8 // bits # number of elements packed into one byte
716
+
717
+ packed_torch = torch.zeros(
718
+ (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // packed_size),
719
+ dtype=torch.uint8,
720
+ device=quant_weight_torch.device,
721
+ )
722
+ self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, bits)
723
+ scales = scales_torch.cpu().numpy()
724
+ zero_points = zero_points_torch.cpu().numpy()
725
+ # reshape to the predefined shape in MatmulNbits
726
+ scales = scales.reshape(-1)
727
+ zero_points = zero_points.reshape(-1)
728
+ rows, cols = b_array_torch.shape
729
+ block_size = self.config.block_size
730
+ blob_size = block_size // packed_size
731
+ k_blocks = (rows + block_size - 1) // block_size
732
+ packed_torch = packed_torch.reshape(cols, k_blocks, blob_size)
733
+
734
+ b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
735
+ b_quant.name = b_pb.name + "_Q" + str(bits)
736
+ for input in bs_graph.input:
737
+ if input.name == input_b:
738
+ bs_graph.input.remove(input)
739
+ break
740
+
741
+ scales_tensor = onnx.numpy_helper.from_array(scales)
742
+ scales_tensor.name = b_pb.name + "_scales"
743
+ bs_graph.initializer.extend([b_quant, scales_tensor])
744
+
745
+ input_names = [node.input[0], b_quant.name, scales_tensor.name]
746
+ zp_tensor = onnx.numpy_helper.from_array(zero_points)
747
+ zp_tensor.name = b_pb.name + "_zero_points"
748
+ bs_graph.initializer.extend([zp_tensor])
749
+ input_names.append(zp_tensor.name)
750
+
751
+ kwargs = {}
752
+ rows, cols = b_array.shape
753
+ kwargs["K"] = rows
754
+ kwargs["N"] = cols
755
+ kwargs["bits"] = bits
756
+ kwargs["block_size"] = self.config.block_size
757
+
758
+ matmul_q_node = onnx.helper.make_node(
759
+ "MatMulNBits",
760
+ inputs=input_names,
761
+ outputs=[node.output[0]],
762
+ name=node.name + "_Q" + str(bits) if node.name else "",
763
+ domain="com.microsoft",
764
+ **kwargs,
765
+ )
766
+
767
+ logger.info(f"complete quantization of {node.name} ...")
768
+
769
+ return [matmul_q_node]
770
+
771
+
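To make the bit-packing in pack_on_row_fast_248bit concrete, here is a small NumPy sketch of the same nibble layout for bits=4 (the values are made up): consecutive elements along the packed axis share one byte, the first in the low nibble and the second in the high nibble.

    import numpy as np

    vals = np.array([1, 7, 12, 3], dtype=np.uint8)  # hypothetical 4-bit values along the packed axis
    packed = np.zeros(2, dtype=np.uint8)
    for j in range(2):                              # compress_ratio = 8 // bits = 2
        packed |= vals[j::2] << (4 * j)
    # packed[0] = 1 | (7 << 4) = 0x71, packed[1] = 12 | (3 << 4) = 0x3C
    assert list(packed) == [0x71, 0x3C]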
772
+ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
773
+ for gid in range(len(graph_path) - 1, -1, -1):
774
+ graph = graph_path[gid]
775
+ for tensor in graph.initializer:
776
+ if tensor.name == name:
777
+ return tensor, graph
778
+ return None, None
779
+
780
+
781
+ # transpose int4 matrix (packed as uint8)
782
+ def transpose_packed_int4_matrix(packed, rows, cols):
783
+ # unpack to int4 matrix
784
+ total = rows * cols
785
+ high = (packed >> 4) & 0x0F
786
+ low = packed & 0x0F
787
+ int4_vals = np.empty(total, dtype=np.uint8)
788
+ int4_vals[0::2] = low
789
+ int4_vals[1::2] = high
790
+ int4_matrix = int4_vals.reshape((rows, cols))
791
+
792
+ # transpose int4 matrix
793
+ int4_matrix_transposed = int4_matrix.T
794
+
795
+ # pack to uint8
796
+ flat = int4_matrix_transposed.reshape(-1)
797
+ packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F)
798
+ return packed.astype(np.uint8)
799
+
800
+
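As a quick check of the nibble conventions used above (the low nibble holds the even-indexed element), a tiny worked example with a made-up 2x2 matrix:

    import numpy as np
    from onnxruntime.quantization.matmul_nbits_quantizer import transpose_packed_int4_matrix

    # 2x2 int4 matrix [[1, 2], [3, 4]] packed row-major, low nibble first: (1, 2) -> 0x21, (3, 4) -> 0x43
    packed = np.array([0x21, 0x43], dtype=np.uint8)
    transposed = transpose_packed_int4_matrix(packed, 2, 2)
    # the transposed matrix [[1, 3], [2, 4]] packs to 0x31, 0x42
    assert list(transposed) == [0x31, 0x42]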
801
+ class DefaultWeightOnlyQuantizer:
802
+ def __init__(self, config: DefaultWeightOnlyQuantConfig):
803
+ self.config = config
804
+
805
+ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
806
+ """4b/8b quantize fp32 weight to int4 using C++ kernels."""
807
+
808
+ qbits = self.config.bits
809
+ kpack = 8 // qbits
810
+ if len(fp32weight.shape) != 2:
811
+ raise ValueError("Current int4 block quantization only supports 2D tensors!")
812
+ rows, cols = fp32weight.shape
813
+
814
+ block_size = self.config.block_size
815
+ k_blocks = (rows + block_size - 1) // block_size
816
+
817
+ if self.config.quant_format == QuantFormat.QOperator:
818
+ blob_size = (block_size + kpack - 1) // kpack
819
+ padded_rows = k_blocks * block_size
820
+ pad_len = padded_rows - rows
821
+ if pad_len > 0:
822
+ fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")
823
+
824
+ # block wise quantization, each block comes from a single column
825
+ packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
826
+ zero_point = np.zeros((cols, ((k_blocks + kpack - 1) // kpack)), dtype="uint8")
827
+ scales = np.zeros((cols, k_blocks), dtype=fp32weight.dtype)
828
+ if qbits == 2:
829
+ quantize_matmul_2bits(
830
+ packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
831
+ )
832
+ elif qbits == 8:
833
+ quantize_matmul_8bits(
834
+ packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
835
+ )
836
+ else:
837
+ quantize_matmul_4bits(
838
+ packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
839
+ )
840
+ else:
841
+ # block size equals rows (K) if channel-wise quantization is enabled
842
+ block_size = rows if self.config.channel_wised_quantize else self.config.block_size
843
+ k_blocks = (rows + block_size - 1) // block_size
844
+
845
+ assert qbits == 4, "QDQ format only support 4 bits quantization"
846
+ packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
847
+ zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
848
+ scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
849
+ quantize_qdq_matmul_4bits(
850
+ packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
851
+ )
852
+
853
+ return (packed, scales, zero_point)
854
+
855
+ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
856
+ """
857
+ Quantize weight B of MatMul node to int4 or int8.
858
+ Currently only supports a 2D constant matrix and axis 0 blockwise quantization.
859
+ """
860
+ bits = self.config.bits
861
+ if bits == 8:
862
+ qtype = TensorProto.INT8 if self.config.is_symmetric else TensorProto.UINT8
863
+ else:
864
+ qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
865
+ input_b = node.input[1]
866
+ b_tensor, b_graph = get_initializer(input_b, graph_stack)
867
+ if b_tensor is None:
868
+ logger.info("MatMul doesn't have const weight. Skip to quantize")
869
+ return [node] # only care about constant weight
870
+
871
+ b_ndarray = ir.from_proto(b_tensor).numpy()
872
+ if len(b_ndarray.shape) != 2:
873
+ logger.info("MatMul weight is not 2D. Skip to quantize")
874
+ return [node] # can only process 2-D matrix
875
+
876
+ bfloat16 = b_ndarray.dtype == "bfloat16"
877
+ if bfloat16:
878
+ b_ndarray = b_ndarray.astype(np.float32)
879
+
880
+ packed, scales, zero_points = self.qbits_block_quant(b_ndarray)
881
+ if bfloat16:
882
+ scales = scales.astype(ml_dtypes.bfloat16)
883
+
884
+ if self.config.quant_format == QuantFormat.QOperator:
885
+ b_quant = ir.serde.serialize_tensor(ir.Tensor(packed, name=b_tensor.name + f"_Q{bits}"))
886
+ scales_tensor = ir.serde.serialize_tensor(ir.Tensor(scales, name=b_tensor.name + "_scales"))
887
+ else:
888
+ b_quant = onnx.helper.make_tensor(
889
+ b_tensor.name + f"_DQ_Q{bits}", qtype, b_ndarray.shape, packed.tobytes(), True
890
+ )
891
+ scales_tensor = ir.serde.serialize_tensor(ir.Tensor(scales, name=b_tensor.name + "_DQ_scales"))
892
+
893
+ # if QDQ, CW and SYM are enabled, optimize for Intel NPU: transposing the weight to NHWC format will increase performance
894
+ qdq_opt_for_intel_npu_enabled = (
895
+ self.config.quant_format == QuantFormat.QDQ
896
+ and self.config.channel_wised_quantize
897
+ and self.config.is_symmetric
898
+ )
899
+ if qdq_opt_for_intel_npu_enabled:
900
+ rows, cols = b_ndarray.shape
901
+ packed = transpose_packed_int4_matrix(packed, rows, cols)
902
+ scales = scales.reshape((cols, 1)) # (cols, 1)
903
+ b_quant = onnx.helper.make_tensor(
904
+ b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True
905
+ )
906
+ scales_tensor = ir.serde.serialize_tensor(ir.Tensor(scales, name=b_tensor.name + "_DQ_scales"))
907
+
908
+ for input in b_graph.input:
909
+ if input.name == input_b:
910
+ b_graph.input.remove(input)
911
+ break
912
+
913
+ b_graph.initializer.extend([b_quant, scales_tensor])
914
+
915
+ output_nodes = []
916
+
917
+ if self.config.quant_format == QuantFormat.QOperator:
918
+ input_names = [node.input[0], b_quant.name, scales_tensor.name]
919
+ if not self.config.is_symmetric:
920
+ zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points")
921
+ input_names.append(zp_tensor.name)
922
+ b_graph.initializer.extend([zp_tensor])
923
+ kwargs = {}
924
+ rows, cols = b_ndarray.shape
925
+ kwargs["K"] = rows
926
+ kwargs["N"] = cols
927
+ kwargs["bits"] = bits
928
+ kwargs["block_size"] = self.config.block_size
929
+
930
+ # Do not output accuracy_level if it is 0 since the attribute is optional and is not supported by most EPs.
931
+ if self.config.accuracy_level:
932
+ kwargs["accuracy_level"] = self.config.accuracy_level
933
+
934
+ matmul_qbit_node = onnx.helper.make_node(
935
+ "MatMulNBits",
936
+ inputs=input_names,
937
+ outputs=[node.output[0]],
938
+ name=node.name + f"_Q{bits}" if node.name else "",
939
+ domain="com.microsoft",
940
+ **kwargs,
941
+ )
942
+
943
+ output_nodes.append(matmul_qbit_node)
944
+ else:
945
+ dq_input_names = [b_quant.name, scales_tensor.name]
946
+ dq_output_names = [b_quant.name + "_output"]
947
+ tp_input_names = [dq_output_names[0]]
948
+ tp_output_names = [dq_output_names[0] + "_transposed"]
949
+ matmul_input_names = [
950
+ node.input[0],
951
+ tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0],
952
+ ]
953
+ matmul_output_names = [node.output[0]]
954
+ if not self.config.is_symmetric:
955
+ zp_tensor = onnx.helper.make_tensor(
956
+ b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
957
+ )
958
+ dq_input_names.append(zp_tensor.name)
959
+ b_graph.initializer.extend([zp_tensor])
960
+ rows, cols = b_ndarray.shape
961
+ dq_kwargs = {
962
+ "axis": 1 if qdq_opt_for_intel_npu_enabled else 0,
963
+ "block_size": rows if self.config.channel_wised_quantize else self.config.block_size,
964
+ }
965
+ dq_node = onnx.helper.make_node(
966
+ "DequantizeLinear",
967
+ inputs=dq_input_names,
968
+ outputs=dq_output_names,
969
+ name=node.name + f"_DQ_Q{bits}" if node.name else "",
970
+ **dq_kwargs,
971
+ )
972
+ matmul_node = onnx.helper.make_node(
973
+ "MatMul",
974
+ inputs=matmul_input_names,
975
+ outputs=matmul_output_names,
976
+ name=node.name + f"_matmul_Q{bits}" if node.name else "",
977
+ )
978
+ if qdq_opt_for_intel_npu_enabled:
979
+ tp_node = onnx.helper.make_node(
980
+ "Transpose",
981
+ inputs=tp_input_names,
982
+ outputs=tp_output_names,
983
+ perm=[1, 0],
984
+ )
985
+ output_nodes.extend([dq_node, tp_node, matmul_node])
986
+ else:
987
+ output_nodes.extend([dq_node, matmul_node])
988
+
989
+ return output_nodes
990
+
991
+ @staticmethod
992
+ def quant_slice_symmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
993
+ max_val = np.max(data, axis=1, keepdims=True)
994
+ min_val = np.min(data, axis=1, keepdims=True)
995
+ abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val)
996
+
997
+ scale = abs_max / -8.0 # if max == min, max may be clipped
998
+ quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(-8, 7).astype(np.int8)
999
+
1000
+ return quantized_slice, scale
1001
+
1002
+ @staticmethod
1003
+ def quant_slice_asymmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
1004
+ min_val = np.minimum(data.min(axis=1, keepdims=True), 0)
1005
+ max_val = np.maximum(data.max(axis=1, keepdims=True), 0)
1006
+
1007
+ scale = (max_val - min_val) / 15.0
1008
+ zero_point = np.where(scale == 0, 8, -min_val / scale).round().clip(0, 15).astype(np.uint8)
1009
+ quantized_slice = np.where(scale == 0, 8, data / scale + zero_point).round().clip(0, 15).astype(np.uint8)
1010
+
1011
+ return quantized_slice, scale, zero_point
1012
+
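The two static helpers above implement the per-block formulas used throughout this file: symmetric quantization maps a block to [-8, 7] with scale = abs_max / -8, and asymmetric quantization maps it to [0, 15] with scale = (max - min) / 15 and an unsigned zero point. A minimal editorial sketch (not part of the package source) that reproduces both formulas on a single three-value block with plain NumPy:

    import numpy as np

    block = np.array([[0.5, -1.0, 2.0]], dtype=np.float32)  # one row, one block

    # symmetric: scale = abs_max / -8, quantized values land in [-8, 7]
    max_val = block.max(axis=1, keepdims=True)
    min_val = block.min(axis=1, keepdims=True)
    abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val)
    scale = abs_max / -8.0                                   # -0.25 for this block
    q_sym = np.where(scale == 0, 0, block / scale).round().clip(-8, 7).astype(np.int8)
    # q_sym is [[-2, 4, -8]]; q_sym * scale recovers [[0.5, -1.0, 2.0]] exactly here

    # asymmetric: scale = (max - min) / 15 with an unsigned zero point in [0, 15]
    min_val = np.minimum(block.min(axis=1, keepdims=True), 0)
    max_val = np.maximum(block.max(axis=1, keepdims=True), 0)
    scale = (max_val - min_val) / 15.0                       # 0.2 for this block
    zp = np.where(scale == 0, 8, -min_val / scale).round().clip(0, 15).astype(np.uint8)
    q_asym = np.where(scale == 0, 8, block / scale + zp).round().clip(0, 15).astype(np.uint8)
    # (q_asym.astype(np.float32) - zp) * scale is approximately the original block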
1013
+ @staticmethod
1014
+ def pack_int8_to_int4(data: np.ndarray) -> np.ndarray:
1015
+ """Pack int8 data to int4 and store in uint8 ndarray."""
1016
+ data_flat = data.reshape(-1)
1017
+ if len(data_flat) % 2 != 0:
1018
+ data_flat = np.append(data_flat, 0)
1019
+ quant_data_int4 = (data_flat[::2] & 0xF) | ((data_flat[1::2] & 0xF) << 4)
1020
+
1021
+ return quant_data_int4.astype("uint8")
1022
+
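pack_int8_to_int4 stores two 4-bit values per byte, with even-indexed elements in the low nibble and odd-indexed elements in the high nibble. A tiny editorial illustration of the bit layout (not package code):

    import numpy as np

    data = np.array([-2, 4], dtype=np.int8)   # -2 -> nibble 0xE (two's complement), 4 -> nibble 0x4
    packed = (data[::2] & 0xF) | ((data[1::2] & 0xF) << 4)
    # packed[0] == 0x4E: low nibble holds data[0], high nibble holds data[1]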
1023
+ @staticmethod
1024
+ def quantize_ndarray(
1025
+ data: np.ndarray,
1026
+ quantize_axis: int,
1027
+ block_size: int,
1028
+ is_symmetric: bool,
1029
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
1030
+ """Quantize ndarray data to int4 using numpy, return (quantized data, scales, zero points)."""
1031
+ # Get the shape of the matrix
1032
+ m = 1 # dimension of the matrix before the quantize axis
1033
+ k = data.shape[quantize_axis] # dimension of the matrix along the quantize axis
1034
+ n = 1 # dimension of the matrix after the quantize axis
1035
+ for i, dim in enumerate(data.shape):
1036
+ if i < quantize_axis:
1037
+ m *= dim
1038
+ elif i > quantize_axis:
1039
+ n *= dim
1040
+
1041
+ k_blocks = (k + block_size - 1) // block_size
1042
+ scales_shape = list(data.shape)
1043
+ scales_shape[quantize_axis] = k_blocks
1044
+
1045
+ data_reshape = data.reshape((m, k, n))
1046
+ scales = np.zeros((m, k_blocks, n), dtype=data.dtype)
1047
+ if is_symmetric:
1048
+ quant_data_int8 = np.zeros((m, k, n), dtype="int8")
1049
+ else:
1050
+ quant_data_int8 = np.zeros((m, k, n), dtype="uint8")
1051
+ zero_point_int8 = np.zeros((m, k_blocks, n), dtype="uint8")
1052
+
1053
+ # slice and quantize
1054
+ for i in range(0, k, block_size):
1055
+ end_idx = min(i + block_size, k)
1056
+ slice = data_reshape[:, i:end_idx, :]
1057
+
1058
+ if is_symmetric:
1059
+ quantized_slice_int8, scale_slice = DefaultWeightOnlyQuantizer.quant_slice_symmetric(slice)
1060
+ else:
1061
+ quantized_slice_int8, scale_slice, zero_point_slice_int8 = (
1062
+ DefaultWeightOnlyQuantizer.quant_slice_asymmetric(slice)
1063
+ )
1064
+
1065
+ quant_data_int8[:, i:end_idx, :] = quantized_slice_int8
1066
+ j = i // block_size
1067
+ scales[:, j : (j + 1), :] = scale_slice
1068
+ if not is_symmetric:
1069
+ zero_point_int8[:, j : (j + 1), :] = zero_point_slice_int8
1070
+
1071
+ # pack int8 to int4
1072
+ quant_data_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(quant_data_int8)
1073
+ zero_point_int4 = None
1074
+ if not is_symmetric:
1075
+ zero_point_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(zero_point_int8)
1076
+ scales = scales.reshape(scales_shape)
1077
+ return quant_data_int4, scales, zero_point_int4
1078
+
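quantize_ndarray folds the dimensions before and after quantize_axis into an (m, k, n) view and quantizes k in block_size chunks. A hedged usage sketch, assuming the class above is in scope (the shapes and block size are illustrative only):

    import numpy as np

    w = np.random.rand(16, 8).astype(np.float32)
    q, scales, zps = DefaultWeightOnlyQuantizer.quantize_ndarray(
        w, quantize_axis=0, block_size=16, is_symmetric=True
    )
    # q packs the 16*8 int4 values into 64 uint8 bytes, scales has shape (1, 8)
    # (one block per column along axis 0), and zps is None in symmetric mode.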
1079
+ def quantize_gather(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
1080
+ """Quantize weight data of Gather node to int4."""
1081
+ assert self.config.quant_format == QuantFormat.QOperator, "Gather only supports QOperator format currently."
1082
+
1083
+ qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
1084
+ data_arg = node.input[0]
1085
+ data_tensorproto, data_graphproto = get_initializer(data_arg, graph_stack)
1086
+ if data_tensorproto is None:
1087
+ logger.info("Gather doesn't have const weight. Skip quantization.")
1088
+ return [node] # only care about constant weight
1089
+
1090
+ data_ndarray = onnx.numpy_helper.to_array(data_tensorproto)
1091
+ data_rank = len(data_ndarray.shape)
1092
+ quantize_axis = self.config.quant_axes.get("Gather", 1)
1093
+ block_size = self.config.block_size
1094
+
1095
+ assert quantize_axis < data_rank and quantize_axis >= -data_rank, "Invalid quantize axis for Gather node."
1096
+ assert block_size >= 16 and ((block_size - 1) & block_size == 0), "Invalid block size for Gather node."
1097
+
1098
+ quantize_axis = (quantize_axis + data_rank) % data_rank
1099
+ quantized_data, scales, zero_points = self.quantize_ndarray(
1100
+ data_ndarray, quantize_axis, block_size, self.config.is_symmetric
1101
+ )
1102
+
1103
+ for input in data_graphproto.input:
1104
+ if input.name == data_arg:
1105
+ data_graphproto.input.remove(input)
1106
+ break
1107
+
1108
+ quantized_data_tensorproto = onnx.helper.make_tensor(
1109
+ data_tensorproto.name + "_Q4", qtype, data_ndarray.shape, quantized_data.tobytes(), True
1110
+ )
1111
+ scales_tensorproto = onnx.numpy_helper.from_array(scales, data_tensorproto.name + "_scales")
1112
+ input_names = [quantized_data_tensorproto.name, node.input[1], scales_tensorproto.name]
1113
+ data_graphproto.initializer.extend([quantized_data_tensorproto, scales_tensorproto])
1114
+ if not self.config.is_symmetric:
1115
+ zp_tensorproto = onnx.helper.make_tensor(
1116
+ data_tensorproto.name + "_zero_points", qtype, scales.shape, zero_points.tobytes(), True
1117
+ )
1118
+ input_names.append(zp_tensorproto.name)
1119
+ data_graphproto.initializer.extend([zp_tensorproto])
1120
+
1121
+ try:
1122
+ gather_axis = onnx.helper.get_node_attr_value(node, "axis")
1123
+ except ValueError:
1124
+ gather_axis = 0
1125
+
1126
+ kwargs = {
1127
+ "gather_axis": gather_axis,
1128
+ "quantize_axis": quantize_axis,
1129
+ "block_size": block_size,
1130
+ }
1131
+
1132
+ gather_q4_node = onnx.helper.make_node(
1133
+ "GatherBlockQuantized",
1134
+ inputs=input_names,
1135
+ outputs=[node.output[0]],
1136
+ name=node.name + "_Q4" if node.name else "",
1137
+ domain="com.microsoft",
1138
+ **kwargs,
1139
+ )
1140
+
1141
+ return [gather_q4_node]
1142
+
1143
+ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
1144
+ """
1145
+ Target node: QOperator node: QDQ nodes:
1146
+ MatMul MatMulNBits DequantizeLinear -> MatMul
1147
+ Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
1148
+ If the node is target node with fp32 or fp16 const weight, quantize the weight to int4 and
1149
+ return the new nodes.
1150
+ If QOperator format, return the corresponding QOperator nodes.
1151
+ If QDQ format, return the corresponding QDQ nodes.
1152
+ Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
1153
+ not supported yet because Gather does not support int4 data.
1154
+ """
1155
+ logger.info(f"start to quantize {node.name} ...")
1156
+
1157
+ bits = self.config.bits
1158
+ if node.op_type == "MatMul":
1159
+ if bits == 8 and self.config.quant_format == QuantFormat.QDQ:
1160
+ logger.error("MatMul only supports QOperator format for 8 bits quantization.")
1161
+ return [node]
1162
+ results = self.quantize_matmul(node, graph_stack)
1163
+ elif node.op_type == "Gather":
1164
+ if self.config.bits != 4:
1165
+ logger.error("Gather only supports 4 bits quantization.")
1166
+ return [node]
1167
+
1168
+ results = self.quantize_gather(node, graph_stack)
1169
+ else:
1170
+ logger.error(f"Unsupported operator {node.op_type} for weight only quantization. Skip quantization.")
1171
+ return [node]
1172
+
1173
+ logger.info(f"complete quantization of {node.name} with {self.config.bits} bits ...")
1174
+ return results
1175
+
1176
+
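As a rough editorial sketch of how DefaultWeightOnlyQuantizer.quantize is driven (MatMulNBitsQuantizer below does this through _process_subgraph), one might call it directly on a single MatMul node whose weight is a graph initializer. Names, shapes, and config values here are illustrative assumptions, and the sketch relies on get_initializer and DefaultWeightOnlyQuantConfig defined earlier in this module:

    import numpy as np
    import onnx
    from onnx import TensorProto, helper, numpy_helper

    w = np.random.rand(64, 32).astype(np.float32)
    graph = helper.make_graph(
        nodes=[helper.make_node("MatMul", ["X", "W"], ["Y"], name="matmul_0")],
        name="g",
        inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 64])],
        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 32])],
        initializer=[numpy_helper.from_array(w, name="W")],
    )
    config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True, quant_format=QuantFormat.QOperator)
    new_nodes = DefaultWeightOnlyQuantizer(config).quantize(graph.node[0], [graph])
    # With QOperator format this is expected to return a single MatMulNBits node,
    # with the packed weight and scales appended to graph.initializer.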
1177
+ class NVAWQWeightOnlyQuantizer:
1178
+ def __init__(
1179
+ self,
1180
+ config: NVAWQWeightOnlyQuantConfig,
1181
+ ):
1182
+ self.config = config
1183
+
1184
+ def quantize_awq(self, model: ModelProto | str) -> ModelProto:
1185
+ """
1186
+ Perform nvidia_awq quantization using ModelOpt's int4 quantize function.
1187
+
1188
+ Args:
1189
+ model (ModelProto | str): The ONNX model, or a path to one, to quantize.
1190
+
1191
+ Returns:
1192
+ ModelProto: The quantized ONNX model.
1193
+ """
1194
+ try:
1195
+ from modelopt.onnx.quantization.int4 import quantize as quantize_int4 # noqa: PLC0415
1196
+ except ImportError:
1197
+ print(
1198
+ "Please ensure that the 'modelopt' package is installed. Please install it using pip install nvidia_modelopt."
1199
+ )
1200
+ raise ImportError(
1201
+ "modelopt is not installed. Please install it using pip install nvidia_modelopt. Exiting."
1202
+ ) from None
1203
+
1204
+ logger.info("Starting nvidia_awq quantization...")
1205
+
1206
+ # Prepare calibration inputs
1207
+ calib_inputs = self.config.calibration_data_reader
1208
+
1209
+ # Perform quantization using ModelOpt's int4 quantize function
1210
+ quantized_model = quantize_int4(
1211
+ model,
1212
+ calibration_method=self.config.calibration_method,
1213
+ calibration_data_reader=calib_inputs,
1214
+ )
1215
+
1216
+ logger.info("Completed nvidia_awq quantization.")
1217
+ return quantized_model
1218
+
1219
+
1220
+ class MatMulNBitsQuantizer:
1221
+ """
1222
+ Target node: QOperator node: QDQ nodes:
1223
+ MatMul MatMulNBits DequantizeLinear -> MatMul
1224
+ Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
1225
+
1226
+ Perform 2/4/8 bits quantization of constant weights for target nodes.
1227
+ If algo_config.quant_format is QOperator:
1228
+ - nodes are replaced by the corresponding QOperator nodes.
1229
+ - quantized weights are stored in the contrib ops.
1230
+ If algo_config.quant_format is QDQ:
1231
+ - the quantized weight is stored in a standard onnx node. For MatMul, it is DequantizeLinear. For Gather,
1232
+ it is the three Gathers, one for quantized data, one for scales and one for optional zero points.
1233
+ - The nodes are replaced by the corresponding QDQ nodes.
1234
+ - currently Gather is not supported in QDQ because Gather does not support int4 yet.
1235
+ Note:
1236
+ - for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
1237
+ during runtime. Therefore it is not recommended.
1238
+ - when a node is in nodes_to_exclude, its configuration in algo_config.customized_weight_config will be ignored.
1239
+ """
1240
+
1241
+ def __init__(
1242
+ self,
1243
+ model: ModelProto | str,
1244
+ bits: int = 4, # default to 4bit
1245
+ block_size: int = 128,
1246
+ is_symmetric: bool = False,
1247
+ accuracy_level: int | None = None,
1248
+ nodes_to_exclude=None,
1249
+ nodes_to_include: list[str] | None = None,
1250
+ quant_format=QuantFormat.QOperator,
1251
+ op_types_to_quantize: tuple[str, ...] | None = None,
1252
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
1253
+ channel_wised_quantize: bool = False,
1254
+ algo_config: WeightOnlyQuantConfig | None = None,
1255
+ ):
1256
+ if nodes_to_exclude is None:
1257
+ nodes_to_exclude = []
1258
+ self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
1259
+ self.model_path = model if isinstance(model, str) else None
1260
+ self.bits = bits
1261
+ self.block_size = block_size
1262
+ self.is_symmetric = is_symmetric
1263
+ self.accuracy_level = accuracy_level
1264
+ self.nodes_to_exclude = set(nodes_to_exclude)
1265
+ self.nodes_to_include = set(nodes_to_include) if nodes_to_include else None
1266
+ self.node_quantizer = None
1267
+
1268
+ if algo_config is None:
1269
+ algo_config = DefaultWeightOnlyQuantConfig(
1270
+ block_size=block_size,
1271
+ is_symmetric=is_symmetric,
1272
+ accuracy_level=accuracy_level,
1273
+ quant_format=quant_format,
1274
+ op_types_to_quantize=op_types_to_quantize,
1275
+ quant_axes=quant_axes,
1276
+ bits=bits,
1277
+ channel_wised_quantize=channel_wised_quantize,
1278
+ )
1279
+
1280
+ self.algo_config = algo_config
1281
+ if hasattr(self.algo_config, "bits"):
1282
+ assert self.algo_config.bits in [2, 4, 8], "Only support 2, 4 or 8 bits quantization"
1283
+
1284
+ if algo_config.algorithm == "HQQ":
1285
+ self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
1286
+ elif algo_config.algorithm == "DEFAULT":
1287
+ self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config)
1288
+ elif algo_config.algorithm == "nvidia_awq":
1289
+ self.node_quantizer = NVAWQWeightOnlyQuantizer(self.algo_config)
1290
+
1291
+ def _process_subgraph(self, graph_stack: list[GraphProto]):
1292
+ new_nodes = []
1293
+ graph = graph_stack[-1]
1294
+
1295
+ for node in graph.node:
1296
+ graph_attrs = [
1297
+ attr
1298
+ for attr in node.attribute
1299
+ if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
1300
+ ]
1301
+ if graph_attrs:
1302
+ kwargs = {}
1303
+ for attr in node.attribute:
1304
+ if attr.type == onnx.AttributeProto.GRAPH:
1305
+ # recursive call to take care of sub-graph
1306
+ graph_stack.append(attr.g)
1307
+ kv = {attr.name: self._process_subgraph(graph_stack)}
1308
+ elif attr.type == onnx.AttributeProto.GRAPHS:
1309
+ value = []
1310
+ for subgraph in attr.graphs:
1311
+ # recursive call to take care of sub-graph
1312
+ graph_stack.append(subgraph)
1313
+ value.extend([self._process_subgraph(graph_stack)])
1314
+ kv = {attr.name: value}
1315
+ else:
1316
+ kv = attribute_to_kwarg(attr)
1317
+ kwargs.update(kv)
1318
+ node = onnx.helper.make_node( # noqa: PLW2901
1319
+ node.op_type, node.input, node.output, name=node.name, **kwargs
1320
+ )
1321
+ out_nodes = []
1322
+ if node.name in self.nodes_to_exclude:
1323
+ logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
1324
+ out_nodes = [node]
1325
+ elif (self.nodes_to_include and node.name in self.nodes_to_include) or (
1326
+ node.op_type in self.algo_config.op_types_to_quantize
1327
+ ):
1328
+ out_nodes = self.node_quantizer.quantize(node, graph_stack)
1329
+ else:
1330
+ logger.info(f"skip to quantize {node.name} ...")
1331
+ out_nodes = [node]
1332
+ new_nodes.extend(out_nodes)
1333
+
1334
+ graph.ClearField("node")
1335
+ graph.node.extend(new_nodes)
1336
+ graph_stack.pop()
1337
+ return graph
1338
+
1339
+ def _generate_q4_node_config(self):
1340
+ """Generate weight only quant configuration for nodes."""
1341
+ q4_node_config = {}
1342
+ for node in self.model.model.graph.node:
1343
+ if node.op_type in ["MatMul"]:
1344
+ if not all(self.model.get_initializer(i) is None for i in node.input):
1345
+ template_config_q4 = {
1346
+ "bits": 4,
1347
+ "group_size": self.block_size,
1348
+ "scheme": "sym" if self.is_symmetric else "asym",
1349
+ }
1350
+ if (
1351
+ self.algo_config.customized_weight_config
1352
+ and node.name in self.algo_config.customized_weight_config
1353
+ ):
1354
+ for key, value in self.algo_config.customized_weight_config[node.name].items():
1355
+ if key in template_config_q4:
1356
+ template_config_q4[key] = value
1357
+ q4_node_config[node.name] = template_config_q4
1358
+ return q4_node_config
1359
+
1360
+ def int4_quant_algo(self):
1361
+ """4b quantize a model with RTN or GPTQ algorithm. Please refer to
1362
+ https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md
1363
+ for more details on weight only quantization using Intel® Neural Compressor.
1364
+ """
1365
+
1366
+ def inc_dataloader():
1367
+ data_reader = copy.deepcopy(self.algo_config.calibration_data_reader)
1368
+ for data in data_reader:
1369
+ yield data, None
1370
+
1371
+ kwargs = {}
1372
+ if self.accuracy_level is not None:
1373
+ kwargs["accuracy_level"] = self.accuracy_level
1374
+ weight_only_node_config = self._generate_q4_node_config()
1375
+
1376
+ algorithm = self.algo_config.algorithm
1377
+ logger.info(f"start to quantize model with {algorithm} algorithm...")
1378
+ if algorithm in ["RTN", "k_quant"]:
1379
+ kwargs["ratios"] = self.algo_config.ratios
1380
+ kwargs["algorithm"] = algorithm
1381
+
1382
+ """
1383
+ We use "fp32" to mark nodes that skip quantization; it does not mean the node's data type is fp32.
1384
+ """
1385
+ for n in self.nodes_to_exclude:
1386
+ weight_only_node_config[n] = "fp32"
1387
+
1388
+ self.model = rtn_quantize(
1389
+ model=self.model_path if self.model_path is not None else self.model.model,
1390
+ weight_config=weight_only_node_config,
1391
+ **kwargs,
1392
+ )
1393
+ elif algorithm == "GPTQ":
1394
+ kwargs["percdamp"] = self.algo_config.percdamp
1395
+ kwargs["blocksize"] = self.algo_config.block_size
1396
+ kwargs["actorder"] = self.algo_config.actorder
1397
+ kwargs["mse"] = self.algo_config.mse
1398
+ kwargs["perchannel"] = self.algo_config.perchannel
1399
+ kwargs["n_samples"] = -1
1400
+ dataloader = inc_dataloader()
1401
+
1402
+ self.model = gptq_quantize(
1403
+ model=self.model_path if self.model_path is not None else self.model.model,
1404
+ weight_config=weight_only_node_config,
1405
+ dataloader=dataloader,
1406
+ **kwargs,
1407
+ )
1408
+ logger.info(f"complete quantization of model with {algorithm} algorithm.")
1409
+
1410
+ def process(self):
1411
+ if self.algo_config.algorithm in ["HQQ", "DEFAULT"]:
1412
+ # use a stack to keep track of sub-graphs
1413
+ graph_stack = [self.model.graph()]
1414
+
1415
+ # Update domain opset
1416
+ if self.algo_config.quant_format == QuantFormat.QOperator:
1417
+ self.model.set_opset_import("com.microsoft", 1)
1418
+
1419
+ if self.algo_config.quant_format == QuantFormat.QDQ or "Gather" in self.algo_config.op_types_to_quantize:
1420
+ opset_import = self.model.opset_import()
1421
+ for opset in opset_import:
1422
+ if opset.domain in [None, "ai.onnx", ""] and opset.version < 21:
1423
+ logger.warning(
1424
+ "The opset of the input model is under 21 and doesn't support int4 data type. "
1425
+ "Force to update it to opset 21, but the generated model may not be a valid model."
1426
+ )
1427
+ self.model.set_opset_import(opset.domain, 21)
1428
+
1429
+ self._process_subgraph(graph_stack)
1430
+ self.model.clean_initializers()
1431
+ elif self.algo_config.algorithm == "nvidia_awq":
1432
+ # Handle nvidia_awq quantization
1433
+ logger.info("Processing nvidia_awq quantization...")
1434
+ self.model = self.node_quantizer.quantize_awq(
1435
+ self.model.model if self.model_path is None else self.model_path
1436
+ )
1437
+ logger.info("Completed nvidia_awq quantization.")
1438
+ self.model = ONNXModel(self.model) # Ensure the model is wrapped back into ONNXModel
1439
+ self.model.clean_initializers()
1440
+ else:
1441
+ # RTN or GPTQ weight-only quantize algorithm
1442
+ self.int4_quant_algo()
1443
+
1444
+
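For programmatic use, the flow in the __main__ block at the bottom of this file reduces to a few lines; a minimal hedged sketch with placeholder paths:

    import onnx

    model = onnx.load("model.onnx")                      # placeholder input path
    quant = MatMulNBitsQuantizer(
        model,
        block_size=32,
        is_symmetric=True,
        quant_format=QuantFormat.QOperator,
        op_types_to_quantize=("MatMul",),
    )
    quant.process()
    quant.model.save_model_to_file("model_q4.onnx", True)  # same call as in __main__ below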
1445
+ def ort_convert_str_to_bool(value):
1446
+ return value.lower() in ("true", "1")
1447
+
1448
+
1449
+ # Custom function to parse str:int pairs
1450
+ def parse_key_value_pair(s):
1451
+ key, value = s.split(":")
1452
+ return key, int(value)
1453
+
1454
+
1455
+ def parse_args():
1456
+ parser = argparse.ArgumentParser(
1457
+ description="""Blockwise int4 quantization for MatMul 2D weight matrices.
1458
+
1459
+ A weight matrix is partitioned into blocks, where each block is a
1460
+ contiguous subset inside each column. Each block is quantized into a
1461
+ set of 4b integers with a scaling factor and an optional offset.
1462
+ """
1463
+ )
1464
+
1465
+ parser.add_argument("--input_model", required=True, help="Path to the input model file")
1466
+ parser.add_argument("--output_model", required=True, help="Path to the output model file")
1467
+ parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization")
1468
+ parser.add_argument(
1469
+ "--quant_method",
1470
+ default="default",
1471
+ type=str,
1472
+ choices=["default", "hqq", "rtn", "k_quant", "gptq", "nvidia_awq"],
1473
+ help="the algorithm used to quantize weight, \nrtn and gptq leverage Intel® Neural Compressor",
1474
+ )
1475
+ parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight")
1476
+ parser.add_argument(
1477
+ "--symmetric",
1478
+ required=False,
1479
+ default=True,
1480
+ const=True,
1481
+ nargs="?",
1482
+ type=ort_convert_str_to_bool,
1483
+ choices=[True, False],
1484
+ help="Indicate whether to quantize the model symmetrically, symmetric is not supported by hqq",
1485
+ )
1486
+ parser.add_argument(
1487
+ "--accuracy_level",
1488
+ required=False,
1489
+ type=int,
1490
+ help="Accuracy level of the 4-bit quantized MatMul computation. "
1491
+ "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
1492
+ "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
1493
+ )
1494
+ parser.add_argument("-v", "--verbose", required=False, action="store_true")
1495
+ parser.set_defaults(verbose=False)
1496
+ parser.add_argument(
1497
+ "--nodes_to_exclude",
1498
+ nargs="+",
1499
+ type=str,
1500
+ required=False,
1501
+ default=[],
1502
+ help="Specify the nodes to be excluded from quantization with node names",
1503
+ )
1504
+ parser.add_argument(
1505
+ "--nodes_to_include",
1506
+ nargs="+",
1507
+ type=str,
1508
+ required=False,
1509
+ help="Specify the specific nodes to be included from quantization with node names",
1510
+ )
1511
+ parser.add_argument(
1512
+ "--quant_format",
1513
+ default="QOperator",
1514
+ type=str,
1515
+ choices=["QOperator", "QDQ"],
1516
+ help="QuantFormat {QOperator, QDQ}"
1517
+ "QOperator format quantizes the model with quantized operators directly."
1518
+ "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.",
1519
+ )
1520
+ parser.add_argument(
1521
+ "--op_types_to_quantize",
1522
+ type=str,
1523
+ nargs="+",
1524
+ choices=["MatMul", "Gather"],
1525
+ help="op_types_to_quantize {MatMul, Gather}. Operators to quantize. Default is MatMul.",
1526
+ )
1527
+ parser.add_argument(
1528
+ "--quant_axes",
1529
+ type=parse_key_value_pair,
1530
+ nargs="+",
1531
+ required=False,
1532
+ help="Key-value pairs in op_type:axis_to_quantize separated by space."
1533
+ "Specify the axis to quantize for an op. Default {MatMul:0, Gather:1}"
1534
+ "Example: --quant_axes MatMul:0 Gather:1",
1535
+ )
1536
+ # Group arguments specific to nvidia_awq
1537
+ nv_awq_config = parser.add_argument_group("nvidia_awq", "Arguments specific to nvidia_awq quantization")
1538
+ nv_awq_config.add_argument(
1539
+ "--calib_dataset_name",
1540
+ type=str,
1541
+ default="cnn",
1542
+ help="Name of the calibration dataset for nvidia_awq.",
1543
+ )
1544
+ nv_awq_config.add_argument(
1545
+ "--tokenizer_dir",
1546
+ type=str,
1547
+ required=False,
1548
+ help="Path of the tokenizer dir.",
1549
+ )
1550
+ nv_awq_config.add_argument(
1551
+ "--calibration_method",
1552
+ type=str,
1553
+ required=False,
1554
+ choices=["awq", "awq_clip"],
1555
+ help="Support two options, awq implementation and weight clipping.",
1556
+ )
1557
+ nv_awq_config.add_argument(
1558
+ "--cache_dir",
1559
+ type=str,
1560
+ default="./cache",
1561
+ help="Cache directory for calibration data.",
1562
+ )
1563
+ return parser.parse_args()
1564
+
1565
+
1566
+ if __name__ == "__main__":
1567
+ args = parse_args()
1568
+ if args.verbose:
1569
+ logger.setLevel(logging.DEBUG)
1570
+
1571
+ input_model_path = args.input_model
1572
+ output_model_path = args.output_model
1573
+ quant_format = QuantFormat[args.quant_format]
1574
+ op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else ("MatMul",)
1575
+ quant_axes = tuple(args.quant_axes) if args.quant_axes else None
1576
+
1577
+ if os.path.exists(output_model_path):
1578
+ logger.error(f"file {output_model_path} already exists")
1579
+ raise Exception(f"file {output_model_path} already exists")
1580
+
1581
+ if args.symmetric and args.quant_method == "hqq":
1582
+ logger.warning("symmetric is not supportted by hqq, will force to symmetric=False")
1583
+ args.symmetric = False
1584
+
1585
+ model = onnx.load(input_model_path)
1586
+ if args.quant_method == "hqq":
1587
+ quant_config = HQQWeightOnlyQuantConfig(
1588
+ block_size=args.block_size, bits=args.bits, op_types_to_quantize=op_types_to_quantize, quant_axes=quant_axes
1589
+ )
1590
+ elif args.quant_method == "default":
1591
+ quant_config = DefaultWeightOnlyQuantConfig(
1592
+ block_size=args.block_size,
1593
+ is_symmetric=args.symmetric,
1594
+ accuracy_level=args.accuracy_level,
1595
+ quant_format=quant_format,
1596
+ op_types_to_quantize=op_types_to_quantize,
1597
+ quant_axes=quant_axes,
1598
+ bits=args.bits,
1599
+ )
1600
+ elif args.quant_method == "rtn":
1601
+ quant_config = RTNWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize)
1602
+ elif args.quant_method == "k_quant":
1603
+ quant_config = KQuantWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize)
1604
+ elif args.quant_method == "gptq":
1605
+ quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize)
1606
+ elif args.quant_method == "nvidia_awq":
1607
+ if quant_format == QuantFormat.QOperator:
1608
+ logger.warning("QOperator is not applicable to nvidia_awq. overriding the value to QDQ")
1609
+ quant_format = QuantFormat.QDQ
1610
+
1611
+ model = input_model_path
1612
+ if args.calibration_method is not None:
1613
+ if args.calibration_method == "awq":
1614
+ calibration_method = "awq_lite"
1615
+ else:
1616
+ calibration_method = "awq_clip"
1617
+ else:
1618
+ calibration_method = "awq_lite"
1619
+
1620
+ quant_config = NVAWQWeightOnlyQuantConfig(
1621
+ dataset_name=args.calib_dataset_name,
1622
+ tokenizer_dir=args.tokenizer_dir,
1623
+ cache_dir=args.cache_dir,
1624
+ calibration_method=calibration_method,
1625
+ )
1626
+ else:
1627
+ raise ValueError(f"Unsupported quantization method: {args.quant_method}")
1628
+
1629
+ quant = MatMulNBitsQuantizer(
1630
+ model=model,
1631
+ bits=args.bits,
1632
+ accuracy_level=args.accuracy_level,
1633
+ nodes_to_exclude=args.nodes_to_exclude,
1634
+ nodes_to_include=args.nodes_to_include,
1635
+ algo_config=quant_config,
1636
+ )
1637
+ quant.process()
1638
+ quant.model.save_model_to_file(output_model_path, True)
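An example command line for this script, run as a module with placeholder paths; every flag maps to an argument defined in parse_args above:

    python -m onnxruntime.quantization.matmul_nbits_quantizer \
        --input_model model.onnx \
        --output_model model_q4.onnx \
        --block_size 32 \
        --quant_method default \
        --bits 4 \
        --quant_format QOperator \
        --op_types_to_quantize MatMul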