monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/handlers/__init__.py +1 -0
- monai/handlers/average_precision.py +53 -0
- monai/inferers/inferer.py +10 -7
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/metrics/__init__.py +1 -0
- monai/metrics/average_precision.py +187 -0
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utility/array.py +2 -12
- monai/transforms/utils.py +1 -2
- monai/transforms/utils_pytorch_numpy_unification.py +2 -4
- monai/utils/enums.py +3 -2
- monai/utils/module.py +6 -6
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
- monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_average_precision.py +79 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1388 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +884 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_average_precision.py +162 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +73 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
@@ -0,0 +1,152 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.data import MetaTensor
|
21
|
+
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, get_equivalent_dtype
|
22
|
+
from tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose
|
23
|
+
|
24
|
+
# Single-value cases: (input, expected output, target dtype, safe flag).
TESTS: list[tuple] = []
for in_type in TEST_NDARRAYS_ALL + (int, float):
    for out_type in TEST_NDARRAYS_ALL:
        # plain conversion of 1.0, no dtype override, safe=False
        TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)), None, False))  # type: ignore
        if in_type is float:
            continue
        # 256 is outside the uint8 range; with safe=True the output is expected to be 255
        TESTS.append((in_type(np.array(256)), out_type(np.array(255)), np.uint8, True))  # type: ignore

# Sequence cases: (input list, expected output, wrap_sequence, target dtype, safe flag).
# With wrap_sequence=True the expected output is one container of the output type;
# otherwise it is a list holding one converted item per input element.
TESTS_LIST: list[tuple] = []
for in_type in TEST_NDARRAYS_ALL + (int, float):
    for out_type in TEST_NDARRAYS_ALL:
        TESTS_LIST.extend(
            [
                (
                    [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore
                    out_type(np.array([1.0, 1.0])),
                    True,
                    None,
                    False,
                ),
                (
                    [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore
                    [out_type(np.array(1.0)), out_type(np.array(1.0))],
                    False,
                    None,
                    False,
                ),
            ]
        )
        if in_type is float:
            continue
        # out-of-range values are expected to be clipped into uint8 (257 -> 255, -12 -> 0)
        TESTS_LIST.extend(
            [
                (
                    [in_type(np.array(257)), in_type(np.array(1))],  # type: ignore
                    out_type(np.array([255, 1])),
                    True,
                    np.uint8,
                    True,
                ),
                (
                    [in_type(np.array(257)), in_type(np.array(-12))],  # type: ignore
                    [out_type(np.array(255)), out_type(np.array(0))],
                    False,
                    np.uint8,
                    True,
                ),
            ]
        )

# numpy unsigned dtypes listed as unsupported, each mapped to the torch dtype that
# test_unsupported_np_types expects convert_data_type to produce for it.
UNSUPPORTED_TYPES = {
    np.dtype("uint16"): torch.int32,
    np.dtype("uint32"): torch.int64,
    np.dtype("uint64"): torch.int64,
}
|
73
|
+
|
74
|
+
|
75
|
+
class TestTensor(torch.Tensor):
    """Bare ``torch.Tensor`` subclass used as a conversion destination in the tests.

    Its name starts with "Test", so ``__test__ = False`` tells pytest not to collect it.
    """

    __test__ = False
|
78
|
+
|
79
|
+
|
80
|
+
class TestConvertDataType(unittest.TestCase):
    """Tests for ``convert_data_type`` on single values, sequences and unsupported numpy dtypes."""

    @parameterized.expand(TESTS)
    def test_convert_data_type(self, in_image, im_out, out_dtype, safe):
        converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out), dtype=out_dtype, safe=safe)
        # the reported original type/device must describe the input
        self.assertEqual(type(in_image), orig_type)
        if isinstance(in_image, torch.Tensor):
            self.assertEqual(in_image.device, orig_device)
        # check output is desired type
        self.assertEqual(type(converted_im), type(im_out))
        # check data has been clipped
        assert_allclose(converted_im, im_out)
        # check dtype is unchanged when no dtype was requested
        if out_dtype is None:
            if isinstance(in_image, (np.ndarray, torch.Tensor)):
                self.assertEqual(converted_im.dtype, im_out.dtype)

    def test_neg_stride(self):
        # a reversed (negative-stride) numpy view must convert to a tensor and keep its element order
        converted_im, *_ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor)
        self.assertIsInstance(converted_im, torch.Tensor)
        self.assertEqual(converted_im.tolist(), [2, 1])

    @parameterized.expand(list(UNSUPPORTED_TYPES.items()))
    def test_unsupported_np_types(self, np_type, pt_type):
        in_image = np.ones(13, dtype=np_type)  # choose a prime size so as to be indivisible by the size of any dtype
        converted_im, *_ = convert_data_type(in_image, torch.Tensor)

        self.assertEqual(converted_im.dtype, pt_type)

    @parameterized.expand(TESTS_LIST)
    def test_convert_list(self, in_image, im_out, wrap, out_dtype, safe):
        output_type = type(im_out) if wrap else type(im_out[0])
        converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap, dtype=out_dtype, safe=safe)
        # wrap_sequence=True: the whole list becomes one container.
        # wrap_sequence=False: every element is converted separately, so check all of them (not only the first).
        if wrap:
            pairs = [(converted_im, im_out)]
        else:
            self.assertEqual(len(converted_im), len(im_out))
            pairs = list(zip(converted_im, im_out))
        expected_dtype = None if out_dtype is None else get_equivalent_dtype(out_dtype, output_type)
        for converted, expected in pairs:
            # check output is desired type
            self.assertEqual(type(converted), type(expected))
            assert_allclose(converted, expected)
            # dtype: unchanged when no dtype was requested, otherwise the requested dtype
            if isinstance(in_image[0], (np.ndarray, torch.Tensor)):
                if expected_dtype is None:
                    self.assertEqual(converted.dtype, expected.dtype)
                else:
                    self.assertEqual(converted.dtype, expected_dtype)
|
124
|
+
|
125
|
+
|
126
|
+
class TestConvertDataSame(unittest.TestCase):
    """Tests for ``convert_to_dst_type``, including a destination that is a torch.Tensor subclass."""

    @parameterized.expand(TESTS + [(np.array(256), TestTensor(np.array([255])), torch.uint8, True)])
    def test_convert_data_type(self, in_image, im_out, out_dtype, safe):
        converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out, dtype=out_dtype, safe=safe)
        # the reported original type must be the input's type
        self.assertEqual(type(in_image), orig_type)
        assert_allclose(converted_im, im_out)
        if isinstance(in_image, torch.Tensor):
            self.assertEqual(in_image.device, orig_device)

        # MetaTensor destinations stay MetaTensor; other tensors (including subclasses such as
        # TestTensor) are expected back as plain torch.Tensor; anything else as np.ndarray
        expected_type = (
            MetaTensor
            if isinstance(im_out, MetaTensor)
            else torch.Tensor if isinstance(im_out, torch.Tensor) else np.ndarray
        )
        self.assertEqual(type(converted_im), expected_type)
        # with no dtype requested, array/tensor inputs must keep the destination's dtype
        # (MetaTensor is a torch.Tensor subclass, so it is covered by the torch.Tensor check)
        if out_dtype is None and isinstance(in_image, (np.ndarray, torch.Tensor)):
            self.assertEqual(converted_im.dtype, im_out.dtype)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.utils.type_conversion import get_equivalent_dtype, get_numpy_dtype_from_string, get_torch_dtype_from_string
|
21
|
+
from tests.test_utils import TEST_NDARRAYS
|
22
|
+
|
23
|
+
# Three spellings of float32 (torch dtype, numpy scalar type, numpy dtype object).
DTYPES = [torch.float32, np.float32, np.dtype(np.float32)]

# Every container type paired with every float32 spelling; the container holds float32 data.
TESTS = [(p(np.array(1.0, dtype=np.float32)), im_dtype) for p in TEST_NDARRAYS for im_dtype in DTYPES]
|
29
|
+
|
30
|
+
|
31
|
+
class TestGetEquivalentDtype(unittest.TestCase):
    """Tests for ``get_equivalent_dtype`` and the string-to-dtype helpers."""

    @parameterized.expand(TESTS)
    def test_get_equivalent_dtype(self, im, input_dtype):
        # each float32 spelling should map onto the dtype of the float32 container ``im``
        self.assertEqual(get_equivalent_dtype(input_dtype, type(im)), im.dtype)

    def test_native_type(self):
        """get_equivalent_dtype currently returns built-in Python types unchanged."""
        for native in (float, int, bool):
            for ref_dtype in DTYPES:
                self.assertEqual(get_equivalent_dtype(native, type(ref_dtype)), native)

    @parameterized.expand(
        [
            ["float", np.float64],
            ["float32", np.float32],
            ["np.float32", np.float32],
            ["float64", np.float64],
            ["torch.float64", np.float64],
        ]
    )
    def test_from_string(self, dtype_str, expected_np):
        expected_pt = get_equivalent_dtype(expected_np, torch.Tensor)
        # the same string must resolve consistently for numpy and for torch
        self.assertEqual(get_numpy_dtype_from_string(dtype_str), expected_np)
        self.assertEqual(get_torch_dtype_from_string(dtype_str), expected_pt)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,99 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.utils import optional_import
|
21
|
+
from monai.utils.type_conversion import get_equivalent_dtype, safe_dtype_range
|
22
|
+
from tests.test_utils import HAS_CUPY, TEST_NDARRAYS_ALL, assert_allclose
|
23
|
+
|
24
|
+
cp, _ = optional_import("cupy")
|
25
|
+
|
26
|
+
# Single-value cases: (input, expected output, target dtype).
TESTS: list[tuple] = []
for in_type in TEST_NDARRAYS_ALL + (int, float):
    # no target dtype: the value should come back untouched
    TESTS.append((in_type(np.array(1.0)), in_type(np.array(1.0)), None))  # type: ignore
    if in_type is float:
        continue
    # values outside the uint8 range are expected to be clipped to [0, 255]
    TESTS.extend(
        [
            (in_type(np.array(256)), in_type(np.array(255)), np.uint8),  # type: ignore
            (in_type(np.array(-12)), in_type(np.array(0)), np.uint8),  # type: ignore
        ]
    )
# a 2D case mixing values above, inside and below the uint8 range (array types only)
for in_type in TEST_NDARRAYS_ALL:
    TESTS.append((in_type(np.array([[256, 255], [-12, 0]])), in_type(np.array([[255, 255], [0, 0]])), np.uint8))

# List cases: (input list, expected list, target dtype).
TESTS_LIST: list[tuple] = []
for in_type in TEST_NDARRAYS_ALL + (int, float):
    TESTS_LIST.append(
        (
            [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore
            [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore
            None,
        )
    )
    if in_type is float:
        continue
    TESTS_LIST.append(
        (
            [in_type(np.array(257)), in_type(np.array(-12))],  # type: ignore
            [in_type(np.array(255)), in_type(np.array(0))],  # type: ignore
            np.uint8,
        )
    )

# numpy inputs that test_type_cupy turns into CuPy arrays with cp.asarray
TESTS_CUPY = [[np.array(1.0), np.array(1.0), None], [np.array([-12]), np.array([0]), np.uint8]]
|
54
|
+
|
55
|
+
|
56
|
+
class TesSafeDtypeRange(unittest.TestCase):
    """Tests for ``safe_dtype_range`` on scalars, arrays, lists and CuPy arrays."""

    @parameterized.expand(TESTS)
    def test_safe_dtype_range(self, in_image, im_out, out_dtype):
        result = safe_dtype_range(in_image, out_dtype)
        # check type is unchanged
        self.assertEqual(type(in_image), type(result))
        # check dtype is unchanged: values are clipped into the target range, the dtype is kept
        if isinstance(in_image, (np.ndarray, torch.Tensor)):
            self.assertEqual(in_image.dtype, result.dtype)
        # check output
        assert_allclose(result, im_out)

    @parameterized.expand(TESTS_LIST)
    def test_safe_dtype_range_list(self, in_image, im_out, out_dtype):
        result = safe_dtype_range(in_image, dtype=out_dtype)
        # check type is unchanged
        self.assertEqual(type(result), type(im_out))
        self.assertEqual(len(result), len(im_out))
        for in_item, res_item, expected in zip(in_image, result, im_out):
            # check output
            assert_allclose(res_item, expected)
            # check dtype is unchanged, element by element, as in test_safe_dtype_range.
            # ``in_image`` itself is a list, so testing it with isinstance never matched an array type
            # and this check never ran.
            if isinstance(in_item, (np.ndarray, torch.Tensor)):
                self.assertEqual(res_item.dtype, in_item.dtype)

    @parameterized.expand(TESTS_CUPY)
    @unittest.skipUnless(HAS_CUPY, "Requires CuPy")
    def test_type_cupy(self, in_image, im_out, out_dtype):
        in_image = cp.asarray(in_image)
        result = safe_dtype_range(in_image, dtype=out_dtype)
        # check type is unchanged
        self.assertEqual(type(in_image), type(result))
        # check dtype is unchanged
        self.assertEqual(result.dtype, in_image.dtype)
        # check output on the host; assertEqual on arrays would depend on the truthiness of an
        # element-wise comparison, which is ambiguous for multi-element arrays
        assert_allclose(cp.asnumpy(result), im_out)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import tensorboard
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from monai.visualize import make_animated_gif_summary
|
21
|
+
|
22
|
+
|
23
|
+
class TestImg2Tensorboard(unittest.TestCase):
    """Checks that ``make_animated_gif_summary`` returns tensorboard ``Summary`` protos."""

    def test_write_gray(self):
        nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32)
        summary_object_np = make_animated_gif_summary(
            tag="test_summary_nparr.png", image=nparr, max_out=1, scale_factor=253.0
        )
        self._check_summaries(
            summary_object_np,
            "make_animated_gif_summary must return a tensorboard.summary object from numpy array",
        )

        tensorarr = torch.tensor(nparr)
        summary_object_tensor = make_animated_gif_summary(
            tag="test_summary_tensorarr.png", image=tensorarr, max_out=1, frame_dim=-1, scale_factor=253.0
        )
        self._check_summaries(
            summary_object_tensor,
            "make_animated_gif_summary must return a tensorboard.summary object from tensor input",
        )

    def _check_summaries(self, summaries, msg):
        """Assert that ``summaries`` is non-empty and every item is a tensorboard Summary proto."""
        items = list(summaries)
        # an empty result would otherwise make the per-item check pass vacuously
        self.assertGreater(len(items), 0, msg)
        for s in items:
            # unlike a bare ``assert``, this check is not stripped when running under ``python -O``
            self.assertIsInstance(s, tensorboard.compat.proto.summary_pb2.Summary, msg)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,128 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from typing import Any
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks.nets import DenseNet, DenseNet121
|
21
|
+
from monai.visualize import OcclusionSensitivity
|
22
|
+
|
23
|
+
|
24
|
+
class DenseNetAdjoint(DenseNet121):
    """DenseNet121 whose call takes an extra ``adjoint_info`` argument.

    Raises ValueError unless ``adjoint_info`` equals 42. The occlusion tests pass
    ``adjoint_info=42`` in the call data, presumably to check that extra call arguments
    reach the model.
    """

    def __call__(self, x, adjoint_info):
        if adjoint_info == 42:
            return super().__call__(x)
        raise ValueError
|
30
|
+
|
31
|
+
|
32
|
+
# Use the first GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
out_channels_2d = 4
out_channels_3d = 3

# Models shared by the test cases: 2D single-channel, 2D two-channel, a small 3D net and the
# 2D variant that requires an extra call argument.
model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device)
model_2d_2c = DenseNet121(spatial_dims=2, in_channels=2, out_channels=out_channels_2d).to(device)
model_3d = DenseNet(
    spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,)
).to(device)
model_2d_adjoint = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device)

# Put every model in evaluation mode.
for _model in (model_2d, model_2d_2c, model_3d, model_2d_adjoint):
    _model.eval()
del _model
|
45
|
+
|
46
|
+
# Each TESTS entry: [constructor kwargs, call kwargs, expected map shape, expected most-probable shape].
# ``b_box`` gives (start, stop) per spatial dim; judging by the expected shapes, -1 keeps the full extent.
TESTS: list[Any] = []
# Each TESTS_FAIL entry: [constructor kwargs, call kwargs, expected exception type].
TESTS_FAIL: list[Any] = []

# 2D w/ bounding box, once per occlusion mode
TESTS.extend(
    [
        {"nn_module": model_2d, "mode": mode},
        {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [2, 40, 1, 62]},
        (1, out_channels_2d, 38, 61),
        (1, 1, 38, 61),
    ]
    for mode in ("gaussian", "mean_patch", "mean_img")
)
TESTS.extend(
    [
        # 3D w/ bounding box and an explicit mask size
        [
            {"nn_module": model_3d, "n_batch": 10, "mask_size": (16, 15, 14)},
            {"x": torch.rand(1, 1, 64, 32, 16).to(device), "b_box": [2, 43, -1, -1, -1, -1]},
            (1, out_channels_3d, 41, 32, 16),
            (1, 1, 41, 32, 16),
        ],
        # 3D w/ bounding box on a small image
        [
            {"nn_module": model_3d, "n_batch": 10},
            {"x": torch.rand(1, 1, 6, 7, 8).to(device), "b_box": [1, 3, -1, -1, -1, -1]},
            (1, out_channels_3d, 2, 7, 8),
            (1, 1, 2, 7, 8),
        ],
        # 2D, two input channels, no bounding box
        [
            {"nn_module": model_2d_2c},
            {"x": torch.rand(1, 2, 48, 64).to(device)},
            (1, out_channels_2d, 48, 64),
            (1, 1, 48, 64),
        ],
        # 2D w/ bounding box and an extra call argument for the adjoint model
        [
            {"nn_module": model_2d_adjoint},
            {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [2, 40, 1, 62], "adjoint_info": 42},
            (1, out_channels_2d, 38, 61),
            (1, 1, 38, 61),
        ],
    ]
)

TESTS_FAIL.extend(
    [
        # 2D should fail: mask_size (200) is larger than the 48x64 image
        [{"nn_module": model_2d, "n_batch": 10, "mask_size": 200}, {"x": torch.rand(1, 1, 48, 64).to(device)}, ValueError],
        # 2D should fail: batch > 1
        [{"nn_module": model_2d, "n_batch": 10, "mask_size": 100}, {"x": torch.rand(2, 1, 48, 64).to(device)}, ValueError],
        # 2D should fail: unknown mode
        [{"nn_module": model_2d, "mode": "test"}, {"x": torch.rand(1, 1, 48, 64).to(device)}, NotImplementedError],
    ]
)
|
105
|
+
|
106
|
+
|
107
|
+
class TestComputeOcclusionSensitivity(unittest.TestCase):
    """Shape and failure-mode tests for ``OcclusionSensitivity``."""

    @parameterized.expand(TESTS)
    def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape):
        occ_sens = OcclusionSensitivity(**init_data)
        m, most_prob = occ_sens(**call_data)
        self.assertTupleEqual(m.shape, map_expected_shape)
        self.assertTupleEqual(most_prob.shape, most_prob_expected_shape)
        # most probable class should be of type int, and should have min>=0, max<num_classes
        self.assertEqual(most_prob.dtype, torch.int64)
        self.assertGreaterEqual(most_prob.min(), 0)
        # the class axis of the map is dim 1 (expected shapes are (1, out_channels, *spatial));
        # m.shape[-1] is a spatial size and would make this bound far too loose
        self.assertLess(most_prob.max(), m.shape[1])

    @parameterized.expand(TESTS_FAIL)
    def test_fail(self, init_data, call_data, error_type):
        # construction is inside the context too: the error may be raised by either step
        with self.assertRaises(error_type):
            occ_sens = OcclusionSensitivity(**init_data)
            occ_sens(**call_data)


if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import glob
import os
import tempfile
import unittest

import torch
from parameterized import parameterized

from monai.utils import optional_import
from monai.visualize import plot_2d_or_3d_image
from tests.test_utils import SkipIfBeforePyTorchVersion, SkipIfNoModule
|
24
|
+
|
25
|
+
# Optional TensorBoard writers. ``has_tb`` gates the whole test class below, while the
# tensorboardX availability flag is discarded because those tests use SkipIfNoModule instead.
SummaryWriter, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
SummaryWriterX, _ = optional_import("tensorboardX", name="SummaryWriter")
|
28
|
+
|
29
|
+
# Each case carries a single argument: the shape of the tensor handed to
# plot_2d_or_3d_image. Assumed layout is (N, C, H, W[, D]) — confirm against its docs.
TEST_CASE_1 = [(1, 1, 10, 10)]  # 4-D input, 1 channel
TEST_CASE_2 = [(1, 3, 10, 10)]  # 4-D input, 3 channels
TEST_CASE_3 = [(1, 4, 10, 10)]  # 4-D input, 4 channels
TEST_CASE_4 = [(1, 1, 10, 10, 10)]  # 5-D input, 1 channel
TEST_CASE_5 = [(1, 3, 10, 10, 10)]  # 5-D input, 3 channels
|
38
|
+
|
39
|
+
|
40
|
+
@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
@SkipIfBeforePyTorchVersion((1, 13))  # issue 6683
class TestPlot2dOr3dImage(unittest.TestCase):
    """Smoke tests for ``plot_2d_or_3d_image`` with torch and tensorboardX summary writers."""

    def _assert_logdir_written(self, logdir):
        """Fail unless at least one entry exists inside ``logdir``."""
        # ``glob.glob(logdir)`` alone only matches the directory itself, which always
        # exists, so it can never fail; match the directory's contents instead.
        self.assertGreater(len(glob.glob(os.path.join(logdir, "*"))), 0)

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
    def test_tb_image(self, shape):
        with tempfile.TemporaryDirectory() as tempdir:
            writer = SummaryWriter(log_dir=tempdir)
            try:
                plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=3, frame_dim=-1)
                writer.flush()
            finally:
                # Close even on failure so the temporary directory can be removed cleanly.
                writer.close()
            self._assert_logdir_written(tempdir)

    # ``parameterized.expand`` must be the outermost decorator: it injects the generated
    # test methods straight into the class namespace, so a skip decorator stacked above
    # it never reaches them. Placed below, the skip is carried into every generated case.
    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
    @SkipIfNoModule("tensorboardX")
    def test_tbx_image(self, shape):
        with tempfile.TemporaryDirectory() as tempdir:
            writer = SummaryWriterX(log_dir=tempdir)
            try:
                plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=2)
                writer.flush()
            finally:
                writer.close()
            self._assert_logdir_written(tempdir)

    @parameterized.expand([TEST_CASE_5])
    @SkipIfNoModule("tensorboardX")
    def test_tbx_video(self, shape):
        with tempfile.TemporaryDirectory() as tempdir:
            writer = SummaryWriterX(log_dir=tempdir)
            try:
                plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3)
                writer.flush()
            finally:
                writer.close()
            self._assert_logdir_written(tempdir)
|
71
|
+
|
72
|
+
|
73
|
+
if __name__ == "__main__":
    # Run this module's test cases when executed directly as a script.
    unittest.main()
|
@@ -0,0 +1,98 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
|
20
|
+
from monai.visualize import CAM
|
21
|
+
|
22
|
+
# Each case is [config, expected CAM output shape]. The config names the backbone to
# build, the input shape, the expected feature-map shape reported by
# ``cam.feature_map_size``, and the layer names handed to ``CAM``.

# 2D DenseNet121, single-channel input
TEST_CASE_0 = [
    {
        "model": "densenet2d",
        "shape": (2, 1, 48, 64),
        "feature_shape": (2, 1, 1, 2),
        "target_layers": "class_layers.relu",
        "fc_layers": "class_layers.out",
    },
    (2, 1, 48, 64),
]

# 3D DenseNet (small: single dense block), single-channel input
TEST_CASE_1 = [
    {
        "model": "densenet3d",
        "shape": (2, 1, 6, 6, 6),
        "feature_shape": (2, 1, 2, 2, 2),
        "target_layers": "class_layers.relu",
        "fc_layers": "class_layers.out",
    },
    (2, 1, 6, 6, 6),
]

# 2D SEResNet50, three-channel input
TEST_CASE_2 = [
    {
        "model": "senet2d",
        "shape": (2, 3, 64, 64),
        "feature_shape": (2, 1, 2, 2),
        "target_layers": "layer4",
        "fc_layers": "last_linear",
    },
    (2, 1, 64, 64),
]

# 3D SEResNet50, three-channel input
TEST_CASE_3 = [
    {
        "model": "senet3d",
        "shape": (2, 3, 8, 8, 48),
        "feature_shape": (2, 1, 1, 1, 2),
        "target_layers": "layer4",
        "fc_layers": "last_linear",
    },
    (2, 1, 8, 8, 48),
]
|
67
|
+
|
68
|
+
|
69
|
+
class TestClassActivationMap(unittest.TestCase):
    """Shape checks for ``CAM`` on small DenseNet and SEResNet50 classifiers (2D and 3D)."""

    @staticmethod
    def _build_model(name):
        """Instantiate the network selected by a test case's ``"model"`` key.

        Args:
            name: one of ``"densenet2d"``, ``"densenet3d"``, ``"senet2d"``, ``"senet3d"``.

        Returns:
            The freshly constructed network.

        Raises:
            ValueError: if ``name`` is not a supported test model, instead of letting an
                unset model surface later as an opaque ``AttributeError`` on ``None``.
        """
        # Factories are lambdas so only the requested network is actually built.
        factories = {
            "densenet2d": lambda: DenseNet121(spatial_dims=2, in_channels=1, out_channels=3),
            "densenet3d": lambda: DenseNet(
                spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)
            ),
            "senet2d": lambda: SEResNet50(spatial_dims=2, in_channels=3, num_classes=4),
            "senet3d": lambda: SEResNet50(spatial_dims=3, in_channels=3, num_classes=4),
        }
        try:
            factory = factories[name]
        except KeyError:
            raise ValueError(f"unsupported test model: {name!r}") from None
        return factory()

    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_shape(self, input_data, expected_shape):
        model = self._build_model(input_data["model"])

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        cam = CAM(nn_module=model, target_layers=input_data["target_layers"], fc_layers=input_data["fc_layers"])
        image = torch.rand(input_data["shape"], device=device)
        result = cam(x=image, layer_idx=-1)
        # Compare the reported feature-map shape for this input shape with the case's expectation.
        fea_shape = cam.feature_map_size(input_data["shape"], device=device)
        self.assertTupleEqual(fea_shape, input_data["feature_shape"])
        self.assertTupleEqual(result.shape, expected_shape)
|
95
|
+
|
96
|
+
|
97
|
+
if __name__ == "__main__":
    # Run this module's test cases when executed directly as a script.
    unittest.main()
|