monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/handlers/__init__.py +1 -0
- monai/handlers/average_precision.py +53 -0
- monai/inferers/inferer.py +10 -7
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/metrics/__init__.py +1 -0
- monai/metrics/average_precision.py +187 -0
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utility/array.py +2 -12
- monai/transforms/utils.py +1 -2
- monai/transforms/utils_pytorch_numpy_unification.py +2 -4
- monai/utils/enums.py +3 -2
- monai/utils/module.py +6 -6
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
- monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_average_precision.py +79 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1388 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +884 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_average_precision.py +162 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +73 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
@@ -0,0 +1,304 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import os
|
15
|
+
import shutil
|
16
|
+
import tempfile
|
17
|
+
import unittest
|
18
|
+
from glob import glob
|
19
|
+
|
20
|
+
import nibabel as nib
|
21
|
+
import numpy as np
|
22
|
+
import torch
|
23
|
+
|
24
|
+
import monai
|
25
|
+
from monai.data import create_test_image_3d, decollate_batch
|
26
|
+
from monai.inferers import sliding_window_inference
|
27
|
+
from monai.metrics import DiceMetric
|
28
|
+
from monai.networks import eval_mode
|
29
|
+
from monai.networks.nets import UNet
|
30
|
+
from monai.transforms import (
|
31
|
+
Activations,
|
32
|
+
AsDiscrete,
|
33
|
+
Compose,
|
34
|
+
EnsureChannelFirstd,
|
35
|
+
LoadImaged,
|
36
|
+
RandCropByPosNegLabeld,
|
37
|
+
RandRotate90d,
|
38
|
+
SaveImage,
|
39
|
+
ScaleIntensityd,
|
40
|
+
Spacingd,
|
41
|
+
)
|
42
|
+
from monai.utils import optional_import, set_determinism
|
43
|
+
from monai.visualize import plot_2d_or_3d_image
|
44
|
+
from tests.test_utils import DistTestCase, TimedCall, skip_if_quick
|
45
|
+
from tests.testing_data.integration_answers import test_integration_value
|
46
|
+
|
47
|
+
# ``optional_import`` returns a pair (imported object, availability flag); only the object is kept here.
SummaryWriter = optional_import("torch.utils.tensorboard", name="SummaryWriter")[0]

# key used to look up the stored reference values via ``test_integration_value``
TASK = "integration_segmentation_3d"
|
50
|
+
|
51
|
+
|
52
|
+
def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None)):
    """Train a 3D UNet on the synthetic volumes in ``root_dir``, validating every second epoch.

    Args:
        root_dir: directory holding ``img*.nii.gz`` / ``seg*.nii.gz`` pairs. TensorBoard logs (``runs``)
            and the best checkpoint (``best_metric_model.pth``) are also written here.
        device: torch device string used for the model and the batches.
        cachedataset: selects the training dataset class.
            2 -> ``CacheDataset`` with a process runtime cache.
            3 -> ``LMDBDataset`` cached in ``root_dir``.
            Anything else -> plain ``Dataset``.
        readers: reader arguments for the training and validation ``LoadImaged`` transforms
            respectively (``None`` keeps the ``LoadImaged`` default).

    Returns:
        tuple of (list of per-epoch average training losses, best validation mean Dice).
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[0]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
        ]
    )
    # fixed seed for the random crop / rotation transforms
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[1]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(
            data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache="process"
        )
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # create UNet, DiceLoss and Adam optimizer
    # NOTE(review): the caller compares against stored reference values, so avoid reordering
    # steps that may consume random state (dataset/loader setup, model construction).
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    max_epochs = 6
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    # nominal batches per epoch, used for progress output and the global TensorBoard step;
    # it does not change between iterations, so compute it once
    epoch_len = len(train_ds) // train_loader.batch_size
    # sliding-window settings used during validation
    sw_batch_size, roi_size = 4, (96, 96, 96)
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"Epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

            if (epoch + 1) % val_interval == 0:
                with eval_mode(model):
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        # decollate prediction into a list and execute post processing for every item
                        val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
                        # compute metrics
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), model_filename)
                        print("saved new best metric model")
                    print(
                        f"current epoch {epoch + 1} current mean dice: {metric:0.4f} "
                        f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
        print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    finally:
        # release the TensorBoard writer even when training or validation raises
        writer.close()
    return epoch_loss_values, best_metric
|
178
|
+
|
179
|
+
|
180
|
+
def run_inference_test(root_dir, device="cuda:0"):
    """Evaluate the checkpoint written by ``run_training_test`` on every volume in ``root_dir``.

    The post-processing chain ends with a ``SaveImage`` that writes each thresholded prediction
    under ``<root_dir>/output``.

    Args:
        root_dir: directory holding ``img*.nii.gz`` / ``seg*.nii.gz`` pairs and ``best_metric_model.pth``.
        device: torch device string used for the model and the batches.

    Returns:
        aggregated mean Dice over all volumes, as a Python float.
    """
    # same file pattern as ``run_training_test`` so only the image volumes are matched
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    saver = SaveImage(
        output_dir=os.path.join(root_dir, "output"),
        dtype=np.float32,
        output_ext=".nii.gz",
        output_postfix="seg",
        mode="bilinear",
    )
    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs one input image per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5), saver])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    # define sliding window size and batch size for windows inference
    sw_batch_size, roi_size = 4, (96, 96, 96)
    with eval_mode(model):
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list; the post-processing chain also saves each item
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            # accumulate metrics
            dice_metric(y_pred=val_outputs, y=val_labels)

    return dice_metric.aggregate().item()
|
234
|
+
|
235
|
+
|
236
|
+
@skip_if_quick
class IntegrationSegmentation3D(DistTestCase):
    """3D UNet segmentation training + sliding-window inference checked against stored reference values."""

    def setUp(self):
        set_determinism(seed=0)

        # 40 synthetic 128^3 image/label pairs saved as NIfTI files with an identity affine
        self.data_dir = tempfile.mkdtemp()
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(self.data_dir, f"seg{i:d}.nii.gz"))

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu:0"

    def tearDown(self):
        set_determinism(seed=None)
        shutil.rmtree(self.data_dir)

    def train_and_infer(self, idx=0):
        """Run one training + inference round and check it against the stored answers.

        Args:
            idx: variant index. It is forwarded as ``cachedataset`` to ``run_training_test`` and also
                selects the image readers.
                1 -> ITK for training and validation.
                2 -> ITK for training, NiBabel for validation.
                Anything else -> default readers.

        Returns:
            list of: the per-epoch training losses, the best validation metric, the inference metric,
            then the mean value of each saved inference output volume.
        """
        set_determinism(0)
        readers_by_idx = {1: ("itkreader", "itkreader"), 2: ("itkreader", "nibabelreader")}
        _readers = readers_by_idx.get(idx, (None, None))
        losses, best_metric = run_training_test(self.data_dir, device=self.device, cachedataset=idx, readers=_readers)
        infer_metric = run_inference_test(self.data_dir, device=self.device)

        # check training properties
        print("losses", losses)
        print("best metric", best_metric)
        print("infer metric", infer_metric)
        self.assertGreater(len(glob(os.path.join(self.data_dir, "runs"))), 0)
        model_file = os.path.join(self.data_dir, "best_metric_model.pth")
        self.assertTrue(os.path.exists(model_file))

        # check inference properties; each saved volume is loaded once and its mean reused
        output_files = sorted(glob(os.path.join(self.data_dir, "output", "img*", "*.nii.gz")))
        output_means = [np.mean(nib.load(output).get_fdata()) for output in output_files]
        print(output_means)
        results = [*losses, best_metric, infer_metric, *output_means]
        self.assertTrue(test_integration_value(TASK, key="losses", data=results[:6], rtol=1e-3))
        self.assertTrue(test_integration_value(TASK, key="best_metric", data=results[6], rtol=1e-2))
        self.assertTrue(test_integration_value(TASK, key="infer_metric", data=results[7], rtol=1e-2))
        self.assertTrue(test_integration_value(TASK, key="output_sums", data=results[8:], rtol=5e-2))
        return results

    def test_training(self):
        # every dataset / reader variant must reproduce the results of variant 0
        repeated = [self.train_and_infer(i) for i in range(4)]
        for other in repeated[1:]:
            np.testing.assert_allclose(repeated[0], other)

    @TimedCall(seconds=360, daemon=False)
    def test_timing(self):
        self.train_and_infer(idx=3)
|
301
|
+
|
302
|
+
|
303
|
+
if __name__ == "__main__":
    # run this module's test cases when the file is executed directly
    unittest.main(module="__main__")
|
@@ -0,0 +1,100 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import os
|
15
|
+
import tempfile
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import nibabel as nib
|
19
|
+
import numpy as np
|
20
|
+
import torch
|
21
|
+
from ignite.engine import Engine, Events
|
22
|
+
from torch.utils.data import DataLoader
|
23
|
+
|
24
|
+
from monai.data import ImageDataset, create_test_image_3d
|
25
|
+
from monai.inferers import sliding_window_inference
|
26
|
+
from monai.networks import eval_mode, predict_segmentation
|
27
|
+
from monai.networks.nets import UNet
|
28
|
+
from monai.transforms import EnsureChannelFirst, SaveImage
|
29
|
+
from monai.utils import set_determinism
|
30
|
+
from tests.test_utils import DistTestCase, TimedCall, make_nifti_image, skip_if_quick
|
31
|
+
|
32
|
+
|
33
|
+
def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
    """Run sliding-window UNet inference on one NIfTI volume through an ignite engine and save the result.

    Args:
        batch_size: number of sliding windows evaluated per network call (``sw_batch_size``).
        img_name: path of the input image.
        seg_name: path of the matching segmentation; it is loaded by the dataset but not used for inference.
        output_dir: directory under which ``SaveImage`` writes the predicted segmentation.
        device: device for the network and the input image.

    Returns:
        expected path of the saved segmentation, ``<output_dir>/<basename>/<basename>_seg.nii.gz``.
    """
    dataset = ImageDataset(
        [img_name],
        [seg_name],
        transform=EnsureChannelFirst(channel_dim="no_channel"),
        seg_transform=EnsureChannelFirst(channel_dim="no_channel"),
        image_only=True,
    )
    data_loader = DataLoader(dataset, batch_size=1, pin_memory=torch.cuda.is_available())

    network = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)
    window_shape = (16, 32, 48)

    saver = SaveImage(output_dir=output_dir, output_ext=".nii.gz", output_postfix="seg")

    def infer_step(_engine, batch):
        # the dataset yields (image, segmentation); only the image is segmented
        image = batch[0]
        with eval_mode(network):
            probabilities = sliding_window_inference(image.to(device), window_shape, batch_size, network, device=device)
            return predict_segmentation(probabilities)

    def save_outputs(engine):
        # write every item of the iteration's output batch to disk
        for item in engine.state.output:
            saver(item)

    infer_engine = Engine(infer_step)
    infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_outputs)
    infer_engine.run(data_loader)

    stem = os.path.basename(img_name)[: -len(".nii.gz")]
    return os.path.join(output_dir, stem, f"{stem}_seg.nii.gz")
|
68
|
+
|
69
|
+
|
70
|
+
@skip_if_quick
class TestIntegrationSlidingWindow(DistTestCase):
    """Sliding-window inference on a small synthetic 3D volume, checked against a known voxel sum."""

    def setUp(self):
        set_determinism(seed=0)

        image, label = create_test_image_3d(28, 25, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
        self.img_name = make_nifti_image(image)
        self.seg_name = make_nifti_image(label)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")

    def tearDown(self):
        set_determinism(seed=None)
        # delete the temporary NIfTI inputs if they still exist
        for path in (self.img_name, self.seg_name):
            if os.path.exists(path):
                os.remove(path)

    @TimedCall(seconds=20)
    def test_training(self):
        set_determinism(seed=0)
        with tempfile.TemporaryDirectory() as tempdir:
            saved_path = run_test(
                batch_size=2, img_name=self.img_name, seg_name=self.seg_name, output_dir=tempdir, device=self.device
            )
            prediction = nib.load(saved_path).get_fdata()
            np.testing.assert_allclose(np.sum(prediction), 33621)
            np.testing.assert_allclose(prediction.shape, (28, 25, 63))
|
97
|
+
|
98
|
+
|
99
|
+
if __name__ == "__main__":
    # run this module's test cases when the file is executed directly
    unittest.main(module="__main__")
|
@@ -0,0 +1,133 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
import torch.nn as nn
|
19
|
+
import torch.nn.functional as F
|
20
|
+
import torch.optim as optim
|
21
|
+
|
22
|
+
from monai.data import create_test_image_2d
|
23
|
+
from monai.networks.layers import AffineTransform
|
24
|
+
from monai.utils import set_determinism
|
25
|
+
from tests.test_utils import DistTestCase, TimedCall
|
26
|
+
|
27
|
+
|
28
|
+
class STNBenchmark(nn.Module):
    """
    Small spatial transformer network comparing PyTorch's ``F.affine_grid``/``F.grid_sample`` warping
    with MONAI's ``AffineTransform``.

    adapted from https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

    Args:
        is_ref: if True, warp with ``F.affine_grid`` + ``F.grid_sample``; otherwise use ``AffineTransform``.
        reverse_indexing: forwarded to ``AffineTransform``; only used when ``is_ref`` is False.
    """

    def __init__(self, is_ref=True, reverse_indexing=False):
        super().__init__()
        self.is_ref = is_ref
        # feature extractor: a 1 x 28 x 28 input gives 10 feature maps of 3 x 3
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )
        # regressor for the 2 x 3 affine matrix
        self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32), nn.ReLU(True), nn.Linear(32, 3 * 2))
        # start from the identity transform: zero weights and an identity bias on the last layer
        head = self.fc_loc[2]
        with torch.no_grad():
            head.weight.zero_()
            head.bias.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        if not is_ref:
            self.xform = AffineTransform(align_corners=False, normalized=True, reverse_indexing=reverse_indexing)

    def _predict_theta(self, x):
        """Regress one 2 x 3 affine matrix per input image."""
        features = self.localization(x).view(-1, 10 * 3 * 3)
        return self.fc_loc(features).view(-1, 2, 3)

    def stn_ref(self, x):
        """Spatial transformer forward pass using PyTorch's reference warping."""
        theta = self._predict_theta(x)
        sampling_grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, sampling_grid, align_corners=False)

    def stn(self, x):
        """Spatial transformer forward pass using MONAI's ``AffineTransform``."""
        theta = self._predict_theta(x)
        return self.xform(x, theta, spatial_size=x.size()[2:])

    def forward(self, x):
        return self.stn_ref(x) if self.is_ref else self.stn(x)
|
75
|
+
|
76
|
+
|
77
|
+
def compare_2d(is_ref=True, device=None, reverse_indexing=False):
    """Fit an ``STNBenchmark`` for 20 SGD steps so that it warps one random image batch onto another.

    Args:
        is_ref: forwarded to ``STNBenchmark`` (True selects the PyTorch reference warping).
        device: device for the images and the model.
        reverse_indexing: forwarded to ``STNBenchmark``.

    Returns:
        tuple of (model output for the source batch after training as a numpy array,
        loss computed in the last iteration, loss computed in the first iteration).
    """
    batch_size = 32

    def random_batch():
        # stack of ``batch_size`` synthetic images, each given a leading channel axis
        samples = [create_test_image_2d(28, 28, 5, rad_max=6, noise_max=1)[0][None] for _ in range(batch_size)]
        return torch.as_tensor(np.stack(samples, axis=0), device=device)

    source = random_batch()
    target = random_batch()
    model = STNBenchmark(is_ref=is_ref, reverse_indexing=reverse_indexing).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    model.train()
    first_loss = None
    for iteration in range(20):
        optimizer.zero_grad()
        loss = torch.mean((model(source) - target) ** 2)
        if iteration == 0:
            first_loss = loss.item()
        loss.backward()
        optimizer.step()
    return model(source).detach().cpu().numpy(), loss.item(), first_loss
|
98
|
+
|
99
|
+
|
100
|
+
class TestSpatialTransformerCore(DistTestCase):
    """Compare training through ``AffineTransform`` with the PyTorch reference spatial transformer."""

    def setUp(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")

    def tearDown(self):
        # undo the seeding done by ``test_training``
        set_determinism(seed=None)

    @TimedCall(seconds=100, skip_timing=not torch.cuda.is_available())
    def test_training(self):
        """
        Check that backpropagation through ``AffineTransform`` matches the PyTorch reference.

        The comparison runs for both ``reverse_indexing=False`` and ``reverse_indexing=True``.
        Final outputs, initial losses and final losses must agree within ``atol``.
        """
        atol = 1e-5
        set_determinism(seed=0)
        out_ref, loss_ref, init_loss_ref = compare_2d(True, self.device)
        print(out_ref.shape, loss_ref, init_loss_ref)

        # subTest is not used: under TimedCall this may run in a child process, where recorded
        # subtest failures would not reach the parent; err_msg identifies the failing case instead
        for reverse_indexing in (False, True):
            # reseed so every run sees the same test images and initial weights as the reference
            set_determinism(seed=0)
            out, loss, init_loss = compare_2d(False, self.device, reverse_indexing)
            print(out.shape, loss, init_loss)
            case = f"reverse_indexing={reverse_indexing}"
            np.testing.assert_allclose(out_ref, out, atol=atol, err_msg=case)
            np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol, err_msg=case)
            np.testing.assert_allclose(loss_ref, loss, atol=atol, err_msg=case)
|
130
|
+
|
131
|
+
|
132
|
+
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()
|
@@ -0,0 +1,67 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
from ignite.engine import create_supervised_trainer
|
19
|
+
from torch.utils.data import DataLoader, Dataset
|
20
|
+
|
21
|
+
from monai.data import create_test_image_2d
|
22
|
+
from monai.losses import DiceLoss
|
23
|
+
from monai.networks.nets import BasicUNet, UNet
|
24
|
+
from tests.test_utils import DistTestCase, TimedCall, skip_if_quick
|
25
|
+
|
26
|
+
|
27
|
+
def run_test(net_name="basicunet", batch_size=64, train_steps=100, device="cuda:0"):
    """
    Train a small 2D segmentation network for one epoch on synthetic data.

    Args:
        net_name: ``"basicunet"`` to build a ``BasicUNet``, or ``"unet"`` to build a ``UNet``.
        batch_size: mini-batch size of the data loader.
        train_steps: number of synthetic samples in the dataset. The epoch therefore has
            ``ceil(train_steps / batch_size)`` iterations.
        device: device used for the network and the batches.

    Returns:
        ``trainer.state.output`` after the epoch. With ignite's default ``output_transform``,
        this is the loss of the last iteration.

    Raises:
        ValueError: if ``net_name`` is not one of the supported names.
    """
    if net_name == "basicunet":
        net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32))
    elif net_name == "unet":
        net = UNet(
            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2
        )
    else:
        # fail with a clear message instead of an AttributeError on ``None.to(device)``
        raise ValueError(f"unsupported net_name {net_name!r}, expected 'basicunet' or 'unet'.")
    net.to(device)

    class _TestBatch(Dataset):
        """Dataset that generates a new synthetic (image, segmentation) pair on every access."""

        def __getitem__(self, _unused_id):
            im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            # prepend a channel axis; the segmentation is cast to float32 for the loss
            return im[None], seg[None].astype(np.float32)

        def __len__(self):
            return train_steps

    loss_fn = DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-4)
    src = DataLoader(_TestBatch(), batch_size=batch_size)

    trainer = create_supervised_trainer(net, opt, loss_fn, device=device, non_blocking=False)
    trainer.run(src, 1)
    return trainer.state.output
|
54
|
+
|
55
|
+
|
56
|
+
@skip_if_quick
class TestIntegrationUnet2D(DistTestCase):
    """Short end-to-end training runs of ``BasicUNet`` and ``UNet`` on synthetic 2D data."""

    @TimedCall(seconds=20, daemon=False)
    def test_unet_training(self):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
        for net_name in ("basicunet", "unet"):
            loss = run_test(net_name=net_name, device=device)
            print(loss)
            # the loss returned after one epoch must not exceed 0.85
            self.assertLessEqual(loss, 0.85, msg=f"final loss of {net_name!r} is too high")
|
64
|
+
|
65
|
+
|
66
|
+
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import torch
|
18
|
+
|
19
|
+
from monai.data import DataLoader
|
20
|
+
from monai.utils import set_determinism
|
21
|
+
from tests.test_utils import DistTestCase, SkipIfBeforePyTorchVersion, TimedCall, skip_if_no_cuda, skip_if_quick
|
22
|
+
|
23
|
+
|
24
|
+
def run_loading_test(num_workers=50, device=None, pw=False):
    """
    Multi-worker data loading stress test.

    Iterates twice over a shuffled ``DataLoader`` of the integers ``0..9999`` in batches of 300.
    Each batch is moved to ``device`` while CUDA memory accounting is checked.

    Args:
        num_workers: number of data loading worker processes.
        device: target device. Defaults to ``"cuda:0"`` when CUDA is available, else ``"cpu"``.
        pw: forwarded as ``persistent_workers`` to the data loader.

    Returns:
        The last element of every batch, in iteration order, across both epochs.
    """
    set_determinism(seed=0)
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    train_ds = list(range(10000))
    train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers, persistent_workers=pw)
    answer = []
    for epoch in range(2):
        # no CUDA memory may still be held when an epoch starts
        np.testing.assert_equal(
            torch.cuda.memory_allocated(), 0, err_msg=f"CUDA memory still allocated before epoch {epoch}"
        )
        for batch_data in train_loader:
            x = batch_data.to(device)
            mem = torch.cuda.memory_allocated()
            # the batch on the device must use a small, non-zero amount of CUDA memory
            np.testing.assert_equal(0 < mem < 5000, True, err_msg=f"unexpected CUDA memory usage: {mem} bytes")
            answer.append(x[-1].item())
            # drop the device copy; presumably so the zero-memory check of the next epoch holds
            del x
    return answer
|
41
|
+
|
42
|
+
|
43
|
+
@skip_if_quick
@skip_if_no_cuda
@SkipIfBeforePyTorchVersion((1, 9))
class IntegrationLoading(DistTestCase):
    """Multi-worker loading must behave the same with and without persistent workers."""

    def tearDown(self):
        # undo the seeding performed inside ``run_loading_test``
        set_determinism(seed=None)

    @TimedCall(seconds=5000, skip_timing=not torch.cuda.is_available(), daemon=False)
    def test_timing(self):
        # value recorded for the first batch, first with pw=False and then with pw=True
        firsts = [run_loading_test(pw=pw)[0] for pw in (False, True)]
        # the persistent-worker run must reproduce the non-persistent one
        np.testing.assert_allclose(firsts[1], firsts[0])
|
58
|
+
|
59
|
+
|
60
|
+
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()
|