onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/quantization/matmul_4bits_quantizer.py
@@ -0,0 +1,1480 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import copy
11
+ import importlib
12
+ import logging
13
+ import os
14
+
15
+ import numpy as np
16
+ import numpy.typing as npt
17
+ import onnx
18
+ from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
19
+ from packaging import version
20
+
21
+ from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits
22
+
23
+ from .calibrate import CalibrationDataReader
24
+ from .onnx_model import ONNXModel
25
+ from .quant_utils import QuantFormat, attribute_to_kwarg
26
+
27
+ logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class WeightOnlyQuantConfig:
32
+ def __init__(
33
+ self,
34
+ algorithm: str,
35
+ quant_format: QuantFormat,
36
+ op_types_to_quantize: tuple[str, ...] | None = None,
37
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
38
+ ):
39
+ """This is the Base class for Weight Only blockwise quantization Configuration.
40
+
41
+ Args:
42
+ algorithm:
43
+ weight only quantize algorithm name.
44
+ quant_format: QuantFormat{QOperator, QDQ}.
45
+ QOperator format quantizes the model with quantized operators directly.
46
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
47
+ op_types_to_quantize (optional):
48
+ set of operator types to quantize. Default {MatMul}
49
+ quant_axes (dict[str, int], optional):
50
+ op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
51
+ """
52
+ self.algorithm = algorithm
53
+ self.quant_format = quant_format
54
+ self.op_types_to_quantize = set(op_types_to_quantize) if op_types_to_quantize else {"MatMul"}
55
+ self.quant_axes = dict(quant_axes) if quant_axes else {"MatMul": 0, "Gather": 1}
56
+
57
+
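For orientation while reading this hunk, a minimal sketch of the defaults the base constructor sets up (illustrative only, not part of the packaged file; the import path is inferred from the file path above):

    from onnxruntime.quantization.quant_utils import QuantFormat
    from onnxruntime.quantization.matmul_4bits_quantizer import WeightOnlyQuantConfig

    cfg = WeightOnlyQuantConfig(algorithm="RTN", quant_format=QuantFormat.QOperator)
    print(cfg.op_types_to_quantize)  # {'MatMul'} -- the default op set
    print(cfg.quant_axes)            # {'MatMul': 0, 'Gather': 1} -- the default quantization axes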
58
+ class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
59
+ def __init__(
60
+ self,
61
+ ratios=None,
62
+ quant_format=QuantFormat.QOperator,
63
+ op_types_to_quantize: tuple[str, ...] | None = None,
64
+ ):
65
+ """
66
+ This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
67
+ RTN is the most straightforward way to quantize weight using scale maps.
68
+
69
+ Args:
70
+ ratios:
71
+ percentile of clip. Defaults to {}.
72
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
73
+ QOperator format quantizes the model with quantized operators directly.
74
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
75
+ Defaults to QuantFormat.QOperator.
76
+ op_types_to_quantize (optional):
77
+ set of operator types to quantize.
78
+ """
79
+ assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
80
+
81
+ if ratios is None:
82
+ ratios = {}
83
+ super().__init__(
84
+ algorithm="RTN",
85
+ quant_format=quant_format,
86
+ op_types_to_quantize=op_types_to_quantize,
87
+ )
88
+ self.ratios = ratios
89
+
90
+
91
+ class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
92
+ def __init__(
93
+ self,
94
+ calibration_data_reader: CalibrationDataReader | None = None,
95
+ percdamp=0.01,
96
+ block_size=128,
97
+ actorder=False,
98
+ mse=False,
99
+ perchannel=True,
100
+ quant_format=QuantFormat.QOperator,
101
+ op_types_to_quantize: tuple[str, ...] | None = None,
102
+ ):
103
+ """
104
+ This is a class for GPTQ algorithm Weight Only Quant Configuration.
105
+ GPTQ algorithm provides more accurate quantization but requires more computational resources.
106
+
107
+ Args:
108
+ calibration_data_reader:
109
+ a calibration data reader. It enumerates calibration data and generates inputs for the original model.
110
+ percdamp:
111
+ percent of the average Hessian diagonal to use for dampening.
112
+ block_size (int, optional):
113
+ channel number in one block to execute a GPTQ quantization iteration.
114
+ actorder (bool, optional):
115
+ whether rearrange Hessian matrix considering the diag's value.
116
+ mse (bool, optional):
117
+ whether get scale and zero point with mse error.
118
+ perchannel (bool, optional):
119
+ whether quantize weight per-channel.
120
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
121
+ QOperator format quantizes the model with quantized operators directly.
122
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
123
+ Defaults to QuantFormat.QOperator.
124
+ op_types_to_quantize (optional):
125
+ set of operator types to quantize.
126
+ """
127
+ assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format"
128
+
129
+ super().__init__(
130
+ algorithm="GPTQ",
131
+ quant_format=quant_format,
132
+ op_types_to_quantize=op_types_to_quantize,
133
+ )
134
+ self.calibration_data_reader = calibration_data_reader
135
+ self.percdamp = percdamp
136
+ self.block_size = block_size
137
+ self.actorder = actorder
138
+ self.mse = mse
139
+ self.perchannel = perchannel
140
+
141
+
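GPTQ is the only configuration above that requires calibration data. A hedged sketch with a stand-in data reader built on the CalibrationDataReader base imported at the top of this file (input names, shapes, and batch count are arbitrary):

    import numpy as np
    from onnxruntime.quantization.calibrate import CalibrationDataReader
    from onnxruntime.quantization.matmul_4bits_quantizer import GPTQWeightOnlyQuantConfig

    class RandomDataReader(CalibrationDataReader):
        """Feeds a few random batches to the quantizer (illustrative only)."""
        def __init__(self, n=4):
            self._data = iter(
                [{"input_ids": np.random.randint(0, 1000, (1, 32), dtype=np.int64)} for _ in range(n)]
            )
        def get_next(self):
            return next(self._data, None)

    gptq_config = GPTQWeightOnlyQuantConfig(calibration_data_reader=RandomDataReader(), block_size=128)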
142
+ class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
143
+ def __init__(
144
+ self,
145
+ block_size=128,
146
+ bits=4,
147
+ axis=1,
148
+ quant_format=QuantFormat.QOperator,
149
+ op_types_to_quantize: tuple[str, ...] | None = None,
150
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
151
+ ):
152
+ """
153
+ This is a class for HQQ algorithm Weight Only Quant Configuration.
154
+ HQQ algorithm quant weight without needing calibrate data.
155
+
156
+ Args:
157
+ block_size (int, optional):
158
+ channel number in one block to execute a HQQ quantization iteration.
159
+ bits (int, optional):
160
+ how many bits to represent weight.
161
+ axis (int, optional):
162
+ 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf
163
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
164
+ QOperator format quantizes the model with quantized operators directly.
165
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
166
+ Defaults to QuantFormat.QOperator.
167
+ op_types_to_quantize (optional):
168
+ set of operator types to quantize.
169
+ quant_axes (dict[str, int], optional):
170
+ op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
171
+ """
172
+ assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format"
173
+
174
+ super().__init__(
175
+ algorithm="HQQ",
176
+ quant_format=quant_format,
177
+ op_types_to_quantize=op_types_to_quantize,
178
+ quant_axes=quant_axes,
179
+ )
180
+ self.block_size = block_size
181
+ self.bits = bits
182
+ self.axis = axis
183
+
184
+
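By contrast, HQQ needs no calibration data, only torch at quantization time (see HQQWeightOnlyQuantizer further down in this file). A minimal sketch, not part of the packaged file:

    from onnxruntime.quantization.matmul_4bits_quantizer import HQQWeightOnlyQuantConfig

    # 4-bit HQQ over blocks of 64 channels along axis 1; no calibration data is required.
    hqq_config = HQQWeightOnlyQuantConfig(block_size=64, bits=4, axis=1)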
185
+ class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig):
186
+ def __init__(
187
+ self,
188
+ block_size: int = 128,
189
+ is_symmetric: bool = False,
190
+ accuracy_level: int | None = None,
191
+ quant_format=QuantFormat.QOperator,
192
+ op_types_to_quantize: tuple[str, ...] | None = None,
193
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
194
+ ):
195
+ """
196
+ This is a class for weight only affine quantization configuration.
197
+
198
+ Args:
199
+ block_size (int, optional):
200
+ channel number in one block to execute an affine quantization iteration.
201
+ is_symmetric (bool, optional):
202
+ whether quantize weight symmetrically.
203
+ accuracy_level (int, optional):
204
+ Accuracy level of the 4-bit quantized MatMul computation.
205
+ Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details.
206
+ (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)
207
+ quant_format (QuantFormat{QOperator, QDQ}, optional):
208
+ QOperator format quantizes the model with quantized operators directly.
209
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
210
+ Defaults to QuantFormat.QOperator.
211
+ op_types_to_quantize (optional):
212
+ set of operator types to quantize.
213
+ quant_axes (dict[str, int], optional):
214
+ op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
215
+ """
216
+ super().__init__(
217
+ algorithm="DEFAULT",
218
+ quant_format=quant_format,
219
+ op_types_to_quantize=op_types_to_quantize,
220
+ quant_axes=quant_axes,
221
+ )
222
+ self.block_size = block_size
223
+ self.is_symmetric = is_symmetric
224
+ self.bits = 4
225
+ self.accuracy_level = accuracy_level
226
+
227
+
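The default configuration is the one most callers reach for. A hedged end-to-end sketch; MatMul4BitsQuantizer and ONNXModel.save_model_to_file are defined elsewhere in this package and are not shown in this excerpt, so their exact signatures are assumed here, and the paths are placeholders:

    from onnxruntime.quantization.matmul_4bits_quantizer import (
        DefaultWeightOnlyQuantConfig,
        MatMul4BitsQuantizer,  # defined later in this file; not shown in this excerpt
    )

    config = DefaultWeightOnlyQuantConfig(block_size=128, is_symmetric=True)
    quant = MatMul4BitsQuantizer("model_fp32.onnx", algo_config=config)  # signature assumed
    quant.process()
    quant.model.save_model_to_file("model_int4.onnx", use_external_data_format=True)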
228
+ class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
229
+ def __init__(
230
+ self,
231
+ tokenizer_dir,
232
+ dataset_name="cnn",
233
+ cache_dir="./cache",
234
+ calibration_method="awq_lite",
235
+ ):
236
+ """
237
+ Configuration for the nvidia_awq quantization method.
238
+
239
+ Args:
240
+ tokenizer_dir (str): pathof the tokenizer dir.
241
+ dataset_name (str): Name of the dataset.
242
+ cache_dir (str): Directory for caching.
243
+ calibration_method (str): calib method for nvidia_awq.
244
+ """
245
+ # Import torch and DataLoader
246
+ try:
247
+ import torch
248
+ from torch.utils.data import DataLoader
249
+
250
+ self.torch = torch
251
+ self.DataLoader = DataLoader
252
+ except ImportError:
253
+ print(
254
+ "Error: The 'torch' library is required but not installed. Please install it using 'pip install torch'."
255
+ )
256
+ raise ImportError("torch is not installed. Exiting.") from None
257
+
258
+ # Import datasets
259
+ try:
260
+ from datasets import load_dataset
261
+
262
+ self.load_dataset = load_dataset
263
+ except ImportError:
264
+ print(
265
+ "Error: The 'datasets' library is required but not installed. Please install it using 'pip install datasets'."
266
+ )
267
+ raise ImportError("datasets is not installed. Exiting.") from None
268
+
269
+ # Import transformers
270
+ try:
271
+ from transformers import AutoConfig, AutoTokenizer
272
+
273
+ self.AutoConfig = AutoConfig
274
+ self.AutoTokenizer = AutoTokenizer
275
+ except ImportError:
276
+ print(
277
+ "Error: The 'transformers' library is required but not installed. Please install it using 'pip install transformers'."
278
+ )
279
+ raise ImportError("transformers is not installed. Exiting.") from None
280
+
281
+ super().__init__(
282
+ algorithm="nvidia_awq",
283
+ quant_format=QuantFormat.QDQ,
284
+ op_types_to_quantize=None, # Assuming op_types_to_quantize is handled elsewhere
285
+ quant_axes=None, # Assuming quant_axes is handled elsewhere
286
+ )
287
+
288
+ # Determine the device
289
+ device = self.torch.device("cuda" if self.torch.cuda.is_available() else "cpu")
290
+
291
+ calib_inputs = self.get_calib_inputs(
292
+ dataset_name=dataset_name,
293
+ model_name=tokenizer_dir,
294
+ cache_dir=cache_dir,
295
+ calib_size=32,
296
+ batch_size=1,
297
+ block_size=512,
298
+ device=device,
299
+ use_fp16=True,
300
+ use_buffer_share=False,
301
+ add_past_kv_inputs=True,
302
+ max_calib_rows_to_load=128,
303
+ add_position_ids=True,
304
+ )
305
+
306
+ self.calibration_data_reader = calib_inputs
307
+ self.calibration_method = calibration_method
308
+
309
+ def make_model_input(
310
+ self,
311
+ config,
312
+ input_ids_arg,
313
+ attention_mask_arg,
314
+ add_past_kv_inputs,
315
+ device,
316
+ use_fp16,
317
+ use_buffer_share,
318
+ add_position_ids,
319
+ ):
320
+ # Access torch from the instance variable
321
+ torch = self.torch
322
+
323
+ input_ids = input_ids_arg
324
+ attention_mask = attention_mask_arg
325
+
326
+ if isinstance(input_ids_arg, list):
327
+ input_ids = torch.tensor(input_ids_arg, device=device, dtype=torch.int64)
328
+ attention_mask = torch.tensor(attention_mask_arg, device=device, dtype=torch.int64)
329
+
330
+ inputs = {
331
+ "input_ids": input_ids.contiguous(),
332
+ "attention_mask": attention_mask.contiguous(),
333
+ }
334
+
335
+ if add_position_ids:
336
+ position_ids = attention_mask.long().cumsum(-1) - 1
337
+ position_ids.masked_fill_(attention_mask == 0, 1)
338
+ inputs["position_ids"] = position_ids.contiguous()
339
+
340
+ if add_past_kv_inputs:
341
+ torch_dtype = torch.float16 if use_fp16 else torch.float32
342
+ batch_size, sequence_length = input_ids.shape
343
+ max_sequence_length = config.max_position_embeddings
344
+ num_heads, head_size = (
345
+ config.num_key_value_heads,
346
+ config.hidden_size // config.num_attention_heads,
347
+ )
348
+ for i in range(config.num_hidden_layers):
349
+ past_key = torch.zeros(
350
+ batch_size,
351
+ num_heads,
352
+ max_sequence_length if use_buffer_share else 0,
353
+ head_size,
354
+ device=device,
355
+ dtype=torch_dtype,
356
+ )
357
+ past_value = torch.zeros(
358
+ batch_size,
359
+ num_heads,
360
+ max_sequence_length if use_buffer_share else 0,
361
+ head_size,
362
+ device=device,
363
+ dtype=torch_dtype,
364
+ )
365
+ inputs.update(
366
+ {
367
+ f"past_key_values.{i}.key": past_key.contiguous(),
368
+ f"past_key_values.{i}.value": past_value.contiguous(),
369
+ }
370
+ )
371
+
372
+ return inputs
373
+
374
+ def get_calib_inputs(
375
+ self,
376
+ dataset_name,
377
+ model_name,
378
+ cache_dir,
379
+ calib_size,
380
+ batch_size,
381
+ block_size,
382
+ device,
383
+ use_fp16,
384
+ use_buffer_share,
385
+ add_past_kv_inputs,
386
+ max_calib_rows_to_load,
387
+ add_position_ids,
388
+ ):
389
+ # Access transformers and datasets from the instance variables
390
+ auto_config = self.AutoConfig
391
+ auto_tokenizer = self.AutoTokenizer
392
+ load_dataset = self.load_dataset
393
+
394
+ config = auto_config.from_pretrained(
395
+ model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
396
+ )
397
+ tokenizer = auto_tokenizer.from_pretrained(
398
+ model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
399
+ )
400
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
401
+ tokenizer.pad_token = tokenizer.eos_token
402
+
403
+ assert calib_size <= max_calib_rows_to_load, "calib size should be no more than max_calib_rows_to_load"
404
+
405
+ if "cnn" in dataset_name:
406
+ dataset2 = load_dataset("cnn_dailymail", name="3.0.0", split="train").select(range(max_calib_rows_to_load))
407
+ column = "article"
408
+ elif "pile" in dataset_name:
409
+ dataset2 = load_dataset("mit-han-lab/pile-val-backup", split="validation")
410
+ column = "text"
411
+ else:
412
+ raise ValueError(f'dataset "{dataset_name}" not supported')
413
+
414
+ dataset2 = dataset2[column][:calib_size]
415
+ batch_encoded = tokenizer.batch_encode_plus(
416
+ dataset2, return_tensors="pt", padding=True, truncation=True, max_length=block_size
417
+ )
418
+ batch_encoded = batch_encoded.to(device)
419
+ batch_encoded_input_ids = batch_encoded["input_ids"]
420
+ batch_encoded_attention_mask = batch_encoded["attention_mask"]
421
+
422
+ # Access DataLoader from the instance variable
423
+ data_loader = self.DataLoader
424
+
425
+ calib_dataloader_input_ids = data_loader(batch_encoded_input_ids, batch_size=batch_size, shuffle=False)
426
+ calib_dataloader_attention_mask = data_loader(
427
+ batch_encoded_attention_mask, batch_size=batch_size, shuffle=False
428
+ )
429
+
430
+ assert len(calib_dataloader_input_ids.dataset) == len(calib_dataloader_attention_mask.dataset)
431
+ assert len(calib_dataloader_input_ids) == len(calib_dataloader_attention_mask)
432
+
433
+ number_of_batched_samples = calib_size // batch_size
434
+
435
+ batched_input_ids = []
436
+ for idx, data in enumerate(calib_dataloader_input_ids):
437
+ batched_input_ids.append(data)
438
+ if idx == (number_of_batched_samples - 1):
439
+ break
440
+
441
+ batched_attention_mask = []
442
+ for idx, data in enumerate(calib_dataloader_attention_mask):
443
+ batched_attention_mask.append(data)
444
+ if idx == (number_of_batched_samples - 1):
445
+ break
446
+
447
+ print(
448
+ f"\n--Quantize-Script-- number_of_batched_samples={number_of_batched_samples}, "
449
+ f"batch-input-ids-list-len={len(batched_input_ids)}, batched_attention_mask={len(batched_attention_mask)}\n"
450
+ )
451
+
452
+ batched_inputs_list = []
453
+ for i in range(number_of_batched_samples):
454
+ input_ids = batched_input_ids[i]
455
+ attention_mask = batched_attention_mask[i]
456
+
457
+ inputs = self.make_model_input(
458
+ config,
459
+ input_ids,
460
+ attention_mask,
461
+ add_past_kv_inputs,
462
+ device,
463
+ use_fp16,
464
+ use_buffer_share,
465
+ add_position_ids,
466
+ )
467
+ inputs = {input_name: torch_tensor.cpu().numpy() for input_name, torch_tensor in inputs.items()}
468
+ batched_inputs_list.append(inputs)
469
+
470
+ print(f"\n--Quantize-Script-- number of batched inputs = {len(batched_inputs_list)}\n")
471
+ return batched_inputs_list
472
+
473
+
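Constructing this configuration is heavyweight: the constructor above immediately imports torch, datasets, and transformers, then downloads and tokenizes the calibration split. A hedged sketch (the model id is a placeholder):

    from onnxruntime.quantization.matmul_4bits_quantizer import NVAWQWeightOnlyQuantConfig

    # Requires torch, datasets and transformers; calibration data is prepared at construction time.
    awq_config = NVAWQWeightOnlyQuantConfig(
        tokenizer_dir="meta-llama/Llama-2-7b-hf",  # placeholder HF model id or local path
        dataset_name="cnn",
        calibration_method="awq_lite",
    )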
474
+ def is_divisible(val1, val2):
475
+ return int(val2 * np.ceil(val1 / val2)) == val1
476
+
477
+
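A quick check of the helper above (values are arbitrary):

    from onnxruntime.quantization.matmul_4bits_quantizer import is_divisible

    assert is_divisible(12, 4)      # ceil(12 / 4) * 4 == 12
    assert not is_divisible(13, 4)  # ceil(13 / 4) * 4 == 16 != 13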
478
+ class HQQWeightOnlyQuantizer:
479
+ def __init__(
480
+ self,
481
+ config: HQQWeightOnlyQuantConfig,
482
+ ):
483
+ self.config = config
484
+
485
+ # Proximal solver || weight - dequantize(quantize(weight))||_p^p
486
+ @staticmethod
487
+ def optimize_weights(
488
+ tensor,
489
+ scale,
490
+ zero,
491
+ min_max: list[int],
492
+ axis: int = 0,
493
+ opt_params: dict | None = None,
494
+ verbose=False,
495
+ ):
496
+ import torch
497
+
498
+ opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params
499
+ lp_norm, beta, kappa, iters = (
500
+ opt_params["lp_norm"],
501
+ opt_params["beta"],
502
+ opt_params["kappa"],
503
+ opt_params["iters"],
504
+ )
505
+
506
+ dtype = torch.float16 if tensor.is_cuda else torch.float32
507
+ w_f = tensor.to(dtype)
508
+ scale = scale.to(dtype)
509
+ zero = zero.to(dtype)
510
+
511
+ def shrink_op(x, beta, p=lp_norm):
512
+ if p == 1:
513
+ return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
514
+ else:
515
+ return torch.sign(x) * torch.nn.functional.relu(
516
+ torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1)
517
+ )
518
+
519
+ best_error = 1e4
520
+ for i in range(iters):
521
+ w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1])
522
+ w_r = (w_q - zero) / scale
523
+ w_e = shrink_op(w_f - w_r, beta)
524
+ zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True)
525
+ beta *= kappa
526
+
527
+ current_error = float(torch.abs(w_f - w_r).mean())
528
+ if verbose:
529
+ print(i, np.round(current_error, 6))
530
+ if current_error < best_error:
531
+ best_error = current_error
532
+ else:
533
+ break
534
+
535
+ del w_f, w_q, w_r, w_e
536
+
537
+ return scale, zero
538
+
539
+ @staticmethod
540
+ def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits):
541
+ if pack_tensor.shape[0] == ori_int_tensor.shape[0]:
542
+ ori_int_tensor = ori_int_tensor.T
543
+ pack_tensor = pack_tensor.T
544
+ if bits in [2, 4, 8]:
545
+ compress_ratio = pack_tensor.element_size() * 8 // bits
546
+ for j in range(compress_ratio):
547
+ pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j))
548
+ else:
549
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
550
+
551
+ # from Official implementation of Half-Quadratic Quantization (HQQ)
552
+ def quantize_internal(
553
+ self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1
554
+ ):
555
+ import torch
556
+
557
+ weight = tensor.float()
558
+ ori_shape = weight.shape
559
+
560
+ pad_len = (group_size - ori_shape[axis] % group_size) % group_size
561
+ if axis == 1:
562
+ weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0)
563
+ else:
564
+ weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0)
565
+ shape = weight.shape
566
+
567
+ # Reshape for grouping
568
+ if (group_size is not None) and channel_wise:
569
+ weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1])
570
+
571
+ # Get min/max values
572
+ if channel_wise is False:
573
+ _min, _max = weight.min(), weight.max()
574
+ optimize = False
575
+ else:
576
+ _min = weight.min(axis=axis, keepdim=True)[0]
577
+ _max = weight.max(axis=axis, keepdim=True)[0]
578
+
579
+ max_v = 2**bits - 1
580
+ min_v = 0
581
+ min_max = [min_v, max_v]
582
+
583
+ # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on.
584
+ # clamp to avoid half-precision problems
585
+ scale = (max_v / (_max - _min)).clamp(max=2e4)
586
+ #!!!!!!!!!!!!!!!
587
+ min_max_axis = _max - _min
588
+ if (min_max_axis == 0).sum().item() > 0:
589
+ min_max_axis[min_max_axis == 0] = max_v
590
+ scale = (max_v / min_max_axis).clamp(max=2e4)
591
+ zero = -_min * scale
592
+
593
+ if round_zero:
594
+ zero = torch.round(zero)
595
+
596
+ # Fine-tune weights
597
+ if optimize:
598
+ scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis)
599
+
600
+ # Quantize
601
+ # Necessary for fake quantization backprop
602
+ w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1])
603
+ w_q = w_q.reshape(shape).int()
604
+
605
+ scale = 1.0 / scale
606
+ if axis == 1:
607
+ scale = scale.reshape(shape[0], -1)
608
+ zero = zero.reshape(shape[0], -1)
609
+ else:
610
+ scale = scale.reshape(-1, shape[-1])
611
+ zero = zero.reshape(-1, shape[-1])
612
+ # cleanup
613
+ del weight, _min, _max
614
+
615
+ return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype)
616
+
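The routine above keeps the scale inverted while quantizing and flips it just before returning, so callers dequantize as (w_q - zero) * scale. A numpy-only sketch of that affine round trip (zero-point rounding and the proximal optimizer are omitted; values are arbitrary):

    import numpy as np

    w = np.array([0.05, -0.31, 0.12, 0.44], dtype=np.float32)
    inv_scale = 15.0 / (w.max() - w.min())      # 4-bit range [0, 15]; scale kept inverted
    zero = np.round(-w.min() * inv_scale)       # 6.0
    w_q = np.clip(np.round(w * inv_scale + zero), 0, 15)
    w_rec = (w_q - zero) * (1.0 / inv_scale)    # dequantize with the returned (flipped) scale
    print(np.abs(w - w_rec).max())              # ~0.02 reconstruction error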
617
+ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
618
+ """
619
+ Target node: QOperator node: QDQ nodes:
620
+ MatMul MatMulNBits DeQuantizeLinear -> MatMul
621
+ Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
622
+ If the node is target node with fp32 or fp16 const weight, quantize the weight to int4 and
623
+ return the new nodes.
624
+ If QOperator format, return the corresponding QOperator nodes.
625
+ If QDQ format, return the corresdponging QDQ nodes.
626
+ Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
627
+ not supported yet because Gather does not support int4 data.
628
+ """
629
+ # With HQQ, zero points are in float. Current GatherBlockQuantized does not support float zero points.
630
+ if node.op_type == "Gather":
631
+ raise NotImplementedError("Gather quantization is not supported yet in HQQ")
632
+
633
+ import torch
634
+
635
+ logger.info(f"start to quantize {node.name} ...")
636
+ input_b = node.input[1]
637
+ b_pb, bs_graph = get_initializer(input_b, graph_stack)
638
+ if b_pb is None:
639
+ logger.info("MatMul doesn't have const weight. Skip to quantize")
640
+ return [node] # only care about constant weight
641
+
642
+ b_array = onnx.numpy_helper.to_array(b_pb)
643
+ if len(b_array.shape) != 2:
644
+ logger.info("MatMul weight is not 2D. Skip to quantize")
645
+ return [node] # can only process 2-D matrix
646
+ b_array_torch = torch.from_numpy(b_array)
647
+ if torch.cuda.is_available():
648
+ b_array_torch = b_array_torch.cuda()
649
+ quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal(
650
+ b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size
651
+ )
652
+ quant_weight_torch = quant_weight_torch.contiguous()
653
+ scales_torch = scales_torch.contiguous()
654
+ zero_points_torch = zero_points_torch.contiguous()
655
+
656
+ packed_torch = torch.zeros(
657
+ (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2),
658
+ dtype=torch.uint8,
659
+ device=quant_weight_torch.device,
660
+ )
661
+ self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits)
662
+ scales = scales_torch.cpu().numpy()
663
+ zero_points = zero_points_torch.cpu().numpy()
664
+ # reshape to the predefined shape in MatMulNBits
665
+ scales = scales.reshape(-1)
666
+ zero_points = zero_points.reshape(-1)
667
+ rows, cols = b_array_torch.shape
668
+ block_size = self.config.block_size
669
+ blob_size = block_size // 2
670
+ k_blocks = (rows + block_size - 1) // block_size
671
+ packed_torch = packed_torch.reshape(cols, k_blocks, blob_size)
672
+
673
+ b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
674
+ b_quant.name = b_pb.name + "_Q4"
675
+ for input in bs_graph.input:
676
+ if input.name == input_b:
677
+ bs_graph.input.remove(input)
678
+ break
679
+
680
+ scales_tensor = onnx.numpy_helper.from_array(scales)
681
+ scales_tensor.name = b_pb.name + "_scales"
682
+ bs_graph.initializer.extend([b_quant, scales_tensor])
683
+
684
+ input_names = [node.input[0], b_quant.name, scales_tensor.name]
685
+ zp_tensor = onnx.numpy_helper.from_array(zero_points)
686
+ zp_tensor.name = b_pb.name + "_zero_points"
687
+ bs_graph.initializer.extend([zp_tensor])
688
+ input_names.append(zp_tensor.name)
689
+
690
+ kwargs = {}
691
+ rows, cols = b_array.shape
692
+ kwargs["K"] = rows
693
+ kwargs["N"] = cols
694
+ kwargs["bits"] = self.config.bits
695
+ kwargs["block_size"] = self.config.block_size
696
+
697
+ matmul_q4_node = onnx.helper.make_node(
698
+ "MatMulNBits",
699
+ inputs=input_names,
700
+ outputs=[node.output[0]],
701
+ name=node.name + "_Q4" if node.name else "",
702
+ domain="com.microsoft",
703
+ **kwargs,
704
+ )
705
+
706
+ logger.info(f"complete quantization of {node.name} ...")
707
+
708
+ return [matmul_q4_node]
709
+
710
+
711
+ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
712
+ for gid in range(len(graph_path) - 1, -1, -1):
713
+ graph = graph_path[gid]
714
+ for tensor in graph.initializer:
715
+ if tensor.name == name:
716
+ return tensor, graph
717
+ return None, None
718
+
719
+
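get_initializer walks the graph stack from the innermost graph outward and returns the first matching initializer together with the graph that owns it. A small self-contained check (names are arbitrary):

    import numpy as np
    import onnx
    from onnxruntime.quantization.matmul_4bits_quantizer import get_initializer

    w = onnx.numpy_helper.from_array(np.ones((4, 4), dtype=np.float32), name="W")
    graph = onnx.helper.make_graph([], "g", inputs=[], outputs=[], initializer=[w])
    tensor, owning_graph = get_initializer("W", [graph])
    assert tensor is not None and tensor.name == "W"
    assert get_initializer("missing", [graph]) == (None, None)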
720
+ class DefaultWeightOnlyQuantizer:
721
+ def __init__(self, config: DefaultWeightOnlyQuantConfig):
722
+ self.config = config
723
+
724
+ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
725
+ """4b quantize fp32 weight to int4 using C++ kernels."""
726
+
727
+ if len(fp32weight.shape) != 2:
728
+ raise ValueError("Current int4 block quantization only supports 2D tensors!")
729
+ rows, cols = fp32weight.shape
730
+
731
+ block_size = self.config.block_size
732
+ k_blocks = (rows + block_size - 1) // block_size
733
+
734
+ if self.config.quant_format == QuantFormat.QOperator:
735
+ blob_size = block_size // 2
736
+ padded_rows = k_blocks * block_size
737
+ pad_len = padded_rows - rows
738
+ if pad_len > 0:
739
+ fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")
740
+
741
+ # block wise quantization, each block comes from a single column
742
+ packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
743
+ zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8")
744
+ scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
745
+ quantize_matmul_4bits(
746
+ packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
747
+ )
748
+ else:
749
+ packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
750
+ zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
751
+ scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
752
+ quantize_qdq_matmul_4bits(
753
+ packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
754
+ )
755
+
756
+ return (packed, scales, zero_point)
757
+
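A minimal sketch of the QOperator layout produced above; it relies on the native quantize_matmul_4bits kernel imported at the top of this file, and the random weight is arbitrary:

    import numpy as np
    from onnxruntime.quantization.matmul_4bits_quantizer import (
        DefaultWeightOnlyQuantConfig,
        DefaultWeightOnlyQuantizer,
    )

    quantizer = DefaultWeightOnlyQuantizer(DefaultWeightOnlyQuantConfig(block_size=32))
    weight = np.random.rand(64, 96).astype(np.float32)    # rows = K = 64, cols = N = 96
    packed, scales, zero_points = quantizer.int4_block_quant(weight)
    print(packed.shape)       # (96, 2, 16): (cols, k_blocks, block_size // 2) uint8 blobs
    print(scales.shape)       # (192,): one scale per column per block
    print(zero_points.shape)  # (96,): two 4-bit zero points packed per byte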
758
+ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
759
+ """
760
+ Quantize weight B of MatMul node to int4.
761
+ Currently only support 2D constant matrix and axis 0 blockwise quantization.
762
+ """
763
+ qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
764
+ input_b = node.input[1]
765
+ b_tensor, b_graph = get_initializer(input_b, graph_stack)
766
+ if b_tensor is None:
767
+ logger.info("MatMul doesn't have const weight. Skip to quantize")
768
+ return [node] # only care about constant weight
769
+
770
+ b_ndarray = onnx.numpy_helper.to_array(b_tensor)
771
+ if len(b_ndarray.shape) != 2:
772
+ logger.info("MatMul weight is not 2D. Skip to quantize")
773
+ return [node] # can only process 2-D matrix
774
+
775
+ packed, scales, zero_points = self.int4_block_quant(b_ndarray)
776
+
777
+ if self.config.quant_format == QuantFormat.QOperator:
778
+ b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + "_Q4")
779
+ scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales")
780
+ else:
781
+ b_quant = onnx.helper.make_tensor(b_tensor.name + "_DQ_Q4", qtype, b_ndarray.shape, packed.tobytes(), True)
782
+ scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
783
+
784
+ for input in b_graph.input:
785
+ if input.name == input_b:
786
+ b_graph.input.remove(input)
787
+ break
788
+
789
+ b_graph.initializer.extend([b_quant, scales_tensor])
790
+
791
+ output_nodes = []
792
+
793
+ if self.config.quant_format == QuantFormat.QOperator:
794
+ input_names = [node.input[0], b_quant.name, scales_tensor.name]
795
+ if not self.config.is_symmetric:
796
+ zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points")
797
+ input_names.append(zp_tensor.name)
798
+ b_graph.initializer.extend([zp_tensor])
799
+ kwargs = {}
800
+ rows, cols = b_ndarray.shape
801
+ kwargs["K"] = rows
802
+ kwargs["N"] = cols
803
+ kwargs["bits"] = 4
804
+ kwargs["block_size"] = self.config.block_size
805
+ if self.config.accuracy_level is not None:
806
+ kwargs["accuracy_level"] = self.config.accuracy_level
807
+
808
+ matmul_q4_node = onnx.helper.make_node(
809
+ "MatMulNBits",
810
+ inputs=input_names,
811
+ outputs=[node.output[0]],
812
+ name=node.name + "_Q4" if node.name else "",
813
+ domain="com.microsoft",
814
+ **kwargs,
815
+ )
816
+
817
+ output_nodes.append(matmul_q4_node)
818
+ else:
819
+ dq_input_names = [b_quant.name, scales_tensor.name]
820
+ dq_output_names = [b_quant.name + "_output"]
821
+ matmul_input_names = [node.input[0], dq_output_names[0]]
822
+ matmul_output_names = [node.output[0]]
823
+ if not self.config.is_symmetric:
824
+ zp_tensor = onnx.helper.make_tensor(
825
+ b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
826
+ )
827
+ dq_input_names.append(zp_tensor.name)
828
+ b_graph.initializer.extend([zp_tensor])
829
+ dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
830
+ dq_node = onnx.helper.make_node(
831
+ "DequantizeLinear",
832
+ inputs=dq_input_names,
833
+ outputs=dq_output_names,
834
+ name=node.name + "_DQ_Q4" if node.name else "",
835
+ **dq_kwargs,
836
+ )
837
+ matmul_node = onnx.helper.make_node(
838
+ "MatMul",
839
+ inputs=matmul_input_names,
840
+ outputs=matmul_output_names,
841
+ name=node.name + "_matmul_Q4" if node.name else "",
842
+ )
843
+ output_nodes.extend([dq_node, matmul_node])
844
+
845
+ return output_nodes
846
+
847
+ @staticmethod
848
+ def quant_slice_symmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
849
+ max_val = np.max(data, axis=1, keepdims=True)
850
+ min_val = np.min(data, axis=1, keepdims=True)
851
+ abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val)
852
+
853
+ scale = abs_max / -8.0 # if max == min, max may be clipped
854
+ quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(-8, 7).astype(np.int8)
855
+
856
+ return quantized_slice, scale
857
+
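A worked example of the symmetric rule above, where the scale is the signed absolute maximum divided by -8:

    import numpy as np
    from onnxruntime.quantization.matmul_4bits_quantizer import DefaultWeightOnlyQuantizer

    data = np.array([[0.5, -1.0, 0.25, 0.75]], dtype=np.float32)
    q, scale = DefaultWeightOnlyQuantizer.quant_slice_symmetric(data)
    print(scale)  # [[0.125]]: the signed abs-max (-1.0) divided by -8
    print(q)      # [[ 4 -8  2  6]]: values land in the signed 4-bit range [-8, 7]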
858
+ @staticmethod
859
+ def quant_slice_asymmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
860
+ min_val = np.minimum(data.min(axis=1, keepdims=True), 0)
861
+ max_val = np.maximum(data.max(axis=1, keepdims=True), 0)
862
+
863
+ scale = (max_val - min_val) / 15.0
864
+ zero_point = np.where(scale == 0, 8, -min_val / scale).round().clip(0, 15).astype(np.uint8)
865
+ quantized_slice = np.where(scale == 0, 8, data / scale + zero_point).round().clip(0, 15).astype(np.uint8)
866
+
867
+ return quantized_slice, scale, zero_point
868
+
869
+ @staticmethod
870
+ def pack_int8_to_int4(data: np.ndarray) -> np.ndarray:
871
+ """Pack int8 data to int4 and store in uint8 ndarray."""
872
+ data_flat = data.reshape(-1)
873
+ if len(data_flat) % 2 != 0:
874
+ data_flat = np.append(data_flat, 0)
875
+ quant_data_int4 = (data_flat[::2] & 0xF) | ((data_flat[1::2] & 0xF) << 4)
876
+
877
+ return quant_data_int4.astype("uint8")
878
+
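A small demonstration of the nibble packing above: two 4-bit values per byte, low nibble first:

    import numpy as np
    from onnxruntime.quantization.matmul_4bits_quantizer import DefaultWeightOnlyQuantizer

    nibbles = np.array([1, -2, 7, 0], dtype=np.int8)
    packed = DefaultWeightOnlyQuantizer.pack_int8_to_int4(nibbles)
    print([hex(int(b)) for b in packed])  # ['0xe1', '0x7']: -2 is stored as the nibble 0xe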
879
+ @staticmethod
880
+ def quantize_ndarray(
881
+ data: np.ndarray,
882
+ quantize_axis: int,
883
+ block_size: int,
884
+ is_symmetric: bool,
885
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
886
+ """Quantize ndarray data to int4 using numpy, return (quantized data, scales, zero points)."""
887
+ # Get the shape of the matrix
888
+ m = 1 # dimension of the matrix before the quantize axis
889
+ k = data.shape[quantize_axis] # dimension of the matrix along the quantize axis
890
+ n = 1 # dimension of the matrix after the quantize axis
891
+ for i, dim in enumerate(data.shape):
892
+ if i < quantize_axis:
893
+ m *= dim
894
+ elif i > quantize_axis:
895
+ n *= dim
896
+
897
+ k_blocks = (k + block_size - 1) // block_size
898
+ scales_shape = list(data.shape)
899
+ scales_shape[quantize_axis] = k_blocks
900
+
901
+ data_reshape = data.reshape((m, k, n))
902
+ scales = np.zeros((m, k_blocks, n), dtype=data.dtype)
903
+ if is_symmetric:
904
+ quant_data_int8 = np.zeros((m, k, n), dtype="int8")
905
+ else:
906
+ quant_data_int8 = np.zeros((m, k, n), dtype="uint8")
907
+ zero_point_int8 = np.zeros((m, k_blocks, n), dtype="uint8")
908
+
909
+ # slice and quantize
910
+ for i in range(0, k, block_size):
911
+ end_idx = min(i + block_size, k)
912
+ slice = data_reshape[:, i:end_idx, :]
913
+
914
+ if is_symmetric:
915
+ quantized_slice_int8, scale_slice = DefaultWeightOnlyQuantizer.quant_slice_symmetric(slice)
916
+ else:
917
+ quantized_slice_int8, scale_slice, zero_point_slice_int8 = (
918
+ DefaultWeightOnlyQuantizer.quant_slice_asymmetric(slice)
919
+ )
920
+
921
+ quant_data_int8[:, i:end_idx, :] = quantized_slice_int8
922
+ j = i // block_size
923
+ scales[:, j : (j + 1), :] = scale_slice
924
+ if not is_symmetric:
925
+ zero_point_int8[:, j : (j + 1), :] = zero_point_slice_int8
926
+
927
+ # pack int8 to int4
928
+ quant_data_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(quant_data_int8)
929
+ zero_point_int4 = None
930
+ if not is_symmetric:
931
+ zero_point_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(zero_point_int8)
932
+ scales = scales.reshape(scales_shape)
933
+ return quant_data_int4, scales, zero_point_int4
934
+
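A small shape check for the numpy quantization path above (random data, asymmetric mode):

    import numpy as np
    from onnxruntime.quantization.matmul_4bits_quantizer import DefaultWeightOnlyQuantizer

    data = np.random.rand(8, 32).astype(np.float32)
    q4, scales, zps = DefaultWeightOnlyQuantizer.quantize_ndarray(
        data, quantize_axis=1, block_size=16, is_symmetric=False
    )
    print(q4.shape, scales.shape, zps.shape)  # (128,) (8, 2) (8,)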
935
+ def quantize_gather(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
936
+ """Quantize weight data of Gather node to int4."""
937
+ assert self.config.quant_format == QuantFormat.QOperator, "Gather only supports QOperator format currently."
938
+
939
+ qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
940
+ data_arg = node.input[0]
941
+ data_tensorproto, data_graphproto = get_initializer(data_arg, graph_stack)
942
+ if data_tensorproto is None:
943
+ logger.info("Gather doesn't have const weight. Skip quantization.")
944
+ return [node] # only care about constant weight
945
+
946
+ data_ndarray = onnx.numpy_helper.to_array(data_tensorproto)
947
+ data_rank = len(data_ndarray.shape)
948
+ quantize_axis = self.config.quant_axes.get("Gather", 1)
949
+ block_size = self.config.block_size
950
+
951
+ assert quantize_axis < data_rank and quantize_axis >= -data_rank, "Invalid quantize axis for Gather node."
952
+ assert block_size >= 16 and ((block_size - 1) & block_size == 0), "Invalid block size for Gather node."
953
+
954
+ quantize_axis = (quantize_axis + data_rank) % data_rank
955
+ quantized_data, scales, zero_points = self.quantize_ndarray(
956
+ data_ndarray, quantize_axis, block_size, self.config.is_symmetric
957
+ )
958
+
959
+ for input in data_graphproto.input:
960
+ if input.name == data_arg:
961
+ data_graphproto.input.remove(input)
962
+ break
963
+
964
+ quantized_data_tensorproto = onnx.helper.make_tensor(
965
+ data_tensorproto.name + "_Q4", qtype, data_ndarray.shape, quantized_data.tobytes(), True
966
+ )
967
+ scales_tensorproto = onnx.numpy_helper.from_array(scales, data_tensorproto.name + "_scales")
968
+ input_names = [quantized_data_tensorproto.name, node.input[1], scales_tensorproto.name]
969
+ data_graphproto.initializer.extend([quantized_data_tensorproto, scales_tensorproto])
970
+ if not self.config.is_symmetric:
971
+ zp_tensorproto = onnx.helper.make_tensor(
972
+ data_tensorproto.name + "_zero_points", qtype, scales.shape, zero_points.tobytes(), True
973
+ )
974
+ input_names.append(zp_tensorproto.name)
975
+ data_graphproto.initializer.extend([zp_tensorproto])
976
+
977
+ try:
978
+ gather_axis = onnx.helper.get_node_attr_value(node, "axis")
979
+ except ValueError:
980
+ gather_axis = 0
981
+
982
+ kwargs = {
983
+ "gather_axis": gather_axis,
984
+ "quantize_axis": quantize_axis,
985
+ "block_size": block_size,
986
+ }
987
+
988
+ gather_q4_node = onnx.helper.make_node(
989
+ "GatherBlockQuantized",
990
+ inputs=input_names,
991
+ outputs=[node.output[0]],
992
+ name=node.name + "_Q4" if node.name else "",
993
+ domain="com.microsoft",
994
+ **kwargs,
995
+ )
996
+
997
+ return [gather_q4_node]
998
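For illustration, a sketch of a config that opts Gather weights into 4-bit quantization (values are illustrative; per the asserts above, Gather currently requires the QOperator format and a power-of-two block size of at least 16, quantizing along the axis given in quant_axes, default 1):

    gather_cfg = DefaultWeightOnlyQuantConfig(
        block_size=32,
        is_symmetric=False,
        quant_format=QuantFormat.QOperator,
        op_types_to_quantize=("MatMul", "Gather"),
        quant_axes=(("Gather", 1),),
    )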
+
999
+ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
1000
+ """
1001
+ Target node: QOperator node: QDQ nodes:
1002
+ MatMul MatMulNBits DeQuantizeLinear -> MatMul
1003
+ Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
1004
+ If the node is target node with fp32 or fp16 const weight, quantize the weight to int4 and
1005
+ return the new nodes.
1006
+ If QOperator format, return the corresponding QOperator nodes.
1007
+ If QDQ format, return the corresponding QDQ nodes.
1008
+ Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
1009
+ not supported yet because Gather does not support int4 data.
1010
+ """
1011
+ logger.info(f"start to quantize {node.name} ...")
1012
+
1013
+ if node.op_type == "MatMul":
1014
+ results = self.quantize_matmul(node, graph_stack)
1015
+ elif node.op_type == "Gather":
1016
+ results = self.quantize_gather(node, graph_stack)
1017
+ else:
1018
+ logger.error(f"Unsupported operator {node.op_type} for weight only quantization. Skip quantization.")
1019
+ results = [node]
1020
+
1021
+ logger.info(f"complete quantization of {node.name} ...")
1022
+
1023
+ return results
1024
+
1025
+
1026
+ class NVAWQWeightOnlyQuantizer:
1027
+ def __init__(
1028
+ self,
1029
+ config: NVAWQWeightOnlyQuantConfig,
1030
+ ):
1031
+ self.config = config
1032
+
1033
+ def quantize_awq(self, model: ModelProto | str) -> ModelProto:
1034
+ """
1035
+ Perform nvidia_awq quantization using ModelOpt's int4 quantize function.
1036
+
1037
+ Args:
1038
+ model (ModelProto | str): The ONNX model, or a path to it, to quantize.
1039
+
1040
+ Returns:
1041
+ ModelProto: The quantized ONNX model.
1042
+ """
1043
+ try:
1044
+ from modelopt.onnx.quantization.int4 import quantize as quantize_int4
1045
+ except ImportError:
1046
+ print(
1047
+ "Please ensure that the 'modelopt' package is installed. Please install it using pip install nvidia_modelopt."
1048
+ )
1049
+ raise ImportError(
1050
+ "modelopt is not installed. Please install it using pip install nvidia_modelopt. Exiting."
1051
+ ) from None
1052
+
1053
+ logger.info("Starting nvidia_awq quantization...")
1054
+
1055
+ # Prepare calibration inputs
1056
+ calib_inputs = self.config.calibration_data_reader
1057
+
1058
+ # Perform quantization using ModelOpt's int4 quantize function
1059
+ quantized_model = quantize_int4(
1060
+ model,
1061
+ calibration_method=self.config.calibration_method,
1062
+ calibration_data_reader=calib_inputs,
1063
+ )
1064
+
1065
+ logger.info("Completed nvidia_awq quantization.")
1066
+ return quantized_model
1067
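For illustration, a sketch of driving this quantizer directly with the config class used in the __main__ block further below (the dataset name, tokenizer path and cache dir are placeholders):

    awq_cfg = NVAWQWeightOnlyQuantConfig(
        dataset_name="cnn",
        tokenizer_dir="/path/to/tokenizer",
        cache_dir="./cache",
        calibration_method="awq_lite",
    )
    quantized_model = NVAWQWeightOnlyQuantizer(awq_cfg).quantize_awq("model.onnx")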
+
1068
+
1069
+ # TODO(fajin): change class name
1070
+ class MatMul4BitsQuantizer:
1071
+ """
1072
+ Target node: QOperator node: QDQ nodes:
1073
+ MatMul MatMulNBits DeQuantizeLinear -> MatMul
1074
+ Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
1075
+
1076
+ Perform 4b quantization of constant weights for target nodes.
1077
+ If algo_config.quant_format is QOperator:
1078
+ - nodes are replaced by the corresponding QOperator nodes.
1079
+ - quantized weights are stored in the contrib ops.
1080
+ If algo_config.quant_format is QDQ:
1081
+ - the quantized weight is stored in a standard onnx node. For MatMul, it is DequantizeLinear. For Gather,
1082
+ it is the three Gathers, one for quantized data, one for scales and one for optional zero points.
1083
+ - The nodes are replaced by the corresponding QDQ nodes.
1084
+ - currently Gather is not supported in QDQ because Gather does not support int4 yet.
1085
+ Note:
1086
+ - for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
1087
+ during runtime. Therefore it is not recommended.
1088
+ """
1089
+
1090
+ def __init__(
1091
+ self,
1092
+ model: ModelProto | str,
1093
+ block_size: int = 128,
1094
+ is_symmetric: bool = False,
1095
+ accuracy_level: int | None = None,
1096
+ nodes_to_exclude=None,
1097
+ nodes_to_include: list[str] | None = None,
1098
+ quant_format=QuantFormat.QOperator,
1099
+ op_types_to_quantize: tuple[str, ...] | None = None,
1100
+ quant_axes: tuple[tuple[str, int], ...] | None = None,
1101
+ algo_config: WeightOnlyQuantConfig | None = None,
1102
+ ):
1103
+ if nodes_to_exclude is None:
1104
+ nodes_to_exclude = []
1105
+ self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
1106
+ self.model_path = model if isinstance(model, str) else None
1107
+ self.block_size = block_size
1108
+ self.is_symmetric = is_symmetric
1109
+ self.accuracy_level = accuracy_level
1110
+ self.nodes_to_exclude = set(nodes_to_exclude)
1111
+ self.nodes_to_include = set(nodes_to_include) if nodes_to_include else None
1112
+ self.node_quantizer = None
1113
+
1114
+ if algo_config is None:
1115
+ algo_config = DefaultWeightOnlyQuantConfig(
1116
+ block_size=block_size,
1117
+ is_symmetric=is_symmetric,
1118
+ accuracy_level=accuracy_level,
1119
+ quant_format=quant_format,
1120
+ op_types_to_quantize=op_types_to_quantize,
1121
+ quant_axes=quant_axes,
1122
+ )
1123
+ self.algo_config = algo_config
1124
+ if algo_config.algorithm == "HQQ":
1125
+ self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
1126
+ elif algo_config.algorithm == "DEFAULT":
1127
+ self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config)
1128
+ elif algo_config.algorithm == "nvidia_awq":
1129
+ self.node_quantizer = NVAWQWeightOnlyQuantizer(self.algo_config)
1130
+
1131
+ def _process_subgraph(self, graph_stack: list[GraphProto]):
1132
+ new_nodes = []
1133
+ graph = graph_stack[-1]
1134
+
1135
+ for node in graph.node:
1136
+ graph_attrs = [
1137
+ attr
1138
+ for attr in node.attribute
1139
+ if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
1140
+ ]
1141
+ if len(graph_attrs):
1142
+ kwargs = {}
1143
+ for attr in node.attribute:
1144
+ if attr.type == onnx.AttributeProto.GRAPH:
1145
+ # recursive call to take care of sub-graph
1146
+ graph_stack.append(attr.g)
1147
+ kv = {attr.name: self._process_subgraph(graph_stack)}
1148
+ elif attr.type == onnx.AttributeProto.GRAPHS:
1149
+ value = []
1150
+ for subgraph in attr.graphs:
1151
+ # recursive call to take care of sub-graph
1152
+ graph_stack.append(subgraph)
1153
+ value.extend([self._process_subgraph(graph_stack)])
1154
+ kv = {attr.name: value}
1155
+ else:
1156
+ kv = attribute_to_kwarg(attr)
1157
+ kwargs.update(kv)
1158
+ node = onnx.helper.make_node( # noqa: PLW2901
1159
+ node.op_type, node.input, node.output, name=node.name, **kwargs
1160
+ )
1161
+ out_nodes = []
1162
+ if node.name in self.nodes_to_exclude:
1163
+ logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
1164
+ out_nodes = [node]
1165
+ elif (self.nodes_to_include and node.name in self.nodes_to_include) or (
1166
+ node.op_type in self.algo_config.op_types_to_quantize
1167
+ ):
1168
+ out_nodes = self.node_quantizer.quantize(node, graph_stack)
1169
+ else:
1170
+ logger.info(f"skip to quantize {node.name} ...")
1171
+ out_nodes = [node]
1172
+ new_nodes.extend(out_nodes)
1173
+
1174
+ graph.ClearField("node")
1175
+ graph.node.extend(new_nodes)
1176
+ graph_stack.pop()
1177
+ return graph
1178
+
1179
+ def _generate_q4_node_config(self):
1180
+ """Generate weight only quant configuration for nodes."""
1181
+ q4_node_config = {}
1182
+ template_config_q4 = {
1183
+ "bits": 4,
1184
+ "group_size": self.block_size,
1185
+ "scheme": "sym" if self.is_symmetric else "asym",
1186
+ }
1187
+ for node in self.model.model.graph.node:
1188
+ if node.op_type in ["MatMul"]:
1189
+ if any(self.model.get_initializer(i) is not None for i in node.input):
1190
+ q4_node_config[node.name] = template_config_q4
1191
+ return q4_node_config
1192
+
1193
+ def int4_quant_algo(self):
1194
+ """4b quantize a model with RTN or GPTQ algorithm. Please refer to
1195
+ https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md
1196
+ for more details on weight only quantization using Intel® Neural Compressor.
1197
+ """
1198
+
1199
+ def inc_dataloader():
1200
+ data_reader = copy.deepcopy(self.algo_config.calibration_data_reader)
1201
+ for data in data_reader:
1202
+ yield data, None
1203
+
1204
+ kwargs = {}
1205
+ if self.accuracy_level is not None:
1206
+ kwargs["accuracy_level"] = self.accuracy_level
1207
+ weight_only_node_config = self._generate_q4_node_config()
1208
+
1209
+ algorithm = self.algo_config.algorithm
1210
+ logger.info(f"start to quantize model with {algorithm} algorithm...")
1211
+ if algorithm == "RTN":
1212
+ from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
1213
+
1214
+ kwargs["ratios"] = self.algo_config.ratios
1215
+
1216
+ self.model = rtn_quantize(
1217
+ model=self.model_path if self.model_path is not None else self.model.model,
1218
+ weight_config=weight_only_node_config,
1219
+ **kwargs,
1220
+ )
1221
+ elif algorithm == "GPTQ":
1222
+ from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
1223
+
1224
+ kwargs["percdamp"] = self.algo_config.percdamp
1225
+ kwargs["blocksize"] = self.algo_config.block_size
1226
+ kwargs["actorder"] = self.algo_config.actorder
1227
+ kwargs["mse"] = self.algo_config.mse
1228
+ kwargs["perchannel"] = self.algo_config.perchannel
1229
+ kwargs["n_samples"] = -1
1230
+ dataloader = inc_dataloader()
1231
+
1232
+ self.model = gptq_quantize(
1233
+ model=self.model_path if self.model_path is not None else self.model.model,
1234
+ weight_config=weight_only_node_config,
1235
+ dataloader=dataloader,
1236
+ **kwargs,
1237
+ )
1238
+ logger.info(f"complete quantization of model with {algorithm} algorithm.")
1239
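For illustration, a sketch of reaching this code path through the public config classes (the model path is a placeholder; as asserted in process() below, the RTN and GPTQ algorithms require neural-compressor >= 2.3.2):

    rtn_cfg = RTNWeightOnlyQuantConfig(op_types_to_quantize=("MatMul",))
    quant = MatMul4BitsQuantizer("model.onnx", algo_config=rtn_cfg)
    quant.process()  # dispatches to int4_quant_algo() for the RTN algorithm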
+
1240
+ def process(self):
1241
+ if self.algo_config.algorithm in ["HQQ", "DEFAULT"]:
1242
+ # use a stack to keep track of sub-graphs
1243
+ graph_stack = [self.model.graph()]
1244
+
1245
+ # Update domain opset
1246
+ if self.algo_config.quant_format == QuantFormat.QOperator:
1247
+ self.model.set_opset_import("com.microsoft", 1)
1248
+
1249
+ if self.algo_config.quant_format == QuantFormat.QDQ or "Gather" in self.algo_config.op_types_to_quantize:
1250
+ opset_import = self.model.opset_import()
1251
+ for opset in opset_import:
1252
+ if opset.domain in [None, "ai.onnx", ""] and opset.version < 21:
1253
+ logger.warning(
1254
+ "The opset of the input model is under 21 and doesn't support int4 data type. "
1255
+ "Force to update it to opset 21, but the generated model may not be a valid model."
1256
+ )
1257
+ self.model.set_opset_import(opset.domain, 21)
1258
+
1259
+ self._process_subgraph(graph_stack)
1260
+ self.model.clean_initializers()
1261
+ elif self.algo_config.algorithm == "nvidia_awq":
1262
+
1263
+ # Handle nvidia_awq quantization
1264
+ logger.info("Processing nvidia_awq quantization...")
1265
+ self.model = self.node_quantizer.quantize_awq(
1266
+ self.model.model if self.model_path is None else self.model_path
1267
+ )
1268
+ logger.info("Completed nvidia_awq quantization.")
1269
+ self.model = ONNXModel(self.model) # Ensure the model is wrapped back into ONNXModel
1270
+ self.model.clean_initializers()
1271
+ else:
1272
+ # use Intel® Neural Compressor for RTN or GPTQ weight-only quantize algorithm
1273
+ try:
1274
+ importlib.import_module("neural_compressor")
1275
+ except Exception as e:
1276
+ logging.error(f"{e}.")
1277
+ raise RuntimeError(
1278
+ "neural-compressor is not correctly installed. Please check your environment."
1279
+ ) from e
1280
+
1281
+ import neural_compressor
1282
+
1283
+ assert version.parse(neural_compressor.__version__) >= version.parse(
1284
+ "2.3.2"
1285
+ ), "Require neural-compressor >= 2.3.2 to support weight only quantization!"
1286
+
1287
+ self.int4_quant_algo()
1288
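For illustration, a minimal sketch of programmatic use with the default algorithm, mirroring the __main__ block below (file names are placeholders):

    model = onnx.load("model.onnx")
    quant = MatMul4BitsQuantizer(model, block_size=128, is_symmetric=False)
    quant.process()
    quant.model.save_model_to_file("model_int4.onnx", True)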
+
1289
+
1290
+ def ort_convert_str_to_bool(value):
1291
+ return value.lower() in ("true", "1")
1292
+
1293
+
1294
+ # Custom function to parse str:int pairs
1295
+ def parse_key_value_pair(s):
1296
+ key, value = s.split(":")
1297
+ return key, int(value)
1298
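For example, parse_key_value_pair("Gather:1") returns ("Gather", 1), which the --quant_axes argument below collects into a list of (op_type, axis) pairs.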
+
1299
+
1300
+ def parse_args():
1301
+ parser = argparse.ArgumentParser(
1302
+ description="""Blockwise int4 quantization for MatMul 2D weight matrices.
1303
+
1304
+ A weight matrix is partitioned into blocks, where each block is a
1305
+ contiguous subset inside each column. Each block is quantized into a
1306
+ set of 4b integers with a scaling factor and an optional offset.
1307
+ """
1308
+ )
1309
+
1310
+ parser.add_argument("--input_model", required=True, help="Path to the input model file")
1311
+ parser.add_argument("--output_model", required=True, help="Path to the output model file")
1312
+ parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization")
1313
+ parser.add_argument(
1314
+ "--quant_method",
1315
+ default="default",
1316
+ type=str,
1317
+ choices=["default", "hqq", "rtn", "gptq", "nvidia_awq"],
1318
+ help="the algorithm used to quantize weight, \nrtn and gptq leverage Intel® Neural Compressor",
1319
+ )
1320
+ parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight")
1321
+ parser.add_argument(
1322
+ "--symmetric",
1323
+ required=False,
1324
+ default=True,
1325
+ const=True,
1326
+ nargs="?",
1327
+ type=ort_convert_str_to_bool,
1328
+ choices=[True, False],
1329
+ help="Indicate whether to quantize the model symmetrically, symmetric is not supported by hqq",
1330
+ )
1331
+ parser.add_argument(
1332
+ "--accuracy_level",
1333
+ required=False,
1334
+ type=int,
1335
+ help="Accuracy level of the 4-bit quantized MatMul computation. "
1336
+ "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
1337
+ "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
1338
+ )
1339
+ parser.add_argument("-v", "--verbose", required=False, action="store_true")
1340
+ parser.set_defaults(verbose=False)
1341
+ parser.add_argument(
1342
+ "--nodes_to_exclude",
1343
+ nargs="+",
1344
+ type=str,
1345
+ required=False,
1346
+ default=[],
1347
+ help="Specify the nodes to be excluded from quantization with node names",
1348
+ )
1349
+ parser.add_argument(
1350
+ "--nodes_to_include",
1351
+ nargs="+",
1352
+ type=str,
1353
+ required=False,
1354
+ help="Specify the specific nodes to be included from quantization with node names",
1355
+ )
1356
+ parser.add_argument(
1357
+ "--quant_format",
1358
+ default="QOperator",
1359
+ type=str,
1360
+ choices=["QOperator", "QDQ"],
1361
+ help="QuantFormat {QOperator, QDQ}"
1362
+ "QOperator format quantizes the model with quantized operators directly."
1363
+ "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.",
1364
+ )
1365
+ parser.add_argument(
1366
+ "--op_types_to_quantize",
1367
+ type=str,
1368
+ nargs="+",
1369
+ choices=["MatMul", "Gather"],
1370
+ help="op_types_to_quantize {MatMul, Gather}. Operators to quantize. Default is MatMul.",
1371
+ )
1372
+ parser.add_argument(
1373
+ "--quant_axes",
1374
+ type=parse_key_value_pair,
1375
+ nargs="+",
1376
+ required=False,
1377
+ help="Key-value pairs in op_type:axis_to_quantize separated by space."
1378
+ "Specify the axis to quantize for an op. Default {MatMul:0, Gather:1}"
1379
+ "Example: --quant_axes MatMul:0 Gather:1",
1380
+ )
1381
+ # Group arguments specific to nvidia_awq
1382
+ nv_awq_config = parser.add_argument_group("nvidia_awq", "Arguments specific to nvidia_awq quantization")
1383
+ nv_awq_config.add_argument(
1384
+ "--calib_dataset_name",
1385
+ type=str,
1386
+ default="cnn",
1387
+ help="Name of the calibration dataset for nvidia_awq.",
1388
+ )
1389
+ nv_awq_config.add_argument(
1390
+ "--tokenizer_dir",
1391
+ type=str,
1392
+ required=False,
1393
+ help="Path of the tokenizer dir.",
1394
+ )
1395
+ nv_awq_config.add_argument(
1396
+ "--calibration_method",
1397
+ type=str,
1398
+ required=False,
1399
+ choices=["awq", "awq_clip"],
1400
+ help="Support two options, awq implementation and weight clipping.",
1401
+ )
1402
+ nv_awq_config.add_argument(
1403
+ "--cache_dir",
1404
+ type=str,
1405
+ default="./cache",
1406
+ help="Cache directory for calibration data.",
1407
+ )
1408
+ return parser.parse_args()
1409
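For reference, a typical invocation of this script (paths and values are illustrative) might look like: python matmul_4bits_quantizer.py --input_model model.onnx --output_model model_int4.onnx --block_size 32 --quant_method default --symmetric True --quant_format QOperator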
+
1410
+
1411
+ if __name__ == "__main__":
1412
+ args = parse_args()
1413
+ if args.verbose:
1414
+ logger.setLevel(logging.DEBUG)
1415
+
1416
+ input_model_path = args.input_model
1417
+ output_model_path = args.output_model
1418
+ quant_format = QuantFormat[args.quant_format]
1419
+ op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else ("MatMul",)
1420
+ quant_axes = tuple(args.quant_axes) if args.quant_axes else None
1421
+
1422
+ if os.path.exists(output_model_path):
1423
+ logger.error(f"file {output_model_path} already exists")
1424
+ raise Exception(f"file {output_model_path} already exists")
1425
+
1426
+ if args.symmetric and args.quant_method == "hqq":
1427
+ logger.warning("symmetric is not supportted by hqq, will force to symmetric=False")
1428
+ args.symmetric = False
1429
+
1430
+ model = onnx.load(input_model_path)
1431
+ if args.quant_method == "hqq":
1432
+ quant_config = HQQWeightOnlyQuantConfig(
1433
+ block_size=args.block_size, bits=args.bits, op_types_to_quantize=op_types_to_quantize, quant_axes=quant_axes
1434
+ )
1435
+ elif args.quant_method == "default":
1436
+ quant_config = DefaultWeightOnlyQuantConfig(
1437
+ block_size=args.block_size,
1438
+ is_symmetric=args.symmetric,
1439
+ accuracy_level=args.accuracy_level,
1440
+ quant_format=quant_format,
1441
+ op_types_to_quantize=op_types_to_quantize,
1442
+ quant_axes=quant_axes,
1443
+ )
1444
+ elif args.quant_method == "rtn":
1445
+ quant_config = RTNWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize)
1446
+ elif args.quant_method == "gptq":
1447
+ quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize)
1448
+ elif args.quant_method == "nvidia_awq":
1449
+
1450
+ if quant_format == QuantFormat.QOperator:
1451
+ logger.warning("QOperator is not applicable to nvidia_awq. overriding the value to QDQ")
1452
+ quant_format = QuantFormat.QDQ
1453
+
1454
+ model = input_model_path
1455
+ if args.calibration_method is not None:
1456
+ if args.calibration_method == "awq":
1457
+ calibration_method = "awq_lite"
1458
+ else:
1459
+ calibration_method = "awq_clip"
1460
+ else:
1461
+ calibration_method = "awq_lite"
1462
+
1463
+ quant_config = NVAWQWeightOnlyQuantConfig(
1464
+ dataset_name=args.calib_dataset_name,
1465
+ tokenizer_dir=args.tokenizer_dir,
1466
+ cache_dir=args.cache_dir,
1467
+ calibration_method=calibration_method,
1468
+ )
1469
+ else:
1470
+ raise ValueError(f"Unsupported quantization method: {args.quant_method}")
1471
+
1472
+ quant = MatMul4BitsQuantizer(
1473
+ model=model,
1474
+ accuracy_level=args.accuracy_level,
1475
+ nodes_to_exclude=args.nodes_to_exclude,
1476
+ nodes_to_include=args.nodes_to_include,
1477
+ algo_config=quant_config,
1478
+ )
1479
+ quant.process()
1480
+ quant.model.save_model_to_file(output_model_path, True)