onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
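The wheel bundles the DirectML execution provider alongside the Python bindings (DirectML.dll, onnxruntime.dll, and onnxruntime_pybind11_state.pyd above). A minimal sketch of loading a model with this build, assuming a hypothetical model.onnx; the file path is illustrative and not part of this package:

import onnxruntime

# The DirectML build registers DmlExecutionProvider in addition to the CPU provider.
print(onnxruntime.get_available_providers())

# Prefer DirectML and fall back to CPU. "model.onnx" is a placeholder path.
session = onnxruntime.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
print([inp.name for inp in session.get_inputs()])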
onnxruntime/transformers/models/sam2/sam2_demo.py
@@ -0,0 +1,322 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import os
+from typing import Union
+
+import matplotlib.image as mpimg
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib.patches import Rectangle
+from PIL import Image
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
+from sam2_utils import load_sam2_model
+
+import onnxruntime
+
+
+def show_mask(mask, ax, random_color=False, borders=True):
+    if random_color:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+    else:
+        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
+    h, w = mask.shape[-2:]
+    mask = mask.astype(np.uint8)
+    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    if borders:
+        import cv2
+
+        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+        # Try to smooth contours
+        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
+        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+    ax.imshow(mask_image)
+
+
+def show_points(coords, labels, ax, marker_size=375):
+    pos_points = coords[labels == 1]
+    neg_points = coords[labels == 0]
+    ax.scatter(
+        pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
+    )
+    ax.scatter(
+        neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
+    )
+
+
+def show_box(box, ax):
+    x0, y0 = box[0], box[1]
+    w, h = box[2] - box[0], box[3] - box[1]
+    ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
+
+
+def show_masks(
+    image,
+    masks,
+    scores,
+    point_coords=None,
+    box_coords=None,
+    input_labels=None,
+    borders=True,
+    output_image_file_prefix=None,
+    image_files=None,
+):
+    for i, (mask, score) in enumerate(zip(masks, scores)):
+        plt.figure(figsize=(10, 10))
+        plt.imshow(image)
+        show_mask(mask, plt.gca(), borders=borders)
+        if point_coords is not None:
+            assert input_labels is not None
+            show_points(point_coords, input_labels, plt.gca())
+
+        if box_coords is not None:
+            show_box(box_coords, plt.gca())
+
+        if len(scores) > 1:
+            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
+
+        plt.axis("off")
+        if output_image_file_prefix:
+            filename = f"{output_image_file_prefix}_{i}.png"
+            if os.path.exists(filename):
+                os.remove(filename)
+            plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
+            if isinstance(image_files, list):
+                image_files.append(filename)
+        plt.show(block=False)
+        plt.close()
+
+
+def get_predictor(
+    sam2_dir: str,
+    device: Union[str, torch.device],
+    dtype: torch.dtype,
+    model_type="sam2_hiera_large",
+    engine="torch",
+    image_encoder_onnx_path: str = "",
+    image_decoder_onnx_path: str = "",
+    image_decoder_multi_onnx_path: str = "",
+    provider: str = "CUDAExecutionProvider",
+):
+    sam2_model = load_sam2_model(sam2_dir, model_type, device=device)
+    if engine == "torch":
+        predictor = SAM2ImagePredictor(sam2_model)
+    else:
+        predictor = SAM2ImageOnnxPredictor(
+            sam2_model,
+            image_encoder_onnx_path=image_encoder_onnx_path,
+            image_decoder_onnx_path=image_decoder_onnx_path,
+            image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
+            provider=provider,
+            device=device,
+            onnx_dtype=dtype,
+        )
+    return predictor
+
+
+def run_demo(
+    sam2_dir: str,
+    model_type: str = "sam2_hiera_large",
+    engine: str = "torch",
+    dtype: torch.dtype = torch.float32,
+    image_encoder_onnx_path: str = "",
+    image_decoder_onnx_path: str = "",
+    image_decoder_multi_onnx_path: str = "",
+    use_gpu: bool = True,
+    enable_batch: bool = False,
+):
+    if use_gpu:
+        assert torch.cuda.is_available()
+        assert "CUDAExecutionProvider" in onnxruntime.get_available_providers()
+        provider = "CUDAExecutionProvider"
+    else:
+        provider = "CPUExecutionProvider"
+
+    device = torch.device("cuda" if use_gpu else "cpu")
+
+    if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8:
+        # Turn on tfloat32 for Ampere GPUs.
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    np.random.seed(3)
+    image = Image.open("truck.jpg")
+    image = np.array(image.convert("RGB"))
+
+    predictor = get_predictor(
+        sam2_dir,
+        device,
+        dtype,
+        model_type,
+        engine,
+        image_encoder_onnx_path,
+        image_decoder_onnx_path,
+        image_decoder_multi_onnx_path,
+        provider=provider,
+    )
+
+    predictor.set_image(image)
+    prefix = f"sam2_demo_{engine}_"
+
+    # The model returns masks, quality predictions for those masks,
+    # and low resolution mask logits that can be passed to the next iteration of prediction.
+    # With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
+    # scores gives the model's own estimation of the quality of these masks.
+    # For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
+    # even if only a single mask is desired;
+    input_point = np.array([[500, 375]])
+    input_label = np.array([1])
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        multimask_output=True,
+    )
+
+    sorted_ind = np.argsort(scores)[::-1]
+    masks = masks[sorted_ind]
+    scores = scores[sorted_ind]
+    logits = logits[sorted_ind]
+
+    image_files = []
+    show_masks(
+        image,
+        masks,
+        scores,
+        point_coords=input_point,
+        input_labels=input_label,
+        borders=True,
+        output_image_file_prefix=prefix + "multimask",
+        image_files=image_files,
+    )
+
+    # Multiple points.
+    input_point = np.array([[500, 375], [1125, 625]])
+    input_label = np.array([1, 1])
+    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
+    masks, scores, _ = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        mask_input=mask_input[None, :, :],
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        point_coords=input_point,
+        input_labels=input_label,
+        output_image_file_prefix=prefix + "multi_points",
+        image_files=image_files,
+    )
+
+    # Specify a window and a background point.
+    input_point = np.array([[500, 375], [1125, 625]])
+    input_label = np.array([1, 0])
+    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
+    masks, scores, _ = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        mask_input=mask_input[None, :, :],
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        point_coords=input_point,
+        input_labels=input_label,
+        output_image_file_prefix=prefix + "background_point",
+        image_files=image_files,
+    )
+
+    # Take a box as input
+    input_box = np.array([425, 600, 700, 875])
+    masks, scores, _ = predictor.predict(
+        point_coords=None,
+        point_labels=None,
+        box=input_box[None, :],
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        box_coords=input_box,
+        output_image_file_prefix=prefix + "box",
+        image_files=image_files,
+    )
+
+    # Combining points and boxes
+    input_box = np.array([425, 600, 700, 875])
+    input_point = np.array([[575, 750]])
+    input_label = np.array([0])
+
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        box=input_box,
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        box_coords=input_box,
+        point_coords=input_point,
+        input_labels=input_label,
+        output_image_file_prefix=prefix + "box_and_point",
+        image_files=image_files,
+    )
+
+    # TODO: support batched prompt inputs
+    if enable_batch:
+        input_boxes = np.array(
+            [
+                [75, 275, 1725, 850],
+                [425, 600, 700, 875],
+                [1375, 550, 1650, 800],
+                [1240, 675, 1400, 750],
+            ]
+        )
+        masks, scores, _ = predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+        plt.figure(figsize=(10, 10))
+        plt.imshow(image)
+        for mask in masks:
+            show_mask(mask.squeeze(0), plt.gca(), random_color=True)
+        for box in input_boxes:
+            show_box(box, plt.gca())
+        plt.axis("off")
+        plt.show()
+        plt.savefig(prefix + "batch_prompt.png")
+        image_files.append(prefix + "batch_prompt.png")
+    return image_files
+
+
+def show_all_images(left_images, right_images, suffix=""):
+    # Show images in two rows since display screen is horizontal in most cases.
+    fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
+    for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images)):
+        left_img = mpimg.imread(left_img_path)
+        right_img = mpimg.imread(right_img_path)
+
+        axes[0, i].imshow(left_img)
+        axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
+        axes[0, i].axis("off")
+        axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])
+
+        axes[1, i].imshow(right_img)
+        axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
+        axes[1, i].axis("off")
+        axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])
+
+    plt.tight_layout()
+    plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000)
+    plt.show()
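A minimal sketch of driving sam2_demo.py above, assuming a local segment-anything-2 checkout, the exported encoder/decoder ONNX files, and truck.jpg in the working directory; all paths are illustrative placeholders, and any engine value other than "torch" selects the ONNX predictor:

import torch

torch_images = run_demo("path/to/segment-anything-2", engine="torch", use_gpu=True)
ort_images = run_demo(
    "path/to/segment-anything-2",
    engine="ort",
    dtype=torch.float32,
    image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",
    image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",
    image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",
    use_gpu=True,
)
# Compare the torch and ONNX Runtime outputs side by side.
show_all_images(torch_images, ort_images)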
onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py
@@ -0,0 +1,280 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import logging
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL.Image import Image
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2_utils import decoder_shape_dict, encoder_shape_dict
+
+from onnxruntime import InferenceSession
+from onnxruntime.transformers.io_binding_helper import CudaSession
+
+logger = logging.getLogger(__name__)
+
+
+def create_ort_session(
+    onnx_path: str,
+    session_options=None,
+    provider="CUDAExecutionProvider",
+    enable_cuda_graph=False,
+    use_tf32=True,
+) -> InferenceSession:
+    if provider == "CUDAExecutionProvider":
+        device_id = torch.cuda.current_device()
+        provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
+        provider_options["use_tf32"] = int(use_tf32)
+        providers = [(provider, provider_options), "CPUExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+    logger.info("Using providers: %s", providers)
+    return InferenceSession(onnx_path, session_options, providers=providers)
+
+
+def create_session(
+    onnx_path: str,
+    session_options=None,
+    provider="CUDAExecutionProvider",
+    device: Union[str, torch.device] = "cuda",
+    enable_cuda_graph=False,
+) -> CudaSession:
+    ort_session = create_ort_session(
+        onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True
+    )
+    cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph)
+    return cuda_session
+
+
+class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
+    def __init__(
+        self,
+        sam_model: SAM2Base,
+        image_encoder_onnx_path: str = "",
+        image_decoder_onnx_path: str = "",
+        image_decoder_multi_onnx_path: str = "",
+        provider: str = "CUDAExecutionProvider",
+        device: Union[str, torch.device] = "cuda",
+        onnx_dtype: torch.dtype = torch.float32,
+        mask_threshold=0.0,
+        max_hole_area=0.0,
+        max_sprinkle_area=0.0,
+        **kwargs,
+    ) -> None:
+        """
+        Uses SAM-2 to compute the image embedding for an image, and then allow mask prediction given prompts.
+
+        Arguments:
+          sam_model (SAM2Base): The model to use for mask prediction.
+          onnx_directory (str): The path of the directory that contains encoder and decoder onnx models.
+          onnx_dtype (torch.dtype): The data type to use for ONNX inputs.
+          mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0.
+          max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
+                                 the maximum area of max_hole_area in low_res_masks.
+          max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
+                                 the maximum area of max_sprinkle_area in low_res_masks.
+        """
+        super().__init__(
+            sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area
+        )
+
+        logger.debug("self.device=%s, device=%s", self.device, device)
+
+        # This model is exported by image_encoder.py.
+        self.encoder_session = create_session(
+            image_encoder_onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+        self.onnx_dtype = onnx_dtype
+
+        # This model is exported by image_decoder.py. It outputs only one mask.
+        self.decoder_session = create_session(
+            image_decoder_onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+
+        # This model is exported by image_decoder.py. It outputs multiple (3) masks.
+        self.decoder_session_multi_out = create_session(
+            image_decoder_multi_onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+
+    @torch.no_grad()
+    def set_image(self, image: Union[np.ndarray, Image]):
+        """
+        Calculates the image embeddings for the provided image.
+
+        Arguments:
+          image (np.ndarray or PIL Image): The input image to embed in RGB format.
+          The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255].
+        """
+        self.reset_predictor()
+        # Transform the image to the form expected by the model
+        if isinstance(image, np.ndarray):
+            # For numpy array image, we assume (HxWxC) format.
+            self._orig_hw = [image.shape[:2]]
+        elif isinstance(image, Image):
+            w, h = image.size
+            self._orig_hw = [(h, w)]
+        else:
+            raise NotImplementedError("Image format not supported")
+
+        input_image = self._transforms(image)
+        input_image = input_image[None, ...].to(self.device)
+
+        assert (
+            len(input_image.shape) == 4 and input_image.shape[1] == 3
+        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+
+        # Computing image embeddings for the provided image
+        io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
+        self.encoder_session.allocate_buffers(io_shapes)
+
+        feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}
+
+        for key, value in feed_dict.items():
+            logger.debug(f"{key}: {value.shape}, {value.dtype}")
+        logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")
+
+        ort_outputs = self.encoder_session.infer(feed_dict)
+
+        self._features = {
+            "image_embed": ort_outputs["image_embeddings"],
+            "high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
+        }
+        self._is_image_set = True
+        logging.info("Image embeddings computed.")
+
+    @torch.no_grad()
+    def _predict(
+        self,
+        point_coords: Optional[torch.Tensor],
+        point_labels: Optional[torch.Tensor],
+        boxes: Optional[torch.Tensor] = None,
+        mask_input: Optional[torch.Tensor] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        img_idx: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using SAM2Transforms.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        assert not return_logits  # onnx model is exported for returning bool masks.
+
+        if not self._is_image_set:
+            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+        if point_coords is not None:
+            concat_points = (point_coords, point_labels)
+        else:
+            concat_points = None
+
+        # Embed prompts
+        if boxes is not None:
+            box_coords = boxes.reshape(-1, 2, 2)
+            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+            box_labels = box_labels.repeat(boxes.size(0), 1)
+            # we merge "boxes" and "points" into a single "concat_points" input (where
+            # boxes are added at the beginning) to sam_prompt_encoder
+            if concat_points is not None:
+                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                concat_points = (concat_coords, concat_labels)
+            else:
+                concat_points = (box_coords, box_labels)
+
+        assert concat_points is not None
+        num_labels = concat_points[0].shape[0]
+        shape_dict = decoder_shape_dict(
+            original_image_height=self._orig_hw[img_idx][0],
+            original_image_width=self._orig_hw[img_idx][1],
+            num_labels=num_labels,
+            max_points=concat_points[0].shape[1],
+            num_masks=3 if multimask_output else 1,
+        )
+        if multimask_output:
+            decoder_session = self.decoder_session_multi_out
+        else:
+            decoder_session = self.decoder_session
+
+        decoder_session.allocate_buffers(shape_dict)
+
+        image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
+        image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
+        image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)
+
+        if mask_input is None:
+            input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=self.onnx_dtype, device=self.device)
+            has_input_masks = torch.zeros(num_labels, dtype=self.onnx_dtype, device=self.device)
+        else:
+            input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
+            has_input_masks = torch.ones(num_labels, dtype=self.onnx_dtype, device=self.device)
+
+        feed_dict = {
+            "image_embeddings": image_embeddings.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+            "image_features_0": image_features_0.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+            "image_features_1": image_features_1.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+            "point_coords": concat_points[0].to(dtype=self.onnx_dtype).to(self.device),
+            "point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
+            "input_masks": input_masks.to(dtype=self.onnx_dtype).to(self.device),
+            "has_input_masks": has_input_masks.to(dtype=self.onnx_dtype).to(self.device),
+            "original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
+        }
+
+        for key, value in feed_dict.items():
+            logger.debug(f"{key}: {value.shape}, {value.dtype}")
+        logger.debug(f"decoder onnx: {self.decoder_session.ort_session._model_path}")
+
+        ort_outputs = decoder_session.infer(feed_dict)
+
+        masks = ort_outputs["masks"]
+        iou_predictions = ort_outputs["iou_predictions"]
+        low_res_masks = ort_outputs["low_res_masks"]
+
+        return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
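A minimal sketch of using SAM2ImageOnnxPredictor above directly, outside the demo script; the checkout directory, ONNX paths, and the CPU provider choice are illustrative assumptions, and load_sam2_model comes from the sibling sam2_utils module listed above:

import numpy as np
from PIL import Image as PILImage
from sam2_utils import load_sam2_model

sam2_model = load_sam2_model("path/to/segment-anything-2", "sam2_hiera_large", device="cpu")
predictor = SAM2ImageOnnxPredictor(
    sam2_model,
    image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",
    image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",
    image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",
    provider="CPUExecutionProvider",
    device="cpu",
)
# set_image runs the encoder session; predict (inherited from SAM2ImagePredictor) runs the decoder.
predictor.set_image(np.array(PILImage.open("truck.jpg").convert("RGB")))
masks, scores, low_res_masks = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)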