onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,226 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import logging
6
+
7
+ from fusion_attention import AttentionMask, FusionAttention
8
+ from fusion_utils import NumpyHelper
9
+ from onnx import NodeProto, helper
10
+ from onnx_model import OnnxModel
11
+ from onnx_model_bert import BertOnnxModel
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class FusionTnlrAttention(FusionAttention):
    """
    Fuse TNLR Attention subgraph into one Attention node.
    TNLR Attention has extra addition after qk nodes and adopts [S, B, NH] as I/O shape.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def create_attention_node(
        self,
        mask_index: str,
        matmul: NodeProto,
        add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> NodeProto | None:
        """Create a com.microsoft Attention node from the matched QKV projection.

        Args:
            mask_index: name of the mask tensor input, or None when there is no mask.
            matmul: the shared QKV projection MatMul node (weight initializer in input[1]).
            add: the bias Add node following the projection.
            num_heads: number of attention heads (must be positive).
            hidden_size: model hidden dimension; 0 means unknown/skip the divisibility check.
            input: name of the attention root input tensor.
            output: name of the tensor the new node writes.
            add_qk_str: name of the extra tensor added after QxK' (TNLR relative position bias).

        Returns:
            The new Attention NodeProto, or None when the inputs cannot be fused
            (hidden size not divisible by num_heads, or missing initializers).
        """
        assert num_heads > 0
        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        weight = self.model.get_initializer(matmul.input[1])
        # Bias may sit on either side of the Add node.
        bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])

        if weight is None or bias is None:
            return None

        qkv_weight = NumpyHelper.to_array(weight)
        qkv_bias = NumpyHelper.to_array(bias)

        attention_node_name = self.model.create_node_name("Attention")

        # Re-pack the fused weight/bias as raw initializers with the dtype of the
        # original projection weight.
        tensor_dtype = weight.data_type
        np_type = helper.tensor_dtype_to_np_dtype(tensor_dtype)
        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=tensor_dtype,
            dims=[hidden_size, 3 * hidden_size],
            vals=qkv_weight.astype(np_type).tobytes(),
            raw=True,
        )
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=tensor_dtype,
            dims=[3 * hidden_size],
            vals=qkv_bias.astype(np_type).tobytes(),
            raw=True,
        )
        self.model.add_initializer(bias, self.this_graph_name)

        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            attention_node_name + "_qkv_bias",
        ]
        # Input 3 is the (optional) mask; an empty string keeps the slot when absent.
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        # Input 4 (past) is left empty; input 5 carries the extra add-after-QK tensor.
        if add_qk_str is not None:
            attention_inputs.append("")
            attention_inputs.append(add_qk_str)

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match the TNLR attention subgraph anchored at a SkipLayerNormalization
        node and replace it with one Attention node plus a [1, 0, 2] Transpose."""
        # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm
        # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
        start_node = normalize_node
        if normalize_node.op_type != "SkipLayerNormalization":
            return

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            return

        # The SkipLayerNormalization input that is not produced by the QKV path
        # is the attention root input.
        other_inputs = []
        for _i, input in enumerate(start_node.input):
            if input not in output_name_to_node:
                continue

            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if v_nodes is None:
            return
        (_, _, _, add, matmul) = v_nodes

        upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0])
        # Bug fix: match_parent_path returns None when the pattern is absent;
        # the original code subscripted it unconditionally and crashed with a
        # TypeError instead of skipping the fusion.
        if upper_nodes is None:
            return
        transpose = upper_nodes[0]

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            return
        (_, add_qk, matmul_qk) = qk_nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            return
        add = q_nodes[-2]
        matmul = q_nodes[-1]

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if k_nodes is None:
            return
        # Q/K/V share one projection; the last (K-path) add/matmul pair is the
        # one handed to create_attention_node.
        add = k_nodes[-2]
        matmul = k_nodes[-1]

        relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
        if relative_position_bias_nodes is None:
            return

        if matmul.input[0] == root_input:
            mask_index = None
            attention_last_node = reshape_qkv
            # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
            # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
            new_node = self.create_attention_node(
                mask_index,
                matmul,
                add,
                self.num_heads,
                self.hidden_size,
                root_input,
                attention_last_node.output[0],
                relative_position_bias_nodes[0].input[0],
            )
            if new_node is None:
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            # Add a transpose node after the attention node
            back_transpose = helper.make_node(
                "Transpose",
                ["back_transpose_in_" + new_node.name],
                [new_node.output[0]],
                "back_transpose_" + new_node.name,
                perm=[1, 0, 2],
            )
            self.model.add_node(back_transpose, self.this_graph_name)
            # Rewire: attention consumes the pre-transpose input and feeds the
            # new back-transpose, restoring the original output layout.
            new_node.input[0] = transpose.input[0]
            new_node.output[0] = "back_transpose_in_" + new_node.name

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
217
+
218
+
219
class TnlrOnnxModel(BertOnnxModel):
    """BERT-style ONNX model wrapper for TNLR that wires in the TNLR attention fusion."""

    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        # The fusion object needs the shared mask helper, so build it first.
        mask_helper = AttentionMask(self)
        self.attention_mask = mask_helper
        self.attention_fusion = FusionTnlrAttention(
            self,
            self.hidden_size,
            self.num_heads,
            mask_helper,
        )

    def fuse_attention(self):
        """Run the TNLR attention fusion pass over the model."""
        self.attention_fusion.apply()
@@ -0,0 +1,258 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import logging
7
+
8
+ from fusion_attention_unet import FusionAttentionUnet
9
+ from fusion_bias_add import FusionBiasAdd
10
+ from fusion_biassplitgelu import FusionBiasSplitGelu
11
+ from fusion_group_norm import FusionGroupNorm
12
+ from fusion_nhwc_conv import FusionNhwcConv
13
+ from fusion_options import FusionOptions
14
+ from fusion_skip_group_norm import FusionSkipGroupNorm
15
+ from fusion_transpose import FusionInsertTranspose, FusionTranspose
16
+ from import_utils import is_installed
17
+ from onnx import ModelProto
18
+ from onnx_model import OnnxModel
19
+ from onnx_model_bert import BertOnnxModel
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class UnetOnnxModel(BertOnnxModel):
25
+ def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
26
+ """Initialize UNet ONNX Model.
27
+
28
+ Args:
29
+ model (ModelProto): the ONNX model
30
+ num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
31
+ hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
32
+ """
33
+ assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
34
+
35
+ super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
36
+
37
+ def preprocess(self):
38
+ self.remove_useless_div()
39
+
40
+ def postprocess(self):
41
+ self.prune_graph()
42
+ self.remove_unused_constant()
43
+
44
+ def remove_useless_div(self):
45
+ """Remove Div by 1"""
46
+ div_nodes = [node for node in self.nodes() if node.op_type == "Div"]
47
+
48
+ nodes_to_remove = []
49
+ for div in div_nodes:
50
+ if self.find_constant_input(div, 1.0) == 1:
51
+ nodes_to_remove.append(div)
52
+
53
+ for node in nodes_to_remove:
54
+ self.replace_input_of_all_nodes(node.output[0], node.input[0])
55
+
56
+ if nodes_to_remove:
57
+ self.remove_nodes(nodes_to_remove)
58
+ logger.info("Removed %d Div nodes", len(nodes_to_remove))
59
+
60
+ def convert_conv_to_nhwc(self):
61
+ # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes.
62
+ conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True)
63
+ conv_to_nhwc_conv.apply()
64
+
65
+ def merge_adjacent_transpose(self):
66
+ fusion_transpose = FusionTranspose(self)
67
+ fusion_transpose.apply()
68
+
69
+ remove_count = 0
70
+ nodes = self.get_nodes_by_op_type("Transpose")
71
+ for node in nodes:
72
+ permutation = OnnxModel.get_node_attribute(node, "perm")
73
+ assert isinstance(permutation, list)
74
+ if permutation != list(range(len(permutation))):
75
+ continue
76
+ assert not (
77
+ self.find_graph_output(node.output[0])
78
+ or self.find_graph_input(node.input[0])
79
+ or self.find_graph_output(node.input[0])
80
+ )
81
+
82
+ # Let all children nodes skip current Transpose node and link to its parent
83
+ # Note that we cannot update parent node output since parent node might have more than one children.
84
+ self.replace_input_of_all_nodes(node.output[0], node.input[0])
85
+
86
+ self.remove_node(node)
87
+ remove_count += 1
88
+
89
+ total = len(fusion_transpose.nodes_to_remove) + remove_count
90
+ if total:
91
+ logger.info("Removed %d Transpose nodes", total)
92
+
93
+ def fuse_multi_head_attention(self, options: FusionOptions | None = None):
94
+ # Self Attention
95
+ enable_packed_qkv = (options is None) or options.enable_packed_qkv
96
+ self_attention_fusion = FusionAttentionUnet(
97
+ self,
98
+ self.hidden_size,
99
+ self.num_heads,
100
+ is_cross_attention=False,
101
+ enable_packed_qkv=enable_packed_qkv,
102
+ enable_packed_kv=False,
103
+ )
104
+ self_attention_fusion.apply()
105
+
106
+ # Cross Attention
107
+ enable_packed_kv = (options is None) or options.enable_packed_kv
108
+ cross_attention_fusion = FusionAttentionUnet(
109
+ self,
110
+ self.hidden_size,
111
+ self.num_heads,
112
+ is_cross_attention=True,
113
+ enable_packed_qkv=False,
114
+ enable_packed_kv=enable_packed_kv,
115
+ )
116
+ cross_attention_fusion.apply()
117
+
118
+ def fuse_bias_add(self):
119
+ fusion = FusionBiasAdd(self)
120
+ fusion.apply()
121
+
122
+ def optimize(self, options: FusionOptions | None = None):
123
+ if is_installed("tqdm"):
124
+ import tqdm # noqa: PLC0415
125
+ from tqdm.contrib.logging import logging_redirect_tqdm # noqa: PLC0415
126
+
127
+ with logging_redirect_tqdm():
128
+ steps = 18
129
+ progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion")
130
+ self._optimize(options, progress_bar)
131
+ else:
132
+ logger.info("tqdm is not installed. Run optimization without progress bar")
133
+ self._optimize(options, None)
134
+
135
+ def _optimize(self, options: FusionOptions | None = None, progress_bar=None):
136
+ if (options is not None) and not options.enable_shape_inference:
137
+ self.disable_shape_inference()
138
+
139
+ self.utils.remove_identity_nodes()
140
+ if progress_bar:
141
+ progress_bar.update(1)
142
+
143
+ # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
144
+ self.utils.remove_useless_cast_nodes()
145
+ if progress_bar:
146
+ progress_bar.update(1)
147
+
148
+ if (options is None) or options.enable_layer_norm:
149
+ self.fuse_layer_norm()
150
+ if progress_bar:
151
+ progress_bar.update(1)
152
+
153
+ if (options is None) or options.enable_gelu:
154
+ self.fuse_gelu()
155
+ if progress_bar:
156
+ progress_bar.update(1)
157
+
158
+ self.preprocess()
159
+ if progress_bar:
160
+ progress_bar.update(1)
161
+
162
+ self.fuse_reshape()
163
+ if progress_bar:
164
+ progress_bar.update(1)
165
+
166
+ if (options is None) or options.enable_group_norm:
167
+ channels_last = (options is None) or options.group_norm_channels_last
168
+ group_norm_fusion = FusionGroupNorm(self, channels_last)
169
+ group_norm_fusion.apply()
170
+
171
+ insert_transpose_fusion = FusionInsertTranspose(self)
172
+ insert_transpose_fusion.apply()
173
+ if progress_bar:
174
+ progress_bar.update(1)
175
+
176
+ if (options is None) or options.enable_bias_splitgelu:
177
+ bias_split_gelu_fusion = FusionBiasSplitGelu(self)
178
+ bias_split_gelu_fusion.apply()
179
+ if progress_bar:
180
+ progress_bar.update(1)
181
+
182
+ if (options is None) or options.enable_attention:
183
+ # self.save_model_to_file("before_mha.onnx")
184
+ self.fuse_multi_head_attention(options)
185
+ if progress_bar:
186
+ progress_bar.update(1)
187
+
188
+ if (options is None) or options.enable_skip_layer_norm:
189
+ self.fuse_skip_layer_norm()
190
+ if progress_bar:
191
+ progress_bar.update(1)
192
+
193
+ self.fuse_shape()
194
+ if progress_bar:
195
+ progress_bar.update(1)
196
+
197
+ # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
198
+ self.utils.remove_useless_reshape_nodes()
199
+ if progress_bar:
200
+ progress_bar.update(1)
201
+
202
+ if (options is None) or options.enable_skip_group_norm:
203
+ skip_group_norm_fusion = FusionSkipGroupNorm(self)
204
+ skip_group_norm_fusion.apply()
205
+ if progress_bar:
206
+ progress_bar.update(1)
207
+
208
+ if (options is None) or options.enable_bias_skip_layer_norm:
209
+ # Fuse SkipLayerNormalization and Add Bias before it.
210
+ self.fuse_add_bias_skip_layer_norm()
211
+ if progress_bar:
212
+ progress_bar.update(1)
213
+
214
+ if options is not None and options.enable_gelu_approximation:
215
+ self.gelu_approximation()
216
+ if progress_bar:
217
+ progress_bar.update(1)
218
+
219
+ if options is None or options.enable_nhwc_conv:
220
+ self.convert_conv_to_nhwc()
221
+ self.merge_adjacent_transpose()
222
+ if progress_bar:
223
+ progress_bar.update(1)
224
+
225
+ if options is not None and options.enable_bias_add:
226
+ self.fuse_bias_add()
227
+ if progress_bar:
228
+ progress_bar.update(1)
229
+
230
+ self.postprocess()
231
+ if progress_bar:
232
+ progress_bar.update(1)
233
+
234
+ logger.info(f"opset version: {self.get_opset_version()}")
235
+
236
+ def get_fused_operator_statistics(self):
237
+ """
238
+ Returns node count of fused operators.
239
+ """
240
+ op_count = {}
241
+ ops = [
242
+ "Attention",
243
+ "MultiHeadAttention",
244
+ "LayerNormalization",
245
+ "SkipLayerNormalization",
246
+ "BiasSplitGelu",
247
+ "GroupNorm",
248
+ "SkipGroupNorm",
249
+ "NhwcConv",
250
+ "BiasAdd",
251
+ ]
252
+
253
+ for op in ops:
254
+ nodes = self.get_nodes_by_op_type(op)
255
+ op_count[op] = len(nodes)
256
+
257
+ logger.info(f"Optimized operators:{op_count}")
258
+ return op_count
@@ -0,0 +1,42 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+
8
+ from fusion_attention_vae import FusionAttentionVae
9
+ from fusion_options import FusionOptions
10
+ from onnx import ModelProto
11
+ from onnx_model_unet import UnetOnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
+ class VaeOnnxModel(UnetOnnxModel):
17
+ def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
18
+ assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
19
+ super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
20
+
21
+ def fuse_multi_head_attention(self, options: FusionOptions | None = None):
22
+ # Self Attention
23
+ self_attention_fusion = FusionAttentionVae(self, self.hidden_size, self.num_heads)
24
+ self_attention_fusion.apply()
25
+
26
+ def get_fused_operator_statistics(self):
27
+ """
28
+ Returns node count of fused operators.
29
+ """
30
+ op_count = {}
31
+ ops = [
32
+ "Attention",
33
+ "GroupNorm",
34
+ "SkipGroupNorm",
35
+ "NhwcConv",
36
+ ]
37
+ for op in ops:
38
+ nodes = self.get_nodes_by_op_type(op)
39
+ op_count[op] = len(nodes)
40
+
41
+ logger.info(f"Optimized operators:{op_count}")
42
+ return op_count
@@ -0,0 +1,55 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from fusion_utils import NumpyHelper
6
+ from onnx import ModelProto, TensorProto
7
+ from onnx.external_data_helper import set_external_data
8
+ from onnx_model import OnnxModel
9
+
10
+ from onnxruntime import OrtValue
11
+
12
+
13
+ def extract_raw_data_from_model(model: ModelProto):
14
+ """
15
+ Extract external data from model and return the external data as a list of tuples (name, value).
16
+ Note this function does not handle external data that is not loaded into the model as raw data.
17
+
18
+ Args:
19
+ model (ModelProto): the model proto to extract external data from.
20
+ Returns:
21
+ (external_names, external_values): a tuple of two lists of external data names and values.
22
+ """
23
+ external_data = []
24
+ onnx_model = OnnxModel(model)
25
+ for graph in onnx_model.graphs():
26
+ for initializer in graph.initializer:
27
+ name = initializer.name
28
+
29
+ if initializer.HasField("raw_data"):
30
+ numpy_tensor = NumpyHelper.to_array(initializer)
31
+ ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
32
+ external_data.append((name, ort_value))
33
+ # mimic set_external_data
34
+ set_external_data(initializer, location="foo.bin")
35
+ initializer.name = name
36
+ initializer.ClearField("raw_data")
37
+
38
+ return zip(*external_data, strict=False)
39
+
40
+
41
+ def has_external_data(model: ModelProto):
42
+ """
43
+ Check if the model has external data.
44
+
45
+ Args:
46
+ model (ModelProto): the model proto to check for external data.
47
+ Returns:
48
+ bool: True if the model has external data, False otherwise.
49
+ """
50
+ onnx_model = OnnxModel(model)
51
+ for graph in onnx_model.graphs():
52
+ for initializer in graph.initializer:
53
+ if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
54
+ return True
55
+ return False