monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utils.py +1 -2
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
- monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1310 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +824 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +75 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.data.meta_tensor import MetaTensor
|
20
|
+
from monai.transforms import RandScaleCropd, RandSpatialCropd
|
21
|
+
from monai.transforms.lazy.functional import apply_pending
|
22
|
+
from tests.croppers import CropTest
|
23
|
+
from tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose
|
24
|
+
|
25
|
+
TEST_SHAPES = [
|
26
|
+
[{"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, (3, 3, 3, 5), (3, 3, 3, 5)],
|
27
|
+
[{"keys": "img", "roi_size": [3, 3, 3], "random_center": True}, (3, 3, 3, 3), (3, 3, 3, 3)],
|
28
|
+
[{"keys": "img", "roi_size": [3, 3, 3], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)],
|
29
|
+
[{"keys": "img", "roi_size": [3, 2, 3], "random_center": False, "random_size": False}, (3, 3, 3, 3), (3, 3, 2, 3)],
|
30
|
+
]
|
31
|
+
|
32
|
+
TEST_VALUES = [
|
33
|
+
[
|
34
|
+
{"keys": "img", "roi_size": [3, 3], "random_center": False},
|
35
|
+
np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),
|
36
|
+
]
|
37
|
+
]
|
38
|
+
|
39
|
+
TEST_RANDOM_SHAPES = [
|
40
|
+
[
|
41
|
+
{"keys": "img", "roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True},
|
42
|
+
(1, 4, 5, 6),
|
43
|
+
(1, 4, 4, 3),
|
44
|
+
],
|
45
|
+
[
|
46
|
+
{"keys": "img", "roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True},
|
47
|
+
(1, 4, 5, 6),
|
48
|
+
(1, 3, 4, 3),
|
49
|
+
],
|
50
|
+
]
|
51
|
+
|
52
|
+
func1 = {RandSpatialCropd: {"keys": "img", "roi_size": [8, 7, -1], "random_center": True, "random_size": False}}
|
53
|
+
func2 = {RandScaleCropd: {"keys": "img", "roi_scale": [0.5, 0.6, -1.0], "random_center": True, "random_size": True}}
|
54
|
+
func3 = {RandScaleCropd: {"keys": "img", "roi_scale": [1.0, 0.5, -1.0], "random_center": False, "random_size": False}}
|
55
|
+
|
56
|
+
TESTS_COMBINE = []
|
57
|
+
TESTS_COMBINE.append([[func1, func2, func3], (3, 10, 10, 8)])
|
58
|
+
TESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4)])
|
59
|
+
TESTS_COMBINE.append([[func2, func2], (3, 8, 8, 4)])
|
60
|
+
|
61
|
+
|
62
|
+
class TestRandSpatialCropd(CropTest):
|
63
|
+
Cropper = RandSpatialCropd
|
64
|
+
|
65
|
+
@parameterized.expand(TEST_SHAPES)
|
66
|
+
def test_shape(self, input_param, input_shape, expected_shape):
|
67
|
+
self.crop_test(input_param, input_shape, expected_shape)
|
68
|
+
|
69
|
+
@parameterized.expand(TEST_VALUES)
|
70
|
+
def test_value(self, input_param, input_im):
|
71
|
+
for im_type in TEST_NDARRAYS_ALL:
|
72
|
+
with self.subTest(im_type=im_type):
|
73
|
+
cropper = self.Cropper(**input_param)
|
74
|
+
input_data = {"img": im_type(input_im)}
|
75
|
+
result = cropper(input_data)["img"]
|
76
|
+
roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size]
|
77
|
+
assert_allclose(result, input_im[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test="tensor")
|
78
|
+
|
79
|
+
@parameterized.expand(TEST_RANDOM_SHAPES)
|
80
|
+
def test_random_shape(self, input_param, input_shape, expected_shape):
|
81
|
+
for im_type in TEST_NDARRAYS_ALL:
|
82
|
+
with self.subTest(im_type=im_type):
|
83
|
+
cropper = self.Cropper(**input_param)
|
84
|
+
cropper.set_random_state(seed=123)
|
85
|
+
input_data = {"img": im_type(np.random.randint(0, 2, input_shape))}
|
86
|
+
expected = cropper(input_data)["img"]
|
87
|
+
self.assertTupleEqual(expected.shape, expected_shape)
|
88
|
+
|
89
|
+
# lazy
|
90
|
+
# reset random seed to ensure the same results
|
91
|
+
cropper.set_random_state(seed=123)
|
92
|
+
cropper.lazy = True
|
93
|
+
pending_result = cropper(input_data)["img"]
|
94
|
+
self.assertIsInstance(pending_result, MetaTensor)
|
95
|
+
assert_allclose(pending_result.peek_pending_affine(), expected.affine)
|
96
|
+
assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])
|
97
|
+
# only support nearest
|
98
|
+
result = apply_pending(pending_result, overrides={"mode": "nearest", "align_corners": False})[0]
|
99
|
+
# compare
|
100
|
+
assert_allclose(result, expected, rtol=1e-5)
|
101
|
+
|
102
|
+
@parameterized.expand(TEST_SHAPES)
|
103
|
+
def test_pending_ops(self, input_param, input_shape, _):
|
104
|
+
self.crop_test_pending_ops(input_param, input_shape)
|
105
|
+
|
106
|
+
@parameterized.expand(TESTS_COMBINE)
|
107
|
+
def test_combine_ops(self, funcs, input_shape):
|
108
|
+
self.crop_test_combine_ops(funcs, input_shape)
|
109
|
+
|
110
|
+
|
111
|
+
if __name__ == "__main__":
|
112
|
+
unittest.main()
|
@@ -0,0 +1,43 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.transforms import RandStdShiftIntensity
|
21
|
+
from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
|
22
|
+
|
23
|
+
|
24
|
+
class TestRandStdShiftIntensity(NumpyImageTestCase2D):
|
25
|
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
|
26
|
+
def test_value(self, p):
|
27
|
+
np.random.seed(0)
|
28
|
+
# simulate the randomize() of transform
|
29
|
+
np.random.random()
|
30
|
+
factor = np.random.uniform(low=-1.0, high=1.0)
|
31
|
+
offset = factor * np.std(self.imt)
|
32
|
+
expected = p(self.imt + offset)
|
33
|
+
shifter = RandStdShiftIntensity(factors=1.0, prob=1.0)
|
34
|
+
shifter.set_random_state(seed=0)
|
35
|
+
_imt = p(self.imt)
|
36
|
+
result = shifter(_imt)
|
37
|
+
if isinstance(_imt, torch.Tensor):
|
38
|
+
self.assertEqual(result.dtype, _imt.dtype)
|
39
|
+
assert_allclose(result, expected, atol=0, rtol=1e-5, type_test="tensor")
|
40
|
+
|
41
|
+
|
42
|
+
if __name__ == "__main__":
|
43
|
+
unittest.main()
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
|
18
|
+
from monai.transforms import RandStdShiftIntensityd
|
19
|
+
from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
|
20
|
+
|
21
|
+
|
22
|
+
class TestRandStdShiftIntensityd(NumpyImageTestCase2D):
|
23
|
+
def test_value(self):
|
24
|
+
for p in TEST_NDARRAYS:
|
25
|
+
key = "img"
|
26
|
+
np.random.seed(0)
|
27
|
+
# simulate the randomize() of transform
|
28
|
+
np.random.random()
|
29
|
+
factor = np.random.uniform(low=-1.0, high=1.0)
|
30
|
+
expected = self.imt + factor * np.std(self.imt)
|
31
|
+
shifter = RandStdShiftIntensityd(keys=[key], factors=1.0, prob=1.0)
|
32
|
+
shifter.set_random_state(seed=0)
|
33
|
+
result = shifter({key: p(self.imt)})[key]
|
34
|
+
assert_allclose(result, expected, rtol=1e-5, type_test="tensor")
|
35
|
+
|
36
|
+
|
37
|
+
if __name__ == "__main__":
|
38
|
+
unittest.main()
|
@@ -0,0 +1,105 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
from scipy.ndimage import zoom as zoom_scipy
|
20
|
+
|
21
|
+
from monai.config import USE_COMPILED
|
22
|
+
from monai.transforms import RandZoom
|
23
|
+
from monai.utils import InterpolateMode
|
24
|
+
from tests.lazy_transforms_utils import test_resampler_lazy
|
25
|
+
from tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion
|
26
|
+
|
27
|
+
VALID_CASES = [
|
28
|
+
(0.8, 1.2, "nearest", False),
|
29
|
+
(0.8, 1.2, InterpolateMode.NEAREST, False),
|
30
|
+
(0.8, 1.2, InterpolateMode.BILINEAR, False, True),
|
31
|
+
(0.8, 1.2, InterpolateMode.BILINEAR, False, False),
|
32
|
+
]
|
33
|
+
|
34
|
+
|
35
|
+
class TestRandZoom(NumpyImageTestCase2D):
|
36
|
+
@parameterized.expand(VALID_CASES)
|
37
|
+
def test_correct_results(self, min_zoom, max_zoom, mode, keep_size, align_corners=None):
|
38
|
+
for p in TEST_NDARRAYS_ALL:
|
39
|
+
init_param = {
|
40
|
+
"prob": 1.0,
|
41
|
+
"min_zoom": min_zoom,
|
42
|
+
"max_zoom": max_zoom,
|
43
|
+
"mode": mode,
|
44
|
+
"keep_size": keep_size,
|
45
|
+
"dtype": torch.float64,
|
46
|
+
"align_corners": align_corners,
|
47
|
+
}
|
48
|
+
random_zoom = RandZoom(**init_param)
|
49
|
+
random_zoom.set_random_state(1234)
|
50
|
+
im = p(self.imt[0])
|
51
|
+
call_param = {"img": im}
|
52
|
+
zoomed = random_zoom(**call_param)
|
53
|
+
|
54
|
+
# test lazy
|
55
|
+
# TODO: temporarily skip "nearest" test
|
56
|
+
if mode == InterpolateMode.BILINEAR:
|
57
|
+
test_resampler_lazy(
|
58
|
+
random_zoom, zoomed, init_param, call_param, seed=1234, atol=1e-4 if USE_COMPILED else 1e-6
|
59
|
+
)
|
60
|
+
|
61
|
+
test_local_inversion(random_zoom, zoomed, im)
|
62
|
+
expected = [
|
63
|
+
zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)
|
64
|
+
for channel in self.imt[0]
|
65
|
+
]
|
66
|
+
|
67
|
+
expected = np.stack(expected).astype(np.float32)
|
68
|
+
assert_allclose(zoomed, p(expected), atol=1.0, type_test=False)
|
69
|
+
|
70
|
+
def test_keep_size(self):
|
71
|
+
for p in TEST_NDARRAYS_ALL:
|
72
|
+
im = p(self.imt[0])
|
73
|
+
random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True)
|
74
|
+
random_zoom.set_random_state(12)
|
75
|
+
zoomed = random_zoom(im)
|
76
|
+
test_local_inversion(random_zoom, zoomed, im)
|
77
|
+
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
|
78
|
+
zoomed = random_zoom(im)
|
79
|
+
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
|
80
|
+
zoomed = random_zoom(im)
|
81
|
+
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
|
82
|
+
random_zoom.prob = 0.0
|
83
|
+
self.assertEqual(random_zoom(im).dtype, torch.float32)
|
84
|
+
|
85
|
+
@parameterized.expand(
|
86
|
+
[("no_min_zoom", None, 1.1, "bilinear", TypeError), ("invalid_mode", 0.9, 1.1, "s", ValueError)]
|
87
|
+
)
|
88
|
+
def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):
|
89
|
+
for p in TEST_NDARRAYS_ALL:
|
90
|
+
with self.assertRaises(raises):
|
91
|
+
random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode)
|
92
|
+
random_zoom(p(self.imt[0]))
|
93
|
+
|
94
|
+
def test_auto_expand_3d(self):
|
95
|
+
for p in TEST_NDARRAYS_ALL:
|
96
|
+
random_zoom = RandZoom(prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode="nearest", keep_size=False)
|
97
|
+
random_zoom.set_random_state(1234)
|
98
|
+
test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4]))
|
99
|
+
zoomed = random_zoom(test_data)
|
100
|
+
assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2, type_test=False)
|
101
|
+
assert_allclose(zoomed.shape, (2, 2, 3, 3), type_test=False)
|
102
|
+
|
103
|
+
|
104
|
+
if __name__ == "__main__":
|
105
|
+
unittest.main()
|
@@ -0,0 +1,108 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
from scipy.ndimage import zoom as zoom_scipy
|
20
|
+
|
21
|
+
from monai.config import USE_COMPILED
|
22
|
+
from monai.transforms import RandZoomd
|
23
|
+
from tests.lazy_transforms_utils import test_resampler_lazy
|
24
|
+
from tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion
|
25
|
+
|
26
|
+
VALID_CASES = [
|
27
|
+
(0.8, 1.2, "nearest", None, False),
|
28
|
+
(0.8, 1.2, "bilinear", None, False),
|
29
|
+
(0.8, 1.2, "bilinear", False, False),
|
30
|
+
]
|
31
|
+
|
32
|
+
|
33
|
+
class TestRandZoomd(NumpyImageTestCase2D):
|
34
|
+
@parameterized.expand(VALID_CASES)
|
35
|
+
def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_size):
|
36
|
+
key = "img"
|
37
|
+
init_param = {
|
38
|
+
"keys": key,
|
39
|
+
"prob": 1.0,
|
40
|
+
"min_zoom": min_zoom,
|
41
|
+
"max_zoom": max_zoom,
|
42
|
+
"mode": mode,
|
43
|
+
"align_corners": align_corners,
|
44
|
+
"keep_size": keep_size,
|
45
|
+
"dtype": torch.float64,
|
46
|
+
}
|
47
|
+
random_zoom = RandZoomd(**init_param)
|
48
|
+
for p in TEST_NDARRAYS_ALL:
|
49
|
+
random_zoom.set_random_state(1234)
|
50
|
+
|
51
|
+
im = p(self.imt[0])
|
52
|
+
call_param = {"data": {key: im}}
|
53
|
+
zoomed = random_zoom(**call_param)
|
54
|
+
|
55
|
+
# test lazy
|
56
|
+
# TODO: temporarily skip "nearest" test
|
57
|
+
if mode == "bilinear":
|
58
|
+
test_resampler_lazy(
|
59
|
+
random_zoom, zoomed, init_param, call_param, key, seed=1234, atol=1e-4 if USE_COMPILED else 1e-6
|
60
|
+
)
|
61
|
+
random_zoom.lazy = False
|
62
|
+
|
63
|
+
test_local_inversion(random_zoom, zoomed, {key: im}, key)
|
64
|
+
expected = [
|
65
|
+
zoom_scipy(channel, zoom=random_zoom.rand_zoom._zoom, mode="nearest", order=0, prefilter=False)
|
66
|
+
for channel in self.imt[0]
|
67
|
+
]
|
68
|
+
|
69
|
+
expected = np.stack(expected).astype(np.float32)
|
70
|
+
assert_allclose(zoomed[key], p(expected), atol=1.0, type_test=False)
|
71
|
+
|
72
|
+
def test_keep_size(self):
|
73
|
+
key = "img"
|
74
|
+
random_zoom = RandZoomd(
|
75
|
+
keys=key, prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True, padding_mode="constant", constant_values=2
|
76
|
+
)
|
77
|
+
for p in TEST_NDARRAYS_ALL:
|
78
|
+
im = p(self.imt[0])
|
79
|
+
zoomed = random_zoom({key: im})
|
80
|
+
test_local_inversion(random_zoom, zoomed, {key: im}, key)
|
81
|
+
np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:])
|
82
|
+
random_zoom.prob = 0.0
|
83
|
+
self.assertEqual(random_zoom({key: p(self.imt[0])})[key].dtype, torch.float32)
|
84
|
+
|
85
|
+
@parameterized.expand(
|
86
|
+
[("no_min_zoom", None, 1.1, "bilinear", TypeError), ("invalid_order", 0.9, 1.1, "s", ValueError)]
|
87
|
+
)
|
88
|
+
def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):
|
89
|
+
key = "img"
|
90
|
+
for p in TEST_NDARRAYS_ALL:
|
91
|
+
with self.assertRaises(raises):
|
92
|
+
random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode)
|
93
|
+
random_zoom({key: p(self.imt[0])})
|
94
|
+
|
95
|
+
def test_auto_expand_3d(self):
|
96
|
+
random_zoom = RandZoomd(
|
97
|
+
keys="img", prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode="nearest", keep_size=False
|
98
|
+
)
|
99
|
+
for p in TEST_NDARRAYS_ALL:
|
100
|
+
random_zoom.set_random_state(1234)
|
101
|
+
test_data = {"img": p(np.random.randint(0, 2, size=[2, 2, 3, 4]))}
|
102
|
+
zoomed = random_zoom(test_data)
|
103
|
+
assert_allclose(random_zoom.rand_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2)
|
104
|
+
assert_allclose(zoomed["img"].shape, (2, 2, 3, 3))
|
105
|
+
|
106
|
+
|
107
|
+
if __name__ == "__main__":
|
108
|
+
unittest.main()
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import monai.transforms as mt
|
17
|
+
from monai.data import CacheDataset
|
18
|
+
from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
|
19
|
+
|
20
|
+
|
21
|
+
class T(mt.Transform):
|
22
|
+
def __call__(self, x):
|
23
|
+
return x * 2
|
24
|
+
|
25
|
+
|
26
|
+
class TestIdentity(NumpyImageTestCase2D):
|
27
|
+
def test_identity(self):
|
28
|
+
for p in TEST_NDARRAYS:
|
29
|
+
img = p(self.imt)
|
30
|
+
identity = mt.RandIdentity()
|
31
|
+
assert_allclose(img, identity(img))
|
32
|
+
|
33
|
+
def test_caching(self, init=1, expect=4, expect_pre_cache=2):
|
34
|
+
# check that we get the correct result (two lots of T so should get 4)
|
35
|
+
x = init
|
36
|
+
transforms = mt.Compose([T(), mt.RandIdentity(), T()])
|
37
|
+
self.assertEqual(transforms(x), expect)
|
38
|
+
|
39
|
+
# check we get correct result with CacheDataset
|
40
|
+
x = [init]
|
41
|
+
ds = CacheDataset(x, transforms)
|
42
|
+
self.assertEqual(ds[0], expect)
|
43
|
+
|
44
|
+
# check that the cached value is correct
|
45
|
+
self.assertEqual(ds._cache[0], expect_pre_cache)
|
46
|
+
|
47
|
+
|
48
|
+
if __name__ == "__main__":
|
49
|
+
unittest.main()
|
@@ -0,0 +1,144 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from copy import deepcopy
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from parameterized import parameterized
|
20
|
+
|
21
|
+
import monai.transforms.intensity.array as ia
|
22
|
+
import monai.transforms.spatial.array as sa
|
23
|
+
import monai.transforms.spatial.dictionary as sd
|
24
|
+
from monai.data import MetaTensor
|
25
|
+
from monai.transforms import RandomOrder
|
26
|
+
from monai.transforms.compose import Compose
|
27
|
+
from monai.utils import set_determinism
|
28
|
+
from monai.utils.enums import TraceKeys
|
29
|
+
from tests.integration.test_one_of import A, B, C, Inv, NonInv, X, Y
|
30
|
+
|
31
|
+
|
32
|
+
class InvC(Inv):
|
33
|
+
def __init__(self, keys):
|
34
|
+
super().__init__(keys)
|
35
|
+
self.fwd_fn = lambda x: x + 1
|
36
|
+
self.inv_fn = lambda x: x - 1
|
37
|
+
|
38
|
+
|
39
|
+
class InvD(Inv):
|
40
|
+
def __init__(self, keys):
|
41
|
+
super().__init__(keys)
|
42
|
+
self.fwd_fn = lambda x: x * 100
|
43
|
+
self.inv_fn = lambda x: x / 100
|
44
|
+
|
45
|
+
|
46
|
+
set_determinism(seed=123)
|
47
|
+
KEYS = ["x", "y"]
|
48
|
+
TEST_INVERSES = [
|
49
|
+
(RandomOrder((InvC(KEYS), InvD(KEYS))), True, True),
|
50
|
+
(Compose((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, False),
|
51
|
+
(RandomOrder((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, False),
|
52
|
+
(RandomOrder((Compose((InvC(KEYS), InvD(KEYS))), Compose((InvD(KEYS), InvC(KEYS))))), True, False),
|
53
|
+
(RandomOrder((NonInv(KEYS), NonInv(KEYS))), False, False),
|
54
|
+
]
|
55
|
+
|
56
|
+
|
57
|
+
class TestRandomOrder(unittest.TestCase):
    """Unit tests for RandomOrder: empty pipeline, interaction with flatten(), and inversion."""

    def test_empty_compose(self):
        # a RandomOrder with no transforms must return its input unchanged
        c = RandomOrder()
        i = 1
        self.assertEqual(c(i), 1)

    def test_compose_flatten_does_not_affect_random_order(self):
        p = Compose([A(), B(), RandomOrder([C(), Inv(KEYS), Compose([X(), Y()])])])
        f = p.flatten()

        # in this case the flattened transform should be the same.
        def _match(a, b):
            self.assertEqual(type(a), type(b))
            # zip() stops silently at the shorter list, so without this check a
            # flatten() that dropped or added transforms would still pass.
            self.assertEqual(len(a.transforms), len(b.transforms))
            for a_, b_ in zip(a.transforms, b.transforms):
                self.assertEqual(type(a_), type(b_))
                if isinstance(a_, (Compose, RandomOrder)):
                    _match(a_, b_)

        _match(p, f)

    @parameterized.expand(TEST_INVERSES)
    def test_inverse(self, transform, invertible, use_metatensor):
        """Forward twice, invert both results, and compare against the original data.

        Args:
            transform: the RandomOrder/Compose pipeline under test.
            invertible: whether inverse() is expected to restore the original values.
            use_metatensor: part of each TEST_INVERSES row; NOTE(review): not read here —
                inputs are always MetaTensor.
        """
        data = {k: MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
        fwd_data1 = transform(data)
        # test call twice won't affect inverse
        fwd_data2 = transform(data)

        if invertible:
            for k in KEYS:
                t = fwd_data1[k].applied_operations[-1]
                # make sure the RandomOrder applied_order was stored
                self.assertEqual(t[TraceKeys.CLASS_NAME], RandomOrder.__name__)

        # call the inverse
        fwd_inv_data1 = transform.inverse(fwd_data1)
        fwd_inv_data2 = transform.inverse(fwd_data2)

        fwd_data = [fwd_data1, fwd_data2]
        fwd_inv_data = [fwd_inv_data1, fwd_inv_data2]
        for i, _fwd_inv_data in enumerate(fwd_inv_data):
            if invertible:
                for k in KEYS:
                    # check data is same as original (and different from forward)
                    self.assertEqual(_fwd_inv_data[k], data[k])
                    self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k])
            else:
                # if not invertible, should not change the data
                self.assertDictEqual(fwd_data[i], _fwd_inv_data)
|
105
|
+
|
106
|
+
|
107
|
+
# Each case: [keys, pipeline]. With ``keys`` None the pipeline holds array transforms
# applied to a bare tensor; otherwise it holds dictionary transforms over ``keys``
# (see TestRandomOrderAPITests.data_from_keys).
TEST_RANDOM_ORDER_EXTENDED_TEST_CASES = [
    [None, ()],
    [None, (sa.Rotate(np.pi / 8),)],
    [
        None,
        (
            sa.Flip(0),
            sa.Flip(1),
            sa.Rotate90(1),
            sa.Zoom(0.8),
            ia.NormalizeIntensity(),
        ),
    ],
    [("a",), (sd.Rotated(("a",), np.pi / 8),)],
]
|
113
|
+
|
114
|
+
|
115
|
+
class TestRandomOrderAPITests(unittest.TestCase):
    """API checks for RandomOrder: the ``start`` / ``end`` call arguments must be rejected."""

    @staticmethod
    def data_from_keys(keys):
        """Build test input: a 1x12x16 integer ramp tensor, or a dict of such tensors.

        With ``keys`` None a single tensor holding 0..191 is returned. Otherwise each
        key maps to that ramp shifted by ``192 * position``, so the keys hold
        different values.
        """

        def _ramp(offset):
            plane = torch.tensor(np.arange(12 * 16)).reshape(12, 16) + offset
            return torch.unsqueeze(plane, dim=0)

        if keys is None:
            return _ramp(0)
        return {k: _ramp(i_k * 192) for i_k, k in enumerate(keys)}

    @parameterized.expand(TEST_RANDOM_ORDER_EXTENDED_TEST_CASES)
    def test_execute_change_start_end(self, keys, pipeline):
        data = self.data_from_keys(keys)

        # Both ``start`` and ``end`` are expected to raise ValueError. Each is tried
        # twice on a fresh transform, so a failed call must not leave the transform
        # in a state where a retry succeeds.
        for bad_kwargs in ({"start": 1}, {"end": 1}):
            rand_order = RandomOrder(deepcopy(pipeline))
            for _attempt in range(2):
                with self.assertRaises(ValueError):
                    rand_order(data, **bad_kwargs)
|
141
|
+
|
142
|
+
|
143
|
+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.transforms import RandomizableTrait, RandTorchVisiond
|
20
|
+
from monai.utils import set_determinism
|
21
|
+
from tests.test_utils import assert_allclose
|
22
|
+
|
23
|
+
def _three_channel(plane):
    """Return a float tensor of shape (3, H, W) whose three channels all equal ``plane``.

    A new tensor is built on every call, so no two test cases share storage.
    """
    return torch.tensor([plane, plane, plane])


# 2x2 input pattern, repeated over three channels for every case below.
_INPUT_PLANE = [[0.0, 1.0], [1.0, 2.0]]

# Each case: [RandTorchVisiond kwargs, input data dict, expected "img" output].
TEST_CASE_1 = [
    {"keys": "img", "name": "ColorJitter"},
    {"img": _three_channel(_INPUT_PLANE)},
    _three_channel(_INPUT_PLANE),
]

TEST_CASE_2 = [
    {"keys": "img", "name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5},
    {"img": _three_channel(_INPUT_PLANE)},
    _three_channel([[0.1090, 0.6193], [0.6193, 0.9164]]),
]

TEST_CASE_3 = [
    {"keys": "img", "name": "Pad", "padding": [1, 1, 1, 1]},
    {"img": _three_channel(_INPUT_PLANE)},
    _three_channel(
        [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 2.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
    ),
]
|
52
|
+
|
53
|
+
|
54
|
+
class TestRandTorchVisiond(unittest.TestCase):
    """Check RandTorchVisiond outputs against fixed expected values under a fixed seed."""

    def tearDown(self):
        # test_value calls set_determinism(seed=0), which changes process-wide
        # determinism settings; reset them so later tests in the same process
        # are not affected.
        set_determinism(None)

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_value(self, input_param, input_data, expected_value):
        # seed before constructing the transform so that its random draws are reproducible
        set_determinism(seed=0)
        transform = RandTorchVisiond(**input_param)
        result = transform(input_data)
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...))
        self.assertIsInstance(transform, RandomizableTrait)
        assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4)
|
62
|
+
|
63
|
+
|
64
|
+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|