onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
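For reference, a minimal sketch of how a model might be run with the DirectML execution provider shipped in this wheel; the model path, input name, and input shape below are placeholders, not part of the package:

import numpy as np
import onnxruntime as ort

# Settings commonly recommended when using the DirectML execution provider:
# disable memory pattern optimization and use sequential execution.
session_options = ort.SessionOptions()
session_options.enable_mem_pattern = False
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

# "model.onnx" and the "input" feed name are hypothetical placeholders.
session = ort.InferenceSession(
    "model.onnx",
    sess_options=session_options,
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
outputs = session.run(None, {"input": np.zeros((1, 3, 224, 224), dtype=np.float32)})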
onnxruntime/transformers/models/sam2/image_encoder.py
@@ -0,0 +1,186 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+import warnings
+
+import torch
+from sam2.modeling.sam2_base import SAM2Base
+from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
+from torch import nn
+
+import onnxruntime
+
+logger = logging.getLogger(__name__)
+
+
+class SAM2ImageEncoder(nn.Module):
+    def __init__(self, sam_model: SAM2Base) -> None:
+        super().__init__()
+        self.model = sam_model
+        self.image_encoder = sam_model.image_encoder
+        self.no_mem_embed = sam_model.no_mem_embed
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        enable_nvtx_profile: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Encodes images into features.
+
+        Only supports H=W=1024. If you want to use different image sizes like 512x512,
+        see https://github.com/facebookresearch/segment-anything-2/issues/138.
+
+        Args:
+            image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
+            enable_nvtx_profile (bool): enable NVTX profiling.
+
+        Returns:
+            image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
+            image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
+            image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
+        """
+        nvtx_helper = None
+        if enable_nvtx_profile:
+            from nvtx_helper import NvtxHelper
+
+            nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
+
+        if nvtx_helper is not None:
+            nvtx_helper.start_profile("image_encoder")
+
+        backbone_out = self.image_encoder(image)
+
+        if nvtx_helper is not None:
+            nvtx_helper.stop_profile("image_encoder")
+            nvtx_helper.start_profile("post_process")
+
+        # precompute projected level 0 and level 1 features in SAM decoder
+        # to avoid running it again on every SAM click
+        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+
+        # Prepare and flatten visual features.
+        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
+        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
+        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+
+        # flatten NxCxHxW to HWxNxC
+        # TODO: we should avoid this transpose since it will be transposed back to NCHW later.
+        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+
+        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
+
+        feats = [
+            feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
+        ][::-1]
+
+        if nvtx_helper is not None:
+            nvtx_helper.stop_profile("post_process")
+            nvtx_helper.print_latency()
+
+        return feats[0], feats[1], feats[2]
+
+
+def export_image_encoder_onnx(
+    sam2_model: SAM2Base,
+    onnx_model_path: str,
+    dynamic_batch_axes: bool = False,
+    verbose: bool = False,
+):
+    image = random_sam2_input_image()
+
+    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
+    logger.info("image.shape: %s", image.shape)
+    logger.info("image_features_0.shape: %s", image_features_0.shape)
+    logger.info("image_features_1.shape: %s", image_features_1.shape)
+    logger.info("image_embeddings.shape: %s", image_embeddings.shape)
+
+    dynamic_axes = None
+    if dynamic_batch_axes:
+        dynamic_axes = {
+            "image": {0: "batch_size"},
+            "image_features_0": {0: "batch_size"},
+            "image_features_1": {0: "batch_size"},
+            "image_embeddings": {0: "batch_size"},
+        }
+
+    with warnings.catch_warnings():
+        if not verbose:
+            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+            warnings.filterwarnings("ignore", category=UserWarning)
+        torch.onnx.export(
+            sam2_encoder,
+            image,
+            onnx_model_path,
+            export_params=True,
+            opset_version=17,
+            do_constant_folding=True,
+            input_names=["image"],
+            output_names=["image_features_0", "image_features_1", "image_embeddings"],
+            dynamic_axes=dynamic_axes,
+        )
+
+    print("encoder onnx model saved to", onnx_model_path)
+
+
+def test_image_encoder_onnx(
+    sam2_model: SAM2Base,
+    onnx_model_path: str,
+    dynamic_batch_axes=False,
+):
+    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
+
+    model_inputs = ort_session.get_inputs()
+    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+    logger.info("input_names: %s", input_names)
+
+    model_outputs = ort_session.get_outputs()
+    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+    logger.info("output_names: %s", output_names)
+
+    batch_sizes = [1, 2] if dynamic_batch_axes else [1]
+    for batch_size in batch_sizes:
+        image = random_sam2_input_image(batch_size)
+
+        sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+        image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
+
+        logger.info("image.shape: %s", image.shape)
+        logger.info("image_features_0.shape: %s", image_features_0.shape)
+        logger.info("image_features_1.shape: %s", image_features_1.shape)
+        logger.info("image_embeddings.shape: %s", image_embeddings.shape)
+
+        outputs = ort_session.run(output_names, {"image": image.numpy()})
+        for i, output_name in enumerate(output_names):
+            logger.info("output %s shape %s", output_name, outputs[i].shape)
+        ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs
+
+        # ONNXRuntime and PyTorch has about 0.75% mismatched elements, but seems not impacting segmentation results.
+        if (
+            compare_tensors_with_tolerance(
+                "image_features_0",
+                image_features_0,
+                torch.tensor(ort_image_features_0),
+                mismatch_percentage_tolerance=1,
+            )
+            and compare_tensors_with_tolerance(
+                "image_features_1",
+                image_features_1,
+                torch.tensor(ort_image_features_1),
+                mismatch_percentage_tolerance=1,
+            )
+            and compare_tensors_with_tolerance(
+                "image_embeddings",
+                image_embeddings,
+                torch.tensor(ort_image_embeddings),
+                mismatch_percentage_tolerance=1,
+            )
+        ):
+            print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
+        else:
+            print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")
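As a usage sketch (not part of the package), the export and test helpers above could be driven as follows; loading the SAM2 checkpoint through sam2.build_sam.build_sam2 and the config, checkpoint, and output paths are assumptions:

from sam2.build_sam import build_sam2

# Hypothetical config, checkpoint, and output paths.
sam2_model = build_sam2("sam2_hiera_l.yaml", "sam2_hiera_large.pt", device="cpu")
export_image_encoder_onnx(sam2_model, "sam2_image_encoder.onnx", dynamic_batch_axes=True)
test_image_encoder_onnx(sam2_model, "sam2_image_encoder.onnx", dynamic_batch_axes=True)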
onnxruntime/transformers/models/sam2/mask_decoder.py
@@ -0,0 +1,208 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+import warnings
+
+import torch
+from image_encoder import SAM2ImageEncoder, random_sam2_input_image
+from prompt_encoder import SAM2PromptEncoder
+from sam2.modeling.sam2_base import SAM2Base
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+class SAM2MaskDecoder(nn.Module):
+    def __init__(
+        self,
+        sam_model: SAM2Base,
+        multimask_output: bool,
+        dynamic_multimask_via_stability: bool = True,
+    ) -> None:
+        super().__init__()
+        self.mask_decoder = sam_model.sam_mask_decoder
+        self.prompt_encoder = sam_model.sam_prompt_encoder
+        self.model = sam_model
+        self.multimask_output = multimask_output
+        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+
+    @torch.no_grad()
+    def forward(
+        self,
+        image_features_0: torch.Tensor,
+        image_features_1: torch.Tensor,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_embeddings: torch.Tensor,
+        dense_embeddings: torch.Tensor,
+    ):
+        """
+        Decode masks from image and prompt embeddings. Only support H=W=1024.
+
+        Args:
+            image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
+            image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
+            image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
+            image_pe (torch.Tensor): [1, 256, H/16, W/16]. image positional encoding.
+            sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
+            dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. embedding for input masks.
+
+        Returns:
+            low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
+            iou_predictions (torch.Tensor): [1, M]. scores for M masks.
+        """
+        low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            repeat_image=sparse_embeddings.shape[0] > 1,  # batch mode
+            high_res_features=[image_features_0, image_features_1],
+        )
+
+        if self.multimask_output:
+            low_res_masks = low_res_masks[:, 1:, :, :]
+            iou_predictions = iou_predictions[:, 1:]
+        elif self.dynamic_multimask_via_stability:
+            # When outputting a single mask, if the stability score from the current single-mask
+            # output (based on output token 0) falls below a threshold, we instead select from
+            # multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score.
+            low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
+                low_res_masks, iou_predictions
+            )
+        else:
+            low_res_masks = low_res_masks[:, 0:1, :, :]
+            iou_predictions = iou_predictions[:, 0:1]
+
+        return low_res_masks, iou_predictions
+
+
+def export_mask_decoder_onnx(
+    sam2_model: SAM2Base,
+    onnx_model_path: str,
+    multimask_output: bool,
+    dynamic_multimask_via_stability: bool = True,
+    verbose=False,
+):
+    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+    image = random_sam2_input_image()
+    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
+    logger.info("image_features_0.shape: %s", image_features_0.shape)
+    logger.info("image_features_1.shape: %s", image_features_1.shape)
+    logger.info("image_embeddings.shape: %s", image_embeddings.shape)
+
+    # encode an random prompt
+    num_labels = 2
+    num_points = 3
+    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
+    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
+    has_input_masks = torch.ones(1, dtype=torch.float)
+
+    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+        point_coords, point_labels, input_masks, has_input_masks
+    )
+
+    logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
+    logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
+    logger.info("image_pe.shape: %s", image_pe.shape)
+
+    sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
+    inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
+    low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
+    logger.info("low_res_masks.shape: %s", low_res_masks.shape)
+    logger.info("iou_predictions.shape: %s", iou_predictions.shape)
+
+    with warnings.catch_warnings():
+        if not verbose:
+            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+            warnings.filterwarnings("ignore", category=UserWarning)
+        torch.onnx.export(
+            sam2_mask_decoder,
+            inputs,
+            onnx_model_path,
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=[
+                "image_features_0",
+                "image_features_1",
+                "image_embeddings",
+                "image_pe",
+                "sparse_embeddings",
+                "dense_embeddings",
+            ],
+            output_names=["low_res_masks", "iou_predictions"],
+            dynamic_axes={
+                "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
+                "dense_embeddings": {0: "num_labels"},
+                "low_res_masks": {0: "num_labels"},
+                "iou_predictions": {0: "num_labels"},
+            },
+        )
+
+    print("mask decoder onnx model saved to", onnx_model_path)
+
+
+def test_mask_decoder_onnx(
+    sam2_model: SAM2Base,
+    onnx_model_path: str,
+    multimask_output: bool,
+    dynamic_multimask_via_stability: bool,
+):
+    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+    image = random_sam2_input_image()
+    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
+    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
+
+    num_labels = 1
+    num_points = 5
+    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
+    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
+    has_input_masks = torch.ones(1, dtype=torch.float)
+
+    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+        point_coords, point_labels, input_masks, has_input_masks
+    )
+
+    sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
+    inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
+    low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
+
+    import onnxruntime
+
+    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
+
+    model_inputs = ort_session.get_inputs()
+    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+    logger.info("input_names: %s", input_names)
+
+    model_outputs = ort_session.get_outputs()
+    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+    logger.info("output_names: %s", output_names)
+
+    outputs = ort_session.run(
+        output_names,
+        {
+            "image_features_0": image_features_0.numpy(),
+            "image_features_1": image_features_1.numpy(),
+            "image_embeddings": image_embeddings.numpy(),
+            "image_pe": image_pe.numpy(),
+            "sparse_embeddings": sparse_embeddings.numpy(),
+            "dense_embeddings": dense_embeddings.numpy(),
+        },
+    )
+
+    for i, output_name in enumerate(output_names):
+        logger.info("output %s shape: %s", output_name, outputs[i].shape)
+
+    ort_low_res_masks, ort_iou_predictions = outputs
+    torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4)
+    torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4)
+    print(f"onnx model has been verified: {onnx_model_path}")
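A corresponding sketch for the decoder above, again with hypothetical output paths and a SAM2 model object loaded elsewhere (see the encoder example earlier):

# sam2_model is a SAM2Base instance loaded elsewhere; the path is a placeholder.
export_mask_decoder_onnx(
    sam2_model,
    "sam2_mask_decoder.onnx",
    multimask_output=True,
    dynamic_multimask_via_stability=True,
)
test_mask_decoder_onnx(
    sam2_model,
    "sam2_mask_decoder.onnx",
    multimask_output=True,
    dynamic_multimask_via_stability=True,
)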
onnxruntime/transformers/models/sam2/nvtx_helper.py
@@ -0,0 +1,33 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import nvtx
+from cuda import cudart
+
+
+class NvtxHelper:
+    def __init__(self, stages):
+        self.stages = stages
+        self.events = {}
+        for stage in stages:
+            for marker in ["start", "stop"]:
+                self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
+        self.markers = {}
+
+    def start_profile(self, stage, color="blue"):
+        self.markers[stage] = nvtx.start_range(message=stage, color=color)
+        event_name = stage + "-start"
+        if event_name in self.events:
+            cudart.cudaEventRecord(self.events[event_name], 0)
+
+    def stop_profile(self, stage):
+        event_name = stage + "-stop"
+        if event_name in self.events:
+            cudart.cudaEventRecord(self.events[event_name], 0)
+        nvtx.end_range(self.markers[stage])
+
+    def print_latency(self):
+        for stage in self.stages:
+            latency = cudart.cudaEventElapsedTime(self.events[f"{stage}-start"], self.events[f"{stage}-stop"])[1]
+            print(f"{stage}: {latency:.2f} ms")
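A sketch of how NvtxHelper is intended to wrap a timed region; it needs the nvtx and cuda-python packages plus a CUDA-capable GPU, and the sleep below stands in for a real workload:

import time

helper = NvtxHelper(["inference"])
helper.start_profile("inference", color="green")
time.sleep(0.05)  # placeholder workload between the start/stop CUDA events
helper.stop_profile("inference")
helper.print_latency()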
onnxruntime/transformers/models/sam2/prompt_encoder.py
@@ -0,0 +1,189 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+import torch
+from sam2.modeling.sam2_base import SAM2Base
+from sam2_utils import compare_tensors_with_tolerance
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+class SAM2PromptEncoder(nn.Module):
+    def __init__(self, sam_model: SAM2Base):
+        super().__init__()
+        self.prompt_encoder = sam_model.sam_prompt_encoder
+        self.model = sam_model
+
+    @torch.no_grad()
+    def forward(
+        self,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+        input_masks: torch.Tensor,
+        has_input_masks: torch.Tensor,
+    ):
+        """Encode prompts.
+
+        Args:
+            point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
+                coordinate in (x, y) format of the P input points in image of size 1024x1024.
+            point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
+                positive (foreground), 0 means negative (background), -1 means padding,
+                2 (box left upper corner), 3 (box right bottom corner).
+            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
+                Typically coming from a previous iteration.
+            has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
+        Returns:
+            sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
+            dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
+            image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
+        """
+        sparse_embeddings = self._embed_points(point_coords, point_labels)
+        dense_embeddings = self._embed_masks(input_masks, has_input_masks)
+        image_pe = self.prompt_encoder.get_dense_pe()
+
+        return sparse_embeddings, dense_embeddings, image_pe
+
+    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+        point_coords = point_coords + 0.5
+
+        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
+        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
+        point_coords = torch.cat([point_coords, padding_point], dim=1)
+        point_labels = torch.cat([point_labels, padding_label], dim=1)
+
+        # Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
+        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
+        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
+
+        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
+        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+        point_embedding = point_embedding * (point_labels != -1)
+        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
+
+        for i in range(self.prompt_encoder.num_point_embeddings):
+            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)
+
+        return point_embedding
+
+    def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
+        mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
+        no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
+        mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
+        logger.info("mask_embedding.shape: %s", mask_embedding.shape)
+        return mask_embedding
+
+
+def export_prompt_encoder_onnx(
+    sam2_model: SAM2Base,
+    onnx_model_path: str,
+):
+    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+    num_labels = 2
+    num_points = 3
+    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
+    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
+    has_input_masks = torch.ones(1, dtype=torch.float)
+
+    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+        point_coords, point_labels, input_masks, has_input_masks
+    )
+
+    logger.info("point_coords.shape: %s", point_coords.shape)
+    logger.info("point_labels.shape: %s", point_labels.shape)
+    logger.info("input_masks.shape: %s", input_masks.shape)
+    logger.info("has_input_masks.shape: %s", has_input_masks.shape)
+
+    logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
+    logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
+    logger.info("image_pe.shape: %s", image_pe.shape)
+
+    torch.onnx.export(
+        sam2_prompt_encoder,
+        (point_coords, point_labels, input_masks, has_input_masks),
+        onnx_model_path,
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
+        output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
+        dynamic_axes={
+            "point_coords": {0: "num_labels", 1: "num_points"},
+            "point_labels": {0: "num_labels", 1: "num_points"},
+            "input_masks": {0: "num_labels"},
+            "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
+            "dense_embeddings": {0: "num_labels"},
+        },
+    )
+
+    print("prompt encoder onnx model saved to ", onnx_model_path)
+
+
+def test_prompt_encoder_onnx(
+    sam2_model: SAM2Base,
+    onnx_model_path: str,
+):
+    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
+
+    num_labels = 1
+    num_points = 5
+    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
+    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
+    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
+    has_input_masks = torch.ones(1, dtype=torch.float)
+
+    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
+        point_coords, point_labels, input_masks, has_input_masks
+    )
+
+    import onnxruntime
+
+    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
+
+    model_inputs = ort_session.get_inputs()
+    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+    logger.info("input_names: %s", input_names)
+
+    model_outputs = ort_session.get_outputs()
+    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+    logger.info("output_names: %s", output_names)
+
+    outputs = ort_session.run(
+        output_names,
+        {
+            "point_coords": point_coords.numpy(),
+            "point_labels": point_labels.numpy(),
+            "input_masks": input_masks.numpy(),
+            "has_input_masks": has_input_masks.numpy(),
+        },
+    )
+
+    for i, output_name in enumerate(output_names):
+        logger.info("output %s shape: %s", output_name, outputs[i].shape)
+
+    ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
+    if (
+        compare_tensors_with_tolerance(
+            "sparse_embeddings",
+            sparse_embeddings,
+            torch.tensor(ort_sparse_embeddings),
+            mismatch_percentage_tolerance=0.2,
+        )
+        and compare_tensors_with_tolerance(
+            "dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
+        )
+        and compare_tensors_with_tolerance(
+            "image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
+        )
+    ):
+        print(f"onnx model has been verified: {onnx_model_path}")
+    else:
+        print(f"onnx model verification failed: {onnx_model_path}")
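Putting the three exported SAM2 models together, one click-prediction pass might look like the sketch below; the .onnx paths correspond to whatever was passed to the export helpers above, and the image and click coordinates are arbitrary placeholders:

import numpy as np
import onnxruntime

providers = onnxruntime.get_available_providers()
encoder = onnxruntime.InferenceSession("sam2_image_encoder.onnx", providers=providers)
prompt_enc = onnxruntime.InferenceSession("sam2_prompt_encoder.onnx", providers=providers)
decoder = onnxruntime.InferenceSession("sam2_mask_decoder.onnx", providers=providers)

# Placeholder 1024x1024 image and a single positive click at its center.
image = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
image_features_0, image_features_1, image_embeddings = encoder.run(None, {"image": image})

sparse_embeddings, dense_embeddings, image_pe = prompt_enc.run(
    None,
    {
        "point_coords": np.array([[[512.0, 512.0]]], dtype=np.float32),
        "point_labels": np.ones((1, 1), dtype=np.int32),
        "input_masks": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_input_masks": np.zeros(1, dtype=np.float32),
    },
)

low_res_masks, iou_predictions = decoder.run(
    None,
    {
        "image_features_0": image_features_0,
        "image_features_1": image_features_1,
        "image_embeddings": image_embeddings,
        "image_pe": image_pe,
        "sparse_embeddings": sparse_embeddings,
        "dense_embeddings": dense_embeddings,
    },
)
print(low_res_masks.shape, iou_predictions.shape)  # e.g. (1, M, 256, 256) and (1, M)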