onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/quantization/neural_compressor/weight_only.py
@@ -0,0 +1,932 @@
1
+ #
2
+ # The implementation of this file is based on:
3
+ # https://github.com/intel/neural-compressor/tree/master/neural_compressor
4
+ #
5
+ # Copyright (c) 2023 Intel Corporation
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ #
19
+ # Modifications:
20
+ # Add k-quant quantization method.
21
+ # Copyright (c) Microsoft Corporation. All rights reserved.
22
+ # Licensed under the MIT License.
23
+
24
+ """WeightOnly for onnxrt adaptor."""
25
+
26
+ import copy
27
+ import logging
28
+ import os
29
+ import sys
30
+
31
+ import numpy as np
32
+ import onnx
33
+ from onnx import numpy_helper
34
+ from onnx.helper import np_dtype_to_tensor_dtype
35
+
36
+ import onnxruntime as ort
37
+
38
+ from .onnx_model import ONNXModel
39
+ from .util import simple_progress_bar
40
+
41
+ logger = logging.getLogger("neural_compressor")
42
+
43
+
44
+ def make_matmul_weight_only_node(
45
+ node,
46
+ weight_shape,
47
+ num_bits,
48
+ group_size,
49
+ k_blocks,
50
+ q_weight,
51
+ scale,
52
+ zero_point,
53
+ accuracy_level=0,
54
+ ): # pragma: no cover
55
+ """Build MatMulNBits node.
56
+
57
+ Args:
58
+ node: original matmul node
59
+ weight_shape: original weight shape
60
+ num_bits (int): num_bits
61
+ group_size (int): how many elements share one scale/zp
62
+ k_blocks (int): number of blocks along the K dimension
63
+ q_weight (array): quantized weight
64
+ scale (array): scale
65
+ zero_point (array): zero point
66
+ accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).
67
+
68
+ Returns:
69
+ matmul_weight_only_node: MatMulNBits node
70
+ new_inits: initializers of the new node
71
+ """
72
+ blob_size = group_size * num_bits // 8
73
+ packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
74
+ q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}"
75
+ input_names = [node.input[0], q_weight_name]
76
+ new_inits = []
77
+ kwargs = {}
78
+
79
+ op_type = "MatMulNBits"
80
+
81
+ # pack quantized weight
82
+ if num_bits == 4:
83
+ q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
84
+ packed[:, :] = q_weight_pairs[:, :blob_size]
85
+ elif num_bits == 8:
86
+ packed = q_weight
87
+ else:
88
+ logger.error(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
89
+
90
+ packed = np.reshape(packed, (-1, k_blocks, blob_size))
91
+
92
+ # build scale tensor
93
+ scale = np.reshape(scale, (-1, k_blocks))
94
+ assert scale.dtype == np.float32 or scale.dtype == np.float16
95
+ scale_tensor = onnx.helper.make_tensor(
96
+ name=node.input[1] + "_scale",
97
+ data_type=np_dtype_to_tensor_dtype(scale.dtype),
98
+ dims=scale.shape,
99
+ vals=scale.tobytes(),
100
+ raw=True,
101
+ )
102
+ input_names.append(scale_tensor.name)
103
+ new_inits.append(scale_tensor)
104
+
105
+ # build zero_point tensor
106
+ if zero_point is not None:
107
+ if num_bits == 8:
108
+ packed_zp = zero_point.astype("uint8")
109
+ elif num_bits == 4:
110
+ # For the 4-bit case, the default zero point is 0x8, so filling both the lower and higher 4 bits with 0x8 gives 0x88 = 136.
111
+ packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
112
+ # create an index array
113
+ idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1)
114
+ # separate odd and even indices
115
+ even_idx = idx[::2]
116
+ odd_idx = idx[1::2]
117
+ # vectorized operation for even and odd indices
118
+ packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
119
+ packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)
120
+ else:
121
+ raise ValueError(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
122
+
123
+ packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
124
+ zp_tensor = onnx.helper.make_tensor(
125
+ name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
126
+ )
127
+ input_names.append(zp_tensor.name)
128
+ new_inits.append(zp_tensor)
129
+
130
+ # set kwargs
131
+ kwargs["K"] = weight_shape[0]
132
+ kwargs["N"] = weight_shape[1]
133
+ kwargs["bits"] = num_bits
134
+ kwargs["block_size"] = group_size
135
+ if accuracy_level > 0:
136
+ # require onnxruntime > 1.16.3
137
+ kwargs["accuracy_level"] = accuracy_level
138
+
139
+ q_weight_tensor = onnx.helper.make_tensor(
140
+ name=q_weight_name,
141
+ data_type=2,
142
+ dims=packed.shape,
143
+ vals=packed.tobytes(),
144
+ raw=True,
145
+ )
146
+ new_inits.append(q_weight_tensor)
147
+
148
+ matmul_weight_only_node = onnx.helper.make_node(
149
+ op_type,
150
+ inputs=input_names,
151
+ outputs=node.output,
152
+ name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
153
+ domain="com.microsoft",
154
+ **kwargs,
155
+ )
156
+ return matmul_weight_only_node, new_inits
157
+
158
+
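For orientation, a minimal standalone sketch (not part of the wheel) of the 4-bit nibble packing used in the MatMulNBits path above: two quantized values share one byte, with the even column in the low nibble and the odd column in the high nibble.

import numpy as np

group_size, num_bits = 32, 4
blob_size = group_size * num_bits // 8                     # 16 bytes hold one group of 32 4-bit values
q = np.random.randint(0, 16, size=(8, group_size), dtype=np.uint8)

packed = q[:, ::2] | (q[:, 1::2] << 4)                     # shape (8, blob_size)
assert packed.shape == (8, blob_size)

# unpack to verify the round trip
restored = np.empty_like(q)
restored[:, ::2] = packed & 0x0F
restored[:, 1::2] = packed >> 4
assert np.array_equal(restored, q)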
159
+ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
160
+ """Quantize tensor per group.
161
+
162
+ Args:
163
+ data : input weight
164
+ num_bits (int, optional): num_bits. Defaults to 4.
165
+ group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
166
+ scheme (str, optional): quantization scheme. Defaults to "asym".
167
+ dtype (str, optional): data type. Defaults to "int".
168
+ ratio (float, optional): percentile of clip. Defaults to 1.0.
169
+
170
+ Returns:
171
+ output: quantized weight
172
+ scale: scale
173
+ zero_point: zero point
174
+ """
175
+ data = np.reshape(data, (-1, group_size))
176
+ if scheme == "asym" or dtype == "uint":
177
+ maxq = 2**num_bits - 1
178
+ minq = 0
179
+ elif scheme == "sym":
180
+ maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0
181
+ minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1
182
+
183
+ rmin = np.min(data, axis=1, keepdims=True) * ratio
184
+ rmax = np.max(data, axis=1, keepdims=True) * ratio
185
+ if scheme == "sym":
186
+ max_range = np.maximum(np.abs(rmin), np.abs(rmax))
187
+ scale = np.ones(rmax.shape)
188
+ mask = max_range > 0
189
+ scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq)
190
+ zero_point = (
191
+ np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1))
192
+ )
193
+ else:
194
+ scale = np.ones(rmax.shape)
195
+ scale[rmin != rmax] = np.array(
196
+ [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()]
197
+ )
198
+ zero_point = (
199
+ ((np.zeros(scale.shape) - rmin) / scale).round()
200
+ if dtype == "int"
201
+ else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8")
202
+ )
203
+
204
+ q_weight = np.empty_like(data, dtype=scale.dtype)
205
+ np.divide(data, scale, out=q_weight)
206
+ np.add(q_weight, zero_point, out=q_weight)
207
+ np.round(q_weight, out=q_weight)
208
+ np.clip(q_weight, minq, maxq, out=q_weight)
209
+
210
+ return q_weight, scale, zero_point
211
+
212
+
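As a quick illustration (not part of the wheel), quant_tensor above can be exercised on a random weight; the shapes below assume the asymmetric unsigned path that rtn_quantize uses.

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 32)).astype(np.float32)       # 64 groups of group_size = 32 values

q, scale, zp = quant_tensor(w, num_bits=4, group_size=32, scheme="asym", dtype="uint")
# q: (64, 32) values in [0, 15]; scale and zp: one entry per group, shape (64, 1)

w_hat = scale * (q - zp)                                    # dequantize
print(np.abs(w_hat - w.reshape(-1, 32)).max())              # max per-element quantization error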
213
+ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
214
+ """Quantize tensor per group based on k quant.
215
+
216
+ Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
217
+
218
+ Args:
219
+ data : input weight
220
+ num_bits (int, optional): num_bits. Defaults to 4.
221
+ group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
222
+
223
+ Returns:
224
+ output: quantized weight
225
+ scale: scale
226
+ zero_point: zero point
227
+ """
228
+ data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
229
+ maxq = 2**num_bits - 1
230
+ minq = 0
231
+ sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
232
+ av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
233
+ weights = np.add(av_x, np.abs(data)) # (nb, group_size)
234
+ rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
235
+ rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
236
+ sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
237
+ sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
238
+ iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
239
+ mask = rmin != rmax
240
+ iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
241
+ scale = 1 / iscale
242
+ quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
243
+ diff = scale * quant_data + rmin - data # (nb, group_size)
244
+ best_mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
245
+ nstep = 20
246
+ rdelta = 0.1
247
+ # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
248
+ rrmin = -1
249
+ for is_ in range(nstep):
250
+ iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
251
+ factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
252
+ mask = rmin != rmax
253
+ iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
254
+ quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
255
+ mul_weights_quant_data_new = weights * quant_data_new
256
+ sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
257
+ sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
258
+ sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
259
+ D = np.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806
260
+
261
+ this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
262
+ this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
263
+
264
+ diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
265
+ mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
266
+
267
+ mad_1 = np.array(mad)
268
+ best_mad_1 = np.array(best_mad)
269
+ idx_to_replace = np.where(mad_1 < best_mad_1)[0]
270
+ quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
271
+ best_mad[idx_to_replace] = mad[idx_to_replace]
272
+ scale[idx_to_replace] = this_scale[idx_to_replace]
273
+ rmin[idx_to_replace] = this_min[idx_to_replace]
274
+
275
+ zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
276
+ scale = scale.astype(np.float64)
277
+ q_weight = np.empty_like(data, dtype=scale.dtype)
278
+ np.divide(data, scale, out=q_weight)
279
+ np.add(q_weight, zero_point, out=q_weight)
280
+ np.round(q_weight, out=q_weight)
281
+ np.clip(q_weight, minq, maxq, out=q_weight)
282
+
283
+ return q_weight, scale, zero_point
284
+
285
+
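A note on the refinement loop above (a reading aid, not part of the wheel): each of the nstep = 20 candidates scans the effective quantization range factor = rrmin + rdelta * is_ + (maxq - minq), i.e. 14.0 through 15.9 for 4-bit weights, around the nominal maxq - minq = 15. For every candidate, the per-group scale and minimum are refit by weighted least squares on x ~ scale * q + min with weights w: with D = sum(w) * sum(w q^2) - (sum(w q))^2, scale = (sum(w) * sum(w q x) - sum(w x) * sum(w q)) / D and min = (sum(w q^2) * sum(w x) - sum(w q) * sum(w q x)) / D, which are exactly this_scale and this_min in the loop; the candidate with the smallest weighted squared error (best_mad) is kept per group.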
286
+ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
287
+ """Quantize tensor per group based on k quant.
288
+
289
+ Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
290
+
291
+ Args:
292
+ data : input weight
293
+ num_bits (int, optional): num_bits. Defaults to 4.
294
+ group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
295
+
296
+ Returns:
297
+ output: quantized weight
298
+ scale: scale
299
+ zero_point: zero point
300
+ """
301
+ try:
302
+ import cupy as cp # noqa: PLC0415
303
+ import torch # noqa: PLC0415
304
+
305
+ if torch.cuda.is_available():
306
+ data = cp.asarray(data)
307
+ data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
308
+ maxq = 2**num_bits - 1
309
+ minq = 0
310
+ sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
311
+ av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
312
+ weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
313
+ rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
314
+ rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
315
+ sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
316
+ sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
317
+ iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
318
+ mask = rmin != rmax
319
+ iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
320
+ scale = 1 / iscale
321
+ quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
322
+ diff = scale * quant_data + rmin - data # (nb, group_size)
323
+ best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
324
+ nstep = 20
325
+ rdelta = 0.1
326
+ rrmin = -1
327
+ for is_ in range(nstep):
328
+ iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
329
+ factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
330
+ mask = rmin != rmax
331
+ iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
332
+ quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
333
+ mul_weights_quant_data_new = weights * quant_data_new
334
+ sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
335
+ sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
336
+ sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
337
+ D = cp.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806
338
+
339
+ this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
340
+ this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
341
+
342
+ diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
343
+ mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
344
+
345
+ mad_1 = cp.array(mad)
346
+ best_mad_1 = cp.array(best_mad)
347
+ idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
348
+ quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
349
+ best_mad[idx_to_replace] = mad[idx_to_replace]
350
+ scale[idx_to_replace] = this_scale[idx_to_replace]
351
+ rmin[idx_to_replace] = this_min[idx_to_replace]
352
+
353
+ zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
354
+ scale = scale.astype(cp.float64)
355
+ q_weight = cp.empty_like(data, dtype=scale.dtype)
356
+ cp.divide(data, scale, out=q_weight)
357
+ cp.add(q_weight, zero_point, out=q_weight)
358
+ cp.round(q_weight, out=q_weight)
359
+ cp.clip(q_weight, minq, maxq, out=q_weight)
360
+
361
+ return q_weight.get(), scale.get(), zero_point.get()
362
+ else:
363
+ logger.warning(
364
+ "Try to use k-quant quantization on CUDA. However, CUDA is not available."
365
+ "Fall back to k-quant quantization on CPU."
366
+ )
367
+ return quant_tensor_k_quant_cpu(data, num_bits, group_size)
368
+ except ImportError:
369
+ logger.info(
370
+ "Now we are using k-quant quantization on cpu, which is time consuming."
371
+ "Please consider install cupy to speed up on CUDA. See https://cupy.dev/"
372
+ "Please also install torch to check CUDA availability."
373
+ )
374
+ return quant_tensor_k_quant_cpu(data, num_bits, group_size)
375
+
376
+
377
+ def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
378
+ """Quant dequant tensor per group.
379
+
380
+ Args:
381
+ data : input weight
382
+ num_bits (int, optional): num_bits. Defaults to 4.
383
+ group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
384
+ scheme (str, optional): quantization scheme. Defaults to "asym".
385
+ dtype (str, optional): data type. Defaults to "int".
386
+ ratio (float, optional): percentile of clip. Defaults to 1.0.
387
+
388
+ Returns:
389
+ output: quant-dequant weight
390
+ """
391
+ org_shape = data.shape
392
+ weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio)
393
+ return np.reshape(scale * (weight - zp), org_shape)
394
+
395
+
396
+ def pad_tensor(weight, group_size, k_blocks):
397
+ """Pad tensor rowi so that it can be is divisible by group_size.
398
+
399
+ Args:
400
+ weight (array): weight
401
+ group_size (int): how many elements share one scale/zp
402
+ k_blocks (int): the number of blocks
403
+
404
+ Returns:
405
+ weight: padded weight
406
+ """
407
+ if group_size == -1:
408
+ return weight
409
+
410
+ org_w_shape = weight.shape
411
+ padded_rows = k_blocks * group_size
412
+ pad_len = padded_rows - org_w_shape[0]
413
+
414
+ if pad_len > 0:
415
+ weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant")
416
+
417
+ return weight
418
+
419
+
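A small sketch (not part of the wheel) of the padding arithmetic, using the same k_blocks formula as rtn_quantize below:

import numpy as np

group_size = 32
w = np.zeros((70, 8), dtype=np.float32)                     # K = 70 is not a multiple of group_size
k_blocks = (w.shape[0] - 1) // group_size + 1               # 3
padded = pad_tensor(w, group_size, k_blocks)
print(padded.shape)                                         # (96, 8): rows padded up to k_blocks * group_size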
420
+ def rtn_quantize(
421
+ model,
422
+ weight_config={}, # noqa: B006
423
+ num_bits=4,
424
+ group_size=32,
425
+ scheme="asym",
426
+ ratios={}, # noqa: B006
427
+ accuracy_level=0,
428
+ providers=["CPUExecutionProvider"], # noqa: B006
429
+ algorithm="k_quant",
430
+ ):
431
+ """Quant the model with round to nearst method.
432
+
433
+ Args:
434
+ model (ModelProto or ONNXModel): onnx model
435
+ weight_config (dict): quantization config
436
+ For example,
437
+ weight_config = {
438
+ 'fc2':
439
+ {
440
+ 'bits': 4,
441
+ 'group_size': 32,
442
+ 'scheme': 'sym',
443
+ 'algorithm': 'RTN'
444
+ }
445
+ }
446
+ num_bits (int, optional): num_bits. Default is 4.
447
+ group_size (int, optional): how many elements share one scale/zp. Default is 32.
448
+ scheme (str, optional): sym or asym. Defaults to "asym".
449
+ ratios (dict, optional): percentile of clip. Defaults to {}.
450
+ accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).
451
+ providers (list): providers to use
452
+
453
+ Returns:
454
+ model: fake quantized ONNXModel
455
+ """
456
+ model = ONNXModel(model)
457
+ base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
458
+ new_nodes = []
459
+ remove_nodes = []
460
+ total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]])
461
+ curr_id = 0
462
+ for node in model.nodes():
463
+ if node.op_type in ["MatMul"]:
464
+ curr_id += 1
465
+ simple_progress_bar(total_num, curr_id)
466
+ if (
467
+ node.op_type in ["MatMul"]
468
+ and model.get_initializer(node.input[1]) is not None
469
+ and weight_config.get(node.name, {}) != "fp32"
470
+ ):
471
+ weight_tensor = model.get_initializer(node.input[1])
472
+ weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy()
473
+ if len(weight.shape) != 2:
474
+ continue
475
+
476
+ dtype = weight.dtype
477
+
478
+ if node.name in weight_config:
479
+ num_bits = weight_config[node.name]["bits"]
480
+ group_size = weight_config[node.name]["group_size"]
481
+ scheme = weight_config[node.name]["scheme"]
482
+
483
+ org_w_shape = weight.shape # ic, oc
484
+ group_size = group_size if group_size != -1 else org_w_shape[0]
485
+
486
+ k_blocks = (org_w_shape[0] - 1) // group_size + 1
487
+ init_share_num = model.get_initializer_share_num(node.input[1])
488
+
489
+ weight = pad_tensor(weight, group_size, k_blocks)
490
+
491
+ satisfy_MatMulNBits_condition = num_bits == 4 or num_bits == 8 # noqa: N806
492
+
493
+ if satisfy_MatMulNBits_condition: # pragma: no cover
494
+ if algorithm == "k_quant":
495
+ q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
496
+ else:
497
+ q_weight, scale, zp = quant_tensor(
498
+ weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
499
+ )
500
+
501
+ q_matmul_node, new_inits = make_matmul_weight_only_node(
502
+ node=node,
503
+ weight_shape=org_w_shape,
504
+ num_bits=num_bits,
505
+ group_size=group_size,
506
+ k_blocks=k_blocks,
507
+ q_weight=q_weight.astype("uint8"),
508
+ scale=scale.astype(dtype),
509
+ zero_point=zp if scheme == "asym" or algorithm == "k_quant" else None,
510
+ accuracy_level=accuracy_level,
511
+ )
512
+
513
+ model.add_initializers(new_inits)
514
+ remove_nodes.append(node)
515
+ new_nodes.append(q_matmul_node)
516
+ else:
517
+ q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
518
+ q_weight = np.reshape(q_weight, (org_w_shape[1], -1))
519
+ q_weight = np.transpose(q_weight)
520
+ q_weight = q_weight[: org_w_shape[0], :].astype(dtype)
521
+ q_weight_tensor = onnx.helper.make_tensor(
522
+ name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
523
+ data_type=np_dtype_to_tensor_dtype(dtype),
524
+ dims=weight.shape,
525
+ vals=q_weight.tobytes(),
526
+ raw=True,
527
+ )
528
+ model.add_initializer(q_weight_tensor)
529
+ node.input[1] = q_weight_tensor.name
530
+ if init_share_num == 1:
531
+ model.remove_initializer(weight_tensor)
532
+
533
+ model.add_nodes(new_nodes)
534
+ model.remove_nodes(remove_nodes)
535
+ model.topological_sort()
536
+ return model
537
+
538
+
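A hypothetical end-to-end use of rtn_quantize (a sketch, not part of the wheel; "model.onnx" and "model_int4.onnx" are placeholder paths):

import onnx

fp32_model = onnx.load("model.onnx")                        # placeholder path
q_model = rtn_quantize(fp32_model, num_bits=4, group_size=32, scheme="asym", algorithm="k_quant")
onnx.save(q_model.model, "model_int4.onnx")                 # ONNXModel keeps the ModelProto in .model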
539
+ def get_weight_scale(weight, group_size):
540
+ """Get the scale of weight."""
541
+ org_shape = weight.shape
542
+ weight = np.reshape(weight, (-1, group_size)) if group_size != -1 else weight
543
+ scale = np.mean(np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=1, keepdims=True), org_shape), axis=0)
544
+ return scale
545
+
546
+
547
+ def prepare_inputs(model, n_samples, dataloader, providers):
548
+ """Prepare inputs for weight only quantization.
549
+
550
+ Args:
551
+ model (ModelProto or ONNXModel): onnx model
552
+ n_samples (int, optional): calibration sample number. -1 means all samples.
553
+ dataloader (object): dataloader for calibration.
554
+ providers (list): providers to use
555
+
556
+ Returns:
557
+ inputs: prepared inputs.
558
+ so: session options
559
+ """
560
+ from importlib.util import find_spec # noqa: PLC0415
561
+
562
+ from .util import to_numpy # noqa: PLC0415
563
+
564
+ so = ort.SessionOptions()
565
+ if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover
566
+ from onnxruntime_extensions import get_library_path # noqa: PLC0415
567
+
568
+ so.register_custom_ops_library(get_library_path())
569
+ if model.is_large_model:
570
+ onnx.save_model(
571
+ model.model,
572
+ model.model_path + "_augment.onnx",
573
+ save_as_external_data=True,
574
+ all_tensors_to_one_file=True,
575
+ convert_attribute=False,
576
+ )
577
+
578
+ session = (
579
+ ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
580
+ if not model.is_large_model
581
+ else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
582
+ )
583
+ inputs_names = [i.name for i in session.get_inputs()]
584
+ del session
585
+
586
+ inputs = []
587
+ for i, data in enumerate(dataloader):
588
+ if n_samples != -1 and ((i + 1) * dataloader.batch_size) > n_samples:
589
+ break
590
+ if len(inputs_names) != 1 or isinstance(data[0], dict):
591
+ assert len(data[0]) == len(inputs_names), (
592
+ f"Input number mismatch, require {len(inputs_names)} but get {len(data[0])}"
593
+ )
594
+
595
+ if isinstance(data[0], dict):
596
+ inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()])) # noqa: C404
597
+ elif isinstance(data[0], np.ndarray): # pragma: no cover
598
+ inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]], strict=False)])) # noqa: C404
599
+ else: # pragma: no cover
600
+ inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0], strict=False)])) # noqa: C404
601
+ return inputs, so
602
+
603
+
604
+ def gptq(
605
+ W,
606
+ H,
607
+ num_bits=4,
608
+ group_size=32,
609
+ scheme="asym",
610
+ blocksize=128,
611
+ percdamp=0.01,
612
+ actorder=False,
613
+ mse=False,
614
+ perchannel=True,
615
+ ):
616
+ """Quant the weight with GPTQ method.
617
+
618
+ Args:
619
+ W (array): weight.
620
+ H (array): Hessian matrix.
621
+ num_bits (int, optional): num_bits. Default is 4.
622
+ group_size (int, optional): how many elements share one scale/zp. Default is 32.
623
+ scheme (str, optional): sym or asym. Defaults to "asym".
624
+ blocksize (int, optional): blocksize to quantize weight.
625
+ percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
626
+ actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
627
+ mse (bool, optional): whether get scale and zero point with mse error.
628
+ perchannel (bool, optional): whether quantize weight per-channel.
629
+
630
+ Returns:
631
+ Q: fake quantized weight
632
+ """
633
+ maxq = 2**num_bits - 1
634
+ grid = 100
635
+ maxshrink = 0.8
636
+ norm = 2.4
637
+
638
+ def find_params(weight):
639
+ org_shape = weight.shape
640
+ # find zp, scale
641
+ if not perchannel:
642
+ weight = np.expand_dims(weight.flatten(), axis=1)
643
+ tmp = np.zeros(weight.shape[1])
644
+ xmin = np.minimum(np.min(weight, axis=0), tmp)
645
+ xmax = np.maximum(np.max(weight, axis=0), tmp)
646
+ if scheme == "sym":
647
+ xmax = np.maximum(np.abs(xmin), xmax)
648
+ tmp = xmin < 0
649
+ if np.any(tmp):
650
+ xmin[tmp] = -xmax[tmp]
651
+ tmp = (xmin == 0) & (xmax == 0)
652
+ xmin[tmp] = -1
653
+ xmax[tmp] = +1
654
+
655
+ scale = (xmax - xmin) / maxq
656
+ if scheme == "sym":
657
+ zero = np.ones(scale.shape) * (maxq + 1) / 2
658
+ else:
659
+ zero = np.round(-xmin / scale)
660
+ if mse:
661
+ best = np.ones([weight.shape[1]]) * float("inf")
662
+ for i in range(int(maxshrink * grid)):
663
+ p = 1 - i / grid
664
+ xmin1 = p * xmin
665
+ xmax1 = p * xmax
666
+ scale1 = (xmax1 - xmin1) / maxq
667
+ zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero
668
+ q = np.clip(np.round(weight / scale1) + zero1, 0, maxq)
669
+ q -= weight
670
+ q = np.power(np.abs(q), norm)
671
+ err = np.sum(q, 0)
672
+ tmp = err < best
673
+ if np.any(tmp):
674
+ best[tmp] = err[tmp]
675
+ scale[tmp] = scale1[tmp]
676
+ zero[tmp] = zero1[tmp]
677
+ if not perchannel:
678
+ tmp = org_shape[1]
679
+ scale = np.repeat(scale, tmp)
680
+ zero = np.repeat(zero, tmp)
681
+ shape = [-1] + [1] * (len(org_shape) - 1)
682
+ scale = np.reshape(scale, shape)
683
+ zero = np.reshape(zero, shape)
684
+ return scale, zero
685
+
686
+ shape = W.shape
687
+ scale, zp = find_params(W)
688
+ dead = np.diag(H) == 0
689
+ H[dead, dead] = 1
690
+ W[dead, :] = 0 # such channels make no contribution to the quantization computation
691
+
692
+ # rearrange considering the diag's value
693
+ if actorder:
694
+ perm = np.argsort(np.diag(H))[::-1]
695
+ W = W[perm, :] # noqa: N806
696
+ H = H[perm, :][:, perm] # noqa: N806
697
+ Losses = np.zeros_like(W) # noqa: N806
698
+ Q = np.zeros_like(W) # noqa: N806
699
+ damp = percdamp * np.mean(np.diag(H))
700
+ diag = np.arange(shape[0])
701
+ H[diag, diag] += damp # dampening: add percdamp of the average Hessian diagonal
702
+ H = np.linalg.cholesky(np.linalg.inv(H)).T # noqa: N806
703
+ Hinv = H # noqa: N806
704
+ for i1 in range(0, shape[0], blocksize):
705
+ i2 = min(i1 + blocksize, shape[0])
706
+ count = i2 - i1
707
+
708
+ W1 = copy.deepcopy(W[i1:i2, :]) # noqa: N806
709
+ Q1 = np.zeros_like(W1) # noqa: N806
710
+ Err1 = np.zeros_like(W1) # noqa: N806
711
+ Losses1 = np.zeros_like(W1) # noqa: N806
712
+ Hinv1 = Hinv[i1:i2, i1:i2] # noqa: N806
713
+
714
+ for i in range(count): # within a block, channel wise
715
+ w = W1[i, :]
716
+ d = Hinv1[i, i]
717
+
718
+ if group_size != -1:
719
+ if (i1 + i) % group_size == 0:
720
+ scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :])
721
+
722
+ q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten()
723
+ Q1[i, :] = q
724
+ Losses1[i, :] = (w - q) ** 2 / d**2
725
+
726
+ err1 = (w - q) / d
727
+ W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0))
728
+ Err1[i, :] = err1
729
+
730
+ Q[i1:i2, :] = Q1
731
+ Losses[i1:i2, :] = Losses1 / 2
732
+
733
+ W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1)
734
+
735
+ if actorder:
736
+ invperm = np.argsort(perm)
737
+ Q = Q[invperm, :] # noqa: N806
738
+
739
+ Q = np.reshape(Q, W.shape) # noqa: N806
740
+ del W
741
+ return Q
742
+
743
+
744
+ def gptq_quantize(
745
+ model,
746
+ dataloader,
747
+ weight_config={}, # noqa: B006
748
+ num_bits=4,
749
+ group_size=32,
750
+ scheme="asym",
751
+ n_samples=128,
752
+ percdamp=0.01,
753
+ blocksize=128,
754
+ actorder=False,
755
+ mse=False,
756
+ perchannel=True,
757
+ accuracy_level=0,
758
+ providers=["CPUExecutionProvider"], # noqa: B006
759
+ ):
760
+ """Quant the model with GPTQ method.
761
+
762
+ Args:
763
+ model (ModelProto or ONNXModel): onnx model
764
+ dataloader (object): dataloader for calibration.
765
+ weight_config (dict): quantization config
766
+ For example,
767
+ weight_config = {
768
+ 'fc2':
769
+ {
770
+ 'bits': 4,
771
+ 'group_size': 32,
772
+ 'scheme': 'sym',
773
+ 'algorithm': 'GPTQ'
774
+ }
775
+ }
776
+ num_bits (int, optional): num_bits. Default is 4.
777
+ group_size (int, optional): how many elements share one scale/zp. Default is 32.
778
+ scheme (str, optional): sym or asym. Defaults to "asym".
779
+ n_samples (int, optional): calibration sample number.
780
+ percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
781
+ blocksize (int, optional): blocksize to quantize weight.
782
+ actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
783
+ mse (bool, optional): whether get scale and zero point with mse error.
784
+ perchannel (bool, optional): whether quantize weight per-channel.
785
+ accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).
786
+ providers (list): providers to use
787
+
788
+ Returns:
789
+ model: fake quantized ONNXModel
790
+ """
791
+ model = ONNXModel(model)
792
+ base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
793
+
794
+ inputs, so = prepare_inputs(model, n_samples, dataloader, providers)
795
+ del dataloader
796
+ org_output = copy.deepcopy(model.model.graph.output)
797
+ model.remove_tensors_from_outputs([i.name for i in org_output])
798
+ output_names = []
799
+ for node in model.nodes():
800
+ if (
801
+ node.op_type in ["MatMul"]
802
+ and weight_config.get(node.name, {}) != "fp32"
803
+ and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
804
+ ):
805
+ output_names.append(node.input[0])
806
+ output_names = list(set(output_names))
807
+ model.add_tensors_to_outputs(output_names)
808
+ if model.is_large_model:
809
+ onnx.save_model(
810
+ model.model,
811
+ model.model_path + "_augment.onnx",
812
+ save_as_external_data=True,
813
+ all_tensors_to_one_file=True,
814
+ convert_attribute=False,
815
+ )
816
+
817
+ session = (
818
+ ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
819
+ if not model.is_large_model
820
+ else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
821
+ )
822
+
823
+ for idx, input_name in enumerate(output_names):
824
+ simple_progress_bar(len(output_names), idx + 1)
825
+ node_list = []
826
+ weights = []
827
+
828
+ for node in model.input_name_to_nodes[input_name]:
829
+ if (
830
+ node.op_type in ["MatMul"]
831
+ and weight_config.get(node.name, {}) != "fp32"
832
+ and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
833
+ and model.get_initializer(node.input[1]) is not None
834
+ ):
835
+ weight = numpy_helper.to_array(
836
+ model.get_initializer(model.get_node(node.name).input[1]), base_dir
837
+ ).copy()
838
+ if len(weight.shape) != 2:
839
+ continue
840
+
841
+ weights.append(weight)
842
+ node_list.append(model.get_node(node.name))
843
+
844
+ if len(weights) == 0:
845
+ continue
846
+
847
+ Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] # noqa: N806
848
+ nsamples = 0
849
+ for data in inputs:
850
+ inp = session.run([input_name], data)[0]
851
+ tmp = inp.shape[0]
852
+ inp = np.reshape(inp, (-1, inp.shape[-1]))
853
+ Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] # noqa: N806
854
+ nsamples += tmp
855
+ inp = np.sqrt(2 / nsamples) * inp
856
+ Hs = [i + np.matmul(inp.T, inp) for i in Hs] # noqa: N806
857
+
858
+ for (
859
+ node,
860
+ weight,
861
+ H, # noqa: N806
862
+ ) in zip(node_list, weights, Hs, strict=False):
863
+ if node.name in weight_config:
864
+ num_bits = weight_config[node.name]["bits"]
865
+ group_size = weight_config[node.name]["group_size"]
866
+ scheme = weight_config[node.name]["scheme"]
867
+ group_size = group_size if group_size != -1 else weight.shape[0]
868
+ dtype = weight.dtype
869
+
870
+ q_weight = gptq(
871
+ weight,
872
+ H,
873
+ num_bits=num_bits,
874
+ group_size=group_size,
875
+ scheme=scheme,
876
+ blocksize=blocksize,
877
+ percdamp=percdamp,
878
+ actorder=actorder,
879
+ mse=mse,
880
+ perchannel=perchannel,
881
+ )
882
+
883
+ weight_tensor = model.get_initializer(node.input[1])
884
+ init_share_num = model.get_initializer_share_num(node.input[1])
885
+
886
+ satisfy_MatMulNBits_condition = num_bits == 4 # noqa: N806
887
+
888
+ if satisfy_MatMulNBits_condition: # pragma: no cover
889
+ org_shape = weight.shape
890
+ k_blocks = (org_shape[0] + group_size - 1) // group_size
891
+ q_weight = pad_tensor(q_weight, group_size, k_blocks)
892
+ q_weight, scale, zp = quant_tensor(q_weight.T, num_bits, group_size, scheme, "uint")
893
+ q_matmul_node, new_inits = make_matmul_weight_only_node(
894
+ node=node,
895
+ weight_shape=org_shape,
896
+ num_bits=num_bits,
897
+ group_size=group_size,
898
+ k_blocks=k_blocks,
899
+ q_weight=q_weight.astype("uint8"),
900
+ scale=scale.astype(dtype),
901
+ zero_point=zp if scheme == "asym" else None,
902
+ accuracy_level=accuracy_level,
903
+ )
904
+
905
+ model.add_initializers(new_inits)
906
+ model.remove_node(node)
907
+ model.add_node(q_matmul_node)
908
+ else:
909
+ q_weight_tensor = onnx.helper.make_tensor(
910
+ name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
911
+ data_type=np_dtype_to_tensor_dtype(dtype),
912
+ dims=q_weight.shape,
913
+ vals=q_weight.astype(dtype).tobytes(),
914
+ raw=True,
915
+ )
916
+ model.add_initializer(q_weight_tensor)
917
+ node.input[1] = q_weight_tensor.name
918
+ if init_share_num == 1:
919
+ model.remove_initializer(weight_tensor)
920
+
921
+ model.remove_tensors_from_outputs(output_names)
922
+ model.model.graph.output.MergeFrom(org_output)
923
+
924
+ model.topological_sort()
925
+
926
+ # reload external data to prevent external data file path errors
927
+ if model.is_large_model:
928
+ from onnx.external_data_helper import load_external_data_for_model # noqa: PLC0415
929
+
930
+ load_external_data_for_model(model.model, os.path.split(model.model_path)[0])
931
+
932
+ return model
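
Finally, a hypothetical sketch (not part of the wheel) of driving gptq_quantize; the toy loader, the input name "input_ids", and the file paths are placeholders chosen for illustration, matching the (inputs, label) batch shape and the batch_size attribute that prepare_inputs above expects.

import numpy as np
import onnx


class ToyLoader:
    # minimal calibration loader: iterable of (inputs, label) pairs with a batch_size attribute
    batch_size = 1

    def __init__(self, feeds):
        self.feeds = feeds

    def __iter__(self):
        return iter([(f, None) for f in self.feeds])


fp32_model = onnx.load("model.onnx")                                      # placeholder path
feeds = [{"input_ids": np.zeros((1, 8), dtype=np.int64)}]                 # placeholder input name/shape
q_model = gptq_quantize(fp32_model, ToyLoader(feeds), num_bits=4, group_size=32, n_samples=1)
onnx.save(q_model.model, "model_gptq_int4.onnx")                          # placeholder output path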