onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,270 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import argparse
6
+ import os
7
+ import pathlib
8
+ import sys
9
+
10
+ import torch
11
+ from image_decoder import export_decoder_onnx, test_decoder_onnx
12
+ from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx
13
+ from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx
14
+ from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx
15
+ from sam2_demo import run_demo, show_all_images
16
+ from sam2_utils import load_sam2_model, sam2_onnx_path, setup_logger
17
+
18
+
19
def parse_arguments(argv=None):
    """Parse the command-line options of the SAM2 ONNX export script.

    Args:
        argv: Optional list of argument strings to parse instead of ``sys.argv[1:]``
            (useful for tests and programmatic use). ``None`` keeps the default
            command-line behavior.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX")

    parser.add_argument(
        "--model_type",
        required=False,
        type=str,
        choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
        default="sam2_hiera_large",
        help="The model type to export",
    )

    parser.add_argument(
        "--components",
        required=False,
        nargs="+",
        choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"],
        default=["image_encoder", "image_decoder"],
        help="Type of ONNX models to export. "
        "Note that image_decoder is a combination of prompt_encoder and mask_decoder",
    )

    parser.add_argument(
        "--output_dir",
        required=False,
        type=str,
        help="The output directory for the ONNX models",
        default="sam2_onnx_models",
    )

    parser.add_argument(
        "--dynamic_batch_axes",
        required=False,
        default=False,
        action="store_true",
        help="Export image_encoder with dynamic batch axes",
    )

    parser.add_argument(
        "--multimask_output",
        required=False,
        default=False,
        action="store_true",
        help="Export mask_decoder or image_decoder with multimask_output",
    )

    parser.add_argument(
        "--disable_dynamic_multimask_via_stability",
        required=False,
        action="store_true",
        # The two literals below are concatenated; the trailing space keeps the sentences apart.
        help="Disable mask_decoder dynamic_multimask_via_stability, and output first mask only. "
        "This option will be ignored when multimask_output is True",
    )

    parser.add_argument(
        "--sam2_dir",
        required=False,
        type=str,
        default="./segment-anything-2",
        help="The directory of segment-anything-2 git repository",
    )

    parser.add_argument(
        "--overwrite",
        required=False,
        default=False,
        action="store_true",
        help="Overwrite onnx model file if exists.",
    )

    parser.add_argument(
        "--demo",
        required=False,
        default=False,
        action="store_true",
        help="Run demo with the exported ONNX models.",
    )

    parser.add_argument(
        "--optimize",
        required=False,
        default=False,
        action="store_true",
        help="Optimize onnx models",
    )

    parser.add_argument(
        "--dtype", required=False, choices=["fp32", "fp16"], default="fp32", help="Data type for inference."
    )

    parser.add_argument(
        "--use_gpu",
        required=False,
        default=False,
        action="store_true",
        # main() forwards this flag both to the optimizer and to run_demo.
        help="Optimize onnx models for GPU, and run the demo on GPU",
    )

    parser.add_argument(
        "--dynamo",
        required=False,
        default=False,
        action="store_true",
        help="Use dynamo for exporting onnx model. Only image_encoder supports dynamo right now.",
    )

    parser.add_argument(
        "--verbose",
        required=False,
        default=False,
        action="store_true",
        help="Print verbose information",
    )

    # parse_args(None) falls back to sys.argv[1:], preserving the original CLI behavior.
    args = parser.parse_args(argv)
    return args
134
+
135
+
136
def optimize_sam2_model(onnx_model_path, optimized_model_path, float16: bool, use_gpu: bool):
    """Run the transformer ``optimize_model`` pass on an exported SAM2 model and save the result.

    Args:
        onnx_model_path: Path of the exported (unoptimized) ONNX model.
        optimized_model_path: Destination path for the optimized model.
        float16: When True, call ``convert_float_to_float16(keep_io_types=False)`` on the
            optimized model before saving (presumably graph I/O is converted too — confirm).
        use_gpu: Forwarded unchanged to ``optimize_model``.
    """
    print(f"Optimizing {onnx_model_path} to {optimized_model_path} with float16={float16} and use_gpu={use_gpu}...")

    # ``optimizer`` is resolved from the directory two levels above this script; put that
    # directory at the front of sys.path (only once) before importing from it.
    here = os.path.dirname(__file__)
    transformers_dir = os.path.normpath(os.path.join(here, os.pardir, os.pardir))
    if transformers_dir not in sys.path:
        sys.path.insert(0, transformers_dir)
    from optimizer import optimize_model  # noqa: PLC0415

    model = optimize_model(onnx_model_path, model_type="sam2", opt_level=1, use_gpu=use_gpu)
    if float16:
        model.convert_float_to_float16(keep_io_types=False)
    model.save_model_to_file(optimized_model_path)
149
+
150
+
151
def _export_and_test_component(args, sam2_model, component):
    """Export one SAM2 component to ONNX and check the ONNX model against the PyTorch model.

    The export is skipped when the ONNX file already exists and ``--overwrite`` was not
    given; the comparison test runs in either case.

    Raises:
        ValueError: If ``component`` is not one of the supported component names.
    """
    onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output)
    should_export = args.overwrite or not os.path.exists(onnx_model_path)
    # The CLI exposes the negated flag; the export/test helpers take the positive setting.
    dynamic_multimask_via_stability = not args.disable_dynamic_multimask_via_stability

    if component == "image_encoder":
        if should_export:
            export_image_encoder_onnx(
                sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose, args.dynamo
            )
        test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=args.dynamic_batch_axes)
    elif component == "mask_decoder":
        if should_export:
            export_mask_decoder_onnx(
                sam2_model,
                onnx_model_path,
                args.multimask_output,
                dynamic_multimask_via_stability,
                args.verbose,
            )
        test_mask_decoder_onnx(
            sam2_model,
            onnx_model_path,
            args.multimask_output,
            dynamic_multimask_via_stability,
        )
    elif component == "prompt_encoder":
        if should_export:
            export_prompt_encoder_onnx(sam2_model, onnx_model_path)
        test_prompt_encoder_onnx(sam2_model, onnx_model_path)
    elif component == "image_decoder":
        if should_export:
            export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
        test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
    else:
        # argparse ``choices`` normally prevents this; raise explicitly rather than rely on an
        # ``assert`` that disappears under ``python -O``.
        raise ValueError(f"Unsupported component: {component}")


def _optimize_if_missing(onnx_model_path, optimized_model_path, float16, use_gpu):
    """Create ``optimized_model_path`` from ``onnx_model_path`` unless it already exists."""
    if not os.path.exists(optimized_model_path):
        optimize_sam2_model(onnx_model_path, optimized_model_path, float16, use_gpu)


def _compare_ort_and_torch_demo(args, sam2_model, suffix, convert_to_fp16):
    """Run the SAM2 demo with ONNX Runtime and with PyTorch, then combine both outputs.

    The ONNX models the demo needs (the image encoder plus the single-mask and multi-mask
    image decoders) are exported first if missing. When ``suffix`` is non-empty, optimized
    variants of those models are used, and are created here if missing.
    """
    image_encoder_onnx_path = sam2_onnx_path(
        args.output_dir, args.model_type, "image_encoder", args.multimask_output
    )
    if not os.path.exists(image_encoder_onnx_path):
        # Honor --dynamo here as well, matching the export performed for --components.
        export_image_encoder_onnx(
            sam2_model, image_encoder_onnx_path, args.dynamic_batch_axes, args.verbose, args.dynamo
        )

    image_decoder_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", False)
    if not os.path.exists(image_decoder_onnx_path):
        export_decoder_onnx(sam2_model, image_decoder_onnx_path, False)

    image_decoder_multi_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", True)
    if not os.path.exists(image_decoder_multi_onnx_path):
        export_decoder_onnx(sam2_model, image_decoder_multi_onnx_path, True)

    dtype = torch.float32 if args.dtype == "fp32" else torch.float16
    if suffix:
        # Build optimized paths with sam2_onnx_path, exactly as the --optimize loop in main()
        # does, so files written there are found here. str.replace(".onnx", ...) would also
        # rewrite any ".onnx" that happens to occur in the output directory name.
        optimized_image_encoder_onnx_path = sam2_onnx_path(
            args.output_dir, args.model_type, "image_encoder", args.multimask_output, suffix
        )
        optimized_image_decoder_onnx_path = sam2_onnx_path(
            args.output_dir, args.model_type, "image_decoder", False, suffix
        )
        optimized_image_decoder_multi_onnx_path = sam2_onnx_path(
            args.output_dir, args.model_type, "image_decoder", True, suffix
        )

        _optimize_if_missing(
            image_encoder_onnx_path, optimized_image_encoder_onnx_path, convert_to_fp16, args.use_gpu
        )
        _optimize_if_missing(
            image_decoder_onnx_path, optimized_image_decoder_onnx_path, convert_to_fp16, args.use_gpu
        )
        _optimize_if_missing(
            image_decoder_multi_onnx_path,
            optimized_image_decoder_multi_onnx_path,
            convert_to_fp16,
            args.use_gpu,
        )

        # Use optimized models to run demo.
        image_encoder_onnx_path = optimized_image_encoder_onnx_path
        image_decoder_onnx_path = optimized_image_decoder_onnx_path
        image_decoder_multi_onnx_path = optimized_image_decoder_multi_onnx_path

    ort_image_files = run_demo(
        args.sam2_dir,
        args.model_type,
        engine="ort",
        dtype=dtype,
        image_encoder_onnx_path=image_encoder_onnx_path,
        image_decoder_onnx_path=image_decoder_onnx_path,
        image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
        use_gpu=args.use_gpu,
    )
    print("demo output files for ONNX Runtime:", ort_image_files)

    # Get results from torch engine to compare.
    torch_image_files = run_demo(args.sam2_dir, args.model_type, engine="torch", dtype=dtype, use_gpu=args.use_gpu)
    print("demo output files for PyTorch:", torch_image_files)

    show_all_images(ort_image_files, torch_image_files, suffix)
    print(f"Combined demo output: sam2_demo{suffix}.png")


def main():
    """Export SAM2 components to ONNX, then optionally optimize them and run the demo.

    All steps are driven by the options returned from ``parse_arguments``:
      1. Export (if needed) and test every component listed in ``--components``.
      2. With ``--optimize``, write an optimized copy of each of those components, at the path
         ``sam2_onnx_path`` returns for the ``_{dtype}_{gpu|cpu}`` suffix.
      3. With ``--demo``, run the demo with ONNX Runtime and with PyTorch and combine the
         outputs.
    """
    args = parse_arguments()

    sam2_model = load_sam2_model(args.sam2_dir, args.model_type, device="cpu")

    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    for component in args.components:
        _export_and_test_component(args, sam2_model, component)

    # A non-empty suffix also tells the demo to use (and if needed create) optimized models.
    suffix = ""
    convert_to_fp16 = args.dtype == "fp16"
    if args.optimize:
        suffix = f"_{args.dtype}_" + ("gpu" if args.use_gpu else "cpu")
        for component in args.components:
            onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output)
            optimized_model_path = sam2_onnx_path(
                args.output_dir, args.model_type, component, args.multimask_output, suffix
            )
            optimize_sam2_model(onnx_model_path, optimized_model_path, convert_to_fp16, args.use_gpu)

    if args.demo:
        _compare_ort_and_torch_demo(args, sam2_model, suffix, convert_to_fp16)
265
+
266
+
267
if __name__ == "__main__":
    # Logger setup happens before main() parses arguments, so --verbose does not change it.
    setup_logger(verbose=False)
    # Run the whole export/test/demo flow with autograd disabled.
    grad_disabled = torch.no_grad()
    with grad_disabled:
        main()
@@ -0,0 +1,272 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (R) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import logging
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from image_encoder import SAM2ImageEncoder, random_sam2_input_image
11
+ from mask_decoder import SAM2MaskDecoder
12
+ from prompt_encoder import SAM2PromptEncoder
13
+ from sam2.modeling.sam2_base import SAM2Base
14
+ from sam2_utils import compare_tensors_with_tolerance
15
+ from torch import nn
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class SAM2ImageDecoder(nn.Module):
    """Decode SAM2 masks from image-encoder outputs and point/mask prompts.

    Wraps the SAM2 prompt encoder and mask decoder, and adds post-processing: upsampling the
    low resolution masks to the original image size and optional thresholding. The whole
    decoder can therefore be exported as a single ONNX graph.
    """

    def __init__(
        self,
        sam_model: SAM2Base,
        multimask_output: bool,
        dynamic_multimask_via_stability: bool = True,
        return_logits: bool = False,
        mask_threshold: float = 0.0,
    ) -> None:
        """
        Args:
            sam_model (SAM2Base): SAM2 model that provides the prompt encoder and mask decoder.
            multimask_output (bool): forwarded to SAM2MaskDecoder. Presumably selects M=3 versus
                M=1 output masks (TODO confirm in mask_decoder).
            dynamic_multimask_via_stability (bool): forwarded to SAM2MaskDecoder.
            return_logits (bool): if True, forward() returns float mask logits. Otherwise it
                returns boolean masks (logits > mask_threshold).
            mask_threshold (float): threshold applied to the upsampled logits when return_logits
                is False.
        """
        super().__init__()
        self.prompt_encoder = SAM2PromptEncoder(sam_model)
        self.mask_decoder = SAM2MaskDecoder(sam_model, multimask_output, dynamic_multimask_via_stability)
        self.return_logits = return_logits
        self.mask_threshold = mask_threshold

    @torch.no_grad()
    def forward(
        self,
        image_features_0: torch.Tensor,
        image_features_1: torch.Tensor,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        input_masks: torch.Tensor,
        has_input_masks: torch.Tensor,
        original_image_size: torch.Tensor,
        enable_nvtx_profile: bool = False,
    ):
        """
        Decode masks from image features and prompts. Batched images are not supported. H=W=1024.

        Args:
            image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. High resolution features of level 0
                from the image encoder.
            image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. High resolution features of level 1
                from the image encoder.
            image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. Image embedding from the image encoder.
            point_coords (torch.Tensor): [L, P, 2], float32. Absolute pixel coordinates in (x, y)
                format of the P input points, in an image of size 1024x1024.
            point_labels (torch.Tensor): [L, P], int32. 1 means positive (foreground), 0 means
                negative (background), -1 means padding, 2 is the box's upper-left corner and
                3 is the box's lower-right corner.
            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model,
                typically coming from a previous iteration.
            has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
            original_image_size (torch.Tensor): [2]. Original image size (H_o, W_o).
            enable_nvtx_profile (bool): enable NVTX profiling (imports the nvtx_helper module).

        Returns:
            masks (torch.Tensor): [L, M, H_o, W_o] where M=3 or 1. Masks upsampled to the original
                image size. Boolean (logits > mask_threshold) unless return_logits is True, in
                which case they are float logits.
            iou_predictions (torch.Tensor): [L, M]. Scores for the M masks.
            low_res_masks (torch.Tensor): [L, M, H/4, W/4]. Low resolution mask logits, clamped to
                [-32, 32]. Always returned.
        """
        nvtx_helper = None
        if enable_nvtx_profile:
            from nvtx_helper import NvtxHelper  # noqa: PLC0415

            nvtx_helper = NvtxHelper(["prompt_encoder", "mask_decoder", "post_process"])

        if nvtx_helper is not None:
            nvtx_helper.start_profile("prompt_encoder", color="blue")

        sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder(
            point_coords, point_labels, input_masks, has_input_masks
        )

        if nvtx_helper is not None:
            nvtx_helper.stop_profile("prompt_encoder")
            nvtx_helper.start_profile("mask_decoder", color="red")

        low_res_masks, iou_predictions = self.mask_decoder(
            image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings
        )

        if nvtx_helper is not None:
            nvtx_helper.stop_profile("mask_decoder")
            nvtx_helper.start_profile("post_process", color="green")

        # Interpolate the low resolution masks back to the original image size.
        masks = F.interpolate(
            low_res_masks,
            (original_image_size[0], original_image_size[1]),
            mode="bilinear",
            align_corners=False,  # Note that align_corners=True has less mismatches during comparing ORT and PyTorch.
        )

        # Only the low-res logits are clamped, and only after `masks` has been computed from the
        # unclamped values. Presumably this keeps them in a safe range for reuse as input_masks
        # (TODO confirm).
        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
        if not self.return_logits:
            masks = masks > self.mask_threshold

        if nvtx_helper is not None:
            nvtx_helper.stop_profile("post_process")
            nvtx_helper.print_latency()

        return masks, iou_predictions, low_res_masks
113
+
114
+
115
def export_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool = False,
    verbose: bool = False,
):
    """Export SAM2ImageDecoder to ONNX (opset 16) with dynamic label, point and output-size axes.

    Example encoder outputs come from running SAM2ImageEncoder on CPU over a random input image.
    Example prompts are random points for 2 labels x 3 points.

    Args:
        sam2_model (SAM2Base): SAM2 model to export the decoder from.
        onnx_model_path (str): destination path of the exported ONNX model.
        multimask_output (bool): forwarded to SAM2ImageDecoder.
        verbose (bool): if True, run the decoder once and log its output shapes, and leave
            TracerWarning/UserWarning visible during export.
    """
    batch_size = 1
    image = random_sam2_input_image(batch_size)
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)

    logger.info("image_features_0.shape: %s", image_features_0.shape)
    logger.info("image_features_1.shape: %s", image_features_1.shape)
    logger.info("image_embeddings.shape: %s", image_embeddings.shape)

    sam2_decoder = SAM2ImageDecoder(
        sam2_model,
        multimask_output=multimask_output,
        dynamic_multimask_via_stability=True,
    ).cpu()

    num_labels = 2
    num_points = 3
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # `high` is exclusive: high=2 draws labels from {0, 1} (negative/positive points).
    # The previous high=1 could only ever produce 0.
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    # NOTE(review): shape [1] although dynamic_axes below names dim 0 "num_labels" (2 here).
    # Presumably the prompt encoder broadcasts it; confirm before changing.
    has_input_masks = torch.ones(1, dtype=torch.float)
    original_image_size = torch.tensor([1200, 1800], dtype=torch.int32)

    example_inputs = (
        image_features_0,
        image_features_1,
        image_embeddings,
        point_coords,
        point_labels,
        input_masks,
        has_input_masks,
        original_image_size,
    )

    logger.info("point_coords.shape: %s", point_coords.shape)
    logger.info("point_labels.shape: %s", point_labels.shape)
    logger.info("input_masks.shape: %s", input_masks.shape)
    logger.info("has_input_masks.shape: %s", has_input_masks.shape)
    logger.info("original_image_size.shape: %s", original_image_size.shape)

    if verbose:
        masks, iou_predictions, low_res_masks = sam2_decoder(*example_inputs)
        logger.info("masks.shape: %s", masks.shape)
        logger.info("iou_predictions.shape: %s", iou_predictions.shape)
        logger.info("low_res_masks.shape: %s", low_res_masks.shape)

    # Must stay in the same order as example_inputs: torch.onnx.export binds names positionally.
    input_names = [
        "image_features_0",
        "image_features_1",
        "image_embeddings",
        "point_coords",
        "point_labels",
        "input_masks",
        "has_input_masks",
        "original_image_size",
    ]

    output_names = ["masks", "iou_predictions", "low_res_masks"]

    dynamic_axes = {
        "point_coords": {0: "num_labels", 1: "num_points"},
        "point_labels": {0: "num_labels", 1: "num_points"},
        "input_masks": {0: "num_labels"},
        "has_input_masks": {0: "num_labels"},
        "masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"},
        "low_res_masks": {0: "num_labels"},
        "iou_predictions": {0: "num_labels"},
    }

    with warnings.catch_warnings():
        if not verbose:
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)

        torch.onnx.export(
            sam2_decoder,
            example_inputs,
            onnx_model_path,
            export_params=True,
            opset_version=16,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    logger.info("decoder onnx model saved to %s", onnx_model_path)
208
+
209
+
210
def test_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool = False,
) -> bool:
    """Compare an exported decoder ONNX model against the PyTorch SAM2ImageDecoder on random inputs.

    Runs both implementations on CPU with the same random image features and prompts (1 label x 5 points,
    no input mask), then compares masks, iou_predictions and low_res_masks with
    compare_tensors_with_tolerance. Prints the verdict.

    Args:
        sam2_model (SAM2Base): SAM2 model used to build the reference PyTorch encoder/decoder.
        onnx_model_path (str): path of the exported decoder ONNX model.
        multimask_output (bool): must match the value used when exporting the model.

    Returns:
        bool: True if all three outputs match within tolerance, False otherwise.
    """
    batch_size = 1
    image = random_sam2_input_image(batch_size)
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)

    sam2_image_decoder = SAM2ImageDecoder(
        sam2_model,
        multimask_output=multimask_output,
        dynamic_multimask_via_stability=True,
    ).cpu()

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # `high` is exclusive: high=2 draws labels from {0, 1}, so both negative and positive points
    # are exercised. The previous high=1 always produced 0.
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.zeros(1, dtype=torch.float)
    original_image_size = torch.tensor([1500, 1500], dtype=torch.int32)

    example_inputs = (
        image_features_0,
        image_features_1,
        image_embeddings,
        point_coords,
        point_labels,
        input_masks,
        has_input_masks,
        original_image_size,
    )

    masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs)

    import onnxruntime  # noqa: PLC0415

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

    input_names = [model_input.name for model_input in ort_session.get_inputs()]
    logger.info("input_names: %s", input_names)

    output_names = [model_output.name for model_output in ort_session.get_outputs()]
    logger.info("output_names: %s", output_names)

    # Assumes the ONNX inputs appear in the same order as example_inputs (the order passed to
    # torch.onnx.export). zip stops at the shorter sequence, as the original index loop did.
    inputs = {name: tensor.numpy() for name, tensor in zip(input_names, example_inputs)}
    outputs = ort_session.run(output_names, inputs)

    for output_name, output in zip(output_names, outputs):
        logger.info("%s.shape: %s", output_name, output.shape)

    ort_masks, ort_iou_predictions, ort_low_res_masks = outputs
    # ORT returns boolean masks when return_logits is False, so compare both sides as float.
    verified = bool(
        compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float())
        and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions))
        and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks))
    )
    if verified:
        print("onnx model has been verified:", onnx_model_path)
    else:
        print("onnx model verification failed:", onnx_model_path)
    return verified