monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/handlers/__init__.py +1 -0
- monai/handlers/average_precision.py +53 -0
- monai/inferers/inferer.py +10 -7
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/metrics/__init__.py +1 -0
- monai/metrics/average_precision.py +187 -0
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utility/array.py +2 -12
- monai/transforms/utils.py +1 -2
- monai/transforms/utils_pytorch_numpy_unification.py +2 -4
- monai/utils/enums.py +3 -2
- monai/utils/module.py +6 -6
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
- monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_average_precision.py +79 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1388 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +884 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_average_precision.py +162 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +73 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
@@ -0,0 +1,63 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.apps.reconstruction.networks.nets.coil_sensitivity_model import CoilSensitivityModel
|
20
|
+
from monai.apps.reconstruction.networks.nets.complex_unet import ComplexUnet
|
21
|
+
from monai.apps.reconstruction.networks.nets.varnet import VariationalNetworkModel
|
22
|
+
from monai.networks import eval_mode
|
23
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, test_script_save
|
24
|
+
|
25
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
26
|
+
coil_sens_model = CoilSensitivityModel(spatial_dims=2, features=[8, 16, 32, 64, 128, 8])
|
27
|
+
refinement_model = ComplexUnet(spatial_dims=2, features=[8, 16, 32, 64, 128, 8])
|
28
|
+
num_cascades = 2
|
29
|
+
TESTS = []
|
30
|
+
TESTS.append([coil_sens_model, refinement_model, num_cascades, (1, 3, 50, 50, 2), (1, 50, 50)]) # batch=1
|
31
|
+
TESTS.append([coil_sens_model, refinement_model, num_cascades, (2, 3, 50, 50, 2), (2, 50, 50)]) # batch=2
|
32
|
+
|
33
|
+
|
34
|
+
class TestVarNet(unittest.TestCase):
|
35
|
+
@parameterized.expand(TESTS)
|
36
|
+
def test_shape(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):
|
37
|
+
net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades).to(device)
|
38
|
+
mask_shape = [1 for _ in input_shape]
|
39
|
+
mask_shape[-2] = input_shape[-2]
|
40
|
+
mask = torch.zeros(mask_shape)
|
41
|
+
mask[..., mask_shape[-2] // 2 - 5 : mask_shape[-2] // 2 + 5, :] = 1
|
42
|
+
|
43
|
+
with eval_mode(net):
|
44
|
+
result = net(torch.randn(input_shape).to(device), mask.bool().to(device))
|
45
|
+
self.assertEqual(result.shape, expected_shape)
|
46
|
+
|
47
|
+
@parameterized.expand(TESTS)
|
48
|
+
@SkipIfBeforePyTorchVersion((1, 9, 1))
|
49
|
+
def test_script(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):
|
50
|
+
net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades)
|
51
|
+
|
52
|
+
mask_shape = [1 for _ in input_shape]
|
53
|
+
mask_shape[-2] = input_shape[-2]
|
54
|
+
mask = torch.zeros(mask_shape)
|
55
|
+
mask[..., mask_shape[-2] // 2 - 5 : mask_shape[-2] // 2 + 5, :] = 1
|
56
|
+
|
57
|
+
test_data = torch.randn(input_shape)
|
58
|
+
|
59
|
+
test_script_save(net, test_data, mask.bool())
|
60
|
+
|
61
|
+
|
62
|
+
if __name__ == "__main__":
|
63
|
+
unittest.main()
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
@@ -0,0 +1,187 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks.utils import copy_model_state
|
21
|
+
from monai.utils import set_determinism
|
22
|
+
|
23
|
+
|
24
|
+
class _TestModelOne(torch.nn.Module):
|
25
|
+
|
26
|
+
def __init__(self, n_n, n_m, n_class):
|
27
|
+
super().__init__()
|
28
|
+
self.layer = torch.nn.Linear(n_n, n_m)
|
29
|
+
self.class_layer = torch.nn.Linear(n_m, n_class)
|
30
|
+
|
31
|
+
def forward(self, x):
|
32
|
+
x = self.layer(x)
|
33
|
+
x = self.class_layer(x)
|
34
|
+
return x
|
35
|
+
|
36
|
+
|
37
|
+
class _TestModelTwo(torch.nn.Module):
|
38
|
+
|
39
|
+
def __init__(self, n_n, n_m, n_d, n_class):
|
40
|
+
super().__init__()
|
41
|
+
self.layer = torch.nn.Linear(n_n, n_m)
|
42
|
+
self.layer_1 = torch.nn.Linear(n_m, n_d)
|
43
|
+
self.class_layer = torch.nn.Linear(n_d, n_class)
|
44
|
+
|
45
|
+
def forward(self, x):
|
46
|
+
x = self.layer(x)
|
47
|
+
x = self.layer_1(x)
|
48
|
+
x = self.class_layer(x)
|
49
|
+
return x
|
50
|
+
|
51
|
+
|
52
|
+
TEST_CASES = []
|
53
|
+
__devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)
|
54
|
+
for _x in __devices:
|
55
|
+
for _y in __devices:
|
56
|
+
TEST_CASES.append((_x, _y))
|
57
|
+
|
58
|
+
|
59
|
+
class TestModuleState(unittest.TestCase):
|
60
|
+
|
61
|
+
def tearDown(self):
|
62
|
+
set_determinism(None)
|
63
|
+
|
64
|
+
@parameterized.expand(TEST_CASES)
|
65
|
+
def test_set_state(self, device_0, device_1):
|
66
|
+
set_determinism(0)
|
67
|
+
model_one = _TestModelOne(10, 20, 3)
|
68
|
+
model_two = _TestModelTwo(10, 20, 10, 4)
|
69
|
+
model_one.to(device_0)
|
70
|
+
model_two.to(device_1)
|
71
|
+
model_dict, ch, unch = copy_model_state(model_one, model_two)
|
72
|
+
x = np.random.randn(4, 10)
|
73
|
+
x = torch.tensor(x, device=device_0, dtype=torch.float32)
|
74
|
+
output = model_one(x).detach().cpu().numpy()
|
75
|
+
expected = np.array(
|
76
|
+
[
|
77
|
+
[-0.36076584, -0.03177825, -0.7702266],
|
78
|
+
[-0.0526831, -0.15855855, -0.01149344],
|
79
|
+
[-0.3760508, -0.22485238, -0.0634037],
|
80
|
+
[0.5977675, -0.67991066, 0.1919502],
|
81
|
+
]
|
82
|
+
)
|
83
|
+
np.testing.assert_allclose(output, expected, atol=1e-3)
|
84
|
+
self.assertEqual(len(ch), 2)
|
85
|
+
self.assertEqual(len(unch), 2)
|
86
|
+
|
87
|
+
@parameterized.expand(TEST_CASES)
|
88
|
+
def test_set_full_state(self, device_0, device_1):
|
89
|
+
set_determinism(0)
|
90
|
+
model_one = _TestModelOne(10, 20, 3)
|
91
|
+
model_two = _TestModelOne(10, 20, 3)
|
92
|
+
model_one.to(device_0)
|
93
|
+
model_two.to(device_1)
|
94
|
+
# test module input
|
95
|
+
model_dict, ch, unch = copy_model_state(model_one, model_two)
|
96
|
+
# test dict input
|
97
|
+
model_dict, ch, unch = copy_model_state(model_dict, model_two)
|
98
|
+
x = np.random.randn(4, 10)
|
99
|
+
x = torch.tensor(x, device=device_0, dtype=torch.float32)
|
100
|
+
output = model_one(x).detach().cpu().numpy()
|
101
|
+
model_two.to(device_0)
|
102
|
+
output_1 = model_two(x).detach().cpu().numpy()
|
103
|
+
np.testing.assert_allclose(output, output_1, atol=1e-3)
|
104
|
+
self.assertEqual(len(ch), 4)
|
105
|
+
self.assertEqual(len(unch), 0)
|
106
|
+
|
107
|
+
@parameterized.expand(TEST_CASES)
|
108
|
+
def test_set_exclude_vars(self, device_0, device_1):
|
109
|
+
set_determinism(0)
|
110
|
+
model_one = _TestModelOne(10, 20, 3)
|
111
|
+
model_two = _TestModelTwo(10, 20, 10, 4)
|
112
|
+
model_one.to(device_0)
|
113
|
+
model_two.to(device_1)
|
114
|
+
# test skip layer.bias
|
115
|
+
model_dict, ch, unch = copy_model_state(model_one, model_two, exclude_vars="layer.bias")
|
116
|
+
x = np.random.randn(4, 10)
|
117
|
+
x = torch.tensor(x, device=device_0, dtype=torch.float32)
|
118
|
+
output = model_one(x).detach().cpu().numpy()
|
119
|
+
expected = np.array(
|
120
|
+
[
|
121
|
+
[-0.34172416, 0.0375042, -0.98340976],
|
122
|
+
[-0.03364138, -0.08927619, -0.2246768],
|
123
|
+
[-0.35700908, -0.15556987, -0.27658707],
|
124
|
+
[0.61680925, -0.6106281, -0.02123314],
|
125
|
+
]
|
126
|
+
)
|
127
|
+
np.testing.assert_allclose(output, expected, atol=1e-3)
|
128
|
+
self.assertEqual(len(ch), 1)
|
129
|
+
self.assertEqual(len(unch), 3)
|
130
|
+
|
131
|
+
@parameterized.expand(TEST_CASES)
|
132
|
+
def test_set_map_across(self, device_0, device_1):
|
133
|
+
set_determinism(0)
|
134
|
+
model_one = _TestModelOne(10, 10, 3)
|
135
|
+
model_two = _TestModelTwo(10, 10, 10, 4)
|
136
|
+
model_one.to(device_0)
|
137
|
+
model_two.to(device_1)
|
138
|
+
# test weight map
|
139
|
+
model_dict, ch, unch = copy_model_state(
|
140
|
+
model_one, model_two, mapping={"layer_1.weight": "layer.weight", "layer_1.bias": "layer_1.weight"}
|
141
|
+
)
|
142
|
+
model_one.load_state_dict(model_dict)
|
143
|
+
x = np.random.randn(4, 10)
|
144
|
+
x = torch.tensor(x, device=device_0, dtype=torch.float32)
|
145
|
+
output = model_one(x).detach().cpu().numpy()
|
146
|
+
expected = np.array(
|
147
|
+
[
|
148
|
+
[0.8244487, -0.19650555, 0.65723234],
|
149
|
+
[0.71239626, 0.25617486, 0.5247122],
|
150
|
+
[0.24168758, 1.0301148, 0.39089814],
|
151
|
+
[0.25791705, 0.8653245, 0.14833644],
|
152
|
+
]
|
153
|
+
)
|
154
|
+
np.testing.assert_allclose(output, expected, atol=1e-3)
|
155
|
+
self.assertEqual(len(ch), 2)
|
156
|
+
self.assertEqual(len(unch), 2)
|
157
|
+
|
158
|
+
@parameterized.expand(TEST_CASES)
|
159
|
+
def test_set_prefix(self, device_0, device_1):
|
160
|
+
set_determinism(0)
|
161
|
+
model_one = torch.nn.Sequential(_TestModelOne(10, 20, 3))
|
162
|
+
model_two = _TestModelTwo(10, 20, 10, 4)
|
163
|
+
model_one.to(device_0)
|
164
|
+
model_two.to(device_1)
|
165
|
+
# test skip layer.bias
|
166
|
+
model_dict, ch, unch = copy_model_state(
|
167
|
+
model_one, model_two, dst_prefix="0.", exclude_vars="layer.bias", inplace=False
|
168
|
+
)
|
169
|
+
model_one.load_state_dict(model_dict)
|
170
|
+
x = np.random.randn(4, 10)
|
171
|
+
x = torch.tensor(x, device=device_0, dtype=torch.float32)
|
172
|
+
output = model_one(x).detach().cpu().numpy()
|
173
|
+
expected = np.array(
|
174
|
+
[
|
175
|
+
[-0.360766, -0.031778, -0.770227],
|
176
|
+
[-0.052683, -0.158559, -0.011493],
|
177
|
+
[-0.376051, -0.224852, -0.063404],
|
178
|
+
[0.597767, -0.679911, 0.19195],
|
179
|
+
]
|
180
|
+
)
|
181
|
+
np.testing.assert_allclose(output, expected, atol=1e-3)
|
182
|
+
self.assertEqual(len(ch), 2)
|
183
|
+
self.assertEqual(len(unch), 2)
|
184
|
+
|
185
|
+
|
186
|
+
if __name__ == "__main__":
|
187
|
+
unittest.main()
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
|
18
|
+
from monai.networks.utils import eval_mode
|
19
|
+
|
20
|
+
|
21
|
+
class TestEvalMode(unittest.TestCase):
|
22
|
+
|
23
|
+
def test_eval_mode(self):
|
24
|
+
t = torch.rand(1, 1, 4, 4)
|
25
|
+
p = torch.nn.Conv2d(1, 1, 3)
|
26
|
+
self.assertTrue(p.training) # True
|
27
|
+
with eval_mode(p):
|
28
|
+
self.assertFalse(p.training) # False
|
29
|
+
with self.assertRaises(RuntimeError):
|
30
|
+
p(t).sum().backward()
|
31
|
+
|
32
|
+
|
33
|
+
if __name__ == "__main__":
|
34
|
+
unittest.main()
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.utils import freeze_layers
|
20
|
+
from monai.utils import set_determinism
|
21
|
+
from tests.networks.utils.test_copy_model_state import _TestModelOne, _TestModelTwo
|
22
|
+
|
23
|
+
TEST_CASES = []
|
24
|
+
__devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)
|
25
|
+
for _x in __devices:
|
26
|
+
TEST_CASES.append(_x)
|
27
|
+
|
28
|
+
|
29
|
+
class TestModuleState(unittest.TestCase):
|
30
|
+
def tearDown(self):
|
31
|
+
set_determinism(None)
|
32
|
+
|
33
|
+
@parameterized.expand(TEST_CASES)
|
34
|
+
def test_freeze_vars(self, device):
|
35
|
+
set_determinism(0)
|
36
|
+
model = _TestModelOne(10, 20, 3)
|
37
|
+
model.to(device)
|
38
|
+
freeze_layers(model, "class")
|
39
|
+
|
40
|
+
for name, param in model.named_parameters():
|
41
|
+
if "class_layer" in name:
|
42
|
+
self.assertFalse(param.requires_grad)
|
43
|
+
else:
|
44
|
+
self.assertTrue(param.requires_grad)
|
45
|
+
|
46
|
+
@parameterized.expand(TEST_CASES)
|
47
|
+
def test_exclude_vars(self, device):
|
48
|
+
set_determinism(0)
|
49
|
+
model = _TestModelTwo(10, 20, 10, 4)
|
50
|
+
model.to(device)
|
51
|
+
freeze_layers(model, exclude_vars="class")
|
52
|
+
|
53
|
+
for name, param in model.named_parameters():
|
54
|
+
if "class_layer" in name:
|
55
|
+
self.assertTrue(param.requires_grad)
|
56
|
+
else:
|
57
|
+
self.assertFalse(param.requires_grad)
|
58
|
+
|
59
|
+
|
60
|
+
if __name__ == "__main__":
|
61
|
+
unittest.main()
|
@@ -0,0 +1,98 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.nets import DenseNet121
|
20
|
+
from monai.networks.utils import replace_modules, replace_modules_temp
|
21
|
+
from tests.test_utils import TEST_DEVICES
|
22
|
+
|
23
|
+
TESTS = []
|
24
|
+
for device in TEST_DEVICES:
|
25
|
+
for match_device in (True, False):
|
26
|
+
# replace 1
|
27
|
+
TESTS.append(("features.denseblock1.denselayer1.layers.relu1", True, match_device, *device))
|
28
|
+
# replace 1 (but not strict)
|
29
|
+
TESTS.append(("features.denseblock1.denselayer1.layers.relu1", False, match_device, *device))
|
30
|
+
# replace multiple
|
31
|
+
TESTS.append(("relu", False, match_device, *device))
|
32
|
+
|
33
|
+
|
34
|
+
class TestReplaceModule(unittest.TestCase):
|
35
|
+
def setUp(self):
|
36
|
+
self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
|
37
|
+
self.num_relus = self.get_num_modules(torch.nn.ReLU)
|
38
|
+
self.total = self.get_num_modules()
|
39
|
+
self.assertGreater(self.num_relus, 0)
|
40
|
+
|
41
|
+
def get_num_modules(self, mod: type[torch.nn.Module] | None = None) -> int:
|
42
|
+
m = [m for _, m in self.net.named_modules()]
|
43
|
+
if mod is not None:
|
44
|
+
m = [_m for _m in m if isinstance(_m, mod)]
|
45
|
+
return len(m)
|
46
|
+
|
47
|
+
def check_replaced_modules(self, name, match_device):
|
48
|
+
# total num modules should remain the same
|
49
|
+
self.assertEqual(self.total, self.get_num_modules())
|
50
|
+
num_relus_mod = self.get_num_modules(torch.nn.ReLU)
|
51
|
+
num_softmax = self.get_num_modules(torch.nn.Softmax)
|
52
|
+
# list of returned modules should be as long as number of softmax
|
53
|
+
self.assertEqual(self.num_relus, num_relus_mod + num_softmax)
|
54
|
+
if name == "relu":
|
55
|
+
# at least 2 softmaxes
|
56
|
+
self.assertGreaterEqual(num_softmax, 2)
|
57
|
+
else:
|
58
|
+
# one softmax
|
59
|
+
self.assertEqual(num_softmax, 1)
|
60
|
+
if match_device:
|
61
|
+
self.assertEqual(len(list({i.device for i in self.net.parameters()})), 1)
|
62
|
+
|
63
|
+
@parameterized.expand(TESTS)
|
64
|
+
def test_replace(self, name, strict_match, match_device, device):
|
65
|
+
self.net.to(device)
|
66
|
+
# replace module(s)
|
67
|
+
replaced = replace_modules(self.net, name, torch.nn.Softmax(), strict_match, match_device)
|
68
|
+
self.check_replaced_modules(name, match_device)
|
69
|
+
# number of returned modules should equal number of softmax modules
|
70
|
+
self.assertEqual(len(replaced), self.get_num_modules(torch.nn.Softmax))
|
71
|
+
# all replaced modules should be ReLU
|
72
|
+
for r in replaced:
|
73
|
+
self.assertIsInstance(r[1], torch.nn.ReLU)
|
74
|
+
# if a specfic module was named, check that the name matches exactly
|
75
|
+
if name == "features.denseblock1.denselayer1.layers.relu1":
|
76
|
+
self.assertEqual(replaced[0][0], name)
|
77
|
+
|
78
|
+
@parameterized.expand(TESTS)
|
79
|
+
def test_replace_context_manager(self, name, strict_match, match_device, device):
|
80
|
+
self.net.to(device)
|
81
|
+
with replace_modules_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device):
|
82
|
+
self.check_replaced_modules(name, match_device)
|
83
|
+
# Check that model was correctly reverted
|
84
|
+
self.assertEqual(self.get_num_modules(), self.total)
|
85
|
+
self.assertEqual(self.get_num_modules(torch.nn.ReLU), self.num_relus)
|
86
|
+
self.assertEqual(self.get_num_modules(torch.nn.Softmax), 0)
|
87
|
+
|
88
|
+
def test_raises(self):
|
89
|
+
# name doesn't exist in module
|
90
|
+
with self.assertRaises(AttributeError):
|
91
|
+
replace_modules(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True)
|
92
|
+
with self.assertRaises(AttributeError):
|
93
|
+
with replace_modules_temp(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True):
|
94
|
+
pass
|
95
|
+
|
96
|
+
|
97
|
+
if __name__ == "__main__":
|
98
|
+
unittest.main()
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
|
18
|
+
from monai.networks.utils import train_mode
|
19
|
+
|
20
|
+
|
21
|
+
class TestEvalMode(unittest.TestCase):
|
22
|
+
|
23
|
+
def test_eval_mode(self):
|
24
|
+
t = torch.rand(1, 1, 4, 4)
|
25
|
+
p = torch.nn.Conv2d(1, 1, 3)
|
26
|
+
p.eval()
|
27
|
+
self.assertFalse(p.training) # False
|
28
|
+
with train_mode(p):
|
29
|
+
self.assertTrue(p.training) # True
|
30
|
+
p(t).sum().backward()
|
31
|
+
|
32
|
+
|
33
|
+
if __name__ == "__main__":
|
34
|
+
unittest.main()
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
@@ -0,0 +1,105 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks.nets import Unet
|
20
|
+
from monai.optimizers import generate_param_groups
|
21
|
+
from monai.utils import ensure_tuple
|
22
|
+
from tests.test_utils import assert_allclose
|
23
|
+
|
24
|
+
TEST_CASE_1 = [{"layer_matches": [lambda x: x.model[-1]], "match_types": "select", "lr_values": [1]}, (1, 100), [5, 21]]
|
25
|
+
|
26
|
+
TEST_CASE_2 = [
|
27
|
+
{
|
28
|
+
"layer_matches": [lambda x: x.model[-1], lambda x: x.model[-2], lambda x: x.model[-3]],
|
29
|
+
"match_types": "select",
|
30
|
+
"lr_values": [1, 2, 3],
|
31
|
+
},
|
32
|
+
(1, 2, 3, 100),
|
33
|
+
[5, 16, 5, 0],
|
34
|
+
]
|
35
|
+
|
36
|
+
TEST_CASE_3 = [
|
37
|
+
{"layer_matches": [lambda x: x.model[2][1].conv[0].conv], "match_types": ["select"], "lr_values": [1]},
|
38
|
+
(1, 100),
|
39
|
+
[2, 24],
|
40
|
+
]
|
41
|
+
|
42
|
+
TEST_CASE_4 = [
|
43
|
+
{
|
44
|
+
"layer_matches": [lambda x: x.model[0], lambda x: "2.0.conv" in x[0]],
|
45
|
+
"match_types": ["select", "filter"],
|
46
|
+
"lr_values": [1, 2],
|
47
|
+
},
|
48
|
+
(1, 2, 100),
|
49
|
+
[5, 4, 17],
|
50
|
+
]
|
51
|
+
|
52
|
+
TEST_CASE_5 = [
|
53
|
+
{"layer_matches": [lambda x: x.model[-1]], "match_types": ["select"], "lr_values": [1], "include_others": False},
|
54
|
+
(1),
|
55
|
+
[5],
|
56
|
+
]
|
57
|
+
|
58
|
+
TEST_CASE_6 = [
|
59
|
+
{
|
60
|
+
"layer_matches": [lambda x: "weight" in x[0]],
|
61
|
+
"match_types": ["filter"],
|
62
|
+
"lr_values": [1],
|
63
|
+
"include_others": True,
|
64
|
+
},
|
65
|
+
(1),
|
66
|
+
[16, 10],
|
67
|
+
]
|
68
|
+
|
69
|
+
|
70
|
+
class TestGenerateParamGroups(unittest.TestCase):
|
71
|
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
|
72
|
+
def test_lr_values(self, input_param, expected_values, expected_groups):
|
73
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
74
|
+
net = Unet(
|
75
|
+
spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1
|
76
|
+
).to(device)
|
77
|
+
|
78
|
+
params = generate_param_groups(network=net, **input_param)
|
79
|
+
optimizer = torch.optim.Adam(params, 100)
|
80
|
+
|
81
|
+
for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)):
|
82
|
+
assert_allclose(param_group["lr"], value)
|
83
|
+
|
84
|
+
n = [len(p["params"]) for p in params]
|
85
|
+
self.assertListEqual(n, expected_groups)
|
86
|
+
|
87
|
+
def test_wrong(self):
|
88
|
+
"""overlapped"""
|
89
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
90
|
+
net = Unet(
|
91
|
+
spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1
|
92
|
+
).to(device)
|
93
|
+
|
94
|
+
params = generate_param_groups(
|
95
|
+
network=net,
|
96
|
+
layer_matches=[lambda x: x.model[-1], lambda x: x.model[-1]],
|
97
|
+
match_types="select",
|
98
|
+
lr_values=0.1,
|
99
|
+
)
|
100
|
+
with self.assertRaises(ValueError):
|
101
|
+
torch.optim.Adam(params, 100)
|
102
|
+
|
103
|
+
|
104
|
+
if __name__ == "__main__":
|
105
|
+
unittest.main()
|