onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
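Taken together, the capi binaries above (onnxruntime.dll, DirectML.dll, onnxruntime_pybind11_state.pyd) back the package's main Python entry point. As a minimal sketch of how this wheel is typically exercised, assuming a model with a single float32 input (the path, input shape, and provider fallback are illustrative, not part of the package diff):

import numpy as np
import onnxruntime as ort

# Request the DirectML execution provider this wheel ships, with CPU fallback.
session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Assumption: one float32 input; the shape here is illustrative.
input_name = session.get_inputs()[0].name
x = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: x})
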
onnxruntime/quantization/operators/__init__.py
@@ -0,0 +1,2 @@
+ # from .base_operator import QuantOperatorBase
+ # from .matmul import MatMulInteger

onnxruntime/quantization/operators/activation.py
@@ -0,0 +1,119 @@
+ import onnx
+
+ from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+ from .base_operator import QuantOperatorBase
+ from .qdq_base_operator import QDQOperatorBase
+
+
+ class QLinearActivation(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def QuantizeClipRelu(self):  # noqa: N802
+         node = self.node
+         assert node.op_type == "Relu" or node.op_type == "Clip"
+
+         # When mode is QLinearOps, the output quantization params are calculated based on outputs from
+         # activation nodes, therefore these nodes can be removed from the graph if they follow a quantized op.
+         # If the input to this node is not quantized, keep this node.
+         # If the activation is symmetric, do not quantize the op and simply return.
+         if node.input[0] not in self.quantizer.quantized_value_map or self.quantizer.is_activation_symmetric:
+             return super().quantize()
+
+         quantized_value = self.quantizer.quantized_value_map[node.input[0]]
+         self.quantizer.quantized_value_map[node.output[0]] = quantized_value
+
+     def quantize(self):
+         node = self.node
+         if node.op_type == "Relu" or node.op_type == "Clip":
+             self.QuantizeClipRelu()
+             return
+
+         nnapi_sigmoid_option = "extra.Sigmoid.nnapi"
+         sigmoid_nnapi_mode = (
+             node.op_type == "Sigmoid"
+             and nnapi_sigmoid_option in self.quantizer.extra_options
+             and self.quantizer.extra_options[nnapi_sigmoid_option]
+         )
+         use_scale = 1 / 256.0 if sigmoid_nnapi_mode else None
+         use_zeropoint = 0 if sigmoid_nnapi_mode else None
+
+         # No assert on op_type as it is controlled by the registry;
+         # only try to quantize when quantization parameters are given for it.
+         (
+             data_found,
+             output_scale_name,
+             output_zp_name,
+             _,
+             _,
+         ) = self.quantizer._get_quantization_params(node.output[0], use_scale, use_zeropoint)
+         (
+             quantized_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+         if not data_found or quantized_input_names is None:
+             return super().quantize()
+
+         qlinear_activation_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+         qlinear_activation_name = ""
+         if node.name:
+             qlinear_activation_name = node.name + "_quant"
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         kwargs["domain"] = ms_domain
+
+         qlinear_activation_inputs = [
+             quantized_input_names[0],
+             scale_names[0],
+             zero_point_names[0],
+             output_scale_name,
+             output_zp_name,
+         ]
+
+         qlinear_activation_node = onnx.helper.make_node(
+             "QLinear" + node.op_type,
+             qlinear_activation_inputs,
+             [qlinear_activation_output],
+             qlinear_activation_name,
+             **kwargs,
+         )
+
+         # Create an entry for this quantized value
+         q_output = QuantizedValue(
+             node.output[0],
+             qlinear_activation_output,
+             output_scale_name,
+             output_zp_name,
+             QuantizedValueType.Input,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = q_output
+
+         nodes.append(qlinear_activation_node)
+         self.quantizer.new_nodes += nodes
+
+
+ class QDQRemovableActivation(QDQOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+
+         # If the input to this node is not quantized, keep this node
+         if not self.quantizer.is_tensor_quantized(node.input[0]):
+             return
+
+         if (
+             not self.quantizer.is_activation_symmetric
+             and not self.quantizer.qdq_keep_removable_activations
+             and self.quantizer.try_replacing_upstream_output(node.input[0], node.output[0])
+         ):
+             self.quantizer.remove_node(self.node)
+         else:
+             self.quantizer.quantize_activation_tensor(node.input[0])
+
+         if not self.disable_qdq_for_node_output:
+             self.quantizer.quantize_activation_tensor(node.output[0])

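The "extra.Sigmoid.nnapi" switch read above is supplied through the extra_options argument of the public static-quantization API, and the QLinear* rewrite shown corresponds to the QOperator output format. A minimal sketch, assuming a single float32 input named "input" (paths, shapes, and the one-batch reader are placeholders):

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static


class OneBatchReader(CalibrationDataReader):
    """Feeds a single random calibration batch; name and shape are assumptions."""

    def __init__(self):
        self._batches = iter([{"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}])

    def get_next(self):
        return next(self._batches, None)


quantize_static(
    "model.onnx",        # placeholder input path
    "model.quant.onnx",  # placeholder output path
    calibration_data_reader=OneBatchReader(),
    quant_format=QuantFormat.QOperator,  # the QLinear* path implemented above
    # Pins Sigmoid output params to scale=1/256, zero_point=0, per the code above.
    extra_options={"extra.Sigmoid.nnapi": True},
)
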
onnxruntime/quantization/operators/argmax.py
@@ -0,0 +1,18 @@
+ from .base_operator import QuantOperatorBase
+
+
+ # Use the quantized tensor as input without DQ.
+ class QArgMax(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+
+         quantized_input_value = self.quantizer.find_quantized_value(node.input[0])
+         if quantized_input_value is None:
+             self.quantizer.new_nodes += [node]
+             return
+
+         node.input[0] = quantized_input_value.q_name
+         self.quantizer.new_nodes += [node]

onnxruntime/quantization/operators/attention.py
@@ -0,0 +1,73 @@
+ import onnx
+ from onnx import onnx_pb as onnx_proto  # noqa: F401
+
+ from ..quant_utils import attribute_to_kwarg, ms_domain
+ from .base_operator import QuantOperatorBase
+
+ """
+ Quantize Attention
+ """
+
+
+ class AttentionQuant(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def should_quantize(self):
+         return self.quantizer.should_quantize_node(self.node)
+
+     def quantize(self):
+         """
+         parameter node: Attention node.
+         parameter new_nodes_list: List of new nodes created before processing this node.
+         return: a list of nodes in topological order that represents quantized Attention node.
+         """
+         node = self.node
+         assert node.op_type == "Attention"
+
+         # TODO: This is a temporary fix to stop exporting QAttention with the qkv_hidden_sizes
+         # attribute. It needs to be removed once QAttention for varied q, k, v sizes
+         # is implemented.
+         for attr in node.attribute:
+             if attr.name == "qkv_hidden_sizes":
+                 return super().quantize()
+
+         (
+             quantized_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+
+         (
+             quantized_input_names_weight,
+             zero_point_names_weight,
+             scale_names_weight,
+             nodes_weight,
+         ) = self.quantizer.quantize_weight(node, [1], reduce_range=True, op_level_per_channel=True)
+         quantized_input_names.extend(quantized_input_names_weight)
+         zero_point_names.extend(zero_point_names_weight)
+         scale_names.extend(scale_names_weight)
+         nodes.extend(nodes_weight)
+
+         if quantized_input_names is None:
+             return super().quantize()
+
+         qattention_name = "" if not node.name else node.name + "_quant"
+
+         inputs = []
+         inputs.extend(quantized_input_names)
+         inputs.extend([node.input[2]])
+         inputs.extend(scale_names)
+         inputs.extend([node.input[3] if len(node.input) > 3 else ""])
+         inputs.extend(zero_point_names)
+         inputs.extend([node.input[4] if len(node.input) > 4 else ""])
+
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         kwargs["domain"] = ms_domain
+         qattention_node = onnx.helper.make_node("QAttention", inputs, node.output, qattention_name, **kwargs)
+         nodes.append(qattention_node)
+
+         self.quantizer.new_nodes += nodes

onnxruntime/quantization/operators/base_operator.py
@@ -0,0 +1,26 @@
+ class QuantOperatorBase:
+     def __init__(self, onnx_quantizer, onnx_node):
+         self.quantizer = onnx_quantizer
+         self.node = onnx_node
+
+     def should_quantize(self):
+         if not self.quantizer.should_quantize_node(self.node):
+             return False
+
+         return self.quantizer.is_float_tensor(self.node.input[0])
+
+     def quantize(self):
+         """
+         Given a node which does not support quantization, this method checks whether the input to
+         this node is quantized and adds a DequantizeLinear node to dequantize this input back to FP32
+         parameter node: Current node
+         parameter new_nodes_list: List of new nodes created before processing current node
+         return: List of new nodes created
+         """
+         for _, node_input in enumerate(self.node.input):
+             dequantize_node = self.quantizer._dequantize_value(node_input)
+             if dequantize_node is not None:
+                 self.quantizer.new_nodes.append(dequantize_node)
+
+         # Append the original node
+         self.quantizer.new_nodes.append(self.node)

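Every per-op class in this directory follows the contract defined by this base class: should_quantize gates the node, and quantize either rewrites it into a quantized form or falls back to super().quantize(), which dequantizes any quantized inputs and keeps the node. A sketch of a hypothetical subclass built only on the attributes visible above (the op and its pass-through handling are invented for illustration):

from onnxruntime.quantization.operators.base_operator import QuantOperatorBase


class QPassthroughOp(QuantOperatorBase):
    """Hypothetical: reuse the input's quantization params for the output,
    the same shortcut QLinearActivation takes for Relu/Clip."""

    def quantize(self):
        node = self.node
        if node.input[0] not in self.quantizer.quantized_value_map:
            # Input not quantized upstream: the base class inserts
            # DequantizeLinear nodes as needed and appends the node as-is.
            return super().quantize()
        # Input already quantized: map the output to the same quantized value.
        self.quantizer.quantized_value_map[node.output[0]] = self.quantizer.quantized_value_map[node.input[0]]
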
onnxruntime/quantization/operators/binary_op.py
@@ -0,0 +1,72 @@
+ import onnx
+ from onnx import onnx_pb as onnx_proto  # noqa: F401
+
+ from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+ from .base_operator import QuantOperatorBase
+
+
+ class QLinearBinaryOp(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+
+         (
+             data_found,
+             output_scale_name,
+             output_zp_name,
+             _,
+             _,
+         ) = self.quantizer._get_quantization_params(node.output[0])
+         (
+             quantized_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0, 1])
+         if not data_found or quantized_input_names is None:
+             return super().quantize()
+
+         qlinear_binary_math_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+         qlinear_binary_math_name = node.name + "_quant" if node.name else ""
+
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         kwargs["domain"] = ms_domain
+
+         qlinear_binary_math_inputs = []
+         # Input 0
+         qlinear_binary_math_inputs.append(quantized_input_names[0])
+         qlinear_binary_math_inputs.append(scale_names[0])
+         qlinear_binary_math_inputs.append(zero_point_names[0])
+         # Input 1
+         qlinear_binary_math_inputs.append(quantized_input_names[1])
+         qlinear_binary_math_inputs.append(scale_names[1])
+         qlinear_binary_math_inputs.append(zero_point_names[1])
+
+         # Output
+         qlinear_binary_math_inputs.append(output_scale_name)
+         qlinear_binary_math_inputs.append(output_zp_name)
+
+         qlinear_binary_math_node = onnx.helper.make_node(
+             "QLinear" + node.op_type,
+             qlinear_binary_math_inputs,
+             [qlinear_binary_math_output],
+             qlinear_binary_math_name,
+             **kwargs,
+         )
+         nodes.append(qlinear_binary_math_node)
+
+         # Create an entry for this quantized value
+         q_output = QuantizedValue(
+             node.output[0],
+             qlinear_binary_math_output,
+             output_scale_name,
+             output_zp_name,
+             QuantizedValueType.Input,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = q_output
+
+         self.quantizer.new_nodes += nodes

onnxruntime/quantization/operators/concat.py
@@ -0,0 +1,62 @@
+ import onnx
+
+ from ..quant_utils import (  # noqa: F401
+     TENSOR_NAME_QUANT_SUFFIX,
+     QuantizedValue,
+     QuantizedValueType,
+     attribute_to_kwarg,
+     ms_domain,
+ )
+ from .base_operator import QuantOperatorBase
+ from .qdq_base_operator import QDQOperatorBase  # noqa: F401
+
+
+ class QLinearConcat(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+
+         (
+             data_found,
+             output_scale_name,
+             output_zp_name,
+             _,
+             _,
+         ) = self.quantizer._get_quantization_params(node.output[0])
+         (
+             q_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [*range(len(node.input))])
+         if not data_found or q_input_names is None:
+             return super().quantize()
+
+         # Create an entry for the output quantized value
+         quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
+         quantized_output_value = QuantizedValue(
+             node.output[0],
+             node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
+             output_scale_name,
+             output_zp_name,
+             quantized_input_value.value_type,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
+
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         kwargs["domain"] = ms_domain
+         qnode_name = node.name + "_quant" if node.name else ""
+
+         qlconcat_inputs = [output_scale_name, output_zp_name]
+         for i in range(len(q_input_names)):
+             qlconcat_inputs.extend([q_input_names[i], scale_names[i], zero_point_names[i]])
+         qlconcat_node = onnx.helper.make_node(
+             "QLinearConcat", qlconcat_inputs, [quantized_output_value.q_name], qnode_name, **kwargs
+         )
+
+         self.quantizer.new_nodes += nodes
+         self.quantizer.new_nodes += [qlconcat_node]

onnxruntime/quantization/operators/conv.py
@@ -0,0 +1,260 @@
+ import numpy as np
+ import onnx
+ from onnx import onnx_pb as onnx_proto
+
+ from ..quant_utils import (
+     TENSOR_NAME_QUANT_SUFFIX,
+     QuantizedValue,
+     QuantizedValueType,
+     attribute_to_kwarg,
+     find_by_name,
+     get_mul_node,
+ )
+ from .base_operator import QuantOperatorBase
+ from .qdq_base_operator import QDQOperatorBase
+
+
+ class ConvInteger(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def add_bias(self, nodes, scaled_output):
+         """
+         Given a node, this function handles bias add by adding a "reshape" node on bias and an "add" node
+         parameter nodes: new nodes would be appended into nodes
+         parameter node: current node (Conv)
+         parameter scaled_output: output of quant conv without bias
+         parameter output: output of Conv
+         parameter bias_name: bias of Conv
+         return: the name of output
+         """
+         node = self.node
+         model = self.quantizer.model
+         # Add tensors for the shape to be reshaped to
+         weight = find_by_name(node.input[1], model.initializer())
+         if weight is None:
+             raise ValueError(f"Expected {node.input[1]} to be an initializer")
+
+         # Add reshape for correct broadcast
+         output = node.output[0]
+         reshape_input_data = node.input[2]  # bias of Conv
+         reshape_input_shape = output + "_bias_reshape_shape"
+         reshape_output = output + "_bias_reshape_output"
+
+         shape = np.ones((len(weight.dims)), dtype=np.int64)
+         shape[1] = -1
+         init_shape = onnx.helper.make_tensor(
+             reshape_input_shape, onnx_proto.TensorProto.INT64, [len(weight.dims)], shape
+         )
+         model.add_initializer(init_shape)
+
+         reshape_node = onnx.helper.make_node("Reshape", [reshape_input_data, reshape_input_shape], [reshape_output])
+         nodes.append(reshape_node)
+
+         # Add an Add operation for bias
+         add_node = onnx.helper.make_node("Add", [scaled_output, reshape_output], [output], output + "_bias_add")
+         nodes.append(add_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "Conv"
+         # Get quantized inputs for both the activation (input[0]) and the weight (input[1])
+         (
+             quantized_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+
+         (
+             quantized_input_names_weight,
+             zero_point_names_weight,
+             scale_names_weight,
+             nodes_weight,
+         ) = self.quantizer.quantize_weight(node, [1], reduce_range=self.quantizer.reduce_range)
+         quantized_input_names.extend(quantized_input_names_weight)
+         zero_point_names.extend(zero_point_names_weight)
+         scale_names.extend(scale_names_weight)
+         nodes.extend(nodes_weight)
+
+         conv_integer_output = node.output[0] + "_output_quantized"
+         conv_integer_name = node.name + "_quant" if node.name else ""
+
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         conv_integer_node = onnx.helper.make_node(
+             "ConvInteger", quantized_input_names + zero_point_names, [conv_integer_output], conv_integer_name, **kwargs
+         )
+         nodes.append(conv_integer_node)
+
+         # Add a cast operation to cast the ConvInteger output to float.
+         onnx_type = self.quantizer.get_tensor_type(node.output[0], mandatory=True)
+         cast_op_output = conv_integer_output + "_cast_output"
+         cast_node = onnx.helper.make_node(
+             "Cast",
+             [conv_integer_output],
+             [cast_op_output],
+             conv_integer_output + "_cast",
+             to=onnx_type,  # TODO: FLOAT or FLOAT16
+         )
+         nodes.append(cast_node)
+
+         # Add a mul operation to multiply the scales of the two inputs.
+         assert len(scale_names) == 2
+         if conv_integer_name:
+             scales_mul_op = conv_integer_name + "_scales_mul"
+         else:
+             scales_mul_op = scale_names[0] + "_" + scale_names[1] + "_mul"
+
+         scales_mul_node = find_by_name(scales_mul_op, self.quantizer.new_nodes)
+         if scales_mul_node is None:
+             scales_mul_node = get_mul_node(scale_names, scales_mul_op + ":0", scales_mul_op)
+             nodes.append(scales_mul_node)
+
+         scales_mul_op_output = scales_mul_node.output[0]
+
+         has_bias = len(node.input) == 3
+         scaled_output_name = node.output[0] if not has_bias else node.output[0] + "quant_scaled_output"
+
+         # Add a mul operation to multiply the mul_scales_op result with the output of ConvInteger
+         # and make the output of this node the same as the output of the original conv node.
+         output_scale_mul_op = conv_integer_name + "_output_scale_mul" if conv_integer_name else ""
+         nodes.append(
+             get_mul_node(
+                 [cast_op_output, scales_mul_op_output],
+                 scaled_output_name,
+                 output_scale_mul_op,
+             )
+         )
+
+         if has_bias:
+             self.add_bias(nodes, scaled_output_name)
+
+         self.quantizer.new_nodes += nodes
+
+
+ class QLinearConv(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "Conv"
+
+         (
+             data_found,
+             output_scale_name,
+             output_zp_name,
+             _,
+             _,
+         ) = self.quantizer._get_quantization_params(node.output[0])
+
+         if self.quantizer.is_input_a_initializer(node.input[1]) and self.quantizer.is_per_channel():
+             (
+                 quantized_input_names,
+                 zero_point_names,
+                 scale_names,
+                 nodes,
+             ) = self.quantizer.quantize_activation(node, [0])
+             quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
+                 node.input[1],
+                 onnx_proto.TensorProto.INT8,
+                 0,  # self.quantizer.weight_qType?
+             )
+             quantized_input_names.append(quant_weight_tuple[0])
+             zero_point_names.append(quant_weight_tuple[1])
+             scale_names.append(quant_weight_tuple[2])
+         else:
+             (
+                 quantized_input_names,
+                 zero_point_names,
+                 scale_names,
+                 nodes,
+             ) = self.quantizer.quantize_activation(node, [0])
+
+             (
+                 quantized_input_names_weight,
+                 zero_point_names_weight,
+                 scale_names_weight,
+                 nodes_weight,
+             ) = self.quantizer.quantize_weight(node, [1], reduce_range=self.quantizer.reduce_range)
+             quantized_input_names.extend(quantized_input_names_weight)
+             zero_point_names.extend(zero_point_names_weight)
+             scale_names.extend(scale_names_weight)
+             nodes.extend(nodes_weight)
+
+         if not data_found or quantized_input_names is None:
+             return super().quantize()
+
+         quantized_bias_name = ""
+         bias_present = False
+         if len(node.input) == 3:
+             if self.quantizer.weight_qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
+                 raise RuntimeError("Quantization to FLOAT8E4M3FN for operator Conv is not supported.")
+             quantized_bias_name = self.quantizer.quantize_bias_static(node.input[2], node.input[0], node.input[1])
+             bias_present = True
+
+         qlinear_conv_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+         qlinear_conv_name = node.name + "_quant" if node.name else ""
+
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         qlinear_conv_inputs = []
+         # Input 0
+         qlinear_conv_inputs.append(quantized_input_names[0])
+         qlinear_conv_inputs.append(scale_names[0])
+         qlinear_conv_inputs.append(zero_point_names[0])
+         # Input 1
+         qlinear_conv_inputs.append(quantized_input_names[1])
+         qlinear_conv_inputs.append(scale_names[1])
+         qlinear_conv_inputs.append(zero_point_names[1])
+
+         # Output
+         qlinear_conv_inputs.append(output_scale_name)
+         qlinear_conv_inputs.append(output_zp_name)
+
+         if bias_present:
+             qlinear_conv_inputs.append(quantized_bias_name)
+
+         qlinear_conv_node = onnx.helper.make_node(
+             "QLinearConv", qlinear_conv_inputs, [qlinear_conv_output], qlinear_conv_name, **kwargs
+         )
+         nodes.append(qlinear_conv_node)
+
+         # Create an entry for this quantized value
+         q_output = QuantizedValue(
+             node.output[0],
+             qlinear_conv_output,
+             output_scale_name,
+             output_zp_name,
+             QuantizedValueType.Input,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = q_output
+
+         self.quantizer.new_nodes += nodes
+
+
+ class QDQConv(QDQOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "Conv" or node.op_type == "ConvTranspose"
+
+         self.quantizer.quantize_activation_tensor(node.input[0])
+         if not self.disable_qdq_for_node_output:
+             self.quantizer.quantize_activation_tensor(node.output[0])
+
+         is_weight_per_channel, weight_axis = self.quantizer.is_tensor_per_channel(
+             node.input[1], default_axis=0 if node.op_type == "Conv" else 1
+         )
+         if is_weight_per_channel:
+             self.quantizer.quantize_weight_tensor_per_channel(node.input[1], weight_axis)
+         else:
+             self.quantizer.quantize_weight_tensor(node.input[1])
+
+         if len(node.input) == 3:
+             self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1])
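The integer Conv paths above are selected by the public entry points: dynamic quantization lowers Conv through ConvInteger plus the Cast/Mul rescaling shown, static quantization in QOperator format emits QLinearConv, and QDQ format goes through QDQConv. A hedged sketch of the dynamic path (paths are placeholders, and the op list is passed explicitly since defaults vary across releases):

from onnxruntime.quantization import QuantType, quantize_dynamic

# Weights are quantized offline; activation scales are computed at run time,
# which is what routes Conv through the ConvInteger rewrite above.
quantize_dynamic(
    "model.onnx",       # placeholder input path
    "model.int8.onnx",  # placeholder output path
    op_types_to_quantize=["Conv", "MatMul"],
    weight_type=QuantType.QUInt8,
)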