onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/quantization/neural_compressor/onnx_model.py
@@ -0,0 +1,1251 @@
+ #
+ # The implementation of this file is based on:
+ # https://github.com/intel/neural-compressor/tree/master/neural_compressor
+ #
+ # Copyright (c) 2023 Intel Corporation
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Class for ONNX model."""
+
+ import copy
+ import logging
+ import os
+ import sys
+ from collections import deque
+ from pathlib import Path
+
+ import onnx
+ import onnx.external_data_helper
+
+ from .util import MAXIMUM_PROTOBUF, find_by_name
+
+ logger = logging.getLogger("neural_compressor")
+
+ # TODO: Check https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/onnx_model.py to see if we can integrate with it.
+
+
+ class ONNXModel:
+     """Build ONNX model."""
+
+     def __init__(self, model, **kwargs):
+         """Initialize an ONNX model.
+
+         Args:
+             model (str or ModelProto): path to onnx model or loaded ModelProto model object.
+             ignore_warning (bool): ignore large model warning. Default is False.
+             load_external_data (bool): load external data for large model. Default is True.
+         """
+         self._model = model if not isinstance(model, str) else onnx.load(model, load_external_data=False)
+         self._model_path = None if not isinstance(model, str) else model
+
+         self.check_is_large_model()
+         if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False):
+             logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize")
+
+         if self._is_large_model and isinstance(model, str) and kwargs.get("load_external_data", True):
+             onnx.external_data_helper.load_external_data_for_model(self._model, os.path.dirname(self._model_path))
+
+         self._config = None
+         if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()):
+             from transformers import AutoConfig  # noqa: PLC0415
+
+             self._config = AutoConfig.from_pretrained(Path(model).parent.as_posix())
+
+         self.node_name_counter = {}
+         self._output_name_to_node = {}
+         self._input_name_to_nodes = {}
+         self._get_input_name_to_nodes(self._model.graph.node)
+         self._get_output_name_to_node(self._model.graph.node)
+         self._graph_info = {}
+         self._get_graph_info()
+         self._q_config = None
+
+     def check_is_large_model(self):
+         """Check model > 2GB."""
+         init_size = 0
+         for init in self._model.graph.initializer:
+             # if the initializer has an external data location, the model is large
+             if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
+                 self._is_large_model = True
+                 return
+             # if serialization raises the initializer-size-over-2GB error, the model is large
+             try:
+                 init_bytes = init.SerializeToString()
+                 init_size += sys.getsizeof(init_bytes)
+             except Exception as e:
+                 if "exceeds maximum protobuf size of 2GB" in str(e):
+                     self._is_large_model = True
+                     return
+                 else:  # pragma: no cover
+                     raise e
+         if init_size > MAXIMUM_PROTOBUF:
+             self._is_large_model = True
+             return
+         self._is_large_model = False
+
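
For orientation, a minimal usage sketch of the constructor and the large-model check (assuming the wheel is installed so the module path shown in the file list above is importable; the one-node graph is illustrative only):

    from onnx import TensorProto, helper
    from onnxruntime.quantization.neural_compressor.onnx_model import ONNXModel

    # A trivial graph: Y = Identity(X).
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1])
    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1])
    node = helper.make_node("Identity", ["X"], ["Y"], name="id0")
    proto = helper.make_model(helper.make_graph([node], "g", [x], [y]))

    wrapper = ONNXModel(proto)                # a ModelProto works; a file path also works
    print(wrapper.is_large_model)             # False: no initializer comes close to 2GB
    print(wrapper.input(), wrapper.output())  # ['X'] ['Y']
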
+     @property
+     def is_large_model(self):
+         """Whether the onnx model is over 2GB."""
+         return self._is_large_model
+
+     @property
+     def model_path(self):
+         """Return model path."""
+         return self._model_path
+
+     @model_path.setter
+     def model_path(self, path):
+         """Set model path."""
+         self._model_path = path
+
+     def framework(self):
+         """Return framework."""
+         return "onnxruntime"
+
+     @property
+     def q_config(self):
+         """Return q_config."""
+         return self._q_config
+
+     @q_config.setter
+     def q_config(self, q_config):
+         """Set q_config."""
+         self._q_config = q_config
+
+     @property
+     def hf_config(self):
+         """Return huggingface config if model is Transformer-based."""
+         return self._config
+
+     @property
+     def model(self):
+         """Return model itself."""
+         return self._model
+
+     @model.setter
+     def model(self, model):
+         """Set model itself."""
+         self._model = model
+         self._graph_info = {}
+         self._get_graph_info()
+         self._output_name_to_node = {}
+         self._input_name_to_nodes = {}
+         self._get_input_name_to_nodes(self._model.graph.node)
+         self._get_output_name_to_node(self._model.graph.node)
+
+     def input(self):
+         """Return the input names of the model."""
+         return [i.name for i in self._model.graph.input]
+
+     def output(self):
+         """Return the output names of the model."""
+         return [i.name for i in self._model.graph.output]
+
+     def update(self):
+         """Update model info."""
+         self._graph_info = {}
+         self._get_graph_info()
+         self._output_name_to_node = {}
+         self._input_name_to_nodes = {}
+         self._get_input_name_to_nodes(self._model.graph.node)
+         self._get_output_name_to_node(self._model.graph.node)
+
+     @property
+     def graph_info(self):
+         """Return a dict mapping each node name to its op_type."""
+         return self._graph_info
+
+     def _get_graph_info(self):
+         """Update graph info."""
+         for node in self._model.graph.node:
+             self.graph_info.update({node.name: node.op_type})
+
+     def save(self, root):
+         """Save ONNX model."""
+         if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]):
+             raise ValueError('"root" directory does not exist.')
+         if self.is_large_model:
+             onnx.external_data_helper.load_external_data_for_model(self._model, os.path.split(self._model_path)[0])
+             onnx.save_model(
+                 self._model,
+                 root,
+                 save_as_external_data=True,
+                 all_tensors_to_one_file=True,
+                 location=root.split("/")[-1] + "_data",
+                 size_threshold=1024,
+                 convert_attribute=False,
+             )
+         else:
+             onnx.save(self._model, root)
+
+         if self._config is not None:
+             model_type = "" if not hasattr(self._config, "model_type") else self._config.model_type
+             self._config.__class__.model_type = model_type
+             output_config_file = Path(root).parent.joinpath("config.json").as_posix()
+             self._config.to_json_file(output_config_file, use_diff=False)
+
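
Continuing the sketch, save() writes the wrapped model and, when a HuggingFace config.json was picked up next to the source model, re-emits it beside the output (the paths here are illustrative):

    wrapper.save("model_out.onnx")           # plain onnx.save for models under 2GB
    reloaded = ONNXModel("model_out.onnx")   # loading by path also records model_path
    print(reloaded.model_path)               # 'model_out.onnx'

For models over 2GB the same call reloads the external data from the original model_path directory and re-saves every tensor into a single "<basename>_data" file next to the new location.
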
+     def nodes(self):
+         """Return model nodes."""
+         return self._model.graph.node
+
+     def initializer(self):
+         """Return model initializers."""
+         return self._model.graph.initializer
+
+     def graph(self):
+         """Return model graph."""
+         return self._model.graph
+
+     def ir_version(self):
+         """Return model ir_version."""
+         return self._model.ir_version
+
+     def opset_import(self):
+         """Return model opset_import."""
+         return self._model.opset_import
+
+     def remove_node(self, node):
+         """Remove a node from model."""
+         if node in self._model.graph.node:
+             self._model.graph.node.remove(node)
+
+     def remove_nodes(self, nodes_to_remove):
+         """Remove nodes from model."""
+         for node in nodes_to_remove:
+             self.remove_node(node)
+
+     def add_node(self, node):
+         """Add a node to model."""
+         self._model.graph.node.extend([node])
+
+     def add_nodes(self, nodes_to_add):
+         """Add nodes to model."""
+         self._model.graph.node.extend(nodes_to_add)
+
+     def add_initializer(self, tensor):
+         """Add an initializer to model."""
+         if find_by_name(tensor.name, self._model.graph.initializer) is None:
+             self._model.graph.initializer.extend([tensor])
+
+     def add_initializers(self, tensors):
+         """Add initializers to model."""
+         for tensor in tensors:
+             self.add_initializer(tensor)
+
+     def get_initializer(self, name):
+         """Get an initializer by name."""
+         for tensor in self._model.graph.initializer:
+             if tensor.name == name:
+                 return tensor
+         return None
+
+     def get_initializer_share_num(self, name):
+         """Get the number of nodes that share the initializer."""
+         num = 0
+         if self.get_initializer(name) is None:
+             return num
+
+         for node in self.nodes():
+             if name in node.input:
+                 num += 1
+         return num
+
+     def get_node(self, name):
+         """Get a node by name."""
+         for node in self._model.graph.node:
+             if node.name == name:
+                 return node
+         return None
+
+     def remove_initializer(self, tensor):
+         """Remove an initializer from model."""
+         if tensor in self._model.graph.initializer:
+             self._model.graph.initializer.remove(tensor)
+
+     def remove_initializers(self, init_to_remove):
+         """Remove initializers from model."""
+         for initializer in init_to_remove:
+             self.remove_initializer(initializer)
+
+     def set_initializer(self, tensor, array, raw=False):
+         """Update initializer."""
+         old_tensor = self.get_initializer(tensor)
+         self.remove_initializer(old_tensor)
+         dims = old_tensor.dims
+         data_type = old_tensor.data_type
+         new_tensor = (
+             onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist())
+             if not raw
+             else onnx.helper.make_tensor(tensor, data_type, dims, array.tobytes(), raw=raw)
+         )
+         self.add_initializer(new_tensor)
+
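
A short sketch of the initializer helpers (numpy and the running wrapper from above are assumed); note that set_initializer takes the tensor name and keeps the stored dims and data_type:

    import numpy as np
    from onnx import numpy_helper

    wrapper.add_initializer(numpy_helper.from_array(np.ones((2, 2), np.float32), name="W"))
    print(wrapper.get_initializer("W") is not None)  # True
    print(wrapper.get_initializer_share_num("W"))    # 0: no node consumes "W" yet

    wrapper.set_initializer("W", np.zeros((2, 2), np.float32))  # replaces the data in place
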
+     @property
+     def input_name_to_nodes(self):
+         """Return the mapping from tensor name to consumer nodes."""
+         return self._input_name_to_nodes
+
+     def _get_input_name_to_nodes(self, nodes):
+         """Build the mapping from input name to consumer nodes."""
+         for node in nodes:
+             attrs = [
+                 attr
+                 for attr in node.attribute
+                 if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
+             ]
+             if len(attrs) > 0:
+                 for attr in attrs:
+                     self._get_input_name_to_nodes(attr.g.node)
+             for input_name in node.input:
+                 if len(input_name.strip()) != 0:
+                     if input_name not in self._input_name_to_nodes:
+                         self._input_name_to_nodes[input_name] = [node]
+                     else:
+                         self._input_name_to_nodes[input_name].append(node)
+
+     @property
+     def output_name_to_node(self):
+         """Return the mapping from tensor name to its producer node."""
+         return self._output_name_to_node
+
+     def _get_output_name_to_node(self, nodes):
+         """Build the mapping from output name to producer node."""
+         for node in nodes:
+             attrs = [
+                 attr
+                 for attr in node.attribute
+                 if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
+             ]
+             if len(attrs) > 0:
+                 for attr in attrs:
+                     self._get_output_name_to_node(attr.g.node)
+             for output_name in node.output:
+                 if len(output_name.strip()) != 0:
+                     self._output_name_to_node[output_name] = node
+
+     def get_siblings(self, node):
+         """Get sibling nodes."""
+         siblings = []
+         for parent in self.get_parents(node):
+             for child in self.get_children(parent):
+                 if child.name != node.name:
+                     siblings.append(child)
+         return siblings
+
+     def get_children(self, node, input_name_to_nodes=None):
+         """Get child nodes."""
+         if input_name_to_nodes is None:
+             input_name_to_nodes = self._input_name_to_nodes
+
+         children = []
+         for output in node.output:
+             if output in input_name_to_nodes:
+                 for child in input_name_to_nodes[output]:
+                     children.append(child)  # noqa: PERF402
+         return children
+
+     def get_parents(self, node, output_name_to_node=None):
+         """Get parent nodes."""
+         if output_name_to_node is None:
+             output_name_to_node = self._output_name_to_node
+
+         parents = []
+         for input in node.input:
+             if input in output_name_to_node:
+                 parents.append(output_name_to_node[input])
+         return parents
+
+     def get_parent(self, node, idx, output_name_to_node=None):
+         """Get the parent node at input index idx."""
+         if output_name_to_node is None:
+             output_name_to_node = self._output_name_to_node
+
+         if len(node.input) <= idx:
+             return None
+
+         input = node.input[idx]
+         if input not in output_name_to_node:
+             return None
+
+         return output_name_to_node[input]
+
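
The traversal helpers are easiest to see on a two-node chain (a fresh illustrative graph; the names are arbitrary):

    n0 = helper.make_node("Relu", ["X"], ["T"], name="relu0")
    n1 = helper.make_node("Identity", ["T"], ["Y"], name="id1")
    g = helper.make_graph(
        [n0, n1], "g2",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1])],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1])],
    )
    m = ONNXModel(helper.make_model(g))

    print([c.name for c in m.get_children(m.get_node("relu0"))])  # ['id1']
    print([p.name for p in m.get_parents(m.get_node("id1"))])     # ['relu0']
    print(m.get_parent(m.get_node("id1"), 0).name)                # 'relu0'
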
+     def find_node_by_name(self, node_name, new_nodes_list, graph):
+         """Find a node by name."""
+         graph_nodes_list = list(graph.node)  # copy the node list so new nodes can be appended
+         graph_nodes_list.extend(new_nodes_list)
+         node = find_by_name(node_name, graph_nodes_list)
+         return node
+
+     def find_nodes_by_initializer(self, graph, initializer):
+         """Find all nodes with given initializer as an input."""
+         nodes = []
+         for node in graph.node:
+             for node_input in node.input:
+                 if node_input == initializer.name:
+                     nodes.append(node)
+         return nodes
+
+     def get_scale_zero(self, tensor):
+         """Helper function to get scale and zero_point."""
+         if not tensor.endswith("_quantized"):
+             logger.debug(f"Tensor {tensor} in the quantized graph is not quantized.")
+             return None, None
+
+         def _searcher(tensor_name):
+             """Search scale and zero point tensor recursively."""
+             node = self._input_name_to_nodes[tensor_name][0]
+             parent = self._output_name_to_node.get(tensor_name, None)
+             direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"]
+             if parent is not None and parent.op_type in direct_int8:
+                 fp32_tensor_name = (
+                     parent.input[0]
+                     .replace("_quantized", "")
+                     .replace("_QuantizeLinear", "")
+                     .replace("_QuantizeInput", "")
+                 )
+             elif node.op_type in ["Gather"]:  # pragma: no cover
+                 fp32_tensor_name = (
+                     node.output[0]
+                     .replace("_quantized", "")
+                     .replace("_QuantizeLinear", "")
+                     .replace("_QuantizeInput", "")
+                 )
+             else:
+                 fp32_tensor_name = (
+                     tensor_name.replace("_quantized", "").replace("_QuantizeLinear", "").replace("_QuantizeInput", "")
+                 )
+             scale = fp32_tensor_name + "_scale"
+             scale_tensor = self.get_initializer(scale)
+             zo = fp32_tensor_name + "_zero_point"
+             zo_tensor = self.get_initializer(zo)
+
+             if scale_tensor is None or zo_tensor is None:
+                 if parent is not None:
+                     scale_tensor, zo_tensor = _searcher(parent.input[0])
+             return scale_tensor, zo_tensor
+
+         node = self._input_name_to_nodes[tensor][0]
+         # TODO: check if scale_tensor and zero_point are needed;
+         # for the bias of QLinearConv, scale and zero_point are not needed
+         if (node.op_type == "QLinearConv" and tensor == node.input[-1]) or (
+             node.op_type == "QGemm" and tensor == node.input[-3]
+         ):
+             return None, None
+         else:
+             scale_tensor, zo_tensor = _searcher(tensor)
+             assert scale_tensor, f"missing scale for tensor {tensor}"
+             assert zo_tensor, f"missing zero point for tensor {tensor}"
+             return scale_tensor, zo_tensor
+
+     def save_model_to_file(self, output_path, use_external_data_format=False):
+         """Save model to external data, which is needed for model size > 2GB."""
+         if use_external_data_format:
+             onnx.external_data_helper.convert_model_to_external_data(
+                 self._model, all_tensors_to_one_file=True, location=Path(output_path).name + ".data"
+             )
+         onnx.save_model(self._model, output_path)
+
+     @staticmethod
+     def replace_node_input(node, old_input_name, new_input_name):
+         """Replace input of a node."""
+         assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
+         for j in range(len(node.input)):
+             if node.input[j] == old_input_name:
+                 node.input[j] = new_input_name
+
+     def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=None, black_optype=None):
+         """Replace inputs of all nodes."""
+         if white_optype is None:
+             white_optype = []
+         if black_optype is None:
+             black_optype = []
+         if len(white_optype) > 0:
+             for node in self.model.graph.node:
+                 if node.op_type in white_optype:
+                     ONNXModel.replace_node_input(node, old_input_name, new_input_name)
+         else:
+             for node in self.model.graph.node:
+                 if node.op_type not in black_optype:
+                     ONNXModel.replace_node_input(node, old_input_name, new_input_name)
+
+     @staticmethod
+     def replace_node_output(node, old_output_name, new_output_name):
+         """Replace output of a node."""
+         assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
+         for j in range(len(node.output)):
+             if node.output[j] == old_output_name:
+                 node.output[j] = new_output_name
+
+     def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=None, black_optype=None):
+         """Replace outputs of all nodes."""
+         if white_optype is None:
+             white_optype = []
+         if black_optype is None:
+             black_optype = []
+         if len(white_optype) > 0:
+             for node in self.model.graph.node:
+                 if node.op_type in white_optype:
+                     ONNXModel.replace_node_output(node, old_output_name, new_output_name)
+         else:
+             for node in self.model.graph.node:
+                 if node.op_type not in black_optype:
+                     ONNXModel.replace_node_output(node, old_output_name, new_output_name)
+
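
Continuing with m, the rename helpers rewrite consumers and producers separately, so a full tensor rename touches both sides; white_optype / black_optype act as an op-type allowlist / blocklist:

    m.replace_input_of_all_nodes("T", "T_renamed")   # consumers of "T" now read "T_renamed"
    m.replace_output_of_all_nodes("T", "T_renamed")  # the producer now writes "T_renamed"
    m.update()                                       # rebuild the cached name-to-node maps
    print([list(n.input) for n in m.nodes()])        # [['X'], ['T_renamed']]
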
+     def remove_unused_nodes(self):
+         """Remove unused nodes."""
+         unused_nodes = []
+         nodes = self.nodes()
+         for node in nodes:
+             if (
+                 node.op_type == "Constant"
+                 and node.output[0] not in self.output()
+                 and node.output[0] not in self._input_name_to_nodes
+             ):
+                 unused_nodes.append(node)
+             elif (
+                 node.op_type == "QuantizeLinear"
+                 and len(self.get_children(node)) == 1
+                 and self.get_children(node)[0].op_type == "DequantizeLinear"
+                 and node.input[0] not in self._output_name_to_node
+                 and self.get_children(node)[0].output[0] not in self._input_name_to_nodes
+             ):
+                 unused_nodes.append(node)
+                 unused_nodes.extend(self.get_children(node))
+             else:
+                 # remove the node if it does not serve as the input or output of any other node
+                 unused = True
+                 for output in node.output:
+                     if output in self._input_name_to_nodes or output in self.output():
+                         unused = False
+                         break
+                 for input in node.input:
+                     if self.get_initializer(input) is not None:
+                         continue
+                     elif input in self._output_name_to_node or input in self.input():
+                         unused = False
+                         break
+                 if unused:
+                     unused_nodes.append(node)
+         self.remove_nodes(unused_nodes)
+
+         unused_weights = []
+         for w in self._model.graph.initializer:
+             if w.name not in self._input_name_to_nodes and w.name not in self.output():
+                 unused_weights.append(w)
+                 # Remove from graph.input
+                 for graph_input in self.graph().input:
+                     if graph_input.name == w.name:
+                         self.graph().input.remove(graph_input)
+
+         self.remove_initializers(unused_weights)
+         self.update()
+
+     def topological_sort(self, enable_subgraph=False):
+         """Topologically sort the model nodes."""
+
+         if not enable_subgraph:
+             input_name_to_nodes = {}
+             output_name_to_node = {}
+             for node in self.model.graph.node:
+                 for input_name in node.input:
+                     if len(input_name.strip()) != 0:
+                         if input_name not in input_name_to_nodes:
+                             input_name_to_nodes[input_name] = [node]
+                         else:
+                             input_name_to_nodes[input_name].append(node)
+                 for output_name in node.output:
+                     if len(output_name.strip()) != 0:
+                         output_name_to_node[output_name] = node
+         else:  # pragma: no cover
+             input_name_to_nodes = self._input_name_to_nodes
+             output_name_to_node = self._output_name_to_node
+
+         all_nodes = {}
+         q = deque()
+         wait = deque()
+         for inp in self.model.graph.input:
+             q.extend(input_name_to_nodes[inp.name])
+         for n in self.model.graph.node:
+             if all(i not in output_name_to_node and i not in self.input() for i in n.input):
+                 q.append(n)
+
+         while q:
+             n = q.popleft()
+             if not all(output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node):
+                 if n not in wait:
+                     wait.append(n)
+                 continue
+
+             all_nodes[n.name] = n
+             for out in n.output:
+                 if out in input_name_to_nodes:
+                     q.extend([i for i in input_name_to_nodes[out] if i.name not in all_nodes and i not in q])
+             if len(q) == 0 and len(wait) != 0:
+                 q = copy.deepcopy(wait)
+                 wait.clear()
+         nodes = [i[1] for i in all_nodes.items()]
+         assert len(list({n.name for n in nodes})) == len(list({n.name for n in self.model.graph.node}))
+         self.model.graph.ClearField("node")
+         self.model.graph.node.extend(nodes)
+
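
The sort is a breadth-first pass: a node is emitted once all of its producing parents have been emitted, while nodes seen too early park in the wait queue and are retried when the ready queue drains. A quick self-check, shuffling m's nodes and restoring order (copying the nodes first, since ClearField detaches the originals):

    import copy
    import random

    scrambled = [copy.deepcopy(n) for n in m.nodes()]
    random.shuffle(scrambled)
    m.model.graph.ClearField("node")
    m.model.graph.node.extend(scrambled)

    m.topological_sort()
    m.update()  # the sort rewrites graph.node, so refresh the cached maps
    print([n.name for n in m.nodes()])  # producers again precede consumers
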
+     def get_nodes_chain(self, start, stop, result_chain=None):
+         """Get the chain of nodes between the given start and stop nodes."""
+         if result_chain is None:
+             result_chain = []
+         # process start node list
+         start_node = deque()
+         for node in start:
+             if isinstance(node, str):
+                 start_node.append(node)
+             elif isinstance(node, onnx.NodeProto):
+                 start_node.append(node.name)
+             else:
+                 assert False, "'get_nodes_chain' function only supports list[str] or list[NodeProto] params"  # noqa: B011
+
+         # process stop node list
+         stop_node = []
+         for node in stop:
+             if isinstance(node, str):
+                 stop_node.append(node)
+             elif isinstance(node, onnx.NodeProto):
+                 stop_node.append(node.name)
+             else:
+                 assert False, "'get_nodes_chain' function only supports list[str] or list[NodeProto] params"  # noqa: B011
+
+         while start_node:
+             node_name = start_node.popleft()
+             if node_name in stop_node:
+                 continue
+             if node_name not in result_chain:
+                 result_chain.append(node_name)
+             else:
+                 continue
+
+             node = find_by_name(node_name, list(self.model.graph.node))
+             for parent in self.get_parents(node):
+                 start_node.append(parent.name)
+
+         return result_chain
+
+     def find_split_node_for_layer_wise_quantization(self):
+         """Find split node for layer wise quantization."""
+         # find split nodes of decoder blocks
+         # embed -> decoder.0 -(split_node)-> ... -(split_node)-> decoder.n -(split_node)-> norm -> head
+         # after split: embed -> decoder.0,
+         #              decoder.1,
+         #              decoder.2,
+         #              ...,
+         #              decoder.n,
+         #              norm -> head
+         start_nodes = []
+         for node in self._model.graph.node:
+             start_node, qkv_nodes_list = None, None
+             if node.op_type == "SkipLayerNormalization":
+                 start_node = node
+                 qkv_nodes_list = [
+                     self.match_parent_path(
+                         start_node,
+                         ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                         [None, 0, 0, 0, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                         [1, 1, 0, 0, 0],
+                     ),
+                 ]
+             if node.op_type == "Add":
+                 start_node = node
+                 qkv_nodes_list = [
+                     # match base attention structure
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                         [0, None, 0, 0, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0]
+                     ),
+                     # match gpt attention no past structure
+                     self.match_parent_path(
+                         start_node,
+                         ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
+                         [None, 0, 0, 0, 0, 0],
+                         output_name_to_node=self.output_name_to_node,
+                         return_indice=[],
+                     ),
+                     # match bart attention structure
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                         [0, None, 0, 0, 0, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                         [1, None, 0, 0, 0, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node,
+                         ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"],
+                         [None, 0, None, 0, None, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node,
+                         ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"],
+                         [None, 0, None, 0, 0],
+                     ),
+                 ]
+             if not start_node:
+                 continue
+             if not any(qkv_nodes_list):
+                 continue
+             start_nodes.append(start_node)
+         return start_nodes
+
+     def find_qkv_in_attention(self, find_all=False):
+         """Find qkv MatMul in Attention.
+
+         Args:
+             find_all (bool, optional): find all qkv MatMul. Defaults to False.
+
+         Returns:
+             qkv (list): qkv MatMul list
+         """
+         qkv = []
+         for node in self._model.graph.node:
+             if node.op_type == "Attention":
+                 qkv.append([node.name])
+                 continue
+             start_node, qkv_nodes_list = None, None
+             if node.op_type == "SkipLayerNormalization":
+                 start_node = node
+                 qkv_nodes_list = [
+                     self.match_parent_path(
+                         start_node,
+                         ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                         [None, 0, 0, 0, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                         [1, 1, 0, 0, 0],
+                     ),
+                 ]
+             if node.op_type == "Add":
+                 start_node = node
+                 qkv_nodes_list = [
+                     # match base attention structure
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                         [0, None, 0, 0, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0]
+                     ),
+                     # match gpt attention no past structure
+                     self.match_parent_path(
+                         start_node,
+                         ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
+                         [None, 0, 0, 0, 0, 0],
+                         output_name_to_node=self.output_name_to_node,
+                         return_indice=[],
+                     ),
+                     # match bart attention structure
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                         [0, None, 0, 0, 0, 0],
+                     ),
+                     self.match_parent_path(
+                         start_node,
+                         ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                         [1, None, 0, 0, 0, 0],
+                     ),
+                 ]
+             if not start_node:
+                 continue
+             if not any(qkv_nodes_list):
+                 continue
+             qkv_nodes = [path for path in qkv_nodes_list if path is not None][-1]
+             other_inputs = []
+             for input in start_node.input:
+                 if input not in self.output_name_to_node:
+                     continue
+                 if input == qkv_nodes[0].output[0]:
+                     continue
+                 other_inputs.append(input)
+             if len(other_inputs) != 1:
+                 continue
+             root_input = other_inputs[0]
+             input_name_to_nodes = self.input_name_to_nodes
+             children = input_name_to_nodes[root_input]
+             children_types = [child.op_type for child in children]
+             if children_types.count("MatMul") == 3:
+                 qkv.append([child.name for child in children if child.op_type == "MatMul"])
+                 if not find_all:
+                     break
+         return qkv
+
+     def find_ffn_matmul(self, attention_index, attention_matmul_list, block_len):
+         """Find MatMul in FFN.
+
+         Args:
+             attention_index (list): index of Attention
+             attention_matmul_list (list): list of Attention and MatMul nodes
+             block_len (int): block length
+
+         Returns:
+             list: list of MatMul in FFN
+         """
+         ffn_matmul = []
+         for idx in range(len(attention_index)):
+             if idx != len(attention_index) - 1:
+                 index = attention_index[idx + 1]
+                 if index - 2 >= 0:
+                     ffn_matmul.append([attention_matmul_list[index - 2], attention_matmul_list[index - 1]])
+             else:
+                 index = attention_index[idx]
+                 if index + block_len - 1 < len(attention_matmul_list):
+                     ffn_matmul.append(
+                         [attention_matmul_list[index + block_len - 2], attention_matmul_list[index + block_len - 1]]
+                     )
+         return ffn_matmul
+
+     def export(self, save_path, conf):
+         """Export Qlinear to QDQ model."""
+         from neural_compressor.config import ONNXQlinear2QDQConfig  # noqa: PLC0415
+         from neural_compressor.utils.export import onnx_qlinear_to_qdq  # noqa: PLC0415
+
+         if isinstance(conf, ONNXQlinear2QDQConfig):
+             add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, self._input_name_to_nodes)
+             self.add_nodes(add_nodes)
+             self.remove_nodes(remove_nodes)
+             self.add_initializers(inits)
+             self.update()
+             self.remove_unused_nodes()
+             self.topological_sort()
+             self.save(save_path)
+         else:
+             logger.warning("Unsupported config for export, only ONNXQlinear2QDQConfig is supported!")
+             exit(0)
+
+     def add_tensors_to_outputs(self, tensor_names):
+         """Add the tensors to the model outputs to get their values.
+
+         Args:
+             tensor_names: The names of tensors to be dumped.
+         """
+         added_outputs = []
+         for tensor in tensor_names:
+             if tensor not in self.output():
+                 added_tensor = onnx.helper.ValueInfoProto()
+                 added_tensor.name = tensor
+                 added_outputs.append(added_tensor)
+         self._model.graph.output.extend(added_outputs)  # pylint: disable=no-member
+
+     def remove_tensors_from_outputs(self, tensor_names):
+         """Remove the tensors from the model outputs.
+
+         Args:
+             tensor_names: The names of tensors to be removed.
+         """
+         removed_outputs = []
+         for tensor in tensor_names:
+             if tensor in self.output():
+                 removed_outputs.append(self._model.graph.output[self.output().index(tensor)])
+         for output in removed_outputs:
+             self._model.graph.output.remove(output)
+
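
These two are typically paired when dumping an intermediate tensor: expose it, run a session, then restore the original outputs. A hedged sketch on the running graph m (the InferenceSession step is indicated in comments only):

    m.add_tensors_to_outputs(["T_renamed"])
    # ort_session = onnxruntime.InferenceSession(m.model.SerializeToString())
    # ... fetch "T_renamed" alongside the regular outputs ...
    m.remove_tensors_from_outputs(["T_renamed"])
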
+     def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=None):
+         """Find parent node based on constraints on op_type.
+
+         Args:
+             node (NodeProto): current node.
+             parent_op_type (str): constraint of parent node op_type.
+             output_name_to_node (dict): dictionary with output name as key, and node as value.
+             exclude (list): list of nodes that are excluded (not allowed to match as parent).
+
+         Returns:
+             parent: The matched parent node. None if not found.
+             index: The input index of the matched parent node. None if not found.
+         """
+         if exclude is None:
+             exclude = []
+         for i, input in enumerate(node.input):
+             if input in output_name_to_node:
+                 parent = output_name_to_node[input]
+                 if parent.op_type == parent_op_type and parent not in exclude:
+                     return parent, i
+         return None, None
+
+     def match_parent(
+         self,
+         node,
+         parent_op_type,
+         input_index=None,
+         output_name_to_node=None,
+         exclude=None,
+         return_indice=None,
+     ):
+         """Find parent node based on constraints on op_type and index.
+
+         Args:
+             node (NodeProto): current node.
+             parent_op_type (str): constraint of parent node op_type.
+             input_index (int or None): only check the parent at the given input index of current node.
+             output_name_to_node (dict): dictionary with output name as key, and node as value.
+             exclude (list): list of nodes that are excluded (not allowed to match as parent).
+             return_indice (list): a list to append the input index when input_index is None.
+
+         Returns:
+             parent: The matched parent node.
+         """
+         assert node is not None
+         assert input_index is None or input_index >= 0
+         if exclude is None:
+             exclude = []
+         if output_name_to_node is None:
+             output_name_to_node = self._output_name_to_node
+
+         if input_index is None:
+             parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
+             if return_indice is not None:
+                 return_indice.append(index)
+             return parent
+
+         if input_index >= len(node.input):
+             return None
+
+         parent = self.get_parent(node, input_index, output_name_to_node)
+         if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
+             return parent
+
+         return None
+
+     def match_parent_path(
+         self,
+         node,
+         parent_op_types,
+         parent_input_index,
+         output_name_to_node=None,
+         return_indice=None,
+     ):
+         """Find a sequence of input edges based on constraints on parent op_type and index.
+
+         Args:
+             node (NodeProto): current node.
+             parent_op_types (list): constraint of parent node op_type of each input edge.
+             parent_input_index (list): constraint of input index of each input edge.
+                 None means no constraint.
+             output_name_to_node (dict): dictionary with output name as key, and node as value.
+             return_indice (list): a list to append the input index when there is
+                 no constraint on input index of an edge.
+
+         Returns:
+             parents: a list of matched parent nodes, or None if any edge fails to match.
+         """
+         assert len(parent_input_index) == len(parent_op_types)
+
+         if output_name_to_node is None:
+             output_name_to_node = self._output_name_to_node
+
+         current_node = node
+         matched_parents = []
+         for i, op_type in enumerate(parent_op_types):
+             matched_parent = self.match_parent(
+                 current_node,
+                 op_type,
+                 parent_input_index[i],
+                 output_name_to_node,
+                 exclude=[],
+                 return_indice=return_indice,
+             )
+             if matched_parent is None:
+                 return None
+
+             matched_parents.append(matched_parent)
+             current_node = matched_parent
+
+         return matched_parents
+
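
match_parent_path walks one edge per entry: parent_op_types[i] constrains the i-th hop and parent_input_index[i] pins which input to follow (None lets match_first_parent pick). On the chain from above:

    tail = m.get_node("id1")
    hop = m.match_parent_path(tail, ["Relu"], [0])
    print([n.name for n in hop])                       # ['relu0']
    print(m.match_parent_path(tail, ["MatMul"], [0]))  # None: the op_type check fails
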
+     def is_smoothquant_model(self):
+         """Check whether the model has been smooth-quantized.
+
+         Returns:
+             bool: whether the model is smooth-quantized.
+         """
+         for init in self.model.graph.initializer:  # noqa: SIM110
+             if "_smooth_scale" in init.name:
+                 return True
+         return False
+
+     def find_split_nodes(self):
+         """Find split nodes for layer-wise quantization."""
+         split_nodes = self.find_split_node_for_layer_wise_quantization()
+         return split_nodes
+
+     def split_model_with_node(
+         self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True
+     ):
+         """Split model into two parts at a given node.
+
+         Args:
+             split_node_name (str): name of the node where the model is split.
+             path_of_model_to_split (str): path of the model to be split.
+             shape_infer (bool): do shape inference. Default is True.
+             save_both_split_models (bool): whether to save the two split models.
+                 False means only save the first split model.
+                 True means save both split models.
+                 Default is True.
+
+         Returns:
+             tuple: the first split model, the second split model
+         """
+         # origin model : ... -> node_1 -> split_node -> node_2 -> ...
+         # split model 1: ... -> node_1 -> split_node
+         # split model 2: node_2 -> ...
+
+         split_model_part_1 = onnx.ModelProto()
+         split_model_part_1.CopyFrom(self._model)
+         split_model_part_1.graph.ClearField("node")
+
+         split_model_part_2 = onnx.ModelProto()
+         split_model_part_2.CopyFrom(self._model)
+         split_model_part_2.graph.ClearField("node")
+
+         split_node_output = None
+         part_idx = 1
+         for node in self._model.graph.node:
+             if part_idx == 1:
+                 split_model_part_1.graph.node.append(node)
+             elif part_idx == 2:
+                 split_model_part_2.graph.node.append(node)
+
+             if node.name == split_node_name:
+                 split_node_output = node.output
+                 part_idx = 2
+
+         assert len(split_node_output) == 1, (
+             f"Only splitting at a node with 1 output tensor is supported, while current split node {split_node_name} has {len(split_node_output)} output tensors"
+         )
+         split_tensor_name = split_node_output[0]
+
+         # infer shape of the model to be split
+         if shape_infer:
+             try:
+                 from neural_compressor.adaptor.ox_utils.util import infer_shapes  # noqa: PLC0415
+
+                 self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path))
+             except Exception as e:  # pragma: no cover
+                 logger.error(
+                     "Shape inference fails for layer-wise quantization. "
+                     "We would recommend checking the graph optimization level of your model "
+                     "and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', "
+                     "as this may help avoid this error."
+                 )
+                 raise e
+
+         split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name)
+         split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape)
+
+         split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True)
+         split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True)
+
+         # remove unused input & output
+         split_model_part_1._remove_unused_input_output()
+         split_model_part_2._remove_unused_input_output()
+
+         split_model_part_1.model.graph.output.append(split_tensor)
+         split_model_part_2.model.graph.input.append(split_tensor)
+
+         insert_output_for_model_1 = []
+         insert_input_for_model_2 = []
+         for output in split_model_part_1.output_name_to_node:
+             if output in split_model_part_2.input_name_to_nodes:
+                 output_type, output_shape = self._get_output_type_shape_by_tensor_name(output)
+                 output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape)
+                 if output_tensor not in split_model_part_1.model.graph.output:
+                     insert_output_for_model_1.append(output_tensor)
+                 if output_tensor not in split_model_part_2.model.graph.input:
+                     insert_input_for_model_2.append(output_tensor)
+
1085
+ # insert model 1 output
1086
+ for output in insert_output_for_model_1:
1087
+ split_model_part_1.model.graph.output.append(output)
1088
+
1089
+ # insert model 2 input
1090
+ for input in insert_input_for_model_2:
1091
+ split_model_part_2.model.graph.input.append(input)
1092
+
1093
+ # remove unused init
1094
+ split_model_part_1.remove_unused_init()
1095
+ split_model_part_2.remove_unused_init()
1096
+
1097
+ split_model_part_1.update()
1098
+ split_model_part_2.update()
1099
+
1100
+ dir_of_model_to_split = os.path.dirname(path_of_model_to_split)
1101
+
1102
+ split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split)
1103
+ split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx")
1104
+ split_model_part_1.model_path = split_model_part_1_path
1105
+ split_model_part_1._save_split_model(split_model_part_1_path)
1106
+ split_model_part_1.check_is_large_model()
1107
+ logger.debug(f"save split model part 1 to {split_model_part_1_path} for layer wise quantization")
1108
+
1109
+ if save_both_split_models:
1110
+ split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split)
1111
+ split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx")
1112
+ split_model_part_2.model_path = split_model_part_2_path
1113
+ split_model_part_2._save_split_model(split_model_part_2_path)
1114
+ split_model_part_2.check_is_large_model()
1115
+ logger.debug(f"save split model part 2 to {split_model_part_2_path} for layer wise quantization")
1116
+ return split_model_part_1, split_model_part_2
1117
+ else:
1118
+ return split_model_part_1, split_model_part_2
1119
+
1120
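+    # A sketch of the layer-wise split flow (node name and paths are
+    # hypothetical): the split node's single output tensor becomes a graph
+    # output of part 1 and a graph input of part 2.
+    #
+    #     model = ONNXModel(onnx.load("large_model.onnx"), ignore_warning=True)
+    #     split_nodes = model.find_split_nodes()
+    #     part_1, part_2 = model.split_model_with_node(
+    #         split_nodes[0].name, "large_model.onnx", save_both_split_models=True
+    #     )
+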
+    def _save_split_model(self, save_path):
+        """Save a split model with external data for layer-wise quantization.
+
+        Args:
+            save_path (str): the path to save the split model
+        """
+        if os.path.exists(save_path + "_data"):
+            os.remove(save_path + "_data")
+        onnx.save_model(
+            self._model,
+            save_path,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=os.path.basename(save_path) + "_data",  # basename so Windows paths are handled too
+            size_threshold=1024,
+            convert_attribute=False,
+        )
+
+    def _get_output_type_shape_by_tensor_name(self, tensor_name):
+        """Get the output type and shape of a tensor by name.
+
+        Args:
+            tensor_name (str): name of a tensor
+
+        Returns:
+            tuple: output type and shape (defaults to FLOAT with an unknown shape
+                if the tensor is not found in value_info)
+        """
+        elem_type = onnx.TensorProto.FLOAT
+        shape = None
+        for output in self._model.graph.value_info:
+            if output.name == tensor_name:
+                elem_type = output.type.tensor_type.elem_type
+                shape = [
+                    dim.dim_value if dim.HasField("dim_value") else -1 for dim in output.type.tensor_type.shape.dim
+                ]
+                break
+        return elem_type, shape
+
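+    # A sketch of how the helper above is used at the split boundary
+    # ("hidden_states" is a hypothetical tensor name): shape inference fills
+    # graph.value_info, from which the boundary value_info proto is rebuilt.
+    #
+    #     elem_type, shape = model._get_output_type_shape_by_tensor_name("hidden_states")
+    #     boundary = onnx.helper.make_tensor_value_info("hidden_states", elem_type, shape)
+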
+    def _remove_unused_input_output(self):
+        """Remove unused inputs & outputs for a split model."""
+        remove_outputs = []
+        remove_inputs = []
+        for output in self._model.graph.output:
+            if output.name not in self.output_name_to_node:
+                remove_outputs.append(output)
+
+        for input in self._model.graph.input:
+            if input.name not in self.input_name_to_nodes:
+                remove_inputs.append(input)
+
+        for output in remove_outputs:
+            self._model.graph.output.remove(output)
+        for input in remove_inputs:
+            self._model.graph.input.remove(input)
+
+    def remove_unused_init(self):
+        """Remove unused initializers."""
+        remove_inits = []
+        for init in self._model.graph.initializer:
+            if init.name not in self.input_name_to_nodes:
+                remove_inits.append(init)
+        self.remove_initializers(remove_inits)
+
+    def load_model_initializer_by_tensor(self, data_path=None):
+        """Load the external data of model initializers tensor by tensor.
+
+        Args:
+            data_path (str, optional): the directory of the saved initializer data. Defaults to None.
+        """
+        if data_path is None:
+            data_path = os.path.dirname(self._model_path)
+        for init in self._model.graph.initializer:
+            if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
+                onnx.external_data_helper.load_external_data_for_tensor(init, data_path)
+
+    def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False):
+        """Write the external data of the merged quantized model to a new location to save memory.
+
+        Args:
+            external_data_location (str, optional): external data location of the merged quantized model.
+                Defaults to "external.data".
+            overwrite (bool, optional): if True, remove existing external data first. Defaults to False.
+        """
+        if overwrite and os.path.exists(os.path.join(os.path.dirname(self._model_path), external_data_location)):
+            os.remove(os.path.join(os.path.dirname(self._model_path), external_data_location))
+        self.load_model_initializer_by_tensor()
+        onnx.external_data_helper.convert_model_to_external_data(self._model, location=external_data_location)
+        # TODO: skip writing initializers that are already saved
+        onnx.external_data_helper.write_external_data_tensors(self._model, filepath=os.path.dirname(self._model_path))
+
+    def merge_split_models(self, to_merge_model):
+        """Merge two split models into the final model."""
+        to_merge_model.write_external_data_to_new_location()
+        self.add_nodes(list(to_merge_model.nodes()))
+        self.add_initializers(list(to_merge_model.initializer()))
+        self.update()
+
+        # add new outputs
+        for output in to_merge_model.graph().output:
+            if output.name not in self.output():
+                self._model.graph.output.append(output)
+
+        # remove unused outputs
+        remove_output = []
+        for output in self._model.graph.output:
+            if output.name in to_merge_model.input():
+                remove_output.append(output)
+        for output in remove_output:
+            self._model.graph.output.remove(output)
+
+        # add new inputs
+        for input in to_merge_model.graph().input:
+            if (
+                input.name not in self.input()
+                and input.name not in self.output()
+                and input.name not in self.output_name_to_node
+            ):
+                self._model.graph.input.append(input)
+
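+    # A sketch of merging quantized parts back together (hypothetical flow,
+    # continuing the split example above):
+    #
+    #     part_1.merge_split_models(part_2)  # part_1 now holds the full graph
+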
+    def re_org_output(self, origin_output):
+        """Reorder the outputs of the merged model for layer-wise quantization.
+
+        Args:
+            origin_output (list): the output names of the original model, in order.
+        """
+        outputs = {}
+        tmp_remove = []
+        for output in self._model.graph.output:
+            outputs[output.name] = output
+            tmp_remove.append(output)
+
+        for output in tmp_remove:
+            self._model.graph.output.remove(output)
+
+        for out_name in origin_output:
+            self._model.graph.output.append(outputs[out_name])
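+
+    # A sketch of restoring the original output order after merging
+    # (output names are hypothetical):
+    #
+    #     part_1.re_org_output(["logits", "last_hidden_state"])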