onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/sam2/sam2_demo.py
@@ -0,0 +1,321 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import os
+
+import matplotlib.image as mpimg
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib.patches import Rectangle
+from PIL import Image
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
+from sam2_utils import load_sam2_model
+
+import onnxruntime
+
+
+def show_mask(mask, ax, random_color=False, borders=True):
+    if random_color:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+    else:
+        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
+    h, w = mask.shape[-2:]
+    mask = mask.astype(np.uint8)
+    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    if borders:
+        import cv2  # noqa: PLC0415
+
+        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+        # Try to smooth contours
+        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
+        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+    ax.imshow(mask_image)
+
+
+def show_points(coords, labels, ax, marker_size=375):
+    pos_points = coords[labels == 1]
+    neg_points = coords[labels == 0]
+    ax.scatter(
+        pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
+    )
+    ax.scatter(
+        neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
+    )
+
+
+def show_box(box, ax):
+    x0, y0 = box[0], box[1]
+    w, h = box[2] - box[0], box[3] - box[1]
+    ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
+
+
+def show_masks(
+    image,
+    masks,
+    scores,
+    point_coords=None,
+    box_coords=None,
+    input_labels=None,
+    borders=True,
+    output_image_file_prefix=None,
+    image_files=None,
+):
+    for i, (mask, score) in enumerate(zip(masks, scores, strict=False)):
+        plt.figure(figsize=(10, 10))
+        plt.imshow(image)
+        show_mask(mask, plt.gca(), borders=borders)
+        if point_coords is not None:
+            assert input_labels is not None
+            show_points(point_coords, input_labels, plt.gca())
+
+        if box_coords is not None:
+            show_box(box_coords, plt.gca())
+
+        if len(scores) > 1:
+            plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
+
+        plt.axis("off")
+        if output_image_file_prefix:
+            filename = f"{output_image_file_prefix}_{i}.png"
+            if os.path.exists(filename):
+                os.remove(filename)
+            plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
+            if isinstance(image_files, list):
+                image_files.append(filename)
+        plt.show(block=False)
+        plt.close()
+
+
+def get_predictor(
+    sam2_dir: str,
+    device: str | torch.device,
+    dtype: torch.dtype,
+    model_type="sam2_hiera_large",
+    engine="torch",
+    image_encoder_onnx_path: str = "",
+    image_decoder_onnx_path: str = "",
+    image_decoder_multi_onnx_path: str = "",
+    provider: str = "CUDAExecutionProvider",
+):
+    sam2_model = load_sam2_model(sam2_dir, model_type, device=device)
+    if engine == "torch":
+        predictor = SAM2ImagePredictor(sam2_model)
+    else:
+        predictor = SAM2ImageOnnxPredictor(
+            sam2_model,
+            image_encoder_onnx_path=image_encoder_onnx_path,
+            image_decoder_onnx_path=image_decoder_onnx_path,
+            image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
+            provider=provider,
+            device=device,
+            onnx_dtype=dtype,
+        )
+    return predictor
+
+
+def run_demo(
+    sam2_dir: str,
+    model_type: str = "sam2_hiera_large",
+    engine: str = "torch",
+    dtype: torch.dtype = torch.float32,
+    image_encoder_onnx_path: str = "",
+    image_decoder_onnx_path: str = "",
+    image_decoder_multi_onnx_path: str = "",
+    use_gpu: bool = True,
+    enable_batch: bool = False,
+):
+    if use_gpu:
+        assert torch.cuda.is_available()
+        assert "CUDAExecutionProvider" in onnxruntime.get_available_providers()
+        provider = "CUDAExecutionProvider"
+    else:
+        provider = "CPUExecutionProvider"
+
+    device = torch.device("cuda" if use_gpu else "cpu")
+
+    if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8:
+        # Turn on tfloat32 for Ampere GPUs.
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    np.random.seed(3)
+    image = Image.open("truck.jpg")
+    image = np.array(image.convert("RGB"))
+
+    predictor = get_predictor(
+        sam2_dir,
+        device,
+        dtype,
+        model_type,
+        engine,
+        image_encoder_onnx_path,
+        image_decoder_onnx_path,
+        image_decoder_multi_onnx_path,
+        provider=provider,
+    )
+
+    predictor.set_image(image)
+    prefix = f"sam2_demo_{engine}_"
+
+    # The model returns masks, quality predictions for those masks,
+    # and low resolution mask logits that can be passed to the next iteration of prediction.
+    # With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
+    # scores gives the model's own estimation of the quality of these masks.
+    # For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
+    # even if only a single mask is desired;
+    input_point = np.array([[500, 375]])
+    input_label = np.array([1])
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        multimask_output=True,
+    )
+
+    sorted_ind = np.argsort(scores)[::-1]
+    masks = masks[sorted_ind]
+    scores = scores[sorted_ind]
+    logits = logits[sorted_ind]
+
+    image_files = []
+    show_masks(
+        image,
+        masks,
+        scores,
+        point_coords=input_point,
+        input_labels=input_label,
+        borders=True,
+        output_image_file_prefix=prefix + "multimask",
+        image_files=image_files,
+    )
+
+    # Multiple points.
+    input_point = np.array([[500, 375], [1125, 625]])
+    input_label = np.array([1, 1])
+    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
+    masks, scores, _ = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        mask_input=mask_input[None, :, :],
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        point_coords=input_point,
+        input_labels=input_label,
+        output_image_file_prefix=prefix + "multi_points",
+        image_files=image_files,
+    )
+
+    # Specify a window and a background point.
+    input_point = np.array([[500, 375], [1125, 625]])
+    input_label = np.array([1, 0])
+    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
+    masks, scores, _ = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        mask_input=mask_input[None, :, :],
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        point_coords=input_point,
+        input_labels=input_label,
+        output_image_file_prefix=prefix + "background_point",
+        image_files=image_files,
+    )
+
+    # Take a box as input
+    input_box = np.array([425, 600, 700, 875])
+    masks, scores, _ = predictor.predict(
+        point_coords=None,
+        point_labels=None,
+        box=input_box[None, :],
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        box_coords=input_box,
+        output_image_file_prefix=prefix + "box",
+        image_files=image_files,
+    )
+
+    # Combining points and boxes
+    input_box = np.array([425, 600, 700, 875])
+    input_point = np.array([[575, 750]])
+    input_label = np.array([0])
+
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        box=input_box,
+        multimask_output=False,
+    )
+    show_masks(
+        image,
+        masks,
+        scores,
+        box_coords=input_box,
+        point_coords=input_point,
+        input_labels=input_label,
+        output_image_file_prefix=prefix + "box_and_point",
+        image_files=image_files,
+    )
+
+    # TODO: support batched prompt inputs
+    if enable_batch:
+        input_boxes = np.array(
+            [
+                [75, 275, 1725, 850],
+                [425, 600, 700, 875],
+                [1375, 550, 1650, 800],
+                [1240, 675, 1400, 750],
+            ]
+        )
+        masks, scores, _ = predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+        plt.figure(figsize=(10, 10))
+        plt.imshow(image)
+        for mask in masks:
+            show_mask(mask.squeeze(0), plt.gca(), random_color=True)
+        for box in input_boxes:
+            show_box(box, plt.gca())
+        plt.axis("off")
+        plt.show()
+        plt.savefig(prefix + "batch_prompt.png")
+        image_files.append(prefix + "batch_prompt.png")
+    return image_files
+
+
+def show_all_images(left_images, right_images, suffix=""):
+    # Show images in two rows since display screen is horizontal in most cases.
+    fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
+    for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images, strict=False)):
+        left_img = mpimg.imread(left_img_path)
+        right_img = mpimg.imread(right_img_path)
+
+        axes[0, i].imshow(left_img)
+        axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
+        axes[0, i].axis("off")
+        axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])
+
+        axes[1, i].imshow(right_img)
+        axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
+        axes[1, i].axis("off")
+        axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])
+
+    plt.tight_layout()
+    plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000)
+    plt.show()
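For orientation, here is a minimal sketch of how the demo file above could be driven. This is not part of the package diff: it assumes the file is importable as `sam2_demo`, that a local SAM-2 checkout with checkpoints sits at `./segment-anything-2`, that `truck.jpg` is in the working directory, and that the ONNX file names are placeholders for models exported by `image_encoder.py` / `image_decoder.py`.

    # Hypothetical driver; sam2_dir, image, and ONNX paths are assumptions, not package contents.
    from sam2_demo import run_demo, show_all_images

    # PyTorch baseline: saves sam2_demo_torch_*.png for each prompt type.
    torch_images = run_demo("./segment-anything-2", engine="torch", use_gpu=True)

    # ONNX Runtime run: same prompts, inference through the exported encoder/decoder models.
    ort_images = run_demo(
        "./segment-anything-2",
        engine="ort",
        image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",
        image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",
        image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",
        use_gpu=True,
    )

    # Torch results on the top row, ONNX Runtime results on the bottom row.
    show_all_images(torch_images, ort_images)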
onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py
@@ -0,0 +1,279 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import logging
+
+import numpy as np
+import torch
+from PIL.Image import Image
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2_utils import decoder_shape_dict, encoder_shape_dict
+
+from onnxruntime import InferenceSession
+from onnxruntime.transformers.io_binding_helper import CudaSession
+
+logger = logging.getLogger(__name__)
+
+
+def create_ort_session(
+    onnx_path: str,
+    session_options=None,
+    provider="CUDAExecutionProvider",
+    enable_cuda_graph=False,
+    use_tf32=True,
+) -> InferenceSession:
+    if provider == "CUDAExecutionProvider":
+        device_id = torch.cuda.current_device()
+        provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
+        provider_options["use_tf32"] = int(use_tf32)
+        providers = [(provider, provider_options), "CPUExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+    logger.info("Using providers: %s", providers)
+    return InferenceSession(onnx_path, session_options, providers=providers)
+
+
+def create_session(
+    onnx_path: str,
+    session_options=None,
+    provider="CUDAExecutionProvider",
+    device: str | torch.device = "cuda",
+    enable_cuda_graph=False,
+) -> CudaSession:
+    ort_session = create_ort_session(
+        onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True
+    )
+    cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph)
+    return cuda_session
+
+
+class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
+    def __init__(
+        self,
+        sam_model: SAM2Base,
+        image_encoder_onnx_path: str = "",
+        image_decoder_onnx_path: str = "",
+        image_decoder_multi_onnx_path: str = "",
+        provider: str = "CUDAExecutionProvider",
+        device: str | torch.device = "cuda",
+        onnx_dtype: torch.dtype = torch.float32,
+        mask_threshold=0.0,
+        max_hole_area=0.0,
+        max_sprinkle_area=0.0,
+        **kwargs,
+    ) -> None:
+        """
+        Uses SAM-2 to compute the image embedding for an image, and then allow mask prediction given prompts.
+
+        Arguments:
+          sam_model (SAM2Base): The model to use for mask prediction.
+          onnx_directory (str): The path of the directory that contains encoder and decoder onnx models.
+          onnx_dtype (torch.dtype): The data type to use for ONNX inputs.
+          mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0.
+          max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
+            the maximum area of max_hole_area in low_res_masks.
+          max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
+            the maximum area of max_sprinkle_area in low_res_masks.
+        """
+        super().__init__(
+            sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area
+        )
+
+        logger.debug("self.device=%s, device=%s", self.device, device)
+
+        # This model is exported by image_encoder.py.
+        self.encoder_session = create_session(
+            image_encoder_onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+        self.onnx_dtype = onnx_dtype
+
+        # This model is exported by image_decoder.py. It outputs only one mask.
+        self.decoder_session = create_session(
+            image_decoder_onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+
+        # This model is exported by image_decoder.py. It outputs multiple (3) masks.
+        self.decoder_session_multi_out = create_session(
+            image_decoder_multi_onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+
+    @torch.no_grad()
+    def set_image(self, image: np.ndarray | Image):
+        """
+        Calculates the image embeddings for the provided image.
+
+        Arguments:
+          image (np.ndarray or PIL Image): The input image to embed in RGB format.
+            The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255].
+        """
+        self.reset_predictor()
+        # Transform the image to the form expected by the model
+        if isinstance(image, np.ndarray):
+            # For numpy array image, we assume (HxWxC) format.
+            self._orig_hw = [image.shape[:2]]
+        elif isinstance(image, Image):
+            w, h = image.size
+            self._orig_hw = [(h, w)]
+        else:
+            raise NotImplementedError("Image format not supported")
+
+        input_image = self._transforms(image)
+        input_image = input_image[None, ...].to(self.device)
+
+        assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
+            f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+        )
+
+        # Computing image embeddings for the provided image
+        io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
+        self.encoder_session.allocate_buffers(io_shapes)
+
+        feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}
+
+        for key, value in feed_dict.items():
+            logger.debug(f"{key}: {value.shape}, {value.dtype}")
+        logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")
+
+        ort_outputs = self.encoder_session.infer(feed_dict)
+
+        self._features = {
+            "image_embed": ort_outputs["image_embeddings"],
+            "high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
+        }
+        self._is_image_set = True
+        logging.info("Image embeddings computed.")
+
+    @torch.no_grad()
+    def _predict(
+        self,
+        point_coords: torch.Tensor | None,
+        point_labels: torch.Tensor | None,
+        boxes: torch.Tensor | None = None,
+        mask_input: torch.Tensor | None = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        img_idx: int = -1,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using SAM2Transforms.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        assert not return_logits  # onnx model is exported for returning bool masks.
+
+        if not self._is_image_set:
+            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+        if point_coords is not None:
+            concat_points = (point_coords, point_labels)
+        else:
+            concat_points = None
+
+        # Embed prompts
+        if boxes is not None:
+            box_coords = boxes.reshape(-1, 2, 2)
+            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+            box_labels = box_labels.repeat(boxes.size(0), 1)
+            # we merge "boxes" and "points" into a single "concat_points" input (where
+            # boxes are added at the beginning) to sam_prompt_encoder
+            if concat_points is not None:
+                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                concat_points = (concat_coords, concat_labels)
+            else:
+                concat_points = (box_coords, box_labels)
+
+        assert concat_points is not None
+        num_labels = concat_points[0].shape[0]
+        shape_dict = decoder_shape_dict(
+            original_image_height=self._orig_hw[img_idx][0],
+            original_image_width=self._orig_hw[img_idx][1],
+            num_labels=num_labels,
+            max_points=concat_points[0].shape[1],
+            num_masks=3 if multimask_output else 1,
+        )
+        if multimask_output:
+            decoder_session = self.decoder_session_multi_out
+        else:
+            decoder_session = self.decoder_session
+
+        decoder_session.allocate_buffers(shape_dict)
+
+        image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
+        image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
+        image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)
+
+        if mask_input is None:
+            input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=self.onnx_dtype, device=self.device)
+            has_input_masks = torch.zeros(num_labels, dtype=self.onnx_dtype, device=self.device)
+        else:
+            input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
+            has_input_masks = torch.ones(num_labels, dtype=self.onnx_dtype, device=self.device)
+
+        feed_dict = {
+            "image_embeddings": image_embeddings.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+            "image_features_0": image_features_0.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+            "image_features_1": image_features_1.contiguous().to(dtype=self.onnx_dtype).to(self.device),
+            "point_coords": concat_points[0].to(dtype=self.onnx_dtype).to(self.device),
+            "point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
+            "input_masks": input_masks.to(dtype=self.onnx_dtype).to(self.device),
+            "has_input_masks": has_input_masks.to(dtype=self.onnx_dtype).to(self.device),
+            "original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
+        }
+
+        for key, value in feed_dict.items():
+            logger.debug(f"{key}: {value.shape}, {value.dtype}")
+        logger.debug(f"decoder onnx: {self.decoder_session.ort_session._model_path}")
+
+        ort_outputs = decoder_session.infer(feed_dict)
+
+        masks = ort_outputs["masks"]
+        iou_predictions = ort_outputs["iou_predictions"]
+        low_res_masks = ort_outputs["low_res_masks"]
+
+        return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
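As a usage note, here is a minimal sketch of calling the predictor above directly, mirroring the single-point example in sam2_demo.py. It is a sketch under assumptions: the module is importable as `sam2_image_onnx_predictor`, `load_sam2_model` from the sibling `sam2_utils.py` is available, and the checkpoint directory, image, and ONNX file names are illustrative rather than shipped with the wheel.

    import numpy as np
    from PIL import Image

    from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
    from sam2_utils import load_sam2_model

    # The torch SAM-2 model supplies configuration and image transforms; inference runs through ONNX Runtime.
    sam2_model = load_sam2_model("./segment-anything-2", "sam2_hiera_large", device="cuda")
    predictor = SAM2ImageOnnxPredictor(
        sam2_model,
        image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",  # assumed filename
        image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",  # assumed filename
        image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",  # assumed filename
        provider="CUDAExecutionProvider",
        device="cuda",
    )

    # Encode the image once, then prompt with a single positive click, as in the demo above.
    image = np.array(Image.open("truck.jpg").convert("RGB"))
    predictor.set_image(image)
    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True,  # three candidate masks plus per-mask quality scores
    )
    best_mask = masks[np.argmax(scores)]  # HxW mask at the original image resolution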