monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utils.py +1 -2
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/METADATA +2 -2
- monai_weekly-1.5.dev2507.dist-info/RECORD +1181 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1310 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +824 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +75 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks.layers import polyval
|
21
|
+
|
22
|
+
TEST_CASES = [
|
23
|
+
[[1.0, 2.5, -4.2], 5.0, 33.3],
|
24
|
+
[[2, 1, 0], 3.0, 21],
|
25
|
+
[[2, 1, 0], [3.0, 3.0], [21, 21]],
|
26
|
+
[torch.as_tensor([2, 1, 0]), [3.0, 3.0], [21, 21]],
|
27
|
+
[torch.as_tensor([2, 1, 0]), torch.as_tensor([3.0, 3.0]), [21, 21]],
|
28
|
+
[torch.as_tensor([2, 1, 0]), np.array([3.0, 3.0]), [21, 21]],
|
29
|
+
[[], np.array([3.0, 3.0]), [0, 0]],
|
30
|
+
]
|
31
|
+
|
32
|
+
|
33
|
+
class TestPolyval(unittest.TestCase):
|
34
|
+
|
35
|
+
@parameterized.expand(TEST_CASES)
|
36
|
+
def test_floats(self, coef, x, expected):
|
37
|
+
result = polyval(coef, x)
|
38
|
+
np.testing.assert_allclose(result.cpu().numpy(), expected)
|
39
|
+
|
40
|
+
@parameterized.expand(TEST_CASES)
|
41
|
+
def test_gpu(self, coef, x, expected):
|
42
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43
|
+
x = torch.as_tensor(x, dtype=torch.float, device=device)
|
44
|
+
x.requires_grad = True
|
45
|
+
coef = torch.as_tensor(coef, dtype=torch.float, device=device)
|
46
|
+
coef.requires_grad = True
|
47
|
+
result = polyval(coef, x)
|
48
|
+
if coef.shape[0] > 0: # empty coef doesn't have grad
|
49
|
+
result.mean().backward()
|
50
|
+
np.testing.assert_allclose(coef.grad.shape, coef.shape)
|
51
|
+
np.testing.assert_allclose(result.cpu().detach().numpy(), expected)
|
52
|
+
|
53
|
+
|
54
|
+
if __name__ == "__main__":
|
55
|
+
unittest.main()
|
@@ -0,0 +1,136 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.layers import ApplyFilter, EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter
|
20
|
+
|
21
|
+
TEST_CASES_MEAN = [(3, 3, torch.ones(3, 3, 3)), (2, 5, torch.ones(5, 5))]
|
22
|
+
|
23
|
+
TEST_CASES_LAPLACE = [
|
24
|
+
(
|
25
|
+
3,
|
26
|
+
3,
|
27
|
+
torch.Tensor(
|
28
|
+
[
|
29
|
+
[[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]],
|
30
|
+
[[-1, -1, -1], [-1, 26, -1], [-1, -1, -1]],
|
31
|
+
[[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]],
|
32
|
+
]
|
33
|
+
),
|
34
|
+
),
|
35
|
+
(2, 3, torch.Tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])),
|
36
|
+
]
|
37
|
+
|
38
|
+
TEST_CASES_ELLIPTICAL = [
|
39
|
+
(
|
40
|
+
3,
|
41
|
+
3,
|
42
|
+
torch.Tensor(
|
43
|
+
[[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, 1, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]]
|
44
|
+
),
|
45
|
+
),
|
46
|
+
(2, 3, torch.Tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
|
47
|
+
]
|
48
|
+
|
49
|
+
TEST_CASES_SHARPEN = [
|
50
|
+
(
|
51
|
+
3,
|
52
|
+
3,
|
53
|
+
torch.Tensor(
|
54
|
+
[
|
55
|
+
[[0, 0, 0], [0, -1, 0], [0, 0, 0]],
|
56
|
+
[[0, -1, 0], [-1, 7, -1], [0, -1, 0]],
|
57
|
+
[[0, 0, 0], [0, -1, 0], [0, 0, 0]],
|
58
|
+
]
|
59
|
+
),
|
60
|
+
),
|
61
|
+
(2, 3, torch.Tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])),
|
62
|
+
]
|
63
|
+
|
64
|
+
|
65
|
+
class _TestFilter:
|
66
|
+
|
67
|
+
def test_init(self, spatial_dims, size, expected):
|
68
|
+
test_filter = self.filter_class(spatial_dims=spatial_dims, size=size)
|
69
|
+
torch.testing.assert_allclose(expected, test_filter.filter)
|
70
|
+
self.assertIsInstance(test_filter, torch.nn.Module)
|
71
|
+
|
72
|
+
def test_forward(self):
|
73
|
+
test_filter = self.filter_class(spatial_dims=2, size=3)
|
74
|
+
input = torch.ones(1, 1, 5, 5)
|
75
|
+
_ = test_filter(input)
|
76
|
+
|
77
|
+
|
78
|
+
class TestApplyFilter(unittest.TestCase):
|
79
|
+
|
80
|
+
def test_init_and_forward_2d(self):
|
81
|
+
filter_2d = torch.ones(3, 3)
|
82
|
+
image_2d = torch.ones(1, 3, 3)
|
83
|
+
apply_filter_2d = ApplyFilter(filter_2d)
|
84
|
+
out = apply_filter_2d(image_2d)
|
85
|
+
self.assertEqual(out.shape, image_2d.shape)
|
86
|
+
|
87
|
+
def test_init_and_forward_3d(self):
|
88
|
+
filter_2d = torch.ones(3, 3, 3)
|
89
|
+
image_2d = torch.ones(1, 3, 3, 3)
|
90
|
+
apply_filter_2d = ApplyFilter(filter_2d)
|
91
|
+
out = apply_filter_2d(image_2d)
|
92
|
+
self.assertEqual(out.shape, image_2d.shape)
|
93
|
+
|
94
|
+
|
95
|
+
class MeanFilterTestCase(_TestFilter, unittest.TestCase):
|
96
|
+
|
97
|
+
def setUp(self) -> None:
|
98
|
+
self.filter_class = MeanFilter
|
99
|
+
|
100
|
+
@parameterized.expand(TEST_CASES_MEAN)
|
101
|
+
def test_init(self, spatial_dims, size, expected):
|
102
|
+
super().test_init(spatial_dims, size, expected)
|
103
|
+
|
104
|
+
|
105
|
+
class LaplaceFilterTestCase(_TestFilter, unittest.TestCase):
|
106
|
+
|
107
|
+
def setUp(self) -> None:
|
108
|
+
self.filter_class = LaplaceFilter
|
109
|
+
|
110
|
+
@parameterized.expand(TEST_CASES_LAPLACE)
|
111
|
+
def test_init(self, spatial_dims, size, expected):
|
112
|
+
super().test_init(spatial_dims, size, expected)
|
113
|
+
|
114
|
+
|
115
|
+
class EllipticalTestCase(_TestFilter, unittest.TestCase):
|
116
|
+
|
117
|
+
def setUp(self) -> None:
|
118
|
+
self.filter_class = EllipticalFilter
|
119
|
+
|
120
|
+
@parameterized.expand(TEST_CASES_ELLIPTICAL)
|
121
|
+
def test_init(self, spatial_dims, size, expected):
|
122
|
+
super().test_init(spatial_dims, size, expected)
|
123
|
+
|
124
|
+
|
125
|
+
class SharpenTestCase(_TestFilter, unittest.TestCase):
|
126
|
+
|
127
|
+
def setUp(self) -> None:
|
128
|
+
self.filter_class = SharpenFilter
|
129
|
+
|
130
|
+
@parameterized.expand(TEST_CASES_SHARPEN)
|
131
|
+
def test_init(self, spatial_dims, size, expected):
|
132
|
+
super().test_init(spatial_dims, size, expected)
|
133
|
+
|
134
|
+
|
135
|
+
if __name__ == "__main__":
|
136
|
+
unittest.main()
|
@@ -0,0 +1,141 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks.layers import SavitzkyGolayFilter
|
21
|
+
from tests.test_utils import skip_if_no_cuda
|
22
|
+
|
23
|
+
# Zero-padding trivial tests
|
24
|
+
|
25
|
+
TEST_CASE_SINGLE_VALUE = [
|
26
|
+
{"window_length": 3, "order": 1},
|
27
|
+
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
|
28
|
+
torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
|
29
|
+
# output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed)
|
30
|
+
1e-6, # absolute tolerance
|
31
|
+
]
|
32
|
+
|
33
|
+
TEST_CASE_1D = [
|
34
|
+
{"window_length": 3, "order": 1},
|
35
|
+
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
|
36
|
+
torch.Tensor([2 / 3, 1.0, 2 / 3])
|
37
|
+
.unsqueeze(0)
|
38
|
+
.unsqueeze(0), # Expected output: zero padded, so linear interpolation
|
39
|
+
# over length-3 windows will result in output of [2/3, 1, 2/3].
|
40
|
+
1e-6, # absolute tolerance
|
41
|
+
]
|
42
|
+
|
43
|
+
TEST_CASE_2D_AXIS_2 = [
|
44
|
+
{"window_length": 3, "order": 1}, # along default axis (2, first spatial dim)
|
45
|
+
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
|
46
|
+
torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0),
|
47
|
+
1e-6, # absolute tolerance
|
48
|
+
]
|
49
|
+
|
50
|
+
TEST_CASE_2D_AXIS_3 = [
|
51
|
+
{"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim)
|
52
|
+
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
|
53
|
+
torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0),
|
54
|
+
1e-6, # absolute tolerance
|
55
|
+
]
|
56
|
+
|
57
|
+
# Replicated-padding trivial tests
|
58
|
+
|
59
|
+
TEST_CASE_SINGLE_VALUE_REP = [
|
60
|
+
{"window_length": 3, "order": 1, "mode": "replicate"},
|
61
|
+
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
|
62
|
+
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
|
63
|
+
# output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed)
|
64
|
+
1e-6, # absolute tolerance
|
65
|
+
]
|
66
|
+
|
67
|
+
TEST_CASE_1D_REP = [
|
68
|
+
{"window_length": 3, "order": 1, "mode": "replicate"},
|
69
|
+
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
|
70
|
+
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation
|
71
|
+
# over length-3 windows will result in output of [2/3, 1, 2/3].
|
72
|
+
1e-6, # absolute tolerance
|
73
|
+
]
|
74
|
+
|
75
|
+
TEST_CASE_2D_AXIS_2_REP = [
|
76
|
+
{"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim)
|
77
|
+
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
|
78
|
+
torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
|
79
|
+
1e-6, # absolute tolerance
|
80
|
+
]
|
81
|
+
|
82
|
+
TEST_CASE_2D_AXIS_3_REP = [
|
83
|
+
{"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim)
|
84
|
+
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
|
85
|
+
torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
|
86
|
+
1e-6, # absolute tolerance
|
87
|
+
]
|
88
|
+
|
89
|
+
# Sine smoothing
|
90
|
+
|
91
|
+
TEST_CASE_SINE_SMOOTH = [
|
92
|
+
{"window_length": 3, "order": 1},
|
93
|
+
# Sine wave with period equal to savgol window length (windowed to reduce edge effects).
|
94
|
+
torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0),
|
95
|
+
# Should be smoothed out to zeros
|
96
|
+
torch.zeros(100).unsqueeze(0).unsqueeze(0),
|
97
|
+
# tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input
|
98
|
+
2e-2, # absolute tolerance
|
99
|
+
]
|
100
|
+
|
101
|
+
|
102
|
+
class TestSavitzkyGolayCPU(unittest.TestCase):
|
103
|
+
@parameterized.expand(
|
104
|
+
[TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH]
|
105
|
+
)
|
106
|
+
def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):
|
107
|
+
result = SavitzkyGolayFilter(**arguments)(image)
|
108
|
+
np.testing.assert_allclose(result, expected_data, atol=atol, rtol=rtol)
|
109
|
+
|
110
|
+
|
111
|
+
class TestSavitzkyGolayCPUREP(unittest.TestCase):
|
112
|
+
@parameterized.expand(
|
113
|
+
[TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]
|
114
|
+
)
|
115
|
+
def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):
|
116
|
+
result = SavitzkyGolayFilter(**arguments)(image)
|
117
|
+
np.testing.assert_allclose(result, expected_data, atol=atol, rtol=rtol)
|
118
|
+
|
119
|
+
|
120
|
+
@skip_if_no_cuda
|
121
|
+
class TestSavitzkyGolayGPU(unittest.TestCase):
|
122
|
+
@parameterized.expand(
|
123
|
+
[TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH]
|
124
|
+
)
|
125
|
+
def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):
|
126
|
+
result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda"))
|
127
|
+
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol, rtol=rtol)
|
128
|
+
|
129
|
+
|
130
|
+
@skip_if_no_cuda
|
131
|
+
class TestSavitzkyGolayGPUREP(unittest.TestCase):
|
132
|
+
@parameterized.expand(
|
133
|
+
[TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]
|
134
|
+
)
|
135
|
+
def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):
|
136
|
+
result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda"))
|
137
|
+
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol, rtol=rtol)
|
138
|
+
|
139
|
+
|
140
|
+
if __name__ == "__main__":
|
141
|
+
unittest.main()
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
|
19
|
+
from monai.networks.layers import separable_filtering
|
20
|
+
|
21
|
+
|
22
|
+
class SeparableFilterTestCase(unittest.TestCase):
|
23
|
+
|
24
|
+
def test_1d(self):
|
25
|
+
a = torch.tensor([[list(range(10))]], dtype=torch.float)
|
26
|
+
out = separable_filtering(a, torch.tensor([-1, 0, 1]))
|
27
|
+
expected = np.array([[[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -8.0]]])
|
28
|
+
np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)
|
29
|
+
if torch.cuda.is_available():
|
30
|
+
out = separable_filtering(a.cuda(), torch.tensor([-1, 0, 1]).cuda())
|
31
|
+
np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)
|
32
|
+
|
33
|
+
def test_2d(self):
|
34
|
+
a = torch.tensor([[[list(range(7)), list(range(7, 0, -1)), list(range(7))]]], dtype=torch.float)
|
35
|
+
expected = np.array(
|
36
|
+
[
|
37
|
+
[28.0, 28.0, 28.0, 28.0, 28.0, 28.0],
|
38
|
+
[30.0, 34.0, 38.0, 42.0, 46.0, 50.0],
|
39
|
+
[28.0, 28.0, 28.0, 28.0, 28.0, 28.0],
|
40
|
+
]
|
41
|
+
)
|
42
|
+
expected = expected[None][None]
|
43
|
+
out = separable_filtering(a, [torch.tensor([1, 1, 1]), torch.tensor([2, 2])])
|
44
|
+
np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)
|
45
|
+
if torch.cuda.is_available():
|
46
|
+
out = separable_filtering(a.cuda(), [torch.tensor([1, 1, 1]).cuda(), torch.tensor([2, 2]).cuda()])
|
47
|
+
np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)
|
48
|
+
|
49
|
+
def test_3d(self):
|
50
|
+
a = torch.tensor(
|
51
|
+
[[list(range(7)), list(range(7)), list(range(7))], [list(range(7)), list(range(7)), list(range(7))]],
|
52
|
+
dtype=torch.float,
|
53
|
+
)
|
54
|
+
a = a[None][None]
|
55
|
+
a = a.expand(2, 3, -1, -1, -1)
|
56
|
+
expected = np.array(
|
57
|
+
[
|
58
|
+
[
|
59
|
+
[4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],
|
60
|
+
[6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0],
|
61
|
+
[4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],
|
62
|
+
],
|
63
|
+
[
|
64
|
+
[4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],
|
65
|
+
[6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0],
|
66
|
+
[4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],
|
67
|
+
],
|
68
|
+
]
|
69
|
+
)
|
70
|
+
# testing shapes
|
71
|
+
k = torch.tensor([1, 1, 1])
|
72
|
+
for kernel in (k, [k] * 3):
|
73
|
+
out = separable_filtering(a, kernel)
|
74
|
+
np.testing.assert_allclose(out.cpu().numpy()[1][2], expected, rtol=1e-4)
|
75
|
+
if torch.cuda.is_available():
|
76
|
+
out = separable_filtering(
|
77
|
+
a.cuda(), kernel.cuda() if isinstance(kernel, torch.Tensor) else [k.cuda() for k in kernel]
|
78
|
+
)
|
79
|
+
np.testing.assert_allclose(out.cpu().numpy()[0][1], expected, rtol=1e-4)
|
80
|
+
|
81
|
+
def test_wrong_args(self):
|
82
|
+
with self.assertRaisesRegex(TypeError, ""):
|
83
|
+
separable_filtering(((1, 1, 1, 2, 3, 2)), torch.ones((2,)))
|
84
|
+
|
85
|
+
|
86
|
+
if __name__ == "__main__":
|
87
|
+
unittest.main()
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks import eval_mode
|
20
|
+
from monai.networks.layers import SkipConnection
|
21
|
+
|
22
|
+
TEST_CASES_3D = []
|
23
|
+
for type_1 in ("cat", "add", "mul"):
|
24
|
+
input_shape = (16, 10, 32, 24, 48)
|
25
|
+
if type_1 == "cat":
|
26
|
+
result_shape = (input_shape[0] * 2, *input_shape[1:])
|
27
|
+
else:
|
28
|
+
result_shape = input_shape
|
29
|
+
test_case = [{"dim": 0, "mode": type_1}, input_shape, result_shape]
|
30
|
+
TEST_CASES_3D.append(test_case)
|
31
|
+
|
32
|
+
|
33
|
+
class TestSkipConnection(unittest.TestCase):
|
34
|
+
|
35
|
+
@parameterized.expand(TEST_CASES_3D)
|
36
|
+
def test_shape(self, input_param, input_shape, expected_shape):
|
37
|
+
net = SkipConnection(submodule=torch.nn.Softmax(dim=1), **input_param)
|
38
|
+
with eval_mode(net):
|
39
|
+
result = net(torch.randn(input_shape))
|
40
|
+
self.assertEqual(result.shape, expected_shape)
|
41
|
+
|
42
|
+
def test_wrong_mode(self):
|
43
|
+
with self.assertRaises(ValueError):
|
44
|
+
SkipConnection(torch.nn.Softmax(dim=1), mode="test")(torch.randn(10, 10))
|
45
|
+
|
46
|
+
|
47
|
+
if __name__ == "__main__":
|
48
|
+
unittest.main()
|
@@ -0,0 +1,89 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from math import prod
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks.layers import EMAQuantizer, VectorQuantizer
|
21
|
+
|
22
|
+
TEST_CASES = [
|
23
|
+
[{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)],
|
24
|
+
[{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)],
|
25
|
+
]
|
26
|
+
|
27
|
+
|
28
|
+
class TestEMA(unittest.TestCase):
|
29
|
+
@parameterized.expand(TEST_CASES)
|
30
|
+
def test_ema_shape(self, input_param, input_shape, output_shape):
|
31
|
+
layer = EMAQuantizer(**input_param)
|
32
|
+
x = torch.randn(input_shape)
|
33
|
+
layer = layer.train()
|
34
|
+
outputs = layer(x)
|
35
|
+
self.assertEqual(outputs[0].shape, input_shape)
|
36
|
+
self.assertEqual(outputs[2].shape, output_shape)
|
37
|
+
|
38
|
+
layer = layer.eval()
|
39
|
+
outputs = layer(x)
|
40
|
+
self.assertEqual(outputs[0].shape, input_shape)
|
41
|
+
self.assertEqual(outputs[2].shape, output_shape)
|
42
|
+
|
43
|
+
@parameterized.expand(TEST_CASES)
|
44
|
+
def test_ema_quantize(self, input_param, input_shape, output_shape):
|
45
|
+
layer = EMAQuantizer(**input_param)
|
46
|
+
x = torch.randn(input_shape)
|
47
|
+
outputs = layer.quantize(x)
|
48
|
+
self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C)
|
49
|
+
self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E)
|
50
|
+
self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D])
|
51
|
+
|
52
|
+
def test_ema(self):
|
53
|
+
layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0)
|
54
|
+
original_weight_0 = layer.embedding.weight[0].clone()
|
55
|
+
original_weight_1 = layer.embedding.weight[1].clone()
|
56
|
+
x_0 = original_weight_0
|
57
|
+
x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
58
|
+
x_0 = x_0.repeat(1, 1, 1, 2) + 0.001
|
59
|
+
|
60
|
+
x_1 = original_weight_1
|
61
|
+
x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
62
|
+
x_1 = x_1.repeat(1, 1, 1, 2)
|
63
|
+
|
64
|
+
x = torch.cat([x_0, x_1], dim=0)
|
65
|
+
layer = layer.train()
|
66
|
+
_ = layer(x)
|
67
|
+
|
68
|
+
self.assertTrue(all(layer.embedding.weight[0] != original_weight_0))
|
69
|
+
self.assertTrue(all(layer.embedding.weight[1] == original_weight_1))
|
70
|
+
|
71
|
+
|
72
|
+
class TestVectorQuantizer(unittest.TestCase):
    """Shape checks for ``VectorQuantizer`` wrapping an ``EMAQuantizer``."""

    @parameterized.expand(TEST_CASES)
    def test_vector_quantizer_shape(self, input_param, input_shape, output_shape):
        """The second element returned by the forward call must keep the input shape."""
        wrapper = VectorQuantizer(EMAQuantizer(**input_param))
        sample = torch.randn(input_shape)
        result = wrapper(sample)
        second_item = result[1]
        self.assertEqual(second_item.shape, input_shape)

    @parameterized.expand(TEST_CASES)
    def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape):
        """``quantize`` must return a single tensor of shape ``output_shape``."""
        wrapper = VectorQuantizer(EMAQuantizer(**input_param))
        sample = torch.randn(input_shape)
        quantized = wrapper.quantize(sample)
        self.assertEqual(quantized.shape, output_shape)
|
86
|
+
|
87
|
+
|
88
|
+
# Run this module's test cases when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.layers import trunc_normal_
|
20
|
+
|
21
|
+
# Valid parameter sets: each entry is [keyword arguments for trunc_normal_, tensor shape].
TEST_CASES = [
    [dict(mean=0.0, std=1.0, a=2, b=4), (6, 12, 3, 1, 7)],
    [dict(mean=0.3, std=0.6, a=-1.0, b=1.3), (1, 4, 4, 4)],
    [dict(mean=0.1, std=0.4, a=1.3, b=1.8), (5, 7, 7, 8, 9)],
]
|
26
|
+
|
27
|
+
# Parameter sets expected to make trunc_normal_ raise ValueError; same layout as TEST_CASES.
TEST_ERRORS = [
    # NOTE(review): presumably rejected because a > b -- confirm against trunc_normal_.
    [dict(mean=0.0, std=1.0, a=5, b=1.1), (1, 1, 8, 8, 8)],
    # NOTE(review): presumably rejected because std is negative -- confirm.
    [dict(mean=0.3, std=-0.1, a=1.0, b=2.0), (8, 5, 2, 6, 9)],
    # NOTE(review): presumably rejected because std is zero -- confirm.
    [dict(mean=0.7, std=0.0, a=0.1, b=2.0), (4, 12, 23, 17)],
]
|
32
|
+
|
33
|
+
|
34
|
+
class TestWeightInit(unittest.TestCase):
    """Tests for ``trunc_normal_`` imported from ``monai.networks.layers``."""

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_shape):
        """Calling ``trunc_normal_`` on a tensor must leave its shape unchanged.

        Args:
            input_param: keyword arguments forwarded to ``trunc_normal_``.
            input_shape: shape of the random tensor passed to it.
        """
        im = torch.rand(input_shape)
        # The return value is ignored; the check is made on the tensor that was passed in.
        trunc_normal_(im, **input_param)
        self.assertEqual(im.shape, input_shape)

    @parameterized.expand(TEST_ERRORS)
    def test_ill_arg(self, input_param, input_shape):
        """Every parameter set in ``TEST_ERRORS`` must be rejected with ``ValueError``.

        Args:
            input_param: keyword arguments forwarded to ``trunc_normal_``.
            input_shape: shape of the random tensor passed to it.
        """
        # Build the tensor outside the assertRaises block so that only the call
        # under test can satisfy the expectation.
        im = torch.rand(input_shape)
        with self.assertRaises(ValueError):
            trunc_normal_(im, **input_param)
|
47
|
+
|
48
|
+
|
49
|
+
# Run this module's test cases when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|