onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,932 @@
+#
+# The implementation of this file is based on:
+# https://github.com/intel/neural-compressor/tree/master/neural_compressor
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modifications:
+# Add k-quant quantization method.
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""WeightOnly for onnxrt adaptor."""
+
+import copy
+import logging
+import os
+import sys
+
+import numpy as np
+import onnx
+from onnx import numpy_helper
+from onnx.helper import np_dtype_to_tensor_dtype
+
+import onnxruntime as ort
+
+from .onnx_model import ONNXModel
+from .util import simple_progress_bar
+
+logger = logging.getLogger("neural_compressor")
+
+
+def make_matmul_weight_only_node(
+    node,
+    weight_shape,
+    num_bits,
+    group_size,
+    k_blocks,
+    q_weight,
+    scale,
+    zero_point,
+    accuracy_level=0,
+):  # pragma: no cover
+    """Build MatMulNBits node.
+
+    Args:
+        node: original matmul node
+        weight_shape: original weight shape
+        num_bits (int): num_bits
+        group_size (int): how many elements share one scale/zp
+        k_blocks (int): block number
+        q_weight (array): quantized weight
+        scale (array): scale
+        zero_point (array): zero point
+        accuracy_level (int): accuracy level. Support 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), or 4 (int8).
+
+    Returns:
+        matmul_weight_only_node: MatMulNBits node
+        new_inits: initializers of the new node
+    """
+    blob_size = group_size * num_bits // 8
+    packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
+    q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}"
+    input_names = [node.input[0], q_weight_name]
+    new_inits = []
+    kwargs = {}
+
+    op_type = "MatMulNBits"
+
+    # pack quantized weight
+    if num_bits == 4:
+        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+        packed[:, :] = q_weight_pairs[:, :blob_size]
+    elif num_bits == 8:
+        packed = q_weight
+    else:
+        logger.error(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
+
+    packed = np.reshape(packed, (-1, k_blocks, blob_size))
+
+    # build scale tensor
+    scale = np.reshape(scale, (-1, k_blocks))
+    assert scale.dtype == np.float32 or scale.dtype == np.float16
+    scale_tensor = onnx.helper.make_tensor(
+        name=node.input[1] + "_scale",
+        data_type=np_dtype_to_tensor_dtype(scale.dtype),
+        dims=scale.shape,
+        vals=scale.tobytes(),
+        raw=True,
+    )
+    input_names.append(scale_tensor.name)
+    new_inits.append(scale_tensor)
+
+    # build zero_point tensor
+    if zero_point is not None:
+        if num_bits == 8:
+            packed_zp = zero_point.astype("uint8")
+        elif num_bits == 4:
+            # For 4-bit case, the default zeros is 0x8. So it is 0x88 = 136 if we fill lower/higher 4 bits with 0x8.
+            packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
+            # create an index array
+            idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1)
+            # separate odd and even indices
+            even_idx = idx[::2]
+            odd_idx = idx[1::2]
+            # vectorized operation for even and odd indices
+            packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
+            packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)
+        else:
+            raise ValueError(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
+
+        packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
+        zp_tensor = onnx.helper.make_tensor(
+            name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
+        )
+        input_names.append(zp_tensor.name)
+        new_inits.append(zp_tensor)
+
+    # set kwargs
+    kwargs["K"] = weight_shape[0]
+    kwargs["N"] = weight_shape[1]
+    kwargs["bits"] = num_bits
+    kwargs["block_size"] = group_size
+    if accuracy_level > 0:
+        # require onnxruntime > 1.16.3
+        kwargs["accuracy_level"] = accuracy_level
+
+    q_weight_tensor = onnx.helper.make_tensor(
+        name=q_weight_name,
+        data_type=2,
+        dims=packed.shape,
+        vals=packed.tobytes(),
+        raw=True,
+    )
+    new_inits.append(q_weight_tensor)
+
+    matmul_weight_only_node = onnx.helper.make_node(
+        op_type,
+        inputs=input_names,
+        outputs=node.output,
+        name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
+        domain="com.microsoft",
+        **kwargs,
+    )
+    return matmul_weight_only_node, new_inits
+
+
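
The `q_weight[:, ::2] | q_weight[:, 1::2] << 4` line above packs two 4-bit codes per byte, low nibble first. A minimal numpy sketch of the same packing, using made-up codes:

```python
import numpy as np

q = np.array([[1, 2, 3, 4]], dtype="uint8")  # hypothetical 4-bit codes in [0, 15]
packed = q[:, ::2] | (q[:, 1::2] << 4)       # element 0 -> low nibble, element 1 -> high nibble
print([hex(v) for v in packed[0]])           # ['0x21', '0x43']
```
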
+def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
+    """Quantize tensor per group.
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+        scheme (str, optional): quantization scheme. Defaults to "asym".
+        dtype (str, optional): data type. Defaults to "int".
+        ratio (float, optional): percentile of clip. Defaults to 1.0.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size))
+    if scheme == "asym" or dtype == "uint":
+        maxq = 2**num_bits - 1
+        minq = 0
+    elif scheme == "sym":
+        maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0
+        minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1
+
+    rmin = np.min(data, axis=1, keepdims=True) * ratio
+    rmax = np.max(data, axis=1, keepdims=True) * ratio
+    if scheme == "sym":
+        max_range = np.maximum(np.abs(rmin), np.abs(rmax))
+        scale = np.ones(rmax.shape)
+        mask = max_range > 0
+        scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq)
+        zero_point = (
+            np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1))
+        )
+    else:
+        scale = np.ones(rmax.shape)
+        scale[rmin != rmax] = np.array(
+            [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()]
+        )
+        zero_point = (
+            ((np.zeros(scale.shape) - rmin) / scale).round()
+            if dtype == "int"
+            else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8")
+        )
+
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+
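
A quick round trip through `quant_tensor` shows the asymmetric scheme in action; a minimal sketch with one hypothetical group of eight weights:

```python
import numpy as np

w = np.array([-0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0, 1.5], dtype=np.float32)
q, scale, zp = quant_tensor(w.copy(), num_bits=4, group_size=8, scheme="asym", dtype="uint")
w_hat = scale * (q - zp)          # dequantize: scale is (rmax - rmin) / 15 here
print(np.max(np.abs(w - w_hat)))  # worst-case error is about scale / 2 (~0.067)
```
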
+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # nb = data.shape[0], (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, group_size)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # noqa: N806
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+        mad_1 = np.array(mad)
+        best_mad_1 = np.array(best_mad)
+        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+
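
The loop above is a grid search over candidate inverse scales; for each candidate, `this_scale` and `this_min` are the closed-form weighted least-squares fit of `scale * quant_data + min` to the original data. A sketch checking that closed form against `np.linalg.lstsq`, with made-up data:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=32)                               # one group of weights
q = rng.integers(0, 16, size=32).astype(np.float64)  # hypothetical integer codes
w = np.abs(x) + np.sqrt(np.mean(x**2))                # same weighting as above

sw, sx = np.sum(w), np.sum(w * x)
sl, sl2, sxl = np.sum(w * q), np.sum(w * q * q), np.sum(w * q * x)
D = sw * sl2 - sl**2
scale, minimum = (sw * sxl - sx * sl) / D, (sl2 * sx - sl * sxl) / D

# The same answer from a weighted least-squares solve of scale * q + minimum ~= x
A = np.stack([q, np.ones_like(q)], axis=1) * np.sqrt(w)[:, None]
ref = np.linalg.lstsq(A, x * np.sqrt(w), rcond=None)[0]
assert np.allclose([scale, minimum], ref)
```
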
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp  # noqa: PLC0415
+        import torch  # noqa: PLC0415
+
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(cp.float32)  # nb = data.shape[0], (nb, group_size)
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = cp.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = cp.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = cp.add(av_x, cp.abs(data))  # (nb, group_size)
+            rmin = cp.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = cp.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = cp.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, group_size)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = cp.subtract(sum_w * sum_l2, sum_l**2)  # noqa: N806
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+                mad_1 = cp.array(mad)
+                best_mad_1 = cp.array(best_mad)
+                idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(cp.float64)
+            q_weight = cp.empty_like(data, dtype=scale.dtype)
+            cp.divide(data, scale, out=q_weight)
+            cp.add(q_weight, zero_point, out=q_weight)
+            cp.round(q_weight, out=q_weight)
+            cp.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning(
+                "Try to use k-quant quantization on CUDA. However, CUDA is not available. "
+                "Fall back to k-quant quantization on CPU."
+            )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+    except ImportError:
+        logger.info(
+            "Now we are using k-quant quantization on cpu, which is time consuming. "
+            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. "
+            "Please also install torch to check CUDA availability."
+        )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+
+
+def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
+    """Quantize and dequantize tensor per group.
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+        scheme (str, optional): quantization scheme. Defaults to "asym".
+        dtype (str, optional): data type. Defaults to "int".
+        ratio (float, optional): percentile of clip. Defaults to 1.0.
+
+    Returns:
+        output: quant-dequant weight
+    """
+    org_shape = data.shape
+    weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio)
+    return np.reshape(scale * (weight - zp), org_shape)
+
+
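
`qdq_tensor` is the fake-quantization path: the output keeps the original shape and dtype but collapses each group onto at most `2**num_bits` levels. A small sketch:

```python
import numpy as np

w = np.arange(-16, 16, dtype=np.float32) / 8.0  # one group of 32 values
w_qdq = qdq_tensor(w.copy(), num_bits=4, group_size=32)
assert w_qdq.shape == w.shape
print(len(np.unique(w_qdq)))                    # at most 2**4 = 16 distinct values
```
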
+def pad_tensor(weight, group_size, k_blocks):
+    """Pad the tensor rows so that the row count is divisible by group_size.
+
+    Args:
+        weight (array): weight
+        group_size (int): how many elements share one scale/zp
+        k_blocks (int): the number of blocks
+
+    Returns:
+        weight: padded weight
+    """
+    if group_size == -1:
+        return weight
+
+    org_w_shape = weight.shape
+    padded_rows = k_blocks * group_size
+    pad_len = padded_rows - org_w_shape[0]
+
+    if pad_len > 0:
+        weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant")
+
+    return weight
+
+
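
`pad_tensor` only ever appends zero rows, so the padding quantizes to codes that are sliced away or ignored downstream. For example:

```python
import numpy as np

w = np.ones((100, 8), dtype=np.float32)  # 100 input channels
k_blocks = (100 - 1) // 32 + 1           # 4 blocks, as rtn_quantize computes it
print(pad_tensor(w, group_size=32, k_blocks=k_blocks).shape)  # (128, 8)
```
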
+def rtn_quantize(
+    model,
+    weight_config={},  # noqa: B006
+    num_bits=4,
+    group_size=32,
+    scheme="asym",
+    ratios={},  # noqa: B006
+    accuracy_level=0,
+    providers=["CPUExecutionProvider"],  # noqa: B006
+    algorithm="k_quant",
+):
+    """Quantize the model with the round-to-nearest method.
+
+    Args:
+        model (ModelProto or ONNXModel): onnx model
+        weight_config (dict): quantization config
+            For example,
+            weight_config = {
+                'fc2':
+                    {
+                        'bits': 4,
+                        'group_size': 32,
+                        'scheme': 'sym',
+                        'algorithm': 'RTN'
+                    }
+            }
+        num_bits (int, optional): num_bits. Default is 4.
+        group_size (int, optional): how many elements share one scale/zp. Default is 32.
+        scheme (str, optional): sym or asym. Defaults to "asym".
+        ratios (dict, optional): percentile of clip. Defaults to {}.
+        accuracy_level (int): accuracy level. Support 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), or 4 (int8).
+        providers (list): providers to use
+        algorithm (str, optional): "k_quant" or round-to-nearest otherwise. Defaults to "k_quant".
+
+    Returns:
+        model: fake quantized ONNXModel
+    """
+    model = ONNXModel(model)
+    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
+    new_nodes = []
+    remove_nodes = []
+    total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]])
+    curr_id = 0
+    for node in model.nodes():
+        if node.op_type in ["MatMul"]:
+            curr_id += 1
+            simple_progress_bar(total_num, curr_id)
+        if (
+            node.op_type in ["MatMul"]
+            and model.get_initializer(node.input[1]) is not None
+            and weight_config.get(node.name, {}) != "fp32"
+        ):
+            weight_tensor = model.get_initializer(node.input[1])
+            weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy()
+            if len(weight.shape) != 2:
+                continue
+
+            dtype = weight.dtype
+
+            if node.name in weight_config:
+                num_bits = weight_config[node.name]["bits"]
+                group_size = weight_config[node.name]["group_size"]
+                scheme = weight_config[node.name]["scheme"]
+
+            org_w_shape = weight.shape  # ic, oc
+            group_size = group_size if group_size != -1 else org_w_shape[0]
+
+            k_blocks = (org_w_shape[0] - 1) // group_size + 1
+            init_share_num = model.get_initializer_share_num(node.input[1])
+
+            weight = pad_tensor(weight, group_size, k_blocks)
+
+            satisfy_MatMulNBits_condition = num_bits == 4 or num_bits == 8  # noqa: N806
+
+            if satisfy_MatMulNBits_condition:  # pragma: no cover
+                if algorithm == "k_quant":
+                    q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
+                else:
+                    q_weight, scale, zp = quant_tensor(
+                        weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                    )
+
+                q_matmul_node, new_inits = make_matmul_weight_only_node(
+                    node=node,
+                    weight_shape=org_w_shape,
+                    num_bits=num_bits,
+                    group_size=group_size,
+                    k_blocks=k_blocks,
+                    q_weight=q_weight.astype("uint8"),
+                    scale=scale.astype(dtype),
+                    zero_point=zp if scheme == "asym" or algorithm == "k_quant" else None,
+                    accuracy_level=accuracy_level,
+                )
+
+                model.add_initializers(new_inits)
+                remove_nodes.append(node)
+                new_nodes.append(q_matmul_node)
+            else:
+                q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
+                q_weight = np.reshape(q_weight, (org_w_shape[1], -1))
+                q_weight = np.transpose(q_weight)
+                q_weight = q_weight[: org_w_shape[0], :].astype(dtype)
+                q_weight_tensor = onnx.helper.make_tensor(
+                    name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
+                    data_type=np_dtype_to_tensor_dtype(dtype),
+                    dims=weight.shape,
+                    vals=q_weight.tobytes(),
+                    raw=True,
+                )
+                model.add_initializer(q_weight_tensor)
+                node.input[1] = q_weight_tensor.name
+            if init_share_num == 1:
+                model.remove_initializer(weight_tensor)
+
+    model.add_nodes(new_nodes)
+    model.remove_nodes(remove_nodes)
+    model.topological_sort()
+    return model
+
+
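
A minimal sketch of driving `rtn_quantize` end to end, assuming a hypothetical model whose "fc2" MatMul has a 2-D initializer as its second input:

```python
import onnx

model = onnx.load("model.onnx")  # hypothetical path
config = {"fc2": {"bits": 4, "group_size": 32, "scheme": "asym", "algorithm": "RTN"}}
quantized = rtn_quantize(model, weight_config=config, algorithm="RTN")
onnx.save(quantized.model, "model_int4.onnx")
```

Passing any `algorithm` other than "k_quant" routes weights through `quant_tensor` above, i.e. plain round-to-nearest.
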
+def get_weight_scale(weight, group_size):
+    """Get the scale of weight."""
+    org_shape = weight.shape
+    weight = np.reshape(weight, (-1, group_size)) if group_size != -1 else weight
+    scale = np.mean(np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=1, keepdims=True), org_shape), axis=0)
+    return scale
+
+
+def prepare_inputs(model, n_samples, dataloader, providers):
+    """Prepare inputs for weight only quantization.
+
+    Args:
+        model (ModelProto or ONNXModel): onnx model
+        n_samples (int, optional): calibration sample number. -1 means all samples.
+        dataloader (object): dataloader for calibration.
+        providers (list): providers to use
+
+    Returns:
+        inputs: prepared inputs.
+        so: session options
+    """
+    from importlib.util import find_spec  # noqa: PLC0415
+
+    from .util import to_numpy  # noqa: PLC0415
+
+    so = ort.SessionOptions()
+    if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"):  # pragma: no cover
+        from onnxruntime_extensions import get_library_path  # noqa: PLC0415
+
+        so.register_custom_ops_library(get_library_path())
+    if model.is_large_model:
+        onnx.save_model(
+            model.model,
+            model.model_path + "_augment.onnx",
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            convert_attribute=False,
+        )
+
+    session = (
+        ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
+        if not model.is_large_model
+        else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
+    )
+    inputs_names = [i.name for i in session.get_inputs()]
+    del session
+
+    inputs = []
+    for i, data in enumerate(dataloader):
+        if n_samples != -1 and ((i + 1) * dataloader.batch_size) > n_samples:
+            break
+        if len(inputs_names) != 1 or isinstance(data[0], dict):
+            assert len(data[0]) == len(inputs_names), (
+                f"Input number mismatch, require {len(inputs_names)} but get {len(data[0])}"
+            )
+
+        if isinstance(data[0], dict):
+            inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()]))  # noqa: C404
+        elif isinstance(data[0], np.ndarray):  # pragma: no cover
+            inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]], strict=False)]))  # noqa: C404
+        else:  # pragma: no cover
+            inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0], strict=False)]))  # noqa: C404
+    return inputs, so
+
+
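
`prepare_inputs` expects a dataloader with a `batch_size` attribute whose items carry the model inputs in `data[0]`, either as a dict of input names to arrays or as positional arrays. A minimal compatible shape, with a hypothetical input name:

```python
import numpy as np

class CalibrationData:
    def __init__(self, samples):
        self.batch_size = 1
        self._samples = samples

    def __iter__(self):
        for s in self._samples:
            yield s, None  # data[0] is a dict covering every model input

samples = [{"input_ids": np.ones((1, 16), dtype=np.int64)} for _ in range(8)]
dataloader = CalibrationData(samples)
```
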
+def gptq(
+    W,
+    H,
+    num_bits=4,
+    group_size=32,
+    scheme="asym",
+    blocksize=128,
+    percdamp=0.01,
+    actorder=False,
+    mse=False,
+    perchannel=True,
+):
+    """Quantize the weight with the GPTQ method.
+
+    Args:
+        W (array): weight.
+        H (array): Hessian matrix.
+        num_bits (int, optional): num_bits. Default is 4.
+        group_size (int, optional): how many elements share one scale/zp. Default is 32.
+        scheme (str, optional): sym or asym. Defaults to "asym".
+        blocksize (int, optional): blocksize to quantize weight.
+        percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
+        actorder (bool, optional): whether to rearrange the Hessian matrix considering the diagonal's values.
+        mse (bool, optional): whether to get scale and zero point with mse error.
+        perchannel (bool, optional): whether to quantize weight per-channel.
+
+    Returns:
+        Q: fake quantized weight
+    """
+    maxq = 2**num_bits - 1
+    grid = 100
+    maxshrink = 0.8
+    norm = 2.4
+
+    def find_params(weight):
+        org_shape = weight.shape
+        # find zp, scale
+        if not perchannel:
+            weight = np.expand_dims(weight.flatten(), axis=1)
+        tmp = np.zeros(weight.shape[1])
+        xmin = np.minimum(np.min(weight, axis=0), tmp)
+        xmax = np.maximum(np.max(weight, axis=0), tmp)
+        if scheme == "sym":
+            xmax = np.maximum(np.abs(xmin), xmax)
+            tmp = xmin < 0
+            if np.any(tmp):
+                xmin[tmp] = -xmax[tmp]
+        tmp = (xmin == 0) & (xmax == 0)
+        xmin[tmp] = -1
+        xmax[tmp] = +1
+
+        scale = (xmax - xmin) / maxq
+        if scheme == "sym":
+            zero = np.ones(scale.shape) * (maxq + 1) / 2
+        else:
+            zero = np.round(-xmin / scale)
+        if mse:
+            best = np.ones([weight.shape[1]]) * float("inf")
+            for i in range(int(maxshrink * grid)):
+                p = 1 - i / grid
+                xmin1 = p * xmin
+                xmax1 = p * xmax
+                scale1 = (xmax1 - xmin1) / maxq
+                zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero
+                q = np.clip(np.round(weight / scale1) + zero1, 0, maxq)
+                q -= weight
+                q = np.power(np.abs(q), norm)
+                err = np.sum(q, 0)
+                tmp = err < best
+                if np.any(tmp):
+                    best[tmp] = err[tmp]
+                    scale[tmp] = scale1[tmp]
+                    zero[tmp] = zero1[tmp]
+        if not perchannel:
+            tmp = org_shape[1]
+            scale = np.repeat(scale, tmp)
+            zero = np.repeat(zero, tmp)
+        shape = [-1] + [1] * (len(org_shape) - 1)
+        scale = np.reshape(scale, shape)
+        zero = np.reshape(zero, shape)
+        return scale, zero
+
+    shape = W.shape
+    scale, zp = find_params(W)
+    dead = np.diag(H) == 0
+    H[dead, dead] = 1
+    W[dead, :] = 0  # such channel makes no contribution to quantization computation
+
+    # rearrange considering the diag's value
+    if actorder:
+        perm = np.argsort(np.diag(H))[::-1]
+        W = W[perm, :]  # noqa: N806
+        H = H[perm, :][:, perm]  # noqa: N806
+    Losses = np.zeros_like(W)  # noqa: N806
+    Q = np.zeros_like(W)  # noqa: N806
+    damp = percdamp * np.mean(np.diag(H))
+    diag = np.arange(shape[0])
+    H[diag, diag] += damp  # add an average value of the Hessian diagonal for dampening
+    H = np.linalg.cholesky(np.linalg.inv(H)).T  # noqa: N806
+    Hinv = H  # noqa: N806
+    for i1 in range(0, shape[0], blocksize):
+        i2 = min(i1 + blocksize, shape[0])
+        count = i2 - i1
+
+        W1 = copy.deepcopy(W[i1:i2, :])  # noqa: N806
+        Q1 = np.zeros_like(W1)  # noqa: N806
+        Err1 = np.zeros_like(W1)  # noqa: N806
+        Losses1 = np.zeros_like(W1)  # noqa: N806
+        Hinv1 = Hinv[i1:i2, i1:i2]  # noqa: N806
+
+        for i in range(count):  # within a block, channel wise
+            w = W1[i, :]
+            d = Hinv1[i, i]
+
+            if group_size != -1:
+                if (i1 + i) % group_size == 0:
+                    scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :])
+
+            q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten()
+            Q1[i, :] = q
+            Losses1[i, :] = (w - q) ** 2 / d**2
+
+            err1 = (w - q) / d
+            W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0))
+            Err1[i, :] = err1
+
+        Q[i1:i2, :] = Q1
+        Losses[i1:i2, :] = Losses1 / 2
+
+        W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1)
+
+    if actorder:
+        invperm = np.argsort(perm)
+        Q = Q[invperm, :]  # noqa: N806
+
+    Q = np.reshape(Q, W.shape)  # noqa: N806
+    del W
+    return Q
+
+
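
Inside each block, `gptq` quantizes one row at a time and folds the resulting error back into the rows not yet quantized: with d = Hinv1[i, i], the update is W1[j, :] -= Hinv1[j, i] * (w - q) / d for j > i, and the cross-block `W[i2:, :] -= Hinv[i2:, i1:i2] @ Err1` applies the accumulated block errors the same way. This error feedback is what separates GPTQ from plain round-to-nearest.
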
+def gptq_quantize(
+    model,
+    dataloader,
+    weight_config={},  # noqa: B006
+    num_bits=4,
+    group_size=32,
+    scheme="asym",
+    n_samples=128,
+    percdamp=0.01,
+    blocksize=128,
+    actorder=False,
+    mse=False,
+    perchannel=True,
+    accuracy_level=0,
+    providers=["CPUExecutionProvider"],  # noqa: B006
+):
+    """Quantize the model with the GPTQ method.
+
+    Args:
+        model (ModelProto or ONNXModel): onnx model
+        dataloader (object): dataloader for calibration.
+        weight_config (dict): quantization config
+            For example,
+            weight_config = {
+                'fc2':
+                    {
+                        'bits': 4,
+                        'group_size': 32,
+                        'scheme': 'sym',
+                        'algorithm': 'GPTQ'
+                    }
+            }
+        num_bits (int, optional): num_bits. Default is 4.
+        group_size (int, optional): how many elements share one scale/zp. Default is 32.
+        scheme (str, optional): sym or asym. Defaults to "asym".
+        n_samples (int, optional): calibration sample number.
+        percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
+        blocksize (int, optional): blocksize to quantize weight.
+        actorder (bool, optional): whether to rearrange the Hessian matrix considering the diagonal's values.
+        mse (bool, optional): whether to get scale and zero point with mse error.
+        perchannel (bool, optional): whether to quantize weight per-channel.
+        accuracy_level (int): accuracy level. Support 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), or 4 (int8).
+        providers (list): providers to use
+
+    Returns:
+        model: fake quantized ONNXModel
+    """
+    model = ONNXModel(model)
+    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
+
+    inputs, so = prepare_inputs(model, n_samples, dataloader, providers)
+    del dataloader
+    org_output = copy.deepcopy(model.model.graph.output)
+    model.remove_tensors_from_outputs([i.name for i in org_output])
+    output_names = []
+    for node in model.nodes():
+        if (
+            node.op_type in ["MatMul"]
+            and weight_config.get(node.name, {}) != "fp32"
+            and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
+        ):
+            output_names.append(node.input[0])
+    output_names = list(set(output_names))
+    model.add_tensors_to_outputs(output_names)
+    if model.is_large_model:
+        onnx.save_model(
+            model.model,
+            model.model_path + "_augment.onnx",
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            convert_attribute=False,
+        )
+
+    session = (
+        ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
+        if not model.is_large_model
+        else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
+    )
+
+    for idx, input_name in enumerate(output_names):
+        simple_progress_bar(len(output_names), idx + 1)
+        node_list = []
+        weights = []
+
+        for node in model.input_name_to_nodes[input_name]:
+            if (
+                node.op_type in ["MatMul"]
+                and weight_config.get(node.name, {}) != "fp32"
+                and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
+                and model.get_initializer(node.input[1]) is not None
+            ):
+                weight = numpy_helper.to_array(
+                    model.get_initializer(model.get_node(node.name).input[1]), base_dir
+                ).copy()
+                if len(weight.shape) != 2:
+                    continue
+
+                weights.append(weight)
+                node_list.append(model.get_node(node.name))
+
+        if len(weights) == 0:
+            continue
+
+        Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights]  # noqa: N806
+        nsamples = 0
+        for data in inputs:
+            inp = session.run([input_name], data)[0]
+            tmp = inp.shape[0]
+            inp = np.reshape(inp, (-1, inp.shape[-1]))
+            Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs]  # noqa: N806
+            nsamples += tmp
+            inp = np.sqrt(2 / nsamples) * inp
+            Hs = [i + np.matmul(inp.T, inp) for i in Hs]  # noqa: N806
+
+        for (
+            node,
+            weight,
+            H,  # noqa: N806
+        ) in zip(node_list, weights, Hs, strict=False):
+            if node.name in weight_config:
+                num_bits = weight_config[node.name]["bits"]
+                group_size = weight_config[node.name]["group_size"]
+                scheme = weight_config[node.name]["scheme"]
+            group_size = group_size if group_size != -1 else weight.shape[0]
+            dtype = weight.dtype
+
+            q_weight = gptq(
+                weight,
+                H,
+                num_bits=num_bits,
+                group_size=group_size,
+                scheme=scheme,
+                blocksize=blocksize,
+                percdamp=percdamp,
+                actorder=actorder,
+                mse=mse,
+                perchannel=perchannel,
+            )
+
+            weight_tensor = model.get_initializer(node.input[1])
+            init_share_num = model.get_initializer_share_num(node.input[1])
+
+            satisfy_MatMulNBits_condition = num_bits == 4  # noqa: N806
+
+            if satisfy_MatMulNBits_condition:  # pragma: no cover
+                org_shape = weight.shape
+                k_blocks = (org_shape[0] + group_size - 1) // group_size
+                q_weight = pad_tensor(q_weight, group_size, k_blocks)
+                q_weight, scale, zp = quant_tensor(q_weight.T, num_bits, group_size, scheme, "uint")
+                q_matmul_node, new_inits = make_matmul_weight_only_node(
+                    node=node,
+                    weight_shape=org_shape,
+                    num_bits=num_bits,
+                    group_size=group_size,
+                    k_blocks=k_blocks,
+                    q_weight=q_weight.astype("uint8"),
+                    scale=scale.astype(dtype),
+                    zero_point=zp if scheme == "asym" else None,
+                    accuracy_level=accuracy_level,
+                )
+
+                model.add_initializers(new_inits)
+                model.remove_node(node)
+                model.add_node(q_matmul_node)
+            else:
+                q_weight_tensor = onnx.helper.make_tensor(
+                    name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
+                    data_type=np_dtype_to_tensor_dtype(dtype),
+                    dims=q_weight.shape,
+                    vals=q_weight.astype(dtype).tobytes(),
+                    raw=True,
+                )
+                model.add_initializer(q_weight_tensor)
+                node.input[1] = q_weight_tensor.name
+            if init_share_num == 1:
+                model.remove_initializer(weight_tensor)
+
+    model.remove_tensors_from_outputs(output_names)
+    model.model.graph.output.MergeFrom(org_output)
+
+    model.topological_sort()
+
+    # reload external data to prevent external data file path errors
+    if model.is_large_model:
+        from onnx.external_data_helper import load_external_data_for_model  # noqa: PLC0415
+
+        load_external_data_for_model(model.model, os.path.split(model.model_path)[0])
+
+    return model
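
A minimal sketch of the GPTQ entry point, reusing the dataloader shape sketched above; the path and node name are hypothetical:

```python
import onnx

model = onnx.load("model.onnx")
config = {"fc2": {"bits": 4, "group_size": 32, "scheme": "asym", "algorithm": "GPTQ"}}
quantized = gptq_quantize(model, dataloader, weight_config=config, n_samples=128)
onnx.save(quantized.model, "model_gptq_int4.onnx")
```

For models with external data (`is_large_model`), the function reloads the external tensors at the end, so saving the result may need `save_as_external_data=True`.
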