onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/quantization/operators/matmul.py
@@ -0,0 +1,231 @@
+ import itertools
+ import logging
+
+ import onnx
+ from onnx import onnx_pb as onnx_proto
+
+ from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, find_by_name, get_mul_node
+ from .base_operator import QuantOperatorBase
+ from .qdq_base_operator import QDQOperatorBase
+
+
+ class QOpMatMul(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def should_quantize(self):
+         if not self.quantizer.should_quantize_node(self.node):
+             logging.debug(f"Ignore MatMul {self.node.name}")
+             return False
+
+         if (not self.quantizer.is_float_tensor(self.node.input[1])) and (
+             not self.quantizer.is_float_tensor(self.node.input[0])
+         ):
+             logging.info(f"Ignore MatMul due to non float inputs {self.node.name}")
+             return False
+
+         # Do not quantize non-constant B matrices for MatMul.
+         if self.quantizer.q_matmul_const_b_only:
+             if not self.quantizer.find_initializer_in_path(self.node.input[1]):
+                 logging.info(f"Ignore MatMul due to non constant B: {self.quantizer.graph_scope}[{self.node.name}]")
+                 return False
+         return True
+
+
+ """
+ Used when quantize mode is QuantizationMode.IntegerOps.
+ """
+
+
+ class MatMulInteger(QOpMatMul):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "MatMul"
+         # Quantize both the activation (input[0]) and the weight (input[1]).
+         (
+             quantized_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+
+         (
+             quantized_input_names_weight,
+             zero_point_names_weight,
+             scale_names_weight,
+             nodes_weight,
+         ) = self.quantizer.quantize_weight(node, [1], reduce_range=True, op_level_per_channel=True)
+         quantized_input_names.extend(quantized_input_names_weight)
+         zero_point_names.extend(zero_point_names_weight)
+         scale_names.extend(scale_names_weight)
+         nodes.extend(nodes_weight)
+
+         matmul_integer_output = node.output[0] + "_output_quantized"
+         matmul_integer_name = node.name + "_quant" if node.name else ""
+         matmul_integer_node = onnx.helper.make_node(
+             "MatMulInteger",
+             quantized_input_names + zero_point_names,
+             [matmul_integer_output],
+             matmul_integer_name,
+         )
+         nodes.append(matmul_integer_node)
+
+         # Add a Cast to convert the MatMulInteger (int32) output to float.
+         cast_op_output = matmul_integer_output + "_cast_output"
+         otype = self.quantizer.get_tensor_type(node.output[0], mandatory=True)
+         cast_node = onnx.helper.make_node(
+             "Cast",
+             [matmul_integer_output],
+             [cast_op_output],
+             matmul_integer_output + "_cast",
+             to=otype,
+         )
+         nodes.append(cast_node)
+
+         # Add a Mul to multiply the scales of the two inputs.
+         assert len(scale_names) == 2
+         scales_mul_op = (
+             matmul_integer_name + "_scales_mul"
+             if matmul_integer_name
+             else scale_names[0] + "_" + scale_names[1] + "_mul"
+         )
+
+         scales_mul_node = find_by_name(scales_mul_op, self.quantizer.new_nodes)
+         if scales_mul_node is None:
+             scales_mul_node = get_mul_node(scale_names, scales_mul_op + ":0", scales_mul_op)
+             nodes.append(scales_mul_node)
+
+         scales_mul_op_output = scales_mul_node.output[0]
+
+         # Add a Mul to multiply the combined scale with the MatMulInteger output,
+         # and make its output the same as the output of the original MatMul node.
+         output_scale_mul_op = ""
+         if matmul_integer_name:
+             output_scale_mul_op = matmul_integer_name + "_output_scale_mul"
+         nodes.append(
+             get_mul_node(
+                 [cast_op_output, scales_mul_op_output],
+                 node.output[0],
+                 output_scale_mul_op,
+             )
+         )
+         self.quantizer.new_nodes += nodes
+
+
+ """
+ Used when quantize mode is QuantizationMode.QLinearOps.
+ """
+
+
+ class QLinearMatMul(QOpMatMul):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "MatMul"
+         # Quantize both the activation (input[0]) and the weight (input[1]).
+         (
+             quantized_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+
+         (
+             quantized_input_names_weight,
+             zero_point_names_weight,
+             scale_names_weight,
+             nodes_weight,
+         ) = self.quantizer.quantize_weight(node, [1], reduce_range=True, op_level_per_channel=True)
+         quantized_input_names.extend(quantized_input_names_weight)
+         zero_point_names.extend(zero_point_names_weight)
+         scale_names.extend(scale_names_weight)
+         nodes.extend(nodes_weight)
+
+         (
+             data_found,
+             output_scale_name,
+             output_zp_name,
+             _,
+             _,
+         ) = self.quantizer._get_quantization_params(node.output[0])
+         if not data_found or quantized_input_names is None:
+             return super().quantize()
+
+         qlinear_matmul_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+         qlinear_matmul_name = node.name + "_quant" if node.name else ""
+
+         qlinear_matmul_inputs = []
+         # Input 0
+         qlinear_matmul_inputs.append(quantized_input_names[0])
+         qlinear_matmul_inputs.append(scale_names[0])
+         qlinear_matmul_inputs.append(zero_point_names[0])
+         # Input 1
+         qlinear_matmul_inputs.append(quantized_input_names[1])
+         qlinear_matmul_inputs.append(scale_names[1])
+         qlinear_matmul_inputs.append(zero_point_names[1])
+         # Output quantization parameters
+         qlinear_matmul_inputs.append(output_scale_name)
+         qlinear_matmul_inputs.append(output_zp_name)
+
+         domain = (
+             "com.microsoft"
+             if self.quantizer.weight_qType
+             in {
+                 onnx_proto.TensorProto.FLOAT8E4M3FN,
+                 onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
+                 onnx_proto.TensorProto.FLOAT8E5M2,
+                 onnx_proto.TensorProto.FLOAT8E5M2FNUZ,
+             }
+             else ""
+         )
+         qlinear_matmul_node = onnx.helper.make_node(
+             "QLinearMatMul",
+             qlinear_matmul_inputs,
+             [qlinear_matmul_output],
+             qlinear_matmul_name,
+             domain=domain,
+         )
+         nodes.append(qlinear_matmul_node)
+
+         # Create an entry for this quantized value.
+         q_output = QuantizedValue(
+             node.output[0],
+             qlinear_matmul_output,
+             output_scale_name,
+             output_zp_name,
+             QuantizedValueType.Input,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = q_output
+
+         self.quantizer.new_nodes += nodes
+
+
+ class QDQMatMul(QDQOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "MatMul"
+
+         if self.disable_qdq_for_node_output:
+             nodes_to_iterate = node.input
+         else:
+             nodes_to_iterate = itertools.chain(node.input, node.output)
+
+         for tensor_name in nodes_to_iterate:
+             # Initializers are quantized as weights (per-channel when possible);
+             # everything else is treated as an activation.
+             if find_by_name(tensor_name, self.quantizer.model.initializer()):
+                 is_per_channel, channel_axis = self.quantizer.is_tensor_per_channel(
+                     tensor_name, default_axis=1, op_type=node.op_type
+                 )
+                 if is_per_channel:
+                     self.quantizer.quantize_weight_tensor_per_channel(tensor_name, channel_axis)
+                 else:
+                     self.quantizer.quantize_weight_tensor(tensor_name)
+             else:
+                 self.quantizer.quantize_activation_tensor(tensor_name)
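For context, a minimal sketch of how the MatMulInteger path above is typically reached: dynamic quantization rewrites float MatMul nodes whose B input is a constant initializer (see q_matmul_const_b_only). The model paths below are placeholders, not files from this package.

from onnxruntime.quantization import QuantType, quantize_dynamic

# Dynamic quantization produces MatMulInteger nodes like the ones constructed above.
quantize_dynamic(
    model_input="model.onnx",         # hypothetical input model
    model_output="model.quant.onnx",  # hypothetical output path
    weight_type=QuantType.QUInt8,
)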
onnxruntime/quantization/operators/maxpool.py
@@ -0,0 +1,34 @@
+ from .direct_q8 import Direct8BitOp, QDQDirect8BitOp
+
+
+ class QMaxPool(Direct8BitOp):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "MaxPool"
+
+         # If the opset version is less than 12, fall back to normal quantization.
+         if self.quantizer.opset_version < 12:
+             super(Direct8BitOp, self).quantize()
+             return
+
+         # Direct 8-bit op
+         return super().quantize()
+
+
+ class QDQMaxPool(QDQDirect8BitOp):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "MaxPool"
+
+         # If the opset version is less than 12, leave the node unchanged.
+         if self.quantizer.opset_version < 12:
+             return
+
+         # Direct 8-bit op
+         return super().quantize()
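Both classes above gate on opset_version because MaxPool only accepts 8-bit integer tensors starting with opset 12; older models fall back to regular quantization (or are left untouched in QDQ mode). A sketch, assuming the standard onnx version converter applies cleanly to the model, of raising the opset before quantization; paths are placeholders.

import onnx
from onnx import version_converter

model = onnx.load("model.onnx")  # hypothetical path
# MaxPool gains int8/uint8 support at opset 12, so convert before quantizing.
converted = version_converter.convert_version(model, 12)
onnx.save(converted, "model_opset12.onnx")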
onnxruntime/quantization/operators/norm.py
@@ -0,0 +1,40 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from .qdq_base_operator import QDQOperatorBase
+
+
+ class QDQNormalization(QDQOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type in {"InstanceNormalization", "LayerNormalization", "BatchNormalization"}
+
+         # Input
+         self.quantizer.quantize_activation_tensor(node.input[0])
+
+         # Scale
+         scale_is_initializer = self.quantizer.is_input_a_initializer(node.input[1])
+         scale_is_per_channel, scale_channel_axis = self.quantizer.is_tensor_per_channel(
+             node.input[1], default_axis=1, op_type=node.op_type
+         )
+
+         if scale_is_per_channel:
+             self.quantizer.quantize_weight_tensor_per_channel(node.input[1], axis=scale_channel_axis)
+         elif scale_is_initializer:
+             self.quantizer.quantize_weight_tensor(node.input[1])
+         else:
+             self.quantizer.quantize_activation_tensor(node.input[1])
+
+         # Bias
+         if len(node.input) > 2 and node.input[2]:
+             self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1])
+
+         # Output
+         if not self.disable_qdq_for_node_output:
+             for output_name in node.output:
+                 self.quantizer.quantize_activation_tensor(output_name)
onnxruntime/quantization/operators/pad.py
@@ -0,0 +1,100 @@
+ import onnx
+
+ from ..quant_utils import (
+     TENSOR_NAME_QUANT_SUFFIX,
+     QuantizedValue,
+     QuantizedValueType,
+     attribute_to_kwarg,
+     quantize_nparray,
+ )
+ from .base_operator import QuantOperatorBase
+
+
+ class QPad(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "Pad"
+
+         # The optional constant_value input is only available from opset 11.
+         # If input[0] is not quantized, do not quantize this node.
+         if (self.quantizer.opset_version < 11) or (node.input[0] not in self.quantizer.quantized_value_map):
+             super().quantize()
+             return
+         quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
+
+         kwargs = {}
+         for attribute in node.attribute:
+             kv = attribute_to_kwarg(attribute)
+             kwargs.update(kv)
+
+         if "mode" not in kwargs or kwargs["mode"] == b"constant":
+             if len(node.input) > 2 and node.input[2] != "":  # There is a 3rd input, 'constant_value'
+                 zp_tensor = self.quantizer.model.get_initializer(quantized_input_value.zp_name)
+                 scale_tensor = self.quantizer.model.get_initializer(quantized_input_value.scale_name)
+                 if zp_tensor is None or scale_tensor is None:
+                     super().quantize()
+                     return
+
+                 padding_constant_initializer = self.quantizer.model.get_initializer(node.input[2])
+                 if padding_constant_initializer is not None:
+                     zp_array = onnx.numpy_helper.to_array(zp_tensor)
+                     zp_value = zp_array.item() if zp_array.ndim == 0 else zp_array[0]
+                     scale_array = onnx.numpy_helper.to_array(scale_tensor)
+                     scale_value = scale_array.item() if scale_array.ndim == 0 else scale_array[0]
+                     padding_constant_array = onnx.numpy_helper.to_array(padding_constant_initializer)
+                     quantized_padding_constant_array = quantize_nparray(
+                         self.quantizer.activation_qType,
+                         padding_constant_array,
+                         scale_value,
+                         zp_value,
+                     )
+                     quantized_padding_constant_name = node.input[2] + TENSOR_NAME_QUANT_SUFFIX
+                     quantized_padding_constant_initializer = onnx.numpy_helper.from_array(
+                         quantized_padding_constant_array,
+                         quantized_padding_constant_name,
+                     )
+                     # Assume this padding constant initializer is only used by this node.
+                     self.quantizer.model.remove_initializer(padding_constant_initializer)
+                     self.quantizer.model.add_initializer(quantized_padding_constant_initializer)
+                     node.input[2] = quantized_padding_constant_name
+                 else:
+                     # TODO: check quantize_inputs after sub graph is supported
+                     pad_value_qnodes = self.quantizer._get_quantize_input_nodes(
+                         node,
+                         2,
+                         self.quantizer.activation_qType,
+                         quantized_input_value.scale_name,
+                         quantized_input_value.zp_name,
+                         initial_type=scale_tensor.data_type,
+                     )
+                     self.quantizer.new_nodes.extend(pad_value_qnodes)
+                     node.input[2] = pad_value_qnodes[0].output[0]
+             else:
+                 # In quantized format, the `zero` before quantization is mapped
+                 # to quantized_input_value.zp_name. Thus, padding the original
+                 # tensor with 0 becomes padding the quantized tensor with the
+                 # zero point.
+                 if len(node.input) == 2:
+                     # Feed the quantization zero point to the padding node.
+                     node.input.append(quantized_input_value.zp_name)
+                 else:
+                     # Assign the quantization zero point to the padding node.
+                     assert node.input[2] == ""
+                     node.input[2] = quantized_input_value.zp_name
+
+         # Create an entry for the output quantized value.
+         quantized_output_value = QuantizedValue(
+             node.output[0],
+             node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
+             quantized_input_value.scale_name,
+             quantized_input_value.zp_name,
+             QuantizedValueType.Input,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
+
+         node.input[0] = quantized_input_value.q_name
+         node.output[0] = quantized_output_value.q_name
+         self.quantizer.new_nodes += [node]
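The constant_value branch above folds the float padding constant through quantize_nparray. A minimal sketch of the affine mapping involved, assuming uint8 activations; the helper below is illustrative, not the library function.

import numpy as np

def quantize_to_uint8(x, scale, zero_point):
    # q = clamp(round(x / scale) + zero_point, 0, 255)
    q = np.round(np.asarray(x, dtype=np.float32) / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

# A float padding value of 0.0 maps exactly to the zero point (128 here),
# which is why the no-constant_value branch feeds zp_name to the Pad node directly.
print(quantize_to_uint8(0.0, 0.02, 128))  # -> 128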
onnxruntime/quantization/operators/pooling.py
@@ -0,0 +1,67 @@
+ import onnx
+
+ from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+ from .base_operator import QuantOperatorBase
+
+
+ class QLinearPool(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+
+         # only try to quantize when given quantization parameters for it
+         (
+             data_found,
+             output_scale_name,
+             output_zp_name,
+             _,
+             _,
+         ) = self.quantizer._get_quantization_params(node.output[0])
+
+         # get quantized input tensor names, quantize input if needed
+         (
+             quantized_input_names,
+             input_zero_point_names,
+             input_scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+
+         if not data_found or quantized_input_names is None:
+             return super().quantize()
+
+         # Create an entry for output quantized value.
+         qlinear_output_name = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+         quantized_output_value = QuantizedValue(
+             node.output[0],
+             qlinear_output_name,
+             output_scale_name,
+             output_zp_name,
+             QuantizedValueType.Input,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
+
+         # Create qlinear pool node for given type (AveragePool, etc)
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         kwargs["domain"] = ms_domain
+         qlinear_node_name = node.name + "_quant" if node.name else ""
+         qnode = onnx.helper.make_node(
+             "QLinear" + node.op_type,
+             [
+                 quantized_input_names[0],
+                 input_scale_names[0],
+                 input_zero_point_names[0],
+                 output_scale_name,
+                 output_zp_name,
+             ],
+             [qlinear_output_name],
+             qlinear_node_name,
+             **kwargs,
+         )
+
+         # add all newly created nodes
+         nodes.append(qnode)
+         self.quantizer.new_nodes += nodes
onnxruntime/quantization/operators/qdq_base_operator.py
@@ -0,0 +1,22 @@
+ import itertools
+
+ from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, quantize_nparray  # noqa: F401
+ from .base_operator import QuantOperatorBase  # noqa: F401
+
+
+ class QDQOperatorBase:
+     def __init__(self, onnx_quantizer, onnx_node):
+         self.quantizer = onnx_quantizer
+         self.node = onnx_node
+         self.disable_qdq_for_node_output = onnx_node.op_type in onnx_quantizer.op_types_to_exclude_output_quantization
+
+     def quantize(self):
+         node = self.node
+
+         if self.disable_qdq_for_node_output:
+             tensors_to_quantize = node.input
+         else:
+             tensors_to_quantize = itertools.chain(node.input, node.output)
+
+         for tensor_name in tensors_to_quantize:
+             self.quantizer.quantize_activation_tensor(tensor_name)
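QDQOperatorBase and its subclasses (QDQMatMul, QDQNormalization, and the QDQ classes below) are exercised by static quantization in QDQ format. A hedged usage sketch; the calibration reader and paths are placeholders.

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static

class MyDataReader(CalibrationDataReader):  # hypothetical reader
    def __init__(self, samples):
        self._iter = iter(samples)  # each sample: {input_name: np.ndarray}

    def get_next(self):
        return next(self._iter, None)

quantize_static(
    "model.onnx",                              # hypothetical input model
    "model.qdq.onnx",                          # hypothetical output path
    calibration_data_reader=MyDataReader([]),  # supply real calibration batches
    quant_format=QuantFormat.QDQ,              # insert QuantizeLinear/DequantizeLinear pairs
)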
onnxruntime/quantization/operators/resize.py
@@ -0,0 +1,34 @@
+ from .direct_q8 import Direct8BitOp, QDQDirect8BitOp
+
+
+ class QResize(Direct8BitOp):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "Resize"
+
+         # If the opset version is less than 11, fall back to normal quantization.
+         if self.quantizer.opset_version < 11:
+             super(Direct8BitOp, self).quantize()
+             return
+
+         # Direct 8-bit op
+         return super().quantize()
+
+
+ class QDQResize(QDQDirect8BitOp):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "Resize"
+
+         # If the opset version is less than 11, keep the node as-is.
+         if self.quantizer.opset_version < 11:
+             return
+
+         # Direct 8-bit op
+         return super().quantize()
onnxruntime/quantization/operators/softmax.py
@@ -0,0 +1,74 @@
+ import onnx
+ import onnx.helper
+
+ from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+ from .base_operator import QuantOperatorBase
+
+
+ class QLinearSoftmax(QuantOperatorBase):
+     def quantize(self):
+         node = self.node
+         # Constrain the softmax output scale and zero point, because softmax output is always in [0, 1].
+         if self.quantizer.activation_qType == onnx.onnx_pb.TensorProto.UINT8:
+             out_scale = 1 / 256.0
+             out_zero_point = 0
+         else:
+             out_scale = 1 / 256.0
+             out_zero_point = -128
+         # Only try to quantize when quantization parameters are given for it.
+         (
+             data_found,
+             output_scale_name,
+             output_zp_name,
+             _,
+             _,
+         ) = self.quantizer._get_quantization_params(node.output[0], out_scale, out_zero_point)
+
+         # Get quantized input tensor names; quantize the input if needed.
+         (
+             quantized_input_names,
+             input_zero_point_names,
+             input_scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+
+         if not data_found or quantized_input_names is None:
+             return super().quantize()
+
+         # Create an entry for the output quantized value.
+         qlinear_output_name = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
+         quantized_output_value = QuantizedValue(
+             node.output[0],
+             qlinear_output_name,
+             output_scale_name,
+             output_zp_name,
+             QuantizedValueType.Input,
+         )
+         self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
+
+         # Create a QLinearSoftmax node for the given type.
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+         kwargs["domain"] = ms_domain
+         # Give QLinearSoftmax the real opset_version; its default SinceVersion would otherwise be 1.
+         kwargs["opset"] = self.quantizer.opset_version
+         qlinear_node_name = node.name + "_quant" if node.name else ""
+         qnode = onnx.helper.make_node(
+             "QLinear" + node.op_type,
+             [
+                 quantized_input_names[0],
+                 input_scale_names[0],
+                 input_zero_point_names[0],
+                 output_scale_name,
+                 output_zp_name,
+             ],
+             [qlinear_output_name],
+             qlinear_node_name,
+             **kwargs,
+         )
+
+         # Add all newly created nodes.
+         nodes.append(qnode)
+         self.quantizer.new_nodes += nodes
+         return None
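A quick check of the fixed output quantization above: with scale = 1/256, the uint8 encoding (zero point 0) and the int8 encoding (zero point -128) both dequantize to [0, 255/256], which brackets softmax's output range of (0, 1).

scale = 1 / 256.0
# Dequantize: x = (q - zero_point) * scale
print((0 - 0) * scale, (255 - 0) * scale)            # uint8: 0.0 .. 0.99609375
print((-128 - -128) * scale, (127 - -128) * scale)   # int8:  0.0 .. 0.99609375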
onnxruntime/quantization/operators/split.py
@@ -0,0 +1,63 @@
+ import onnx
+
+ from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg
+ from .base_operator import QuantOperatorBase
+ from .qdq_base_operator import QDQOperatorBase
+
+
+ class QSplit(QuantOperatorBase):
+     def __init__(self, onnx_quantizer, onnx_node):
+         super().__init__(onnx_quantizer, onnx_node)
+
+     def quantize(self):
+         node = self.node
+         (
+             quantized_input_names,
+             zero_point_names,
+             scale_names,
+             nodes,
+         ) = self.quantizer.quantize_activation(node, [0])
+         if quantized_input_names is None:
+             return super().quantize()
+
+         quantized_node_name = ""
+         if node.name:
+             quantized_node_name = node.name + "_quant"
+         kwargs = {}
+         for attribute in node.attribute:
+             kwargs.update(attribute_to_kwarg(attribute))
+
+         # Each output derives its scale/zero point from the input.
+         quantized_output_names = []
+         for output_name in node.output:
+             quantized_output_name = output_name + "quantized"
+             quantized_output_names.append(quantized_output_name)
+             q_output = QuantizedValue(
+                 output_name,
+                 quantized_output_name,
+                 scale_names[0],
+                 zero_point_names[0],
+                 QuantizedValueType.Input,
+             )
+             self.quantizer.quantized_value_map[output_name] = q_output
+
+         if len(node.input) > 1:
+             quantized_input_names.extend(node.input[1:])
+         quantized_node = onnx.helper.make_node(
+             node.op_type, quantized_input_names, quantized_output_names, quantized_node_name, **kwargs
+         )
+
+         nodes.append(quantized_node)
+         self.quantizer.new_nodes += nodes
+
+
+ class QDQSplit(QDQOperatorBase):
+     def quantize(self):
+         node = self.node
+         assert node.op_type == "Split"
+
+         if not self.quantizer.is_tensor_quantized(node.input[0]):
+             self.quantizer.quantize_activation_tensor(node.input[0])
+         if not self.disable_qdq_for_node_output:
+             for output in node.output:
+                 self.quantizer.quantize_output_same_as_input(output, node.input[0], node.name)