monai-weekly 1.5.dev2505__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/meta_tensor.py +5 -0
- monai/data/utils.py +6 -13
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utils.py +1 -2
- monai/utils/jupyter_utils.py +1 -1
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- monai/visualize/img2tensorboard.py +2 -2
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
- monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1310 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +824 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +75 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2505.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2505.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,139 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
|
19
|
+
from monai.transforms import CutMix, CutMixd, CutOut, CutOutd, MixUp, MixUpd
|
20
|
+
from tests.test_utils import assert_allclose
|
21
|
+
|
22
|
+
|
23
|
+
class TestMixup(unittest.TestCase):
|
24
|
+
def test_mixup(self):
|
25
|
+
for dims in [2, 3]:
|
26
|
+
shape = (6, 3) + (32,) * dims
|
27
|
+
sample = torch.rand(*shape, dtype=torch.float32)
|
28
|
+
mixup = MixUp(6, 1.0)
|
29
|
+
mixup.set_random_state(seed=0)
|
30
|
+
output = mixup(sample)
|
31
|
+
np.random.seed(0)
|
32
|
+
# simulate the randomize() of transform
|
33
|
+
np.random.random()
|
34
|
+
weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
|
35
|
+
perm = np.random.permutation(6)
|
36
|
+
self.assertEqual(output.shape, sample.shape)
|
37
|
+
mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
|
38
|
+
expected = mixweight * sample + (1 - mixweight) * sample[perm, ...]
|
39
|
+
assert_allclose(output, expected, type_test=False, atol=1e-7)
|
40
|
+
|
41
|
+
with self.assertRaises(ValueError):
|
42
|
+
MixUp(6, -0.5)
|
43
|
+
|
44
|
+
mixup = MixUp(6, 0.5)
|
45
|
+
for dims in [2, 3]:
|
46
|
+
with self.assertRaises(ValueError):
|
47
|
+
shape = (5, 3) + (32,) * dims
|
48
|
+
sample = torch.rand(*shape, dtype=torch.float32)
|
49
|
+
mixup(sample)
|
50
|
+
|
51
|
+
def test_mixupd(self):
|
52
|
+
for dims in [2, 3]:
|
53
|
+
shape = (6, 3) + (32,) * dims
|
54
|
+
t = torch.rand(*shape, dtype=torch.float32)
|
55
|
+
sample = {"a": t, "b": t}
|
56
|
+
mixup = MixUpd(["a", "b"], 6)
|
57
|
+
mixup.set_random_state(seed=0)
|
58
|
+
output = mixup(sample)
|
59
|
+
np.random.seed(0)
|
60
|
+
# simulate the randomize() of transform
|
61
|
+
np.random.random()
|
62
|
+
weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
|
63
|
+
perm = np.random.permutation(6)
|
64
|
+
self.assertEqual(output["a"].shape, sample["a"].shape)
|
65
|
+
mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
|
66
|
+
expected = mixweight * sample["a"] + (1 - mixweight) * sample["a"][perm, ...]
|
67
|
+
assert_allclose(output["a"], expected, type_test=False, atol=1e-7)
|
68
|
+
assert_allclose(output["a"], output["b"], type_test=False, atol=1e-7)
|
69
|
+
# self.assertTrue(torch.allclose(output["a"], output["b"]))
|
70
|
+
|
71
|
+
with self.assertRaises(ValueError):
|
72
|
+
MixUpd(["k1", "k2"], 6, -0.5)
|
73
|
+
|
74
|
+
|
75
|
+
class TestCutMix(unittest.TestCase):
|
76
|
+
def test_cutmix(self):
|
77
|
+
for dims in [2, 3]:
|
78
|
+
shape = (6, 3) + (32,) * dims
|
79
|
+
sample = torch.rand(*shape, dtype=torch.float32)
|
80
|
+
cutmix = CutMix(6, 1.0)
|
81
|
+
cutmix.set_random_state(seed=0)
|
82
|
+
output = cutmix(sample)
|
83
|
+
self.assertEqual(output.shape, sample.shape)
|
84
|
+
self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))
|
85
|
+
|
86
|
+
def test_cutmixd(self):
|
87
|
+
for dims in [2, 3]:
|
88
|
+
shape = (6, 3) + (32,) * dims
|
89
|
+
t = torch.rand(*shape, dtype=torch.float32)
|
90
|
+
label = torch.randint(0, 1, shape)
|
91
|
+
sample = {"a": t, "b": t, "lbl1": label, "lbl2": label}
|
92
|
+
cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2"))
|
93
|
+
cutmix.set_random_state(seed=123)
|
94
|
+
output = cutmix(sample)
|
95
|
+
# but mixing of labels is not affected by it
|
96
|
+
self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"]))
|
97
|
+
|
98
|
+
|
99
|
+
class TestCutOut(unittest.TestCase):
|
100
|
+
def test_cutout(self):
|
101
|
+
for dims in [2, 3]:
|
102
|
+
shape = (6, 3) + (32,) * dims
|
103
|
+
sample = torch.rand(*shape, dtype=torch.float32)
|
104
|
+
cutout = CutOut(6, 1.0)
|
105
|
+
cutout.set_random_state(seed=123)
|
106
|
+
output = cutout(sample)
|
107
|
+
np.random.seed(123)
|
108
|
+
# simulate the randomize() of transform
|
109
|
+
np.random.random()
|
110
|
+
weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
|
111
|
+
perm = np.random.permutation(6)
|
112
|
+
coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]]
|
113
|
+
assert_allclose(weight, cutout._params[0])
|
114
|
+
assert_allclose(perm, cutout._params[1])
|
115
|
+
self.assertSequenceEqual(coords, cutout._params[2])
|
116
|
+
self.assertEqual(output.shape, sample.shape)
|
117
|
+
|
118
|
+
def test_cutoutd(self):
|
119
|
+
for dims in [2, 3]:
|
120
|
+
shape = (6, 3) + (32,) * dims
|
121
|
+
t = torch.rand(*shape, dtype=torch.float32)
|
122
|
+
sample = {"a": t, "b": t}
|
123
|
+
cutout = CutOutd(["a", "b"], 6, 1.0)
|
124
|
+
cutout.set_random_state(seed=123)
|
125
|
+
output = cutout(sample)
|
126
|
+
np.random.seed(123)
|
127
|
+
# simulate the randomize() of transform
|
128
|
+
np.random.random()
|
129
|
+
weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
|
130
|
+
perm = np.random.permutation(6)
|
131
|
+
coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in t.shape[2:]]
|
132
|
+
assert_allclose(weight, cutout.cutout._params[0])
|
133
|
+
assert_allclose(perm, cutout.cutout._params[1])
|
134
|
+
self.assertSequenceEqual(coords, cutout.cutout._params[2])
|
135
|
+
self.assertEqual(output["a"].shape, sample["a"].shape)
|
136
|
+
|
137
|
+
|
138
|
+
if __name__ == "__main__":
|
139
|
+
unittest.main()
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
from parameterized import parameterized
|
17
|
+
|
18
|
+
from monai.transforms import RemoveRepeatedChannel
|
19
|
+
from tests.test_utils import TEST_NDARRAYS
|
20
|
+
|
21
|
+
TEST_CASES = []
|
22
|
+
for q in TEST_NDARRAYS:
|
23
|
+
TEST_CASES.append([{"repeats": 2}, q([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)])
|
24
|
+
|
25
|
+
|
26
|
+
class TestRemoveRepeatedChannel(unittest.TestCase):
|
27
|
+
@parameterized.expand(TEST_CASES)
|
28
|
+
def test_shape(self, input_param, input_data, expected_shape):
|
29
|
+
result = RemoveRepeatedChannel(**input_param)(input_data)
|
30
|
+
self.assertEqual(result.shape, expected_shape)
|
31
|
+
|
32
|
+
|
33
|
+
if __name__ == "__main__":
|
34
|
+
unittest.main()
|
@@ -0,0 +1,44 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.transforms import RemoveRepeatedChanneld
|
20
|
+
from tests.test_utils import TEST_NDARRAYS
|
21
|
+
|
22
|
+
TESTS = []
|
23
|
+
for p in TEST_NDARRAYS:
|
24
|
+
TESTS.append(
|
25
|
+
[
|
26
|
+
{"keys": ["img"], "repeats": 2},
|
27
|
+
{
|
28
|
+
"img": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])),
|
29
|
+
"seg": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])),
|
30
|
+
},
|
31
|
+
(2, 2),
|
32
|
+
]
|
33
|
+
)
|
34
|
+
|
35
|
+
|
36
|
+
class TestRemoveRepeatedChanneld(unittest.TestCase):
|
37
|
+
@parameterized.expand(TESTS)
|
38
|
+
def test_shape(self, input_param, input_data, expected_shape):
|
39
|
+
result = RemoveRepeatedChanneld(**input_param)(input_data)
|
40
|
+
self.assertEqual(result["img"].shape, expected_shape)
|
41
|
+
|
42
|
+
|
43
|
+
if __name__ == "__main__":
|
44
|
+
unittest.main()
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
from parameterized import parameterized
|
17
|
+
|
18
|
+
from monai.transforms import RepeatChannel
|
19
|
+
from tests.test_utils import TEST_NDARRAYS
|
20
|
+
|
21
|
+
TESTS = []
|
22
|
+
for p in TEST_NDARRAYS:
|
23
|
+
TESTS.append([{"repeats": 3}, p([[[0, 1], [1, 2]]]), (3, 2, 2)])
|
24
|
+
|
25
|
+
|
26
|
+
class TestRepeatChannel(unittest.TestCase):
|
27
|
+
@parameterized.expand(TESTS)
|
28
|
+
def test_shape(self, input_param, input_data, expected_shape):
|
29
|
+
result = RepeatChannel(**input_param)(input_data)
|
30
|
+
self.assertEqual(result.shape, expected_shape)
|
31
|
+
|
32
|
+
|
33
|
+
if __name__ == "__main__":
|
34
|
+
unittest.main()
|
@@ -0,0 +1,41 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.transforms import RepeatChanneld
|
20
|
+
from tests.test_utils import TEST_NDARRAYS
|
21
|
+
|
22
|
+
TESTS = []
|
23
|
+
for p in TEST_NDARRAYS:
|
24
|
+
TESTS.append(
|
25
|
+
[
|
26
|
+
{"keys": ["img"], "repeats": 3},
|
27
|
+
{"img": p(np.array([[[0, 1], [1, 2]]])), "seg": p(np.array([[[0, 1], [1, 2]]]))},
|
28
|
+
(3, 2, 2),
|
29
|
+
]
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
class TestRepeatChanneld(unittest.TestCase):
|
34
|
+
@parameterized.expand(TESTS)
|
35
|
+
def test_shape(self, input_param, input_data, expected_shape):
|
36
|
+
result = RepeatChanneld(**input_param)(input_data)
|
37
|
+
self.assertEqual(result["img"].shape, expected_shape)
|
38
|
+
|
39
|
+
|
40
|
+
if __name__ == "__main__":
|
41
|
+
unittest.main()
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.config import USE_COMPILED
|
21
|
+
from monai.data import MetaTensor
|
22
|
+
from monai.transforms import Resample
|
23
|
+
from monai.transforms.utils import create_grid
|
24
|
+
from monai.utils import GridSampleMode, GridSamplePadMode, NdimageMode, SplineMode, convert_to_numpy
|
25
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, is_tf32_env
|
26
|
+
|
27
|
+
_rtol = 1e-3 if is_tf32_env() else 1e-4
|
28
|
+
|
29
|
+
TEST_IDENTITY = []
|
30
|
+
for interp in GridSampleMode if not USE_COMPILED else ("nearest", "bilinear"): # type: ignore
|
31
|
+
for pad in GridSamplePadMode:
|
32
|
+
for p in (np.float32, np.float64):
|
33
|
+
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
|
34
|
+
TEST_IDENTITY.append([dict(device=device), p, interp, pad, (1, 3, 4)])
|
35
|
+
if interp != "bicubic":
|
36
|
+
TEST_IDENTITY.append([dict(device=device), p, interp, pad, (1, 3, 5, 8)])
|
37
|
+
for interp_s in SplineMode if not USE_COMPILED else []: # type: ignore
|
38
|
+
for pad_s in NdimageMode:
|
39
|
+
for p_s in (int, float, np.float32, np.float64):
|
40
|
+
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
|
41
|
+
TEST_IDENTITY.append([dict(device=device), p_s, interp_s, pad_s, (1, 20, 21)])
|
42
|
+
TEST_IDENTITY.append([dict(device=device), p_s, interp_s, pad_s, (1, 21, 23, 24)])
|
43
|
+
|
44
|
+
|
45
|
+
@SkipIfBeforePyTorchVersion((1, 9, 1))
|
46
|
+
class TestResampleBackends(unittest.TestCase):
|
47
|
+
@parameterized.expand(TEST_IDENTITY)
|
48
|
+
def test_resample_identity(self, input_param, im_type, interp, padding, input_shape):
|
49
|
+
"""test resampling of an identity grid with padding 2, im_type, interp, padding, input_shape"""
|
50
|
+
xform = Resample(dtype=im_type, **input_param)
|
51
|
+
n_elem = np.prod(input_shape)
|
52
|
+
img = convert_to_numpy(np.arange(n_elem).reshape(input_shape), dtype=im_type)
|
53
|
+
grid = create_grid(input_shape[1:], homogeneous=True, backend="numpy")
|
54
|
+
grid_p = np.stack([np.pad(g, 2, "constant") for g in grid]) # testing pad
|
55
|
+
output = xform(img=img, grid=grid_p, mode=interp, padding_mode=padding)
|
56
|
+
self.assertTrue(not torch.any(torch.isinf(output) | torch.isnan(output)))
|
57
|
+
self.assertIsInstance(output, MetaTensor)
|
58
|
+
slices = [slice(None)]
|
59
|
+
slices.extend([slice(2, -2) for _ in img.shape[1:]])
|
60
|
+
output_c = output[slices]
|
61
|
+
assert_allclose(output_c, img, rtol=_rtol, atol=1e-3, type_test="tensor")
|
62
|
+
|
63
|
+
|
64
|
+
if __name__ == "__main__":
|
65
|
+
unittest.main()
|
@@ -0,0 +1,110 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import itertools
|
15
|
+
import os
|
16
|
+
import random
|
17
|
+
import shutil
|
18
|
+
import string
|
19
|
+
import tempfile
|
20
|
+
import unittest
|
21
|
+
|
22
|
+
import nibabel as nib
|
23
|
+
import numpy as np
|
24
|
+
import torch
|
25
|
+
from parameterized import parameterized
|
26
|
+
|
27
|
+
from monai.data import MetaTensor
|
28
|
+
from monai.data.image_reader import ITKReader, NibabelReader
|
29
|
+
from monai.data.image_writer import ITKWriter
|
30
|
+
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImage, SaveImaged
|
31
|
+
from monai.utils import optional_import
|
32
|
+
from tests.lazy_transforms_utils import test_resampler_lazy
|
33
|
+
from tests.test_utils import assert_allclose, download_url_or_skip_test, testing_data_config
|
34
|
+
|
35
|
+
_, has_itk = optional_import("itk", allow_namespace_pkg=True)
|
36
|
+
|
37
|
+
TEST_CASES = ["itkreader", "nibabelreader"]
|
38
|
+
|
39
|
+
|
40
|
+
def get_rand_fname(len=10, suffix=".nii.gz"):
|
41
|
+
letters = string.ascii_letters
|
42
|
+
out = "".join(random.choice(letters) for _ in range(len))
|
43
|
+
out += suffix
|
44
|
+
return out
|
45
|
+
|
46
|
+
|
47
|
+
@unittest.skipUnless(has_itk, "itk not installed")
|
48
|
+
class TestResampleToMatch(unittest.TestCase):
|
49
|
+
@classmethod
|
50
|
+
def setUpClass(cls):
|
51
|
+
super(__class__, cls).setUpClass()
|
52
|
+
cls.fnames = []
|
53
|
+
cls.tmpdir = tempfile.mkdtemp()
|
54
|
+
for key in ("0000_t2_tse_tra_4", "0000_ep2d_diff_tra_7"):
|
55
|
+
fname = os.path.join(cls.tmpdir, f"test_{key}.nii.gz")
|
56
|
+
url = testing_data_config("images", key, "url")
|
57
|
+
hash_type = testing_data_config("images", key, "hash_type")
|
58
|
+
hash_val = testing_data_config("images", key, "hash_val")
|
59
|
+
download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val)
|
60
|
+
cls.fnames.append(fname)
|
61
|
+
|
62
|
+
@classmethod
|
63
|
+
def tearDownClass(cls):
|
64
|
+
shutil.rmtree(cls.tmpdir)
|
65
|
+
super(__class__, cls).tearDownClass()
|
66
|
+
|
67
|
+
@parameterized.expand(itertools.product([NibabelReader, ITKReader], ["monai.data.NibabelWriter", ITKWriter]))
|
68
|
+
def test_correct(self, reader, writer):
|
69
|
+
loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))])
|
70
|
+
data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})
|
71
|
+
tr = ResampleToMatch()
|
72
|
+
im_mod = tr(data["im2"], data["im1"])
|
73
|
+
|
74
|
+
# check lazy resample
|
75
|
+
call_param = {"img": data["im2"], "img_dst": data["im1"]}
|
76
|
+
test_resampler_lazy(tr, im_mod, init_param={}, call_param=call_param)
|
77
|
+
|
78
|
+
saver = SaveImaged(
|
79
|
+
"im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, writer=writer, resample=False
|
80
|
+
)
|
81
|
+
im_mod.meta["filename_or_obj"] = get_rand_fname()
|
82
|
+
saver({"im3": im_mod})
|
83
|
+
|
84
|
+
saved = nib.load(os.path.join(self.tmpdir, im_mod.meta["filename_or_obj"]))
|
85
|
+
assert_allclose(data["im1"].shape[1:], saved.shape)
|
86
|
+
assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19]))
|
87
|
+
|
88
|
+
def test_inverse(self):
|
89
|
+
loader = Compose([LoadImaged(("im1", "im2"), image_only=True), EnsureChannelFirstd(("im1", "im2"))])
|
90
|
+
data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})
|
91
|
+
tr = ResampleToMatch()
|
92
|
+
im_mod = tr(data["im2"], data["im1"])
|
93
|
+
self.assertNotEqual(im_mod.shape, data["im2"].shape)
|
94
|
+
self.assertGreater(((im_mod.affine - data["im2"].affine) ** 2).sum() ** 0.5, 1e-2)
|
95
|
+
# inverse
|
96
|
+
im_mod2 = tr.inverse(im_mod)
|
97
|
+
self.assertEqual(im_mod2.shape, data["im2"].shape)
|
98
|
+
self.assertLess(((im_mod2.affine - data["im2"].affine) ** 2).sum() ** 0.5, 1e-2)
|
99
|
+
self.assertEqual(im_mod2.applied_operations, [])
|
100
|
+
|
101
|
+
def test_no_name(self):
|
102
|
+
img_1 = MetaTensor(torch.zeros(1, 2, 2, 2))
|
103
|
+
img_2 = MetaTensor(torch.zeros(1, 3, 3, 3))
|
104
|
+
im_mod = ResampleToMatch()(img_1, img_2)
|
105
|
+
self.assertEqual(im_mod.meta["filename_or_obj"], "resample_to_match_source")
|
106
|
+
SaveImage(output_dir=self.tmpdir, output_postfix="", separate_folder=False, resample=False)(im_mod)
|
107
|
+
|
108
|
+
|
109
|
+
if __name__ == "__main__":
|
110
|
+
unittest.main()
|
@@ -0,0 +1,93 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import os
|
15
|
+
import shutil
|
16
|
+
import tempfile
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
from monai.transforms import (
|
20
|
+
Compose,
|
21
|
+
CopyItemsd,
|
22
|
+
EnsureChannelFirstd,
|
23
|
+
Invertd,
|
24
|
+
Lambda,
|
25
|
+
LoadImaged,
|
26
|
+
ResampleToMatchd,
|
27
|
+
SaveImaged,
|
28
|
+
)
|
29
|
+
from tests.lazy_transforms_utils import test_resampler_lazy
|
30
|
+
from tests.test_utils import assert_allclose, download_url_or_skip_test, testing_data_config
|
31
|
+
|
32
|
+
|
33
|
+
def update_fname(d):
|
34
|
+
d["im3"].meta["filename_or_obj"] = "file3.nii.gz"
|
35
|
+
return d
|
36
|
+
|
37
|
+
|
38
|
+
class TestResampleToMatchd(unittest.TestCase):
|
39
|
+
@classmethod
|
40
|
+
def setUpClass(cls):
|
41
|
+
super(__class__, cls).setUpClass()
|
42
|
+
cls.fnames = []
|
43
|
+
cls.tmpdir = tempfile.mkdtemp()
|
44
|
+
for key in ("0000_t2_tse_tra_4", "0000_ep2d_diff_tra_7"):
|
45
|
+
fname = os.path.join(cls.tmpdir, f"test_{key}.nii.gz")
|
46
|
+
url = testing_data_config("images", key, "url")
|
47
|
+
hash_type = testing_data_config("images", key, "hash_type")
|
48
|
+
hash_val = testing_data_config("images", key, "hash_val")
|
49
|
+
download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val)
|
50
|
+
cls.fnames.append(fname)
|
51
|
+
|
52
|
+
@classmethod
|
53
|
+
def tearDownClass(cls):
|
54
|
+
shutil.rmtree(cls.tmpdir)
|
55
|
+
super(__class__, cls).tearDownClass()
|
56
|
+
|
57
|
+
def test_correct(self):
|
58
|
+
transforms = Compose(
|
59
|
+
[
|
60
|
+
LoadImaged(("im1", "im2")),
|
61
|
+
EnsureChannelFirstd(("im1", "im2")),
|
62
|
+
CopyItemsd(("im2"), names=("im3")),
|
63
|
+
ResampleToMatchd("im3", "im1"),
|
64
|
+
Lambda(update_fname),
|
65
|
+
SaveImaged("im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, resample=False),
|
66
|
+
]
|
67
|
+
)
|
68
|
+
data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
|
69
|
+
# check that output sizes match
|
70
|
+
assert_allclose(data["im1"].shape, data["im3"].shape)
|
71
|
+
# and that the meta data has been updated accordingly
|
72
|
+
assert_allclose(data["im3"].affine, data["im1"].affine)
|
73
|
+
# check we're different from the original
|
74
|
+
self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
|
75
|
+
self.assertTrue(any(i != j for i, j in zip(data["im3"].affine.flatten(), data["im2"].affine.flatten())))
|
76
|
+
# test the inverse
|
77
|
+
data = Invertd("im3", transforms)(data)
|
78
|
+
assert_allclose(data["im2"].shape, data["im3"].shape)
|
79
|
+
|
80
|
+
def test_lazy(self):
|
81
|
+
pre_transforms = Compose(
|
82
|
+
[LoadImaged(("im1", "im2")), EnsureChannelFirstd(("im1", "im2")), CopyItemsd(("im2"), names=("im3"))]
|
83
|
+
)
|
84
|
+
data = pre_transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
|
85
|
+
init_param = {"keys": "im3", "key_dst": "im1"}
|
86
|
+
resampler = ResampleToMatchd(**init_param)
|
87
|
+
call_param = {"data": data}
|
88
|
+
non_lazy_out = resampler(**call_param)
|
89
|
+
test_resampler_lazy(resampler, non_lazy_out, init_param, call_param, output_key="im3")
|
90
|
+
|
91
|
+
|
92
|
+
if __name__ == "__main__":
|
93
|
+
unittest.main()
|