onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/sam2/sam2_demo.py
@@ -0,0 +1,321 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os
+
+ import matplotlib.image as mpimg
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from matplotlib.patches import Rectangle
+ from PIL import Image
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
+ from sam2_utils import load_sam2_model
+
+ import onnxruntime
+
+
+ def show_mask(mask, ax, random_color=False, borders=True):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
+     h, w = mask.shape[-2:]
+     mask = mask.astype(np.uint8)
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     if borders:
+         import cv2  # noqa: PLC0415
+
+         contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+         # Try to smooth contours
+         contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
+         mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+     ax.imshow(mask_image)
+
+
+ def show_points(coords, labels, ax, marker_size=375):
+     pos_points = coords[labels == 1]
+     neg_points = coords[labels == 0]
+     ax.scatter(
+         pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
+     )
+     ax.scatter(
+         neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
+     )
+
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
+
+
+ def show_masks(
+     image,
+     masks,
+     scores,
+     point_coords=None,
+     box_coords=None,
+     input_labels=None,
+     borders=True,
+     output_image_file_prefix=None,
+     image_files=None,
+ ):
+     for i, (mask, score) in enumerate(zip(masks, scores, strict=False)):
+         plt.figure(figsize=(10, 10))
+         plt.imshow(image)
+         show_mask(mask, plt.gca(), borders=borders)
+         if point_coords is not None:
+             assert input_labels is not None
+             show_points(point_coords, input_labels, plt.gca())
+
+         if box_coords is not None:
+             show_box(box_coords, plt.gca())
+
+         if len(scores) > 1:
+             plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
+
+         plt.axis("off")
+         if output_image_file_prefix:
+             filename = f"{output_image_file_prefix}_{i}.png"
+             if os.path.exists(filename):
+                 os.remove(filename)
+             plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
+             if isinstance(image_files, list):
+                 image_files.append(filename)
+         plt.show(block=False)
+         plt.close()
+
+
+ def get_predictor(
+     sam2_dir: str,
+     device: str | torch.device,
+     dtype: torch.dtype,
+     model_type="sam2_hiera_large",
+     engine="torch",
+     image_encoder_onnx_path: str = "",
+     image_decoder_onnx_path: str = "",
+     image_decoder_multi_onnx_path: str = "",
+     provider: str = "CUDAExecutionProvider",
+ ):
+     sam2_model = load_sam2_model(sam2_dir, model_type, device=device)
+     if engine == "torch":
+         predictor = SAM2ImagePredictor(sam2_model)
+     else:
+         predictor = SAM2ImageOnnxPredictor(
+             sam2_model,
+             image_encoder_onnx_path=image_encoder_onnx_path,
+             image_decoder_onnx_path=image_decoder_onnx_path,
+             image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
+             provider=provider,
+             device=device,
+             onnx_dtype=dtype,
+         )
+     return predictor
+
+
+ def run_demo(
+     sam2_dir: str,
+     model_type: str = "sam2_hiera_large",
+     engine: str = "torch",
+     dtype: torch.dtype = torch.float32,
+     image_encoder_onnx_path: str = "",
+     image_decoder_onnx_path: str = "",
+     image_decoder_multi_onnx_path: str = "",
+     use_gpu: bool = True,
+     enable_batch: bool = False,
+ ):
+     if use_gpu:
+         assert torch.cuda.is_available()
+         assert "CUDAExecutionProvider" in onnxruntime.get_available_providers()
+         provider = "CUDAExecutionProvider"
+     else:
+         provider = "CPUExecutionProvider"
+
+     device = torch.device("cuda" if use_gpu else "cpu")
+
+     if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8:
+         # Turn on tfloat32 for Ampere GPUs.
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+
+     np.random.seed(3)
+     image = Image.open("truck.jpg")
+     image = np.array(image.convert("RGB"))
+
+     predictor = get_predictor(
+         sam2_dir,
+         device,
+         dtype,
+         model_type,
+         engine,
+         image_encoder_onnx_path,
+         image_decoder_onnx_path,
+         image_decoder_multi_onnx_path,
+         provider=provider,
+     )
+
+     predictor.set_image(image)
+     prefix = f"sam2_demo_{engine}_"
+
+     # The model returns masks, quality predictions for those masks,
+     # and low resolution mask logits that can be passed to the next iteration of prediction.
+     # With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
+     # scores gives the model's own estimation of the quality of these masks.
+     # For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
+     # even if only a single mask is desired;
+     input_point = np.array([[500, 375]])
+     input_label = np.array([1])
+     masks, scores, logits = predictor.predict(
+         point_coords=input_point,
+         point_labels=input_label,
+         multimask_output=True,
+     )
+
+     sorted_ind = np.argsort(scores)[::-1]
+     masks = masks[sorted_ind]
+     scores = scores[sorted_ind]
+     logits = logits[sorted_ind]
+
+     image_files = []
+     show_masks(
+         image,
+         masks,
+         scores,
+         point_coords=input_point,
+         input_labels=input_label,
+         borders=True,
+         output_image_file_prefix=prefix + "multimask",
+         image_files=image_files,
+     )
+
+     # Multiple points.
+     input_point = np.array([[500, 375], [1125, 625]])
+     input_label = np.array([1, 1])
+     mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
+     masks, scores, _ = predictor.predict(
+         point_coords=input_point,
+         point_labels=input_label,
+         mask_input=mask_input[None, :, :],
+         multimask_output=False,
+     )
+     show_masks(
+         image,
+         masks,
+         scores,
+         point_coords=input_point,
+         input_labels=input_label,
+         output_image_file_prefix=prefix + "multi_points",
+         image_files=image_files,
+     )
+
+     # Specify a foreground point and a background point.
+     input_point = np.array([[500, 375], [1125, 625]])
+     input_label = np.array([1, 0])
+     mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
+     masks, scores, _ = predictor.predict(
+         point_coords=input_point,
+         point_labels=input_label,
+         mask_input=mask_input[None, :, :],
+         multimask_output=False,
+     )
+     show_masks(
+         image,
+         masks,
+         scores,
+         point_coords=input_point,
+         input_labels=input_label,
+         output_image_file_prefix=prefix + "background_point",
+         image_files=image_files,
+     )
+
+     # Take a box as input
+     input_box = np.array([425, 600, 700, 875])
+     masks, scores, _ = predictor.predict(
+         point_coords=None,
+         point_labels=None,
+         box=input_box[None, :],
+         multimask_output=False,
+     )
+     show_masks(
+         image,
+         masks,
+         scores,
+         box_coords=input_box,
+         output_image_file_prefix=prefix + "box",
+         image_files=image_files,
+     )
+
+     # Combining points and boxes
+     input_box = np.array([425, 600, 700, 875])
+     input_point = np.array([[575, 750]])
+     input_label = np.array([0])
+
+     masks, scores, logits = predictor.predict(
+         point_coords=input_point,
+         point_labels=input_label,
+         box=input_box,
+         multimask_output=False,
+     )
+     show_masks(
+         image,
+         masks,
+         scores,
+         box_coords=input_box,
+         point_coords=input_point,
+         input_labels=input_label,
+         output_image_file_prefix=prefix + "box_and_point",
+         image_files=image_files,
+     )
+
+     # TODO: support batched prompt inputs
+     if enable_batch:
+         input_boxes = np.array(
+             [
+                 [75, 275, 1725, 850],
+                 [425, 600, 700, 875],
+                 [1375, 550, 1650, 800],
+                 [1240, 675, 1400, 750],
+             ]
+         )
+         masks, scores, _ = predictor.predict(
+             point_coords=None,
+             point_labels=None,
+             box=input_boxes,
+             multimask_output=False,
+         )
+         plt.figure(figsize=(10, 10))
+         plt.imshow(image)
+         for mask in masks:
+             show_mask(mask.squeeze(0), plt.gca(), random_color=True)
+         for box in input_boxes:
+             show_box(box, plt.gca())
+         plt.axis("off")
+         plt.show()
+         plt.savefig(prefix + "batch_prompt.png")
+         image_files.append(prefix + "batch_prompt.png")
+     return image_files
+
+
+ def show_all_images(left_images, right_images, suffix=""):
+     # Show images in two rows since display screen is horizontal in most cases.
+     fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
+     for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images, strict=False)):
+         left_img = mpimg.imread(left_img_path)
+         right_img = mpimg.imread(right_img_path)
+
+         axes[0, i].imshow(left_img)
+         axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
+         axes[0, i].axis("off")
+         axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])
+
+         axes[1, i].imshow(right_img)
+         axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
+         axes[1, i].axis("off")
+         axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])
+
+     plt.tight_layout()
+     plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000)
+     plt.show()
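
Taken together, sam2_demo.py compares the PyTorch and ONNX Runtime SAM-2 pipelines by writing one set of output images per engine. A minimal sketch of how it might be driven, assuming the SAM-2 checkpoints and the exported ONNX models already exist at the hypothetical paths below (any engine value other than "torch" selects the ONNX predictor, and run_demo expects a truck.jpg in the working directory):

    import torch
    from sam2_demo import run_demo, show_all_images

    # Hypothetical checkpoint directory and ONNX paths, produced beforehand
    # by the image_encoder.py / image_decoder.py exports in this package.
    torch_images = run_demo("checkpoints/sam2", engine="torch", use_gpu=True)
    ort_images = run_demo(
        "checkpoints/sam2",
        engine="ort",
        dtype=torch.float32,
        image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",
        image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",
        image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",
        use_gpu=True,
    )
    # Place torch outputs in the top row and ONNX Runtime outputs in the bottom row.
    show_all_images(torch_images, ort_images)
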
onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py
@@ -0,0 +1,279 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import logging
+
+ import numpy as np
+ import torch
+ from PIL.Image import Image
+ from sam2.modeling.sam2_base import SAM2Base
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ from sam2_utils import decoder_shape_dict, encoder_shape_dict
+
+ from onnxruntime import InferenceSession
+ from onnxruntime.transformers.io_binding_helper import CudaSession
+
+ logger = logging.getLogger(__name__)
+
+
+ def create_ort_session(
+     onnx_path: str,
+     session_options=None,
+     provider="CUDAExecutionProvider",
+     enable_cuda_graph=False,
+     use_tf32=True,
+ ) -> InferenceSession:
+     if provider == "CUDAExecutionProvider":
+         device_id = torch.cuda.current_device()
+         provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
+         provider_options["use_tf32"] = int(use_tf32)
+         providers = [(provider, provider_options), "CPUExecutionProvider"]
+     else:
+         providers = ["CPUExecutionProvider"]
+     logger.info("Using providers: %s", providers)
+     return InferenceSession(onnx_path, session_options, providers=providers)
+
+
+ def create_session(
+     onnx_path: str,
+     session_options=None,
+     provider="CUDAExecutionProvider",
+     device: str | torch.device = "cuda",
+     enable_cuda_graph=False,
+ ) -> CudaSession:
+     ort_session = create_ort_session(
+         onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True
+     )
+     cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph)
+     return cuda_session
+
+
+ class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
+     def __init__(
+         self,
+         sam_model: SAM2Base,
+         image_encoder_onnx_path: str = "",
+         image_decoder_onnx_path: str = "",
+         image_decoder_multi_onnx_path: str = "",
+         provider: str = "CUDAExecutionProvider",
+         device: str | torch.device = "cuda",
+         onnx_dtype: torch.dtype = torch.float32,
+         mask_threshold=0.0,
+         max_hole_area=0.0,
+         max_sprinkle_area=0.0,
+         **kwargs,
+     ) -> None:
+         """
+         Uses SAM-2 to compute the image embedding for an image, and then allows mask prediction given prompts.
+
+         Arguments:
+             sam_model (SAM2Base): The model to use for mask prediction.
+             image_encoder_onnx_path (str): The path of the image encoder onnx model exported by image_encoder.py.
+             image_decoder_onnx_path (str): The path of the image decoder onnx model (single mask output)
+                 exported by image_decoder.py.
+             image_decoder_multi_onnx_path (str): The path of the image decoder onnx model (multiple mask outputs)
+                 exported by image_decoder.py.
+             provider (str): The onnxruntime execution provider for running the onnx models.
+             device (str or torch.device): The device to run inference on.
+             onnx_dtype (torch.dtype): The data type to use for ONNX inputs.
+             mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0.
+             max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
+                 the maximum area of max_hole_area in low_res_masks.
+             max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
+                 the maximum area of max_sprinkle_area in low_res_masks.
+         """
+         super().__init__(
+             sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area
+         )
+
+         logger.debug("self.device=%s, device=%s", self.device, device)
+
+         # This model is exported by image_encoder.py.
+         self.encoder_session = create_session(
+             image_encoder_onnx_path,
+             session_options=None,
+             provider=provider,
+             device=device,
+             enable_cuda_graph=False,
+         )
+         self.onnx_dtype = onnx_dtype
+
+         # This model is exported by image_decoder.py. It outputs only one mask.
+         self.decoder_session = create_session(
+             image_decoder_onnx_path,
+             session_options=None,
+             provider=provider,
+             device=device,
+             enable_cuda_graph=False,
+         )
+
+         # This model is exported by image_decoder.py. It outputs multiple (3) masks.
+         self.decoder_session_multi_out = create_session(
+             image_decoder_multi_onnx_path,
+             session_options=None,
+             provider=provider,
+             device=device,
+             enable_cuda_graph=False,
+         )
+
+     @torch.no_grad()
+     def set_image(self, image: np.ndarray | Image):
+         """
+         Calculates the image embeddings for the provided image.
+
+         Arguments:
+             image (np.ndarray or PIL Image): The input image to embed in RGB format.
+                 The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255].
+         """
+         self.reset_predictor()
+         # Transform the image to the form expected by the model
+         if isinstance(image, np.ndarray):
+             # For numpy array image, we assume (HxWxC) format.
+             self._orig_hw = [image.shape[:2]]
+         elif isinstance(image, Image):
+             w, h = image.size
+             self._orig_hw = [(h, w)]
+         else:
+             raise NotImplementedError("Image format not supported")
+
+         input_image = self._transforms(image)
+         input_image = input_image[None, ...].to(self.device)
+
+         assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
+             f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+         )
+
+         # Computing image embeddings for the provided image
+         io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
+         self.encoder_session.allocate_buffers(io_shapes)
+
+         feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}
+
+         for key, value in feed_dict.items():
+             logger.debug(f"{key}: {value.shape}, {value.dtype}")
+         logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")
+
+         ort_outputs = self.encoder_session.infer(feed_dict)
+
+         self._features = {
+             "image_embed": ort_outputs["image_embeddings"],
+             "high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
+         }
+         self._is_image_set = True
+         logger.info("Image embeddings computed.")
+
+     @torch.no_grad()
+     def _predict(
+         self,
+         point_coords: torch.Tensor | None,
+         point_labels: torch.Tensor | None,
+         boxes: torch.Tensor | None = None,
+         mask_input: torch.Tensor | None = None,
+         multimask_output: bool = True,
+         return_logits: bool = False,
+         img_idx: int = -1,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Predict masks for the given input prompts, using the currently set image.
+         Input prompts are batched torch tensors and are expected to already be
+         transformed to the input frame using SAM2Transforms.
+
+         Arguments:
+             point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+                 model. Each point is in (X,Y) in pixels.
+             point_labels (torch.Tensor or None): A BxN array of labels for the
+                 point prompts. 1 indicates a foreground point and 0 indicates a
+                 background point.
+             boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+                 model, in XYXY format.
+             mask_input (np.ndarray): A low resolution mask input to the model, typically
+                 coming from a previous prediction iteration. Has form Bx1xHxW, where
+                 for SAM, H=W=256. Masks returned by a previous iteration of the
+                 predict method do not need further transformation.
+             multimask_output (bool): If true, the model will return three masks.
+                 For ambiguous input prompts (such as a single click), this will often
+                 produce better masks than a single prediction. If only a single
+                 mask is needed, the model's predicted quality score can be used
+                 to select the best mask. For non-ambiguous prompts, such as multiple
+                 input prompts, multimask_output=False can give better results.
+             return_logits (bool): If true, returns un-thresholded masks logits
+                 instead of a binary mask.
+
+         Returns:
+             (torch.Tensor): The output masks in BxCxHxW format, where C is the
+                 number of masks, and (H, W) is the original image size.
+             (torch.Tensor): An array of shape BxC containing the model's
+                 predictions for the quality of each mask.
+             (torch.Tensor): An array of shape BxCxHxW, where C is the number
+                 of masks and H=W=256. These low res logits can be passed to
+                 a subsequent iteration as mask input.
+         """
+         assert not return_logits  # onnx model is exported for returning bool masks.
+
+         if not self._is_image_set:
+             raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+         if point_coords is not None:
+             concat_points = (point_coords, point_labels)
+         else:
+             concat_points = None
+
+         # Embed prompts
+         if boxes is not None:
+             box_coords = boxes.reshape(-1, 2, 2)
+             box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+             box_labels = box_labels.repeat(boxes.size(0), 1)
+             # we merge "boxes" and "points" into a single "concat_points" input (where
+             # boxes are added at the beginning) to sam_prompt_encoder
+             if concat_points is not None:
+                 concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                 concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                 concat_points = (concat_coords, concat_labels)
+             else:
+                 concat_points = (box_coords, box_labels)
+
+         assert concat_points is not None
+         num_labels = concat_points[0].shape[0]
+         shape_dict = decoder_shape_dict(
+             original_image_height=self._orig_hw[img_idx][0],
+             original_image_width=self._orig_hw[img_idx][1],
+             num_labels=num_labels,
+             max_points=concat_points[0].shape[1],
+             num_masks=3 if multimask_output else 1,
+         )
+         if multimask_output:
+             decoder_session = self.decoder_session_multi_out
+         else:
+             decoder_session = self.decoder_session
+
+         decoder_session.allocate_buffers(shape_dict)
+
+         image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
+         image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
+         image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)
+
+         if mask_input is None:
+             input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=self.onnx_dtype, device=self.device)
+             has_input_masks = torch.zeros(num_labels, dtype=self.onnx_dtype, device=self.device)
+         else:
+             input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
+             has_input_masks = torch.ones(num_labels, dtype=self.onnx_dtype, device=self.device)
+
+         feed_dict = {
+             "image_embeddings": image_embeddings.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+             "image_features_0": image_features_0.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+             "image_features_1": image_features_1.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+             "point_coords": concat_points[0].to(dtype=self.onnx_dtype).to(self.device),
+             "point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
+             "input_masks": input_masks.to(dtype=self.onnx_dtype).to(self.device),
+             "has_input_masks": has_input_masks.to(dtype=self.onnx_dtype).to(self.device),
+             "original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
+         }
+
+         for key, value in feed_dict.items():
+             logger.debug(f"{key}: {value.shape}, {value.dtype}")
+         logger.debug(f"decoder onnx: {self.decoder_session.ort_session._model_path}")
+
+         ort_outputs = decoder_session.infer(feed_dict)
+
+         masks = ort_outputs["masks"]
+         iou_predictions = ort_outputs["iou_predictions"]
+         low_res_masks = ort_outputs["low_res_masks"]
+
+         return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
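
A minimal usage sketch of the predictor above, mirroring what sam2_demo.py does through get_predictor. The checkpoint directory and ONNX paths are hypothetical; load_sam2_model comes from sam2_utils in the same package, and predict is inherited from SAM2ImagePredictor, which transforms the numpy prompts and calls the _predict override shown here:

    import numpy as np
    from PIL import Image
    from sam2_utils import load_sam2_model
    from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor

    # Hypothetical paths to the SAM-2 checkpoints and the exported ONNX models.
    sam2_model = load_sam2_model("checkpoints/sam2", "sam2_hiera_large", device="cuda")
    predictor = SAM2ImageOnnxPredictor(
        sam2_model,
        image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",
        image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",
        image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",
        provider="CUDAExecutionProvider",
        device="cuda",
    )
    # Run the ONNX image encoder once, then decode masks for a single point prompt.
    predictor.set_image(np.array(Image.open("truck.jpg").convert("RGB")))
    masks, scores, low_res_masks = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )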