monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/handlers/__init__.py +1 -0
- monai/handlers/average_precision.py +53 -0
- monai/inferers/inferer.py +10 -7
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/metrics/__init__.py +1 -0
- monai/metrics/average_precision.py +187 -0
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utility/array.py +2 -12
- monai/transforms/utils.py +1 -2
- monai/transforms/utils_pytorch_numpy_unification.py +2 -4
- monai/utils/enums.py +3 -2
- monai/utils/module.py +6 -6
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
- monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_average_precision.py +79 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1388 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +884 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_average_precision.py +162 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +73 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2025-02-
|
11
|
+
"date": "2025-02-23T02:28:09+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.5.
|
14
|
+
"full-revisionid": "e55b5cbfbbba1800a968a9c06b2deaaa5c9bec54",
|
15
|
+
"version": "1.5.dev2508"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
@@ -18,7 +18,6 @@ import numpy as np
|
|
18
18
|
import torch
|
19
19
|
|
20
20
|
from monai.config import KeysCollection
|
21
|
-
from monai.networks.utils import pytorch_after
|
22
21
|
from monai.transforms import MapTransform
|
23
22
|
from monai.utils.misc import ImageMetaKey
|
24
23
|
|
@@ -74,9 +73,7 @@ class EnsureSameShaped(MapTransform):
|
|
74
73
|
f", the metadata was not updated {filename}."
|
75
74
|
)
|
76
75
|
d[key] = torch.nn.functional.interpolate(
|
77
|
-
input=d[key].unsqueeze(0),
|
78
|
-
size=image_shape,
|
79
|
-
mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
|
76
|
+
input=d[key].unsqueeze(0), size=image_shape, mode="nearest-exact"
|
80
77
|
).squeeze(0)
|
81
78
|
else:
|
82
79
|
raise ValueError(
|
monai/data/utils.py
CHANGED
@@ -50,7 +50,6 @@ from monai.utils import (
|
|
50
50
|
issequenceiterable,
|
51
51
|
look_up_option,
|
52
52
|
optional_import,
|
53
|
-
pytorch_after,
|
54
53
|
)
|
55
54
|
|
56
55
|
pd, _ = optional_import("pandas")
|
@@ -450,12 +449,9 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
|
|
450
449
|
Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
|
451
450
|
and so should not be used as a collate function directly in dataloaders.
|
452
451
|
"""
|
453
|
-
|
454
|
-
from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
|
452
|
+
from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
|
455
453
|
|
456
|
-
|
457
|
-
else:
|
458
|
-
collated = default_collate(batch)
|
454
|
+
collated = collate_tensor_fn(batch)
|
459
455
|
|
460
456
|
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
|
461
457
|
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
|
@@ -494,18 +490,15 @@ def list_data_collate(batch: Sequence):
|
|
494
490
|
Need to use this collate if apply some transforms that can generate batch data.
|
495
491
|
|
496
492
|
"""
|
493
|
+
from torch.utils.data._utils.collate import default_collate_fn_map
|
497
494
|
|
498
|
-
|
499
|
-
# needs to go here to avoid circular import
|
500
|
-
from torch.utils.data._utils.collate import default_collate_fn_map
|
501
|
-
|
502
|
-
from monai.data.meta_tensor import MetaTensor
|
495
|
+
from monai.data.meta_tensor import MetaTensor
|
503
496
|
|
504
|
-
|
497
|
+
default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
|
505
498
|
elem = batch[0]
|
506
499
|
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
|
507
500
|
key = None
|
508
|
-
collate_fn = default_collate
|
501
|
+
collate_fn = default_collate
|
509
502
|
try:
|
510
503
|
if config.USE_META_DICT:
|
511
504
|
data = pickle_operations(data) # bc 0.9.0
|
monai/handlers/__init__.py
CHANGED
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
from collections.abc import Callable
|
15
|
+
|
16
|
+
from monai.handlers.ignite_metric import IgniteMetricHandler
|
17
|
+
from monai.metrics import AveragePrecisionMetric
|
18
|
+
from monai.utils import Average
|
19
|
+
|
20
|
+
|
21
|
+
class AveragePrecision(IgniteMetricHandler):
|
22
|
+
"""
|
23
|
+
Computes Average Precision (AP).
|
24
|
+
accumulating predictions and the ground-truth during an epoch and applying `compute_average_precision`.
|
25
|
+
|
26
|
+
Args:
|
27
|
+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
|
28
|
+
Type of averaging performed if not binary classification. Defaults to ``"macro"``.
|
29
|
+
|
30
|
+
- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
|
31
|
+
This does not take label imbalance into account.
|
32
|
+
- ``"weighted"``: calculate metrics for each label, and find their average,
|
33
|
+
weighted by support (the number of true instances for each label).
|
34
|
+
- ``"micro"``: calculate metrics globally by considering each element of the label
|
35
|
+
indicator matrix as a label.
|
36
|
+
- ``"none"``: the scores for each class are returned.
|
37
|
+
|
38
|
+
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
|
39
|
+
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
|
40
|
+
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
|
41
|
+
`engine.state` and `output_transform` inherit from the ignite concept:
|
42
|
+
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
|
43
|
+
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
|
44
|
+
|
45
|
+
Note:
|
46
|
+
Average Precision expects y to be comprised of 0's and 1's.
|
47
|
+
y_pred must either be probability estimates or confidence values.
|
48
|
+
|
49
|
+
"""
|
50
|
+
|
51
|
+
def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
|
52
|
+
metric_fn = AveragePrecisionMetric(average=Average(average))
|
53
|
+
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
|
monai/inferers/inferer.py
CHANGED
@@ -1202,15 +1202,16 @@ class LatentDiffusionInferer(DiffusionInferer):
|
|
1202
1202
|
|
1203
1203
|
if self.autoencoder_latent_shape is not None:
|
1204
1204
|
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1205
|
+
if save_intermediates:
|
1206
|
+
latent_intermediates = [
|
1207
|
+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
|
1208
|
+
for l in latent_intermediates
|
1209
|
+
]
|
1208
1210
|
|
1209
1211
|
decode = autoencoder_model.decode_stage_2_outputs
|
1210
1212
|
if isinstance(autoencoder_model, SPADEAutoencoderKL):
|
1211
1213
|
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
|
1212
1214
|
image = decode(latent / self.scale_factor)
|
1213
|
-
|
1214
1215
|
if save_intermediates:
|
1215
1216
|
intermediates = []
|
1216
1217
|
for latent_intermediate in latent_intermediates:
|
@@ -1727,9 +1728,11 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
|
|
1727
1728
|
|
1728
1729
|
if self.autoencoder_latent_shape is not None:
|
1729
1730
|
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
|
1730
|
-
|
1731
|
-
|
1732
|
-
|
1731
|
+
if save_intermediates:
|
1732
|
+
latent_intermediates = [
|
1733
|
+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
|
1734
|
+
for l in latent_intermediates
|
1735
|
+
]
|
1733
1736
|
|
1734
1737
|
decode = autoencoder_model.decode_stage_2_outputs
|
1735
1738
|
if isinstance(autoencoder_model, SPADEAutoencoderKL):
|
monai/inferers/utils.py
CHANGED
@@ -31,11 +31,10 @@ from monai.utils import (
|
|
31
31
|
fall_back_tuple,
|
32
32
|
look_up_option,
|
33
33
|
optional_import,
|
34
|
-
pytorch_after,
|
35
34
|
)
|
36
35
|
|
37
36
|
tqdm, _ = optional_import("tqdm", name="tqdm")
|
38
|
-
_nearest_mode = "nearest-exact"
|
37
|
+
_nearest_mode = "nearest-exact"
|
39
38
|
|
40
39
|
__all__ = ["sliding_window_inference"]
|
41
40
|
|
monai/losses/dice.py
CHANGED
@@ -25,7 +25,7 @@ from monai.losses.focal_loss import FocalLoss
|
|
25
25
|
from monai.losses.spatial_mask import MaskedLoss
|
26
26
|
from monai.losses.utils import compute_tp_fp_fn
|
27
27
|
from monai.networks import one_hot
|
28
|
-
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option
|
28
|
+
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option
|
29
29
|
|
30
30
|
|
31
31
|
class DiceLoss(_Loss):
|
@@ -738,12 +738,7 @@ class DiceCELoss(_Loss):
|
|
738
738
|
batch=batch,
|
739
739
|
weight=dice_weight,
|
740
740
|
)
|
741
|
-
|
742
|
-
self.cross_entropy = nn.CrossEntropyLoss(
|
743
|
-
weight=weight, reduction=reduction, label_smoothing=label_smoothing
|
744
|
-
)
|
745
|
-
else:
|
746
|
-
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
|
741
|
+
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)
|
747
742
|
self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
|
748
743
|
if lambda_dice < 0.0:
|
749
744
|
raise ValueError("lambda_dice should be no less than 0.0.")
|
@@ -751,7 +746,6 @@ class DiceCELoss(_Loss):
|
|
751
746
|
raise ValueError("lambda_ce should be no less than 0.0.")
|
752
747
|
self.lambda_dice = lambda_dice
|
753
748
|
self.lambda_ce = lambda_ce
|
754
|
-
self.old_pt_ver = not pytorch_after(1, 10)
|
755
749
|
|
756
750
|
def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
757
751
|
"""
|
@@ -764,12 +758,6 @@ class DiceCELoss(_Loss):
|
|
764
758
|
if n_pred_ch != n_target_ch and n_target_ch == 1:
|
765
759
|
target = torch.squeeze(target, dim=1)
|
766
760
|
target = target.long()
|
767
|
-
elif self.old_pt_ver:
|
768
|
-
warnings.warn(
|
769
|
-
f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. "
|
770
|
-
"Using argmax (as a workaround) to convert target to a single channel."
|
771
|
-
)
|
772
|
-
target = torch.argmax(target, dim=1)
|
773
761
|
elif not torch.is_floating_point(target):
|
774
762
|
target = target.to(dtype=input.dtype)
|
775
763
|
|
monai/losses/ds_loss.py
CHANGED
@@ -17,8 +17,6 @@ import torch
|
|
17
17
|
import torch.nn.functional as F
|
18
18
|
from torch.nn.modules.loss import _Loss
|
19
19
|
|
20
|
-
from monai.utils import pytorch_after
|
21
|
-
|
22
20
|
|
23
21
|
class DeepSupervisionLoss(_Loss):
|
24
22
|
"""
|
@@ -42,7 +40,7 @@ class DeepSupervisionLoss(_Loss):
|
|
42
40
|
self.loss = loss
|
43
41
|
self.weight_mode = weight_mode
|
44
42
|
self.weights = weights
|
45
|
-
self.interp_mode = "nearest-exact"
|
43
|
+
self.interp_mode = "nearest-exact"
|
46
44
|
|
47
45
|
def get_weights(self, levels: int = 1) -> list[float]:
|
48
46
|
"""
|
monai/metrics/__init__.py
CHANGED
@@ -12,6 +12,7 @@
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
14
|
from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
|
15
|
+
from .average_precision import AveragePrecisionMetric, compute_average_precision
|
15
16
|
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
|
16
17
|
from .cumulative_average import CumulativeAverage
|
17
18
|
from .f_beta_score import FBetaScore
|
@@ -0,0 +1,187 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import warnings
|
15
|
+
from typing import TYPE_CHECKING, cast
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
import numpy.typing as npt
|
21
|
+
|
22
|
+
import torch
|
23
|
+
|
24
|
+
from monai.utils import Average, look_up_option
|
25
|
+
|
26
|
+
from .metric import CumulativeIterationMetric
|
27
|
+
|
28
|
+
|
29
|
+
class AveragePrecisionMetric(CumulativeIterationMetric):
|
30
|
+
"""
|
31
|
+
Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
|
32
|
+
imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.
|
33
|
+
It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
|
34
|
+
threshold, with the increase in recall from the previous threshold used as the weight:
|
35
|
+
|
36
|
+
.. math::
|
37
|
+
\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
|
38
|
+
:label: ap
|
39
|
+
|
40
|
+
where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.
|
41
|
+
|
42
|
+
Referring to: `sklearn.metrics.average_precision_score
|
43
|
+
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
|
44
|
+
|
45
|
+
The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.
|
46
|
+
|
47
|
+
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
|
48
|
+
|
49
|
+
Args:
|
50
|
+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
|
51
|
+
Type of averaging performed if not binary classification.
|
52
|
+
Defaults to ``"macro"``.
|
53
|
+
|
54
|
+
- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
|
55
|
+
This does not take label imbalance into account.
|
56
|
+
- ``"weighted"``: calculate metrics for each label, and find their average,
|
57
|
+
weighted by support (the number of true instances for each label).
|
58
|
+
- ``"micro"``: calculate metrics globally by considering each element of the label
|
59
|
+
indicator matrix as a label.
|
60
|
+
- ``"none"``: the scores for each class are returned.
|
61
|
+
|
62
|
+
"""
|
63
|
+
|
64
|
+
def __init__(self, average: Average | str = Average.MACRO) -> None:
|
65
|
+
super().__init__()
|
66
|
+
self.average = average
|
67
|
+
|
68
|
+
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
|
69
|
+
return y_pred, y
|
70
|
+
|
71
|
+
def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
|
72
|
+
"""
|
73
|
+
Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
|
74
|
+
This function reads the buffers and computes the Average Precision.
|
75
|
+
|
76
|
+
Args:
|
77
|
+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
|
78
|
+
Type of averaging performed if not binary classification. Defaults to `self.average`.
|
79
|
+
|
80
|
+
"""
|
81
|
+
y_pred, y = self.get_buffer()
|
82
|
+
# compute final value and do metric reduction
|
83
|
+
if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
|
84
|
+
raise ValueError("y_pred and y must be PyTorch Tensor.")
|
85
|
+
|
86
|
+
return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)
|
87
|
+
|
88
|
+
|
89
|
+
def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
|
90
|
+
if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
|
91
|
+
raise AssertionError("y and y_pred must be 1 dimension data with same length.")
|
92
|
+
y_unique = y.unique()
|
93
|
+
if len(y_unique) == 1:
|
94
|
+
warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
|
95
|
+
return float("nan")
|
96
|
+
if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
|
97
|
+
warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
|
98
|
+
return float("nan")
|
99
|
+
|
100
|
+
n = len(y)
|
101
|
+
indices = y_pred.argsort(descending=True)
|
102
|
+
y = y[indices].cpu().numpy() # type: ignore[assignment]
|
103
|
+
y_pred = y_pred[indices].cpu().numpy() # type: ignore[assignment]
|
104
|
+
npos = ap = tmp_pos = 0.0
|
105
|
+
|
106
|
+
for i in range(n):
|
107
|
+
y_i = cast(float, y[i])
|
108
|
+
if i + 1 < n and y_pred[i] == y_pred[i + 1]:
|
109
|
+
tmp_pos += y_i
|
110
|
+
else:
|
111
|
+
tmp_pos += y_i
|
112
|
+
npos += tmp_pos
|
113
|
+
ap += tmp_pos * npos / (i + 1)
|
114
|
+
tmp_pos = 0
|
115
|
+
|
116
|
+
return ap / npos
|
117
|
+
|
118
|
+
|
119
|
+
def compute_average_precision(
|
120
|
+
y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
|
121
|
+
) -> np.ndarray | float | npt.ArrayLike:
|
122
|
+
"""Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
|
123
|
+
imbalanced. It summarizes a Precision-Recall according to equation :eq:`ap`.
|
124
|
+
Referring to: `sklearn.metrics.average_precision_score
|
125
|
+
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
|
126
|
+
|
127
|
+
Args:
|
128
|
+
y_pred: input data to compute, typical classification model output.
|
129
|
+
the first dim must be batch, if multi-classes, it must be in One-Hot format.
|
130
|
+
for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
|
131
|
+
y: ground truth to compute AP metric, the first dim must be batch.
|
132
|
+
if multi-classes, it must be in One-Hot format.
|
133
|
+
for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
|
134
|
+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
|
135
|
+
Type of averaging performed if not binary classification.
|
136
|
+
Defaults to ``"macro"``.
|
137
|
+
|
138
|
+
- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
|
139
|
+
This does not take label imbalance into account.
|
140
|
+
- ``"weighted"``: calculate metrics for each label, and find their average,
|
141
|
+
weighted by support (the number of true instances for each label).
|
142
|
+
- ``"micro"``: calculate metrics globally by considering each element of the label
|
143
|
+
indicator matrix as a label.
|
144
|
+
- ``"none"``: the scores for each class are returned.
|
145
|
+
|
146
|
+
Raises:
|
147
|
+
ValueError: When ``y_pred`` dimension is not one of [1, 2].
|
148
|
+
ValueError: When ``y`` dimension is not one of [1, 2].
|
149
|
+
ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].
|
150
|
+
|
151
|
+
Note:
|
152
|
+
Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.
|
153
|
+
|
154
|
+
"""
|
155
|
+
y_pred_ndim = y_pred.ndimension()
|
156
|
+
y_ndim = y.ndimension()
|
157
|
+
if y_pred_ndim not in (1, 2):
|
158
|
+
raise ValueError(
|
159
|
+
f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
|
160
|
+
)
|
161
|
+
if y_ndim not in (1, 2):
|
162
|
+
raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
|
163
|
+
if y_pred_ndim == 2 and y_pred.shape[1] == 1:
|
164
|
+
y_pred = y_pred.squeeze(dim=-1)
|
165
|
+
y_pred_ndim = 1
|
166
|
+
if y_ndim == 2 and y.shape[1] == 1:
|
167
|
+
y = y.squeeze(dim=-1)
|
168
|
+
|
169
|
+
if y_pred_ndim == 1:
|
170
|
+
return _calculate(y_pred, y)
|
171
|
+
|
172
|
+
if y.shape != y_pred.shape:
|
173
|
+
raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")
|
174
|
+
|
175
|
+
average = look_up_option(average, Average)
|
176
|
+
if average == Average.MICRO:
|
177
|
+
return _calculate(y_pred.flatten(), y.flatten())
|
178
|
+
y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
|
179
|
+
ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
|
180
|
+
if average == Average.NONE:
|
181
|
+
return ap_values
|
182
|
+
if average == Average.MACRO:
|
183
|
+
return np.mean(ap_values)
|
184
|
+
if average == Average.WEIGHTED:
|
185
|
+
weights = [sum(y_) for y_ in y]
|
186
|
+
return np.average(ap_values, weights=weights) # type: ignore[no-any-return]
|
187
|
+
raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
|
@@ -31,7 +31,6 @@ from monai.utils import (
|
|
31
31
|
issequenceiterable,
|
32
32
|
look_up_option,
|
33
33
|
optional_import,
|
34
|
-
pytorch_after,
|
35
34
|
)
|
36
35
|
|
37
36
|
_C, _ = optional_import("monai._C")
|
@@ -293,14 +292,7 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
|
|
293
292
|
x = x.view(1, kernel.shape[0], *spatials)
|
294
293
|
conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
|
295
294
|
if "padding" not in kwargs:
|
296
|
-
|
297
|
-
kwargs["padding"] = "same"
|
298
|
-
else:
|
299
|
-
# even-sized kernels are not supported
|
300
|
-
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
|
301
|
-
elif kwargs["padding"] == "same" and not pytorch_after(1, 10):
|
302
|
-
# even-sized kernels are not supported
|
303
|
-
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
|
295
|
+
kwargs["padding"] = "same"
|
304
296
|
|
305
297
|
if "stride" not in kwargs:
|
306
298
|
kwargs["stride"] = 1
|
@@ -372,11 +364,7 @@ class SavitzkyGolayFilter(nn.Module):
|
|
372
364
|
a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1)
|
373
365
|
y = torch.zeros(order + 1, dtype=torch.float, device="cpu")
|
374
366
|
y[0] = 1.0
|
375
|
-
return (
|
376
|
-
torch.lstsq(y, a).solution.squeeze() # type: ignore
|
377
|
-
if not pytorch_after(1, 11)
|
378
|
-
else torch.linalg.lstsq(a, y).solution.squeeze()
|
379
|
-
)
|
367
|
+
return torch.linalg.lstsq(a, y).solution.squeeze()
|
380
368
|
|
381
369
|
|
382
370
|
class HilbertTransform(nn.Module):
|
monai/networks/utils.py
CHANGED
@@ -31,7 +31,7 @@ import torch.nn as nn
|
|
31
31
|
from monai.apps.utils import get_logger
|
32
32
|
from monai.config import PathLike
|
33
33
|
from monai.utils.misc import ensure_tuple, save_obj, set_determinism
|
34
|
-
from monai.utils.module import look_up_option, optional_import
|
34
|
+
from monai.utils.module import look_up_option, optional_import
|
35
35
|
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
|
36
36
|
|
37
37
|
onnx, _ = optional_import("onnx")
|
@@ -676,15 +676,6 @@ def convert_to_onnx(
|
|
676
676
|
torch_versioned_kwargs["verify"] = verify
|
677
677
|
verify = False
|
678
678
|
else:
|
679
|
-
if not pytorch_after(1, 10):
|
680
|
-
if "example_outputs" not in kwargs:
|
681
|
-
# https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
|
682
|
-
raise TypeError(
|
683
|
-
"example_outputs is required in scripting mode before PyTorch 1.10."
|
684
|
-
"Please provide example outputs or use trace mode to export onnx model."
|
685
|
-
)
|
686
|
-
torch_versioned_kwargs["example_outputs"] = kwargs["example_outputs"]
|
687
|
-
del kwargs["example_outputs"]
|
688
679
|
mode_to_export = torch.jit.script(model, **kwargs)
|
689
680
|
|
690
681
|
if torch.is_tensor(inputs) or isinstance(inputs, dict):
|
@@ -746,8 +737,7 @@ def convert_to_onnx(
|
|
746
737
|
# compare onnx/ort and PyTorch results
|
747
738
|
for r1, r2 in zip(torch_out, onnx_out):
|
748
739
|
if isinstance(r1, torch.Tensor):
|
749
|
-
|
750
|
-
assert_fn(r1.cpu(), convert_to_tensor(r2, dtype=r1.dtype), rtol=rtol, atol=atol) # type: ignore
|
740
|
+
torch.testing.assert_close(r1.cpu(), convert_to_tensor(r2, dtype=r1.dtype), rtol=rtol, atol=atol) # type: ignore
|
751
741
|
|
752
742
|
return onnx_model
|
753
743
|
|
@@ -817,8 +807,7 @@ def convert_to_torchscript(
|
|
817
807
|
# compare TorchScript and PyTorch results
|
818
808
|
for r1, r2 in zip(torch_out, torchscript_out):
|
819
809
|
if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
|
820
|
-
|
821
|
-
assert_fn(r1, r2, rtol=rtol, atol=atol) # type: ignore
|
810
|
+
torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore
|
822
811
|
|
823
812
|
return script_module
|
824
813
|
|
@@ -1031,8 +1020,7 @@ def convert_to_trt(
|
|
1031
1020
|
# compare TorchScript and PyTorch results
|
1032
1021
|
for r1, r2 in zip(torch_out, trt_out):
|
1033
1022
|
if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
|
1034
|
-
|
1035
|
-
assert_fn(r1, r2, rtol=rtol, atol=atol) # type: ignore
|
1023
|
+
torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore
|
1036
1024
|
|
1037
1025
|
return trt_model
|
1038
1026
|
|
monai/transforms/compose.py
CHANGED
@@ -47,7 +47,7 @@ __all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"]
|
|
47
47
|
def execute_compose(
|
48
48
|
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
|
49
49
|
transforms: Sequence[Any],
|
50
|
-
map_items: bool = True,
|
50
|
+
map_items: bool | int = True,
|
51
51
|
unpack_items: bool = False,
|
52
52
|
start: int = 0,
|
53
53
|
end: int | None = None,
|
@@ -65,8 +65,13 @@ def execute_compose(
|
|
65
65
|
Args:
|
66
66
|
data: a tensor-like object to be transformed
|
67
67
|
transforms: a sequence of transforms to be carried out
|
68
|
-
map_items: whether to apply
|
69
|
-
|
68
|
+
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
|
69
|
+
it can behave as follows:
|
70
|
+
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
|
71
|
+
to the first level of items in `data`.
|
72
|
+
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
|
73
|
+
should be recursively applied. This allows treating multi-sample transforms applied after another
|
74
|
+
multi-sample transform while controlling how deep the mapping goes.
|
70
75
|
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
|
71
76
|
defaults to `False`.
|
72
77
|
start: the index of the first transform to be executed. If not set, this defaults to 0
|
@@ -205,8 +210,14 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
|
|
205
210
|
|
206
211
|
Args:
|
207
212
|
transforms: sequence of callables.
|
208
|
-
map_items: whether to apply
|
209
|
-
|
213
|
+
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
|
214
|
+
it can behave as follows:
|
215
|
+
|
216
|
+
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
|
217
|
+
to the first level of items in `data`.
|
218
|
+
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
|
219
|
+
should be recursively applied. This allows treating multi-sample transforms applied after another
|
220
|
+
multi-sample transform while controlling how deep the mapping goes.
|
210
221
|
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
|
211
222
|
defaults to `False`.
|
212
223
|
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
|
@@ -227,7 +238,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
|
|
227
238
|
def __init__(
|
228
239
|
self,
|
229
240
|
transforms: Sequence[Callable] | Callable | None = None,
|
230
|
-
map_items: bool = True,
|
241
|
+
map_items: bool | int = True,
|
231
242
|
unpack_items: bool = False,
|
232
243
|
log_stats: bool | str = False,
|
233
244
|
lazy: bool | None = False,
|
@@ -238,9 +249,9 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
|
|
238
249
|
if transforms is None:
|
239
250
|
transforms = []
|
240
251
|
|
241
|
-
if not isinstance(map_items, bool):
|
252
|
+
if not isinstance(map_items, (bool, int)):
|
242
253
|
raise ValueError(
|
243
|
-
f"Argument 'map_items' should be boolean. Got {type(map_items)}."
|
254
|
+
f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
|
244
255
|
"Check brackets when passing a sequence of callables."
|
245
256
|
)
|
246
257
|
|
@@ -391,8 +402,14 @@ class OneOf(Compose):
|
|
391
402
|
transforms: sequence of callables.
|
392
403
|
weights: probabilities corresponding to each callable in transforms.
|
393
404
|
Probabilities are normalized to sum to one.
|
394
|
-
map_items: whether to apply
|
395
|
-
|
405
|
+
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
|
406
|
+
it can behave as follows:
|
407
|
+
|
408
|
+
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
|
409
|
+
to the first level of items in `data`.
|
410
|
+
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
|
411
|
+
should be recursively applied. This allows treating multi-sample transforms applied after another
|
412
|
+
multi-sample transform while controlling how deep the mapping goes.
|
396
413
|
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
|
397
414
|
defaults to `False`.
|
398
415
|
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
|
@@ -414,7 +431,7 @@ class OneOf(Compose):
|
|
414
431
|
self,
|
415
432
|
transforms: Sequence[Callable] | Callable | None = None,
|
416
433
|
weights: Sequence[float] | float | None = None,
|
417
|
-
map_items: bool = True,
|
434
|
+
map_items: bool | int = True,
|
418
435
|
unpack_items: bool = False,
|
419
436
|
log_stats: bool | str = False,
|
420
437
|
lazy: bool | None = False,
|
@@ -56,7 +56,6 @@ from monai.utils import (
|
|
56
56
|
ensure_tuple_rep,
|
57
57
|
fall_back_tuple,
|
58
58
|
look_up_option,
|
59
|
-
pytorch_after,
|
60
59
|
)
|
61
60
|
|
62
61
|
__all__ = [
|
@@ -392,11 +391,7 @@ class Crop(InvertibleTransform, LazyTransform):
|
|
392
391
|
roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
|
393
392
|
roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True, device="cpu")
|
394
393
|
_zeros = torch.zeros_like(roi_center_t)
|
395
|
-
half = (
|
396
|
-
torch.divide(roi_size_t, 2, rounding_mode="floor")
|
397
|
-
if pytorch_after(1, 8)
|
398
|
-
else torch.floor_divide(roi_size_t, 2)
|
399
|
-
)
|
394
|
+
half = torch.divide(roi_size_t, 2, rounding_mode="floor")
|
400
395
|
roi_start_t = torch.maximum(roi_center_t - half, _zeros)
|
401
396
|
roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t)
|
402
397
|
else:
|