onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ import inspect
5
+ from collections import abc
6
+
7
+ import torch
8
+
9
+
10
+ def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
11
+ # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433
12
+
13
+ def _add_input(name, input):
14
+ """Returns number of expanded inputs that _add_input processed"""
15
+
16
+ if input is None:
17
+ # Drop all None inputs and return 0.
18
+ return 0
19
+
20
+ num_expanded_non_none_inputs = 0
21
+ if isinstance(input, abc.Sequence):
22
+ # If the input is a sequence (like a list), expand the list so that
23
+ # each element of the list is an input by itself.
24
+ for i, val in enumerate(input):
25
+ # Name each input with the index appended to the original name of the
26
+ # argument.
27
+ num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val)
28
+
29
+ # Return here since the list by itself is not a valid input.
30
+ # All the elements of the list have already been added as inputs individually.
31
+ return num_expanded_non_none_inputs
32
+ elif isinstance(input, abc.Mapping):
33
+ # If the input is a mapping (like a dict), expand the dict so that
34
+ # each element of the dict is an input by itself.
35
+ for key, val in input.items():
36
+ num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val)
37
+
38
+ # Return here since the dict by itself is not a valid input.
39
+ # All the elements of the dict have already been added as inputs individually.
40
+ return num_expanded_non_none_inputs
41
+
42
+ # InputInfo should contain all the names irrespective of whether they are
43
+ # a part of the onnx graph or not.
44
+ input_names.append(name)
45
+
46
+ # A single input non none input was processed, return 1
47
+ return 1
48
+
49
+ input_names = []
50
+ var_positional_idx = 0
51
+ num_expanded_non_none_positional_inputs = 0
52
+
53
+ for input_idx, input_parameter in enumerate(all_input_parameters):
54
+ if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
55
+ # VAR_POSITIONAL parameter carries all *args parameters from original forward method
56
+ for args_i in range(input_idx, len(inputs)):
57
+ name = f"{input_parameter.name}_{var_positional_idx}"
58
+ var_positional_idx += 1
59
+ inp = inputs[args_i]
60
+ num_expanded_non_none_positional_inputs += _add_input(name, inp)
61
+ elif (
62
+ input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
63
+ or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
64
+ or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY
65
+ ):
66
+ # All positional non-*args and non-**kwargs are processed here
67
+ name = input_parameter.name
68
+ inp = None
69
+ input_idx += var_positional_idx # noqa: PLW2901
70
+ is_positional = True
71
+ if input_idx < len(inputs) and inputs[input_idx] is not None:
72
+ inp = inputs[input_idx]
73
+ elif name in kwargs and kwargs[name] is not None:
74
+ inp = kwargs[name]
75
+ is_positional = False
76
+ num_expanded_non_none_inputs_local = _add_input(name, inp)
77
+ if is_positional:
78
+ num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
79
+ elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
80
+ # **kwargs is always the last argument of forward()
81
+ for name, inp in kwargs.items():
82
+ if name not in input_names:
83
+ _add_input(name, inp)
84
+
85
+ return input_names
86
+
87
+
88
+ def _flatten_module_input(names, args, kwargs):
89
+ """Flatten args and kwargs in a single tuple of tensors."""
90
+ # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110
91
+
92
+ def is_primitive_type(value):
93
+ return type(value) in {int, bool, float}
94
+
95
+ def to_tensor(value):
96
+ return torch.tensor(value)
97
+
98
+ ret = [to_tensor(arg) if is_primitive_type(arg) else arg for arg in args]
99
+ ret += [
100
+ to_tensor(kwargs[name]) if is_primitive_type(kwargs[name]) else kwargs[name] for name in names if name in kwargs
101
+ ]
102
+
103
+ # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter
104
+ # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise.
105
+ if not kwargs:
106
+ ret.append({})
107
+
108
+ return tuple(ret)
109
+
110
+
111
def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
    """
    Infer input names and ordering from the arguments used to run a PyTorch module,
    for use with torch.onnx.export.
    Assumes the model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.

    Example usage:
        input_names, inputs_as_tuple = infer_input_info(module, ...)
        torch.onnx.export(module, inputs_as_tuple, 'model.onnx', input_names=input_names, output_names=[...], ...)

    :param module: Module
    :param inputs: Positional inputs
    :param kwargs: Keyword argument inputs
    :return: Tuple of ordered input names and input values. These can be used directly with torch.onnx.export as the
        `input_names` and `inputs` arguments.
    """
    # Read the forward() signature, then flatten the concrete inputs to match it.
    forward_parameters = inspect.signature(module.forward).parameters.values()
    names = _parse_inputs_for_onnx_export(forward_parameters, inputs, kwargs)
    flattened_inputs = _flatten_module_input(names, inputs, kwargs)

    return names, flattened_inputs
File without changes
@@ -0,0 +1,37 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+
5
+ import argparse
6
+ import os
7
+ import pathlib
8
+
9
+ import onnx
10
+
11
+
12
def optimize_qdq_model():
    """CLI entry point: load a QDQ format ONNX model, run QDQ optimizations, and save the result."""
    arg_parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
    )

    arg_parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    arg_parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")

    parsed = arg_parser.parse_args()

    # strict=True makes a missing input path fail fast here rather than in onnx.load.
    model = onnx.load(str(parsed.input_model.resolve(strict=True)))

    # run QDQ model optimizations here

    # Originally, the fixing up of DQ nodes with multiple consumers was implemented as one such optimization.
    # That was moved to an ORT graph transformer.
    print("As of ORT 1.15, the fixing up of DQ nodes with multiple consumers is done by an ORT graph transformer.")

    # There are no optimizations being run currently but we expect that there may be in the future.

    onnx.save(model, str(parsed.output_model.resolve()))


if __name__ == "__main__":
    optimize_qdq_model()
@@ -0,0 +1,292 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import json
7
+ from argparse import ArgumentParser
8
+
9
+ import onnx
10
+ from onnx import TensorProto, helper
11
+
12
+
13
def graph_topological_sort(graph):
    """Reorder ``graph.node`` in place into a valid topological order.

    Each node starts with a count of its unresolved (non-empty) inputs; nodes
    with no unresolved inputs seed the ready list. Tensors supplied by graph
    initializers and graph inputs resolve their consumers first, then every
    node appended to the ready list resolves the consumers of its own outputs.

    Raises AssertionError if the graph is not a DAG.

    Note: this logic only applies to the top-level graph, since a subgraph
    could use an initializer from its parent graph.
    """
    total = len(graph.node)
    unresolved = [0] * total  # per-node count of inputs not yet produced
    consumers = {}  # tensor name -> list of node indices consuming it
    ordered = []  # nodes accumulated in final topological order

    for idx, node in enumerate(graph.node):
        # Optional inputs appear as empty strings, so count only non-empty ones.
        unresolved[idx] = sum(1 for tensor in node.input if tensor)
        if unresolved[idx] == 0:
            # e.g. Constant nodes depend on nothing and are immediately ready.
            ordered.append(graph.node[idx])
            continue
        for tensor in node.input:
            consumers.setdefault(tensor, []).append(idx)

    # Tensors already available before any node runs: initializers + graph inputs
    # (deduplicated, processed in sorted order).
    available = sorted({init.name for init in graph.initializer} | {inp.name for inp in graph.input})
    for tensor in available:
        for idx in consumers.get(tensor, []):
            unresolved[idx] -= 1
            if unresolved[idx] == 0:
                ordered.append(graph.node[idx])

    # Standard Kahn-style sweep: each ready node releases its consumers.
    cursor = 0
    while cursor < len(ordered):
        for produced in ordered[cursor].output:
            for idx in consumers.get(produced, []):
                unresolved[idx] -= 1
                if unresolved[idx] == 0:
                    ordered.append(graph.node[idx])
        cursor += 1

    assert len(ordered) == len(graph.node), "Graph is not a DAG"
    graph.ClearField("node")
    graph.node.extend(ordered)
+
64
+
65
class QnnTensorStruct:
    """Plain record for one QNN graph tensor: its name, the ONNX data type it
    maps to, and its dimensions."""

    def __init__(self):
        # Defaults describe an unnamed float tensor with unknown shape.
        self.name, self.onnx_data_type, self.dim = "", TensorProto.FLOAT, []
+
71
+
72
def qnn_data_type_to_onnx_data_type(qnn_data_type):
    """Map a QNN data-type code to the corresponding ``TensorProto`` data type.

    Quantized (fixed-point) and plain integer QNN codes map to the same ONNX
    integer type. Unrecognized codes map to ``TensorProto.UNDEFINED``.
    """
    qnn_to_onnx = {
        0x0408: TensorProto.UINT8,  # QNN_DATATYPE_UFIXED_POINT_8
        0x0108: TensorProto.UINT8,  # QNN_DATATYPE_UINT_8
        0x0416: TensorProto.UINT16,  # QNN_DATATYPE_UFIXED_POINT_16
        0x0116: TensorProto.UINT16,  # QNN_DATATYPE_UINT_16
        0x0432: TensorProto.UINT32,  # QNN_DATATYPE_UFIXED_POINT_32
        0x0132: TensorProto.UINT32,  # QNN_DATATYPE_UINT_32
        0x0164: TensorProto.UINT64,  # QNN_DATATYPE_UINT_64
        0x0308: TensorProto.INT8,  # QNN_DATATYPE_FIXED_POINT_8
        0x0008: TensorProto.INT8,  # QNN_DATATYPE_INT_8
        0x0316: TensorProto.INT16,  # QNN_DATATYPE_FIXED_POINT_16
        0x0016: TensorProto.INT16,  # QNN_DATATYPE_INT_16
        0x0332: TensorProto.INT32,  # QNN_DATATYPE_FIXED_POINT_32
        0x0032: TensorProto.INT32,  # QNN_DATATYPE_INT_32
        0x0064: TensorProto.INT64,  # QNN_DATATYPE_INT_64
        0x0216: TensorProto.FLOAT16,  # QNN_DATATYPE_FLOAT_16
        0x0232: TensorProto.FLOAT,  # QNN_DATATYPE_FLOAT_32
        0x0508: TensorProto.BOOL,  # QNN_DATATYPE_BOOL_8
    }
    return qnn_to_onnx.get(qnn_data_type, TensorProto.UNDEFINED)
108
+
109
+
110
def parse_qnn_json_file(qnn_json_file_path, qnn_input_output_tensor_dic):
    """Fill ``qnn_input_output_tensor_dic`` with the graph input/output tensors
    described by a QNN converted ``model_net.json`` file.

    Keys are tensor names; values are ``QnnTensorStruct`` records holding the
    mapped ONNX data type and the QNN dims. Raises AssertionError when the
    JSON is missing expected keys or describes fewer than two I/O tensors.
    """
    with open(qnn_json_file_path) as json_fp:
        converted = json.load(json_fp)

    assert "graph" in converted, "QNN converted json file not valid. Can't find graph."
    assert "tensors" in converted["graph"], "QNN converted json file not valid. Can't find tensors."

    for tensor_name, attributes in converted["graph"]["tensors"].items():
        assert all(key in attributes for key in ("type", "data_type", "dims")), (
            "QNN converted json file not valid. Can't find some keys from tensors"
        )
        # type 0 = QNN graph input tensor, type 1 = QNN graph output tensor;
        # everything else (intermediates, params) is skipped.
        if attributes["type"] in (0, 1):
            record = QnnTensorStruct()
            record.name = tensor_name
            record.onnx_data_type = qnn_data_type_to_onnx_data_type(attributes["data_type"])
            record.dim = attributes["dims"]
            qnn_input_output_tensor_dic[tensor_name] = record

    assert len(qnn_input_output_tensor_dic) > 1, (
        "Converted QNN model not valid. It should have at least 1 input & 1 output."
    )
+ )
132
+
133
+
134
def compare_onnx_shape_with_qnn_shape(onnx_dims, qnn_dims):
    """Return True iff every ONNX ``dim_value`` equals the matching QNN dim.

    Raises AssertionError when the two shapes have different rank.
    """
    assert len(onnx_dims) == len(qnn_dims), "Onnx shape and Qnn shape has different rank."
    return all(onnx_dim.dim_value == qnn_dim for onnx_dim, qnn_dim in zip(onnx_dims, qnn_dims))
137
+
138
+
139
def gen_to_channel_first_perm(rank):
    """Build a Transpose perm moving the trailing (channel) axis to position 1.

    E.g. rank 4 -> [0, 3, 1, 2] (NHWC-style to NCHW-style).
    Raises AssertionError for rank <= 2, where such a transpose is meaningless.
    """
    assert rank > 2, "Shape rank should >2 for the Transpose node."
    # Batch axis stays first, last axis moves second, remaining axes keep order.
    return [0, rank - 1, *range(1, rank - 1)]
148
+
149
+
150
def gen_to_channel_last_perm(rank):
    """Build a Transpose perm moving axis 1 (channel) to the trailing position.

    E.g. rank 4 -> [0, 2, 3, 1] (NCHW-style to NHWC-style).
    Raises AssertionError for rank <= 2, where such a transpose is meaningless.
    """
    assert rank > 2, "Shape rank should >2 for the Transpose node."
    # Batch axis stays first, remaining axes keep order, axis 1 goes last.
    return [0, *range(2, rank), 1]
159
+
160
+
161
+ # Onnxruntime QNN EP can support context binary file generated by QNN tool chain. However QNN generated context binary file
162
+ # uses channel last data layout and 8 bits or 16 bits for input and output.
163
+ # This script gets the QNN model input & output information from QNN converted model_net.json file, compare them with Onnx model
164
+ # and inserts Cast, Transpose nodes to Onnx model if required
165
def main():
    """Align an ONNX model's graph inputs/outputs with a QNN-generated context binary.

    Reads the QNN converted model_net.json to learn the data type and (channel-last)
    shape of each QNN graph input/output, then inserts Cast and/or Transpose nodes
    into the ONNX model wherever its inputs/outputs disagree, and rewrites the graph
    input/output types and shapes to match QNN. The adjusted model is saved next to
    the original with an "_add_trans" suffix.
    """
    parser = ArgumentParser(
        "Insert Cast, Transpose nodes into Onnx model to make it aligned with QNN generated context binary."
    )
    parser.add_argument("-m", "--onnx_model", help="Required. Path to Onnx model file.", required=True, type=str)
    parser.add_argument(
        "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str
    )
    args = parser.parse_args()

    # Parse Qnn model_net.json file to get the graph input output information
    qnn_input_output_tensor_dic = {}
    parse_qnn_json_file(args.qnn_json, qnn_input_output_tensor_dic)

    model = onnx.load(args.onnx_model)

    nodes_to_add = []
    # Track the tensor name changes so consumer/producer nodes can be re-pointed below.
    # Maps original graph input/output name -> new tensor name after node insertion.
    graph_input_output_name_dic = {}
    for graph_input in model.graph.input:
        if graph_input.name in qnn_input_output_tensor_dic:
            input_name_fater_node_insert = graph_input.name
            qnn_input_tensor = qnn_input_output_tensor_dic[graph_input.name]
            # Insert Cast node if Onnx input and Qnn input has different data type
            if graph_input.type.tensor_type.elem_type != qnn_input_tensor.onnx_data_type:
                # Cast from the QNN input type (which the graph input is changed to,
                # below) back to the type the original model expects. Note `to=` is
                # captured BEFORE elem_type is overwritten.
                cast_input_name = input_name_fater_node_insert
                cast_output_name = cast_input_name + "_qnn_cast"
                input_cast_node = helper.make_node(
                    "Cast",
                    name=cast_output_name,
                    inputs=[cast_input_name],
                    outputs=[cast_output_name],
                    to=graph_input.type.tensor_type.elem_type,
                )
                # Change input data type to Qnn input data type
                graph_input.type.tensor_type.elem_type = qnn_input_tensor.onnx_data_type
                nodes_to_add.extend([input_cast_node])
                input_name_fater_node_insert = cast_output_name
                graph_input_output_name_dic[graph_input.name] = cast_output_name

            if not compare_onnx_shape_with_qnn_shape(graph_input.type.tensor_type.shape.dim, qnn_input_tensor.dim):
                # Add Transpose node (channel last to channel first). It chains after
                # the Cast (if one was inserted) via input_name_fater_node_insert.
                transpose_perm = gen_to_channel_first_perm(len(graph_input.type.tensor_type.shape.dim))
                transpose_input_name = input_name_fater_node_insert
                transpose_output_name = transpose_input_name + "_qnn_trans"
                input_transpose_node = helper.make_node(
                    "Transpose",
                    name=transpose_output_name,
                    inputs=[transpose_input_name],
                    outputs=[transpose_output_name],
                    perm=transpose_perm,
                )
                nodes_to_add.extend([input_transpose_node])
                graph_input_output_name_dic[graph_input.name] = transpose_output_name

                # Change input shape to Qnn input shape
                for i in range(len(graph_input.type.tensor_type.shape.dim)):
                    graph_input.type.tensor_type.shape.dim[i].dim_value = qnn_input_tensor.dim[i]
        else:
            raise AssertionError("Error: Onnx model input: " + graph_input.name + " not exist from QNN model input.")

    for graph_output in model.graph.output:
        if graph_output.name in qnn_input_output_tensor_dic:
            output_name_after_node_insert = graph_output.name
            # Insert Cast node if Onnx output and Qnn output has different data type
            qnn_output_tensor = qnn_input_output_tensor_dic[graph_output.name]
            if graph_output.type.tensor_type.elem_type != qnn_output_tensor.onnx_data_type:
                # On the output side the renaming goes upstream: the Cast produces the
                # original graph output name, and the producer node is later re-pointed
                # to feed the Cast's (new) input name.
                cast_output_name = output_name_after_node_insert
                cast_input_name = cast_output_name + "_qnn_cast"
                output_cast_node = helper.make_node(
                    "Cast",
                    name=cast_input_name,
                    inputs=[cast_input_name],
                    outputs=[cast_output_name],
                    to=qnn_output_tensor.onnx_data_type,
                )
                # Change graph output data type to the Qnn output data type
                graph_output.type.tensor_type.elem_type = qnn_output_tensor.onnx_data_type
                nodes_to_add.extend([output_cast_node])
                output_name_after_node_insert = cast_input_name
                graph_input_output_name_dic[graph_output.name] = cast_input_name

            if not compare_onnx_shape_with_qnn_shape(graph_output.type.tensor_type.shape.dim, qnn_output_tensor.dim):
                # Add Transpose node (channel first to channel last), chained upstream
                # of the Cast (if one was inserted) via output_name_after_node_insert.
                transpose_perm = gen_to_channel_last_perm(len(graph_output.type.tensor_type.shape.dim))
                transpose_output_name = output_name_after_node_insert
                transpose_input_name = transpose_output_name + "_qnn_trans"
                output_transpose_node = helper.make_node(
                    "Transpose",
                    name=transpose_input_name,
                    inputs=[transpose_input_name],
                    outputs=[transpose_output_name],
                    perm=transpose_perm,
                )
                nodes_to_add.extend([output_transpose_node])
                graph_input_output_name_dic[graph_output.name] = transpose_input_name

                # Change output shape to Qnn output shape
                for i in range(len(graph_output.type.tensor_type.shape.dim)):
                    graph_output.type.tensor_type.shape.dim[i].dim_value = qnn_input_output_tensor_dic[
                        graph_output.name
                    ].dim[i]
        else:
            raise AssertionError("Error: Onnx model output: " + graph_output.name + " not exist from QNN model output.")

    # Re-point existing nodes only; nodes_to_add is extended AFTER this loop, so the
    # inserted Cast/Transpose nodes themselves are never rewritten.
    for node in model.graph.node:
        for node_input_index, node_input in enumerate(node.input):
            # update consumer node for graph inputs to connect to inserted node
            if node_input in graph_input_output_name_dic:
                node.input[node_input_index] = graph_input_output_name_dic[node_input]

        for node_output_index, node_output in enumerate(node.output):
            # update producer node for graph outputs to connect to inserted node
            if node_output in graph_input_output_name_dic:
                node.output[node_output_index] = graph_input_output_name_dic[node_output]

    model.graph.node.extend(nodes_to_add)
    graph_topological_sort(model.graph)

    # Add extra parameter all_tensors_to_one_file=False, size_threshold=5000 if the model exceeds protobuf 2GB limit e.g below
    # onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"), all_tensors_to_one_file=False, size_threshold=5000)
    onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"))


if __name__ == "__main__":
    main()