onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. The information is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/sam2/image_encoder.py
@@ -0,0 +1,186 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+ import warnings
+
+ import torch
+ from sam2.modeling.sam2_base import SAM2Base
+ from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
+ from torch import nn
+
+ import onnxruntime
+
+ logger = logging.getLogger(__name__)
+
+
+ class SAM2ImageEncoder(nn.Module):
+     def __init__(self, sam_model: SAM2Base) -> None:
+         super().__init__()
+         self.model = sam_model
+         self.image_encoder = sam_model.image_encoder
+         self.no_mem_embed = sam_model.no_mem_embed
+
+     def forward(
+         self,
+         image: torch.Tensor,
+         enable_nvtx_profile: bool = False,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Encodes images into features.
+
+         Only supports H=W=1024. If you want to use different image sizes like 512x512,
+         see https://github.com/facebookresearch/segment-anything-2/issues/138.
+
+         Args:
+             image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
+             enable_nvtx_profile (bool): enable NVTX profiling.
+
+         Returns:
+             image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
+             image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
+             image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
+         """
+         nvtx_helper = None
+         if enable_nvtx_profile:
+             from nvtx_helper import NvtxHelper
+
+             nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
+
+         if nvtx_helper is not None:
+             nvtx_helper.start_profile("image_encoder")
+
+         backbone_out = self.image_encoder(image)
+
+         if nvtx_helper is not None:
+             nvtx_helper.stop_profile("image_encoder")
+             nvtx_helper.start_profile("post_process")
+
+         # precompute projected level 0 and level 1 features in SAM decoder
+         # to avoid running it again on every SAM click
+         backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+         backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+
+         # Prepare and flatten visual features.
+         feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
+         vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
+         feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+
+         # flatten NxCxHxW to HWxNxC
+         # TODO: we should avoid this transpose since it will be transposed back to NCHW later.
+         vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+
+         vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
+
+         feats = [
+             feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
+             for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
+         ][::-1]
+
+         if nvtx_helper is not None:
+             nvtx_helper.stop_profile("post_process")
+             nvtx_helper.print_latency()
+
+         return feats[0], feats[1], feats[2]
+
+
+ def export_image_encoder_onnx(
+     sam2_model: SAM2Base,
+     onnx_model_path: str,
+     dynamic_batch_axes: bool = False,
+     verbose: bool = False,
+ ):
+     image = random_sam2_input_image()
+
+     sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+     image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
+     logger.info("image.shape: %s", image.shape)
+     logger.info("image_features_0.shape: %s", image_features_0.shape)
+     logger.info("image_features_1.shape: %s", image_features_1.shape)
+     logger.info("image_embeddings.shape: %s", image_embeddings.shape)
+
+     dynamic_axes = None
+     if dynamic_batch_axes:
+         dynamic_axes = {
+             "image": {0: "batch_size"},
+             "image_features_0": {0: "batch_size"},
+             "image_features_1": {0: "batch_size"},
+             "image_embeddings": {0: "batch_size"},
+         }
+
+     with warnings.catch_warnings():
+         if not verbose:
+             warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+             warnings.filterwarnings("ignore", category=UserWarning)
+         torch.onnx.export(
+             sam2_encoder,
+             image,
+             onnx_model_path,
+             export_params=True,
+             opset_version=17,
+             do_constant_folding=True,
+             input_names=["image"],
+             output_names=["image_features_0", "image_features_1", "image_embeddings"],
+             dynamic_axes=dynamic_axes,
+         )
+
+     print("encoder onnx model saved to", onnx_model_path)
+
+
+ def test_image_encoder_onnx(
+     sam2_model: SAM2Base,
+     onnx_model_path: str,
+     dynamic_batch_axes=False,
+ ):
+     ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
+
+     model_inputs = ort_session.get_inputs()
+     input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+     logger.info("input_names: %s", input_names)
+
+     model_outputs = ort_session.get_outputs()
+     output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+     logger.info("output_names: %s", output_names)
+
+     batch_sizes = [1, 2] if dynamic_batch_axes else [1]
+     for batch_size in batch_sizes:
+         image = random_sam2_input_image(batch_size)
+
+         sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+         image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
+
+         logger.info("image.shape: %s", image.shape)
+         logger.info("image_features_0.shape: %s", image_features_0.shape)
+         logger.info("image_features_1.shape: %s", image_features_1.shape)
+         logger.info("image_embeddings.shape: %s", image_embeddings.shape)
+
+         outputs = ort_session.run(output_names, {"image": image.numpy()})
+         for i, output_name in enumerate(output_names):
+             logger.info("output %s shape %s", output_name, outputs[i].shape)
+         ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs
+
+         # ONNX Runtime and PyTorch have about 0.75% mismatched elements, but this does not seem to impact segmentation results.
+         if (
+             compare_tensors_with_tolerance(
+                 "image_features_0",
+                 image_features_0,
+                 torch.tensor(ort_image_features_0),
+                 mismatch_percentage_tolerance=1,
+             )
+             and compare_tensors_with_tolerance(
+                 "image_features_1",
+                 image_features_1,
+                 torch.tensor(ort_image_features_1),
+                 mismatch_percentage_tolerance=1,
+             )
+             and compare_tensors_with_tolerance(
+                 "image_embeddings",
+                 image_embeddings,
+                 torch.tensor(ort_image_embeddings),
+                 mismatch_percentage_tolerance=1,
+             )
+         ):
+             print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
+         else:
+             print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")
onnxruntime/transformers/models/sam2/mask_decoder.py
@@ -0,0 +1,208 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+ import warnings
+
+ import torch
+ from image_encoder import SAM2ImageEncoder, random_sam2_input_image
+ from prompt_encoder import SAM2PromptEncoder
+ from sam2.modeling.sam2_base import SAM2Base
+ from torch import nn
+
+ logger = logging.getLogger(__name__)
+
+
+ class SAM2MaskDecoder(nn.Module):
+     def __init__(
+         self,
+         sam_model: SAM2Base,
+         multimask_output: bool,
+         dynamic_multimask_via_stability: bool = True,
+     ) -> None:
+         super().__init__()
+         self.mask_decoder = sam_model.sam_mask_decoder
+         self.prompt_encoder = sam_model.sam_prompt_encoder
+         self.model = sam_model
+         self.multimask_output = multimask_output
+         self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+
+     @torch.no_grad()
+     def forward(
+         self,
+         image_features_0: torch.Tensor,
+         image_features_1: torch.Tensor,
+         image_embeddings: torch.Tensor,
+         image_pe: torch.Tensor,
+         sparse_embeddings: torch.Tensor,
+         dense_embeddings: torch.Tensor,
+     ):
+         """
+         Decode masks from image and prompt embeddings. Only supports H=W=1024.
+
+         Args:
+             image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
+             image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
+             image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
+             image_pe (torch.Tensor): [1, 256, H/16, W/16]. image positional encoding.
+             sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
+             dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. embedding for input masks.
+
+         Returns:
+             low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
+             iou_predictions (torch.Tensor): [1, M]. scores for M masks.
+         """
+         low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
+             image_embeddings=image_embeddings,
+             image_pe=image_pe,
+             sparse_prompt_embeddings=sparse_embeddings,
+             dense_prompt_embeddings=dense_embeddings,
+             repeat_image=sparse_embeddings.shape[0] > 1,  # batch mode
+             high_res_features=[image_features_0, image_features_1],
+         )
+
+         if self.multimask_output:
+             low_res_masks = low_res_masks[:, 1:, :, :]
+             iou_predictions = iou_predictions[:, 1:]
+         elif self.dynamic_multimask_via_stability:
+             # When outputting a single mask, if the stability score from the current single-mask
+             # output (based on output token 0) falls below a threshold, we instead select from
+             # multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score.
+             low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
+                 low_res_masks, iou_predictions
+             )
+         else:
+             low_res_masks = low_res_masks[:, 0:1, :, :]
+             iou_predictions = iou_predictions[:, 0:1]
+
+         return low_res_masks, iou_predictions
+
+
+ def export_mask_decoder_onnx(
+     sam2_model: SAM2Base,
+     onnx_model_path: str,
+     multimask_output: bool,
+     dynamic_multimask_via_stability: bool = True,
+     verbose=False,
+ ):
+     sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+     image = random_sam2_input_image()
+     sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+     image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
+     logger.info("image_features_0.shape: %s", image_features_0.shape)
+     logger.info("image_features_1.shape: %s", image_features_1.shape)
+     logger.info("image_embeddings.shape: %s", image_embeddings.shape)
+
+     # encode a random prompt
+     num_labels = 2
+     num_points = 3
+     point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+     point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
+     input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
+     has_input_masks = torch.ones(1, dtype=torch.float)
+
+     sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+         point_coords, point_labels, input_masks, has_input_masks
+     )
+
+     logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
+     logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
+     logger.info("image_pe.shape: %s", image_pe.shape)
+
+     sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
+     inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
+     low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
+     logger.info("low_res_masks.shape: %s", low_res_masks.shape)
+     logger.info("iou_predictions.shape: %s", iou_predictions.shape)
+
+     with warnings.catch_warnings():
+         if not verbose:
+             warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+             warnings.filterwarnings("ignore", category=UserWarning)
+         torch.onnx.export(
+             sam2_mask_decoder,
+             inputs,
+             onnx_model_path,
+             export_params=True,
+             opset_version=18,
+             do_constant_folding=True,
+             input_names=[
+                 "image_features_0",
+                 "image_features_1",
+                 "image_embeddings",
+                 "image_pe",
+                 "sparse_embeddings",
+                 "dense_embeddings",
+             ],
+             output_names=["low_res_masks", "iou_predictions"],
+             dynamic_axes={
+                 "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
+                 "dense_embeddings": {0: "num_labels"},
+                 "low_res_masks": {0: "num_labels"},
+                 "iou_predictions": {0: "num_labels"},
+             },
+         )
+
+     print("mask decoder onnx model saved to", onnx_model_path)
+
+
+ def test_mask_decoder_onnx(
+     sam2_model: SAM2Base,
+     onnx_model_path: str,
+     multimask_output: bool,
+     dynamic_multimask_via_stability: bool,
+ ):
+     sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+     image = random_sam2_input_image()
+     sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+     image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
+
+     num_labels = 1
+     num_points = 5
+     point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+     point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
+     input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
+     has_input_masks = torch.ones(1, dtype=torch.float)
+
+     sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+         point_coords, point_labels, input_masks, has_input_masks
+     )
+
+     sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
+     inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
+     low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
+
+     import onnxruntime
+
+     ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
+
+     model_inputs = ort_session.get_inputs()
+     input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+     logger.info("input_names: %s", input_names)
+
+     model_outputs = ort_session.get_outputs()
+     output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+     logger.info("output_names: %s", output_names)
+
+     outputs = ort_session.run(
+         output_names,
+         {
+             "image_features_0": image_features_0.numpy(),
+             "image_features_1": image_features_1.numpy(),
+             "image_embeddings": image_embeddings.numpy(),
+             "image_pe": image_pe.numpy(),
+             "sparse_embeddings": sparse_embeddings.numpy(),
+             "dense_embeddings": dense_embeddings.numpy(),
+         },
+     )
+
+     for i, output_name in enumerate(output_names):
+         logger.info("output %s shape: %s", output_name, outputs[i].shape)
+
+     ort_low_res_masks, ort_iou_predictions = outputs
+     torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4)
+     torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4)
+     print(f"onnx model has been verified: {onnx_model_path}")
onnxruntime/transformers/models/sam2/nvtx_helper.py
@@ -0,0 +1,33 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import nvtx
+ from cuda import cudart
+
+
+ class NvtxHelper:
+     def __init__(self, stages):
+         self.stages = stages
+         self.events = {}
+         for stage in stages:
+             for marker in ["start", "stop"]:
+                 self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
+         self.markers = {}
+
+     def start_profile(self, stage, color="blue"):
+         self.markers[stage] = nvtx.start_range(message=stage, color=color)
+         event_name = stage + "-start"
+         if event_name in self.events:
+             cudart.cudaEventRecord(self.events[event_name], 0)
+
+     def stop_profile(self, stage):
+         event_name = stage + "-stop"
+         if event_name in self.events:
+             cudart.cudaEventRecord(self.events[event_name], 0)
+         nvtx.end_range(self.markers[stage])
+
+     def print_latency(self):
+         for stage in self.stages:
+             latency = cudart.cudaEventElapsedTime(self.events[f"{stage}-start"], self.events[f"{stage}-stop"])[1]
+             print(f"{stage}: {latency:.2f} ms")
onnxruntime/transformers/models/sam2/prompt_encoder.py
@@ -0,0 +1,189 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+
+ import torch
+ from sam2.modeling.sam2_base import SAM2Base
+ from sam2_utils import compare_tensors_with_tolerance
+ from torch import nn
+
+ logger = logging.getLogger(__name__)
+
+
+ class SAM2PromptEncoder(nn.Module):
+     def __init__(self, sam_model: SAM2Base):
+         super().__init__()
+         self.prompt_encoder = sam_model.sam_prompt_encoder
+         self.model = sam_model
+
+     @torch.no_grad()
+     def forward(
+         self,
+         point_coords: torch.Tensor,
+         point_labels: torch.Tensor,
+         input_masks: torch.Tensor,
+         has_input_masks: torch.Tensor,
+     ):
+         """Encode prompts.
+
+         Args:
+             point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
+                 coordinate in (x, y) format of the P input points in image of size 1024x1024.
+             point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
+                 positive (foreground), 0 means negative (background), -1 means padding,
+                 2 (box left upper corner), 3 (box right bottom corner).
+             input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
+                 Typically coming from a previous iteration.
+             has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
+         Returns:
+             sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
+             dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
+             image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
+         """
+         sparse_embeddings = self._embed_points(point_coords, point_labels)
+         dense_embeddings = self._embed_masks(input_masks, has_input_masks)
+         image_pe = self.prompt_encoder.get_dense_pe()
+
+         return sparse_embeddings, dense_embeddings, image_pe
+
+     def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+         point_coords = point_coords + 0.5
+
+         padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
+         padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
+         point_coords = torch.cat([point_coords, padding_point], dim=1)
+         point_labels = torch.cat([point_labels, padding_label], dim=1)
+
+         # Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
+         point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
+         point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
+
+         point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
+         point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+         point_embedding = point_embedding * (point_labels != -1)
+         point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
+
+         for i in range(self.prompt_encoder.num_point_embeddings):
+             point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)
+
+         return point_embedding
+
+     def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
+         mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
+         no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+         logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
+         mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
+         logger.info("mask_embedding.shape: %s", mask_embedding.shape)
+         return mask_embedding
+
+
+ def export_prompt_encoder_onnx(
+     sam2_model: SAM2Base,
+     onnx_model_path: str,
+ ):
+     sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+     num_labels = 2
+     num_points = 3
+     point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+     point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
+     input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
+     has_input_masks = torch.ones(1, dtype=torch.float)
+
+     sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+         point_coords, point_labels, input_masks, has_input_masks
+     )
+
+     logger.info("point_coords.shape: %s", point_coords.shape)
+     logger.info("point_labels.shape: %s", point_labels.shape)
+     logger.info("input_masks.shape: %s", input_masks.shape)
+     logger.info("has_input_masks.shape: %s", has_input_masks.shape)
+
+     logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
+     logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
+     logger.info("image_pe.shape: %s", image_pe.shape)
+
+     torch.onnx.export(
+         sam2_prompt_encoder,
+         (point_coords, point_labels, input_masks, has_input_masks),
+         onnx_model_path,
+         export_params=True,
+         opset_version=18,
+         do_constant_folding=True,
+         input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
+         output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
+         dynamic_axes={
+             "point_coords": {0: "num_labels", 1: "num_points"},
+             "point_labels": {0: "num_labels", 1: "num_points"},
+             "input_masks": {0: "num_labels"},
+             "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
+             "dense_embeddings": {0: "num_labels"},
+         },
+     )
+
+     print("prompt encoder onnx model saved to ", onnx_model_path)
+
+
+ def test_prompt_encoder_onnx(
+     sam2_model: SAM2Base,
+     onnx_model_path: str,
+ ):
+     sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+     num_labels = 1
+     num_points = 5
+     point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+     point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
+     input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
+     has_input_masks = torch.ones(1, dtype=torch.float)
+
+     sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+         point_coords, point_labels, input_masks, has_input_masks
+     )
+
+     import onnxruntime
+
+     ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
+
+     model_inputs = ort_session.get_inputs()
+     input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+     logger.info("input_names: %s", input_names)
+
+     model_outputs = ort_session.get_outputs()
+     output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+     logger.info("output_names: %s", output_names)
+
+     outputs = ort_session.run(
+         output_names,
+         {
+             "point_coords": point_coords.numpy(),
+             "point_labels": point_labels.numpy(),
+             "input_masks": input_masks.numpy(),
+             "has_input_masks": has_input_masks.numpy(),
+         },
+     )
+
+     for i, output_name in enumerate(output_names):
+         logger.info("output %s shape: %s", output_name, outputs[i].shape)
+
+     ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
+     if (
+         compare_tensors_with_tolerance(
+             "sparse_embeddings",
+             sparse_embeddings,
+             torch.tensor(ort_sparse_embeddings),
+             mismatch_percentage_tolerance=0.2,
+         )
+         and compare_tensors_with_tolerance(
+             "dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
+         )
+         and compare_tensors_with_tolerance(
+             "image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
+         )
+     ):
+         print(f"onnx model has been verified: {onnx_model_path}")
+     else:
+         print(f"onnx model verification failed: {onnx_model_path}")