monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utils.py +1 -2
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
- monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1310 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +824 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +75 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,159 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from unittest.case import skipUnless
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from parameterized import parameterized
|
20
|
+
|
21
|
+
from monai.transforms.utils import convert_points_to_disc, keep_merge_components_with_points, sample_points_from_label
|
22
|
+
from monai.utils import min_version
|
23
|
+
from monai.utils.module import optional_import
|
24
|
+
from tests.test_utils import skip_if_no_cuda, skip_if_quick
|
25
|
+
|
26
|
+
cp, has_cp = optional_import("cupy")
|
27
|
+
cucim_skimage, has_cucim = optional_import("cucim.skimage")
|
28
|
+
measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
|
29
|
+
|
30
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
31
|
+
|
32
|
+
|
33
|
+
TESTS_SAMPLE_POINTS_FROM_LABEL = []
|
34
|
+
for use_center in [True, False]:
|
35
|
+
labels = torch.zeros(1, 1, 32, 32, 32)
|
36
|
+
labels[0, 0, 5:10, 5:10, 5:10] = 1
|
37
|
+
labels[0, 0, 10:15, 10:15, 10:15] = 3
|
38
|
+
labels[0, 0, 20:25, 20:25, 20:25] = 5
|
39
|
+
TESTS_SAMPLE_POINTS_FROM_LABEL.append(
|
40
|
+
[{"labels": labels, "label_set": (1, 3, 5), "use_center": use_center}, (3, 1, 3), (3, 1)]
|
41
|
+
)
|
42
|
+
|
43
|
+
TEST_CONVERT_POINTS_TO_DISC = []
|
44
|
+
for radius in [1, 2]:
|
45
|
+
for disc in [True, False]:
|
46
|
+
image_size = (32, 32, 32)
|
47
|
+
point = torch.randn(3, 1, 3)
|
48
|
+
point_label = torch.randint(0, 4, (3, 1))
|
49
|
+
expected_shape = (point.shape[0], 2, *image_size)
|
50
|
+
TEST_CONVERT_POINTS_TO_DISC.append(
|
51
|
+
[
|
52
|
+
{"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc},
|
53
|
+
expected_shape,
|
54
|
+
]
|
55
|
+
)
|
56
|
+
image_size = (16, 32, 64)
|
57
|
+
point = torch.tensor([[[8, 16, 42], [2, 8, 21]]])
|
58
|
+
point_label = torch.tensor([[1, 0]])
|
59
|
+
expected_shape = (point.shape[0], 2, *image_size)
|
60
|
+
TEST_CONVERT_POINTS_TO_DISC.append(
|
61
|
+
[
|
62
|
+
{"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc},
|
63
|
+
expected_shape,
|
64
|
+
]
|
65
|
+
)
|
66
|
+
|
67
|
+
TEST_CONVERT_POINTS_TO_DISC_VALUE = []
|
68
|
+
image_size = (16, 32, 64)
|
69
|
+
point = torch.tensor([[[8, 16, 42], [2, 8, 21]]])
|
70
|
+
point_label = torch.tensor([[1, 0]])
|
71
|
+
expected_shape = (point.shape[0], 2, *image_size)
|
72
|
+
for radius in [5, 10]:
|
73
|
+
for disc in [True, False]:
|
74
|
+
TEST_CONVERT_POINTS_TO_DISC_VALUE.append(
|
75
|
+
[
|
76
|
+
{"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc},
|
77
|
+
[point, point_label],
|
78
|
+
]
|
79
|
+
)
|
80
|
+
|
81
|
+
|
82
|
+
TEST_LCC_MASK_POINT_TORCH = []
|
83
|
+
for bs in [1, 2]:
|
84
|
+
for num_points in [1, 3]:
|
85
|
+
shape = (bs, 1, 128, 32, 32)
|
86
|
+
TEST_LCC_MASK_POINT_TORCH.append(
|
87
|
+
[
|
88
|
+
{
|
89
|
+
"img_pos": torch.randint(0, 2, shape, dtype=torch.bool),
|
90
|
+
"img_neg": torch.randint(0, 2, shape, dtype=torch.bool),
|
91
|
+
"point_coords": torch.randint(0, 10, (bs, num_points, 3)),
|
92
|
+
"point_labels": torch.randint(0, 4, (bs, num_points)),
|
93
|
+
},
|
94
|
+
shape,
|
95
|
+
]
|
96
|
+
)
|
97
|
+
|
98
|
+
TEST_LCC_MASK_POINT_NP = []
|
99
|
+
for bs in [1, 2]:
|
100
|
+
for num_points in [1, 3]:
|
101
|
+
shape = (bs, 1, 32, 32, 64)
|
102
|
+
TEST_LCC_MASK_POINT_NP.append(
|
103
|
+
[
|
104
|
+
{
|
105
|
+
"img_pos": np.random.randint(0, 2, shape, dtype=bool),
|
106
|
+
"img_neg": np.random.randint(0, 2, shape, dtype=bool),
|
107
|
+
"point_coords": np.random.randint(0, 5, (bs, num_points, 3)),
|
108
|
+
"point_labels": np.random.randint(0, 4, (bs, num_points)),
|
109
|
+
},
|
110
|
+
shape,
|
111
|
+
]
|
112
|
+
)
|
113
|
+
|
114
|
+
|
115
|
+
@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required")
|
116
|
+
class TestSamplePointsFromLabel(unittest.TestCase):
|
117
|
+
@parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL)
|
118
|
+
def test_shape(self, input_data, expected_point_shape, expected_point_label_shape):
|
119
|
+
point, point_label = sample_points_from_label(**input_data)
|
120
|
+
self.assertEqual(point.shape, expected_point_shape)
|
121
|
+
self.assertEqual(point_label.shape, expected_point_label_shape)
|
122
|
+
|
123
|
+
|
124
|
+
class TestConvertPointsToDisc(unittest.TestCase):
|
125
|
+
@parameterized.expand(TEST_CONVERT_POINTS_TO_DISC)
|
126
|
+
def test_shape(self, input_data, expected_shape):
|
127
|
+
result = convert_points_to_disc(**input_data)
|
128
|
+
self.assertEqual(result.shape, expected_shape)
|
129
|
+
|
130
|
+
@parameterized.expand(TEST_CONVERT_POINTS_TO_DISC_VALUE)
|
131
|
+
def test_value(self, input_data, points):
|
132
|
+
result = convert_points_to_disc(**input_data)
|
133
|
+
point, point_label = points
|
134
|
+
for i in range(point.shape[0]):
|
135
|
+
for j in range(point.shape[1]):
|
136
|
+
self.assertEqual(result[i, point_label[i, j], point[i, j][0], point[i, j][1], point[i, j][2]], True)
|
137
|
+
|
138
|
+
|
139
|
+
@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required")
|
140
|
+
class TestKeepMergeComponentsWithPoints(unittest.TestCase):
|
141
|
+
@skip_if_quick
|
142
|
+
@skip_if_no_cuda
|
143
|
+
@skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required")
|
144
|
+
@parameterized.expand(TEST_LCC_MASK_POINT_TORCH)
|
145
|
+
def test_cp_shape(self, input_data, shape):
|
146
|
+
for key in input_data:
|
147
|
+
input_data[key] = input_data[key].to(device)
|
148
|
+
mask = keep_merge_components_with_points(**input_data)
|
149
|
+
self.assertEqual(mask.shape, shape)
|
150
|
+
|
151
|
+
@skipUnless(has_measure, "skimage required")
|
152
|
+
@parameterized.expand(TEST_LCC_MASK_POINT_NP)
|
153
|
+
def test_np_shape(self, input_data, shape):
|
154
|
+
mask = keep_merge_components_with_points(**input_data)
|
155
|
+
self.assertEqual(mask.shape, shape)
|
156
|
+
|
157
|
+
|
158
|
+
if __name__ == "__main__":
|
159
|
+
unittest.main()
|
tests/losses/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
@@ -0,0 +1,88 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.losses.deform import BendingEnergyLoss
|
21
|
+
|
22
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
23
|
+
|
24
|
+
TEST_CASES = [
|
25
|
+
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
|
26
|
+
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0],
|
27
|
+
[
|
28
|
+
{"normalize": False},
|
29
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
|
30
|
+
4.0,
|
31
|
+
],
|
32
|
+
[
|
33
|
+
{"normalize": False},
|
34
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
|
35
|
+
4.0,
|
36
|
+
],
|
37
|
+
[{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 4.0],
|
38
|
+
[
|
39
|
+
{"normalize": True},
|
40
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
|
41
|
+
100.0,
|
42
|
+
],
|
43
|
+
[
|
44
|
+
{"normalize": True},
|
45
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
|
46
|
+
100.0,
|
47
|
+
],
|
48
|
+
[{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 100.0],
|
49
|
+
]
|
50
|
+
|
51
|
+
|
52
|
+
class TestBendingEnergy(unittest.TestCase):
|
53
|
+
|
54
|
+
@parameterized.expand(TEST_CASES)
|
55
|
+
def test_shape(self, input_param, input_data, expected_val):
|
56
|
+
result = BendingEnergyLoss(**input_param).forward(**input_data)
|
57
|
+
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
|
58
|
+
|
59
|
+
def test_ill_shape(self):
|
60
|
+
loss = BendingEnergyLoss()
|
61
|
+
# not in 3-d, 4-d, 5-d
|
62
|
+
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
|
63
|
+
loss.forward(torch.ones((1, 3), device=device))
|
64
|
+
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
|
65
|
+
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
|
66
|
+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
|
67
|
+
loss.forward(torch.ones((1, 3, 4, 5, 5), device=device))
|
68
|
+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
|
69
|
+
loss.forward(torch.ones((1, 3, 5, 4, 5)))
|
70
|
+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
|
71
|
+
loss.forward(torch.ones((1, 3, 5, 5, 4)))
|
72
|
+
|
73
|
+
# number of vector components unequal to number of spatial dims
|
74
|
+
with self.assertRaisesRegex(ValueError, "Number of vector components"):
|
75
|
+
loss.forward(torch.ones((1, 2, 5, 5, 5)))
|
76
|
+
with self.assertRaisesRegex(ValueError, "Number of vector components"):
|
77
|
+
loss.forward(torch.ones((1, 2, 5, 5, 5)))
|
78
|
+
|
79
|
+
def test_ill_opts(self):
|
80
|
+
pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
|
81
|
+
with self.assertRaisesRegex(ValueError, ""):
|
82
|
+
BendingEnergyLoss(reduction="unknown")(pred)
|
83
|
+
with self.assertRaisesRegex(ValueError, ""):
|
84
|
+
BendingEnergyLoss(reduction=None)(pred)
|
85
|
+
|
86
|
+
|
87
|
+
if __name__ == "__main__":
|
88
|
+
unittest.main()
|
@@ -0,0 +1,117 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.losses.deform import DiffusionLoss
|
21
|
+
|
22
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
23
|
+
|
24
|
+
TEST_CASES = [
|
25
|
+
# all first partials are zero, so the diffusion loss is also zero
|
26
|
+
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
|
27
|
+
# all first partials are one, so the diffusion loss is also one
|
28
|
+
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0],
|
29
|
+
# before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67
|
30
|
+
[
|
31
|
+
{"normalize": False},
|
32
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
|
33
|
+
56.0 / 3.0,
|
34
|
+
],
|
35
|
+
# same as the previous case
|
36
|
+
[
|
37
|
+
{"normalize": False},
|
38
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
|
39
|
+
56.0 / 3.0,
|
40
|
+
],
|
41
|
+
# same as the previous case
|
42
|
+
[{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
|
43
|
+
# we have shown in the demo notebook that
|
44
|
+
# diffusion loss is scale-invariant when the all axes have the same resolution
|
45
|
+
[
|
46
|
+
{"normalize": True},
|
47
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
|
48
|
+
56.0 / 3.0,
|
49
|
+
],
|
50
|
+
[
|
51
|
+
{"normalize": True},
|
52
|
+
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
|
53
|
+
56.0 / 3.0,
|
54
|
+
],
|
55
|
+
[{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
|
56
|
+
# for the following case, consider the following 2D matrix:
|
57
|
+
# tensor([[[[0, 1, 2],
|
58
|
+
# [1, 2, 3],
|
59
|
+
# [2, 3, 4],
|
60
|
+
# [3, 4, 5],
|
61
|
+
# [4, 5, 6]],
|
62
|
+
# [[0, 1, 2],
|
63
|
+
# [1, 2, 3],
|
64
|
+
# [2, 3, 4],
|
65
|
+
# [3, 4, 5],
|
66
|
+
# [4, 5, 6]]]])
|
67
|
+
# the first partials wrt x are all ones, and so are the first partials wrt y
|
68
|
+
# the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2
|
69
|
+
[{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0],
|
70
|
+
# consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook,
|
71
|
+
# the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y
|
72
|
+
# the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689
|
73
|
+
[
|
74
|
+
{"normalize": True},
|
75
|
+
{"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)},
|
76
|
+
(1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0,
|
77
|
+
],
|
78
|
+
]
|
79
|
+
|
80
|
+
|
81
|
+
class TestDiffusionLoss(unittest.TestCase):
|
82
|
+
|
83
|
+
@parameterized.expand(TEST_CASES)
|
84
|
+
def test_shape(self, input_param, input_data, expected_val):
|
85
|
+
result = DiffusionLoss(**input_param).forward(**input_data)
|
86
|
+
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
|
87
|
+
|
88
|
+
def test_ill_shape(self):
|
89
|
+
loss = DiffusionLoss()
|
90
|
+
# not in 3-d, 4-d, 5-d
|
91
|
+
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
|
92
|
+
loss.forward(torch.ones((1, 3), device=device))
|
93
|
+
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
|
94
|
+
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
|
95
|
+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
|
96
|
+
loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
|
97
|
+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
|
98
|
+
loss.forward(torch.ones((1, 3, 5, 2, 5)))
|
99
|
+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
|
100
|
+
loss.forward(torch.ones((1, 3, 5, 5, 2)))
|
101
|
+
|
102
|
+
# number of vector components unequal to number of spatial dims
|
103
|
+
with self.assertRaisesRegex(ValueError, "Number of vector components"):
|
104
|
+
loss.forward(torch.ones((1, 2, 5, 5, 5)))
|
105
|
+
with self.assertRaisesRegex(ValueError, "Number of vector components"):
|
106
|
+
loss.forward(torch.ones((1, 2, 5, 5, 5)))
|
107
|
+
|
108
|
+
def test_ill_opts(self):
|
109
|
+
pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
|
110
|
+
with self.assertRaisesRegex(ValueError, ""):
|
111
|
+
DiffusionLoss(reduction="unknown")(pred)
|
112
|
+
with self.assertRaisesRegex(ValueError, ""):
|
113
|
+
DiffusionLoss(reduction=None)(pred)
|
114
|
+
|
115
|
+
|
116
|
+
if __name__ == "__main__":
|
117
|
+
unittest.main()
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
from __future__ import annotations
|
12
|
+
|
13
|
+
import os
|
14
|
+
import unittest
|
15
|
+
from pathlib import Path
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from parameterized import parameterized
|
20
|
+
|
21
|
+
from monai import transforms
|
22
|
+
from monai.losses.image_dissimilarity import GlobalMutualInformationLoss
|
23
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, download_url_or_skip_test, skip_if_quick, testing_data_config
|
24
|
+
|
25
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
26
|
+
|
27
|
+
TESTS_PATH = Path(__file__).parents[2]
|
28
|
+
FILE_PATH = os.path.join(TESTS_PATH, "testing_data", "temp_" + "mri.nii")
|
29
|
+
|
30
|
+
EXPECTED_VALUE = {
|
31
|
+
"xyz_translation": [
|
32
|
+
-1.5860257,
|
33
|
+
-0.62433463,
|
34
|
+
-0.38217825,
|
35
|
+
-0.2905613,
|
36
|
+
-0.23233329,
|
37
|
+
-0.1961407,
|
38
|
+
-0.16905619,
|
39
|
+
-0.15100679,
|
40
|
+
-0.13666219,
|
41
|
+
-0.12635908,
|
42
|
+
],
|
43
|
+
"xyz_rotation": [
|
44
|
+
-1.5860257,
|
45
|
+
-0.30265224,
|
46
|
+
-0.18666176,
|
47
|
+
-0.15887907,
|
48
|
+
-0.1625064,
|
49
|
+
-0.16603896,
|
50
|
+
-0.19222091,
|
51
|
+
-0.18158069,
|
52
|
+
-0.167644,
|
53
|
+
-0.16698098,
|
54
|
+
],
|
55
|
+
}
|
56
|
+
|
57
|
+
|
58
|
+
@skip_if_quick
|
59
|
+
class TestGlobalMutualInformationLoss(unittest.TestCase):
|
60
|
+
def setUp(self):
|
61
|
+
config = testing_data_config("images", "Prostate_T2W_AX_1")
|
62
|
+
download_url_or_skip_test(
|
63
|
+
url=config["url"],
|
64
|
+
filepath=FILE_PATH,
|
65
|
+
hash_val=config.get("hash_val"),
|
66
|
+
hash_type=config.get("hash_type", "sha256"),
|
67
|
+
)
|
68
|
+
|
69
|
+
@SkipIfBeforePyTorchVersion((1, 9))
|
70
|
+
def test_bspline(self):
|
71
|
+
loss_fn = GlobalMutualInformationLoss(kernel_type="b-spline", num_bins=32, sigma_ratio=0.015)
|
72
|
+
|
73
|
+
transform_params_dict = {
|
74
|
+
"xyz_translation": [(i, i, i) for i in range(10)],
|
75
|
+
"xyz_rotation": [(np.pi / 100 * i, np.pi / 100 * i, np.pi / 100 * i) for i in range(10)],
|
76
|
+
}
|
77
|
+
|
78
|
+
def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.0)):
|
79
|
+
"""
|
80
|
+
Read and transform Prostate_T2W_AX_1.nii
|
81
|
+
Args:
|
82
|
+
translate_params: a tuple of 3 floats, translation is in pixel/voxel relative to the center of the input
|
83
|
+
image. Defaults to no translation.
|
84
|
+
rotate_params: a rotation angle in radians, a tuple of 3 floats for 3D.
|
85
|
+
Defaults to no rotation.
|
86
|
+
Returns:
|
87
|
+
numpy array of shape HWD
|
88
|
+
"""
|
89
|
+
transform_list = [
|
90
|
+
transforms.LoadImaged(keys="img", image_only=True),
|
91
|
+
transforms.Affined(
|
92
|
+
keys="img",
|
93
|
+
translate_params=translate_params,
|
94
|
+
rotate_params=rotate_params,
|
95
|
+
device=None,
|
96
|
+
padding_mode="border",
|
97
|
+
),
|
98
|
+
transforms.NormalizeIntensityd(keys=["img"]),
|
99
|
+
]
|
100
|
+
transformation = transforms.Compose(transform_list)
|
101
|
+
return transformation({"img": FILE_PATH})["img"]
|
102
|
+
|
103
|
+
a1 = transformation()
|
104
|
+
a1 = a1.clone().unsqueeze(0).unsqueeze(0).to(device)
|
105
|
+
|
106
|
+
for mode in transform_params_dict:
|
107
|
+
transform_params_list = transform_params_dict[mode]
|
108
|
+
expected_value_list = EXPECTED_VALUE[mode]
|
109
|
+
for transform_params, expected_value in zip(transform_params_list, expected_value_list):
|
110
|
+
a2 = transformation(
|
111
|
+
translate_params=transform_params if "translation" in mode else (0.0, 0.0, 0.0),
|
112
|
+
rotate_params=transform_params if "rotation" in mode else (0.0, 0.0, 0.0),
|
113
|
+
)
|
114
|
+
a2 = a2.clone().unsqueeze(0).unsqueeze(0).to(device)
|
115
|
+
result = loss_fn(a2, a1).detach().cpu().numpy()
|
116
|
+
np.testing.assert_allclose(result, expected_value, rtol=0.08, atol=0.05)
|
117
|
+
|
118
|
+
|
119
|
+
class TestGlobalMutualInformationLossIll(unittest.TestCase):
|
120
|
+
@parameterized.expand(
|
121
|
+
[
|
122
|
+
(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)), # mismatched_simple_dims
|
123
|
+
(
|
124
|
+
torch.ones((1, 3, 3), dtype=torch.float),
|
125
|
+
torch.ones((1, 3), dtype=torch.float),
|
126
|
+
), # mismatched_advanced_dims
|
127
|
+
]
|
128
|
+
)
|
129
|
+
def test_ill_shape(self, input1, input2):
|
130
|
+
loss = GlobalMutualInformationLoss()
|
131
|
+
with self.assertRaises(ValueError):
|
132
|
+
loss.forward(input1, input2)
|
133
|
+
|
134
|
+
@parameterized.expand(
|
135
|
+
[
|
136
|
+
(0, "mean", ValueError, ""), # num_bins_zero
|
137
|
+
(-1, "mean", ValueError, ""), # num_bins_negative
|
138
|
+
(64, "unknown", ValueError, ""), # reduction_unknown
|
139
|
+
(64, None, ValueError, ""), # reduction_none
|
140
|
+
]
|
141
|
+
)
|
142
|
+
def test_ill_opts(self, num_bins, reduction, expected_exception, expected_message):
|
143
|
+
pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
|
144
|
+
target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
|
145
|
+
with self.assertRaisesRegex(expected_exception, expected_message):
|
146
|
+
GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target)
|
147
|
+
|
148
|
+
|
149
|
+
if __name__ == "__main__":
|
150
|
+
unittest.main()
|