monai-weekly 1.5.dev2506__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/transforms.py +1 -4
- monai/data/utils.py +6 -13
- monai/handlers/__init__.py +1 -0
- monai/handlers/average_precision.py +53 -0
- monai/inferers/inferer.py +10 -7
- monai/inferers/utils.py +1 -2
- monai/losses/dice.py +2 -14
- monai/losses/ds_loss.py +1 -3
- monai/metrics/__init__.py +1 -0
- monai/metrics/average_precision.py +187 -0
- monai/networks/layers/simplelayers.py +2 -14
- monai/networks/utils.py +4 -16
- monai/transforms/compose.py +28 -11
- monai/transforms/croppad/array.py +1 -6
- monai/transforms/io/array.py +0 -1
- monai/transforms/transform.py +15 -6
- monai/transforms/utility/array.py +2 -12
- monai/transforms/utils.py +1 -2
- monai/transforms/utils_pytorch_numpy_unification.py +2 -4
- monai/utils/enums.py +3 -2
- monai/utils/module.py +6 -6
- monai/utils/tf32.py +0 -10
- monai/visualize/class_activation_maps.py +5 -8
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +21 -17
- monai_weekly-1.5.dev2508.dist-info/RECORD +1185 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +1 -0
- tests/apps/__init__.py +10 -0
- tests/apps/deepedit/__init__.py +10 -0
- tests/apps/deepedit/test_deepedit_transforms.py +314 -0
- tests/apps/deepgrow/__init__.py +10 -0
- tests/apps/deepgrow/test_deepgrow_dataset.py +109 -0
- tests/apps/deepgrow/transforms/__init__.py +10 -0
- tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +97 -0
- tests/apps/deepgrow/transforms/test_deepgrow_transforms.py +556 -0
- tests/apps/detection/__init__.py +10 -0
- tests/apps/detection/metrics/__init__.py +10 -0
- tests/apps/detection/metrics/test_detection_coco_metrics.py +69 -0
- tests/apps/detection/networks/__init__.py +10 -0
- tests/apps/detection/networks/test_retinanet.py +210 -0
- tests/apps/detection/networks/test_retinanet_detector.py +203 -0
- tests/apps/detection/test_box_transform.py +370 -0
- tests/apps/detection/utils/__init__.py +10 -0
- tests/apps/detection/utils/test_anchor_box.py +88 -0
- tests/apps/detection/utils/test_atss_box_matcher.py +46 -0
- tests/apps/detection/utils/test_box_coder.py +43 -0
- tests/apps/detection/utils/test_detector_boxselector.py +67 -0
- tests/apps/detection/utils/test_detector_utils.py +96 -0
- tests/apps/detection/utils/test_hardnegsampler.py +54 -0
- tests/apps/nuclick/__init__.py +10 -0
- tests/apps/nuclick/test_nuclick_transforms.py +259 -0
- tests/apps/pathology/__init__.py +10 -0
- tests/apps/pathology/handlers/__init__.py +10 -0
- tests/apps/pathology/handlers/test_from_engine_hovernet.py +38 -0
- tests/apps/pathology/test_lesion_froc.py +333 -0
- tests/apps/pathology/test_pathology_prob_nms.py +55 -0
- tests/apps/pathology/test_prepare_batch_hovernet.py +70 -0
- tests/apps/pathology/test_sliding_window_hovernet_inference.py +303 -0
- tests/apps/pathology/transforms/__init__.py +10 -0
- tests/apps/pathology/transforms/post/__init__.py +10 -0
- tests/apps/pathology/transforms/post/test_generate_distance_map.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_distance_mapd.py +70 -0
- tests/apps/pathology/transforms/post/test_generate_instance_border.py +49 -0
- tests/apps/pathology/transforms/post/test_generate_instance_borderd.py +59 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroid.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py +56 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contour.py +58 -0
- tests/apps/pathology/transforms/post/test_generate_instance_contourd.py +61 -0
- tests/apps/pathology/transforms/post/test_generate_instance_type.py +51 -0
- tests/apps/pathology/transforms/post/test_generate_instance_typed.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contour.py +55 -0
- tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py +57 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markers.py +53 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py +83 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_mask.py +77 -0
- tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py +77 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py +61 -0
- tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py +66 -0
- tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py +66 -0
- tests/apps/pathology/transforms/post/test_watershed.py +60 -0
- tests/apps/pathology/transforms/post/test_watershedd.py +70 -0
- tests/apps/pathology/transforms/test_pathology_he_stain.py +230 -0
- tests/apps/pathology/transforms/test_pathology_he_stain_dict.py +225 -0
- tests/apps/reconstruction/__init__.py +10 -0
- tests/apps/reconstruction/nets/__init__.py +10 -0
- tests/apps/reconstruction/nets/test_recon_net_utils.py +82 -0
- tests/apps/reconstruction/test_complex_utils.py +77 -0
- tests/apps/reconstruction/test_fastmri_reader.py +82 -0
- tests/apps/reconstruction/test_mri_utils.py +37 -0
- tests/apps/reconstruction/transforms/__init__.py +10 -0
- tests/apps/reconstruction/transforms/test_kspace_mask.py +50 -0
- tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py +77 -0
- tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py +57 -0
- tests/apps/test_auto3dseg_bundlegen.py +156 -0
- tests/apps/test_check_hash.py +53 -0
- tests/apps/test_cross_validation.py +74 -0
- tests/apps/test_decathlondataset.py +93 -0
- tests/apps/test_download_and_extract.py +70 -0
- tests/apps/test_download_url_yandex.py +45 -0
- tests/apps/test_mednistdataset.py +72 -0
- tests/apps/test_mmar_download.py +154 -0
- tests/apps/test_tciadataset.py +123 -0
- tests/apps/vista3d/__init__.py +10 -0
- tests/apps/vista3d/test_point_based_window_inferer.py +77 -0
- tests/apps/vista3d/test_vista3d_sampler.py +100 -0
- tests/apps/vista3d/test_vista3d_transforms.py +94 -0
- tests/bundle/__init__.py +10 -0
- tests/bundle/test_bundle_ckpt_export.py +107 -0
- tests/bundle/test_bundle_download.py +435 -0
- tests/bundle/test_bundle_get_data.py +94 -0
- tests/bundle/test_bundle_push_to_hf_hub.py +41 -0
- tests/bundle/test_bundle_trt_export.py +147 -0
- tests/bundle/test_bundle_utils.py +149 -0
- tests/bundle/test_bundle_verify_metadata.py +66 -0
- tests/bundle/test_bundle_verify_net.py +76 -0
- tests/bundle/test_bundle_workflow.py +272 -0
- tests/bundle/test_component_locator.py +38 -0
- tests/bundle/test_config_item.py +138 -0
- tests/bundle/test_config_parser.py +392 -0
- tests/bundle/test_reference_resolver.py +114 -0
- tests/config/__init__.py +10 -0
- tests/config/test_cv2_dist.py +53 -0
- tests/engines/__init__.py +10 -0
- tests/engines/test_ensemble_evaluator.py +94 -0
- tests/engines/test_prepare_batch_default.py +76 -0
- tests/engines/test_prepare_batch_default_dist.py +76 -0
- tests/engines/test_prepare_batch_diffusion.py +104 -0
- tests/engines/test_prepare_batch_extra_input.py +80 -0
- tests/fl/__init__.py +10 -0
- tests/fl/monai_algo/__init__.py +10 -0
- tests/fl/monai_algo/test_fl_monai_algo.py +251 -0
- tests/fl/monai_algo/test_fl_monai_algo_dist.py +117 -0
- tests/fl/test_fl_monai_algo_stats.py +81 -0
- tests/fl/utils/__init__.py +10 -0
- tests/fl/utils/test_fl_exchange_object.py +63 -0
- tests/handlers/__init__.py +10 -0
- tests/handlers/test_handler_average_precision.py +79 -0
- tests/handlers/test_handler_checkpoint_loader.py +182 -0
- tests/handlers/test_handler_checkpoint_saver.py +233 -0
- tests/handlers/test_handler_classification_saver.py +64 -0
- tests/handlers/test_handler_classification_saver_dist.py +77 -0
- tests/handlers/test_handler_clearml_image.py +65 -0
- tests/handlers/test_handler_clearml_stats.py +65 -0
- tests/handlers/test_handler_confusion_matrix.py +104 -0
- tests/handlers/test_handler_confusion_matrix_dist.py +70 -0
- tests/handlers/test_handler_decollate_batch.py +66 -0
- tests/handlers/test_handler_early_stop.py +68 -0
- tests/handlers/test_handler_garbage_collector.py +73 -0
- tests/handlers/test_handler_hausdorff_distance.py +111 -0
- tests/handlers/test_handler_ignite_metric.py +191 -0
- tests/handlers/test_handler_lr_scheduler.py +94 -0
- tests/handlers/test_handler_mean_dice.py +98 -0
- tests/handlers/test_handler_mean_iou.py +76 -0
- tests/handlers/test_handler_metrics_reloaded.py +149 -0
- tests/handlers/test_handler_metrics_saver.py +89 -0
- tests/handlers/test_handler_metrics_saver_dist.py +120 -0
- tests/handlers/test_handler_mlflow.py +296 -0
- tests/handlers/test_handler_nvtx.py +93 -0
- tests/handlers/test_handler_panoptic_quality.py +89 -0
- tests/handlers/test_handler_parameter_scheduler.py +136 -0
- tests/handlers/test_handler_post_processing.py +74 -0
- tests/handlers/test_handler_prob_map_producer.py +111 -0
- tests/handlers/test_handler_regression_metrics.py +160 -0
- tests/handlers/test_handler_regression_metrics_dist.py +245 -0
- tests/handlers/test_handler_rocauc.py +48 -0
- tests/handlers/test_handler_rocauc_dist.py +54 -0
- tests/handlers/test_handler_stats.py +281 -0
- tests/handlers/test_handler_surface_distance.py +113 -0
- tests/handlers/test_handler_tb_image.py +61 -0
- tests/handlers/test_handler_tb_stats.py +166 -0
- tests/handlers/test_handler_validation.py +59 -0
- tests/handlers/test_trt_compile.py +145 -0
- tests/handlers/test_write_metrics_reports.py +68 -0
- tests/inferers/__init__.py +10 -0
- tests/inferers/test_avg_merger.py +179 -0
- tests/inferers/test_controlnet_inferers.py +1388 -0
- tests/inferers/test_diffusion_inferer.py +236 -0
- tests/inferers/test_latent_diffusion_inferer.py +884 -0
- tests/inferers/test_patch_inferer.py +309 -0
- tests/inferers/test_saliency_inferer.py +55 -0
- tests/inferers/test_slice_inferer.py +57 -0
- tests/inferers/test_sliding_window_inference.py +377 -0
- tests/inferers/test_sliding_window_splitter.py +284 -0
- tests/inferers/test_wsi_sliding_window_splitter.py +249 -0
- tests/inferers/test_zarr_avg_merger.py +326 -0
- tests/integration/__init__.py +10 -0
- tests/integration/test_auto3dseg_ensemble.py +211 -0
- tests/integration/test_auto3dseg_hpo.py +189 -0
- tests/integration/test_deepedit_interaction.py +122 -0
- tests/integration/test_downsample_block.py +50 -0
- tests/integration/test_hovernet_nuclear_type_post_processingd.py +71 -0
- tests/integration/test_integration_autorunner.py +201 -0
- tests/integration/test_integration_bundle_run.py +240 -0
- tests/integration/test_integration_classification_2d.py +282 -0
- tests/integration/test_integration_determinism.py +95 -0
- tests/integration/test_integration_fast_train.py +231 -0
- tests/integration/test_integration_gpu_customization.py +159 -0
- tests/integration/test_integration_lazy_samples.py +219 -0
- tests/integration/test_integration_nnunetv2_runner.py +96 -0
- tests/integration/test_integration_segmentation_3d.py +304 -0
- tests/integration/test_integration_sliding_window.py +100 -0
- tests/integration/test_integration_stn.py +133 -0
- tests/integration/test_integration_unet_2d.py +67 -0
- tests/integration/test_integration_workers.py +61 -0
- tests/integration/test_integration_workflows.py +365 -0
- tests/integration/test_integration_workflows_adversarial.py +173 -0
- tests/integration/test_integration_workflows_gan.py +158 -0
- tests/integration/test_loader_semaphore.py +48 -0
- tests/integration/test_mapping_filed.py +122 -0
- tests/integration/test_meta_affine.py +183 -0
- tests/integration/test_metatensor_integration.py +114 -0
- tests/integration/test_module_list.py +76 -0
- tests/integration/test_one_of.py +283 -0
- tests/integration/test_pad_collation.py +124 -0
- tests/integration/test_reg_loss_integration.py +107 -0
- tests/integration/test_retinanet_predict_utils.py +154 -0
- tests/integration/test_seg_loss_integration.py +159 -0
- tests/integration/test_spatial_combine_transforms.py +185 -0
- tests/integration/test_testtimeaugmentation.py +186 -0
- tests/integration/test_vis_gradbased.py +69 -0
- tests/integration/test_vista3d_utils.py +159 -0
- tests/losses/__init__.py +10 -0
- tests/losses/deform/__init__.py +10 -0
- tests/losses/deform/test_bending_energy.py +88 -0
- tests/losses/deform/test_diffusion_loss.py +117 -0
- tests/losses/image_dissimilarity/__init__.py +10 -0
- tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +150 -0
- tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +162 -0
- tests/losses/test_adversarial_loss.py +94 -0
- tests/losses/test_barlow_twins_loss.py +109 -0
- tests/losses/test_cldice_loss.py +51 -0
- tests/losses/test_contrastive_loss.py +86 -0
- tests/losses/test_dice_ce_loss.py +123 -0
- tests/losses/test_dice_focal_loss.py +124 -0
- tests/losses/test_dice_loss.py +227 -0
- tests/losses/test_ds_loss.py +189 -0
- tests/losses/test_focal_loss.py +379 -0
- tests/losses/test_generalized_dice_focal_loss.py +85 -0
- tests/losses/test_generalized_dice_loss.py +221 -0
- tests/losses/test_generalized_wasserstein_dice_loss.py +234 -0
- tests/losses/test_giou_loss.py +62 -0
- tests/losses/test_hausdorff_loss.py +264 -0
- tests/losses/test_masked_dice_loss.py +152 -0
- tests/losses/test_masked_loss.py +87 -0
- tests/losses/test_multi_scale.py +86 -0
- tests/losses/test_nacl_loss.py +167 -0
- tests/losses/test_perceptual_loss.py +122 -0
- tests/losses/test_spectral_loss.py +86 -0
- tests/losses/test_ssim_loss.py +59 -0
- tests/losses/test_sure_loss.py +72 -0
- tests/losses/test_tversky_loss.py +198 -0
- tests/losses/test_unified_focal_loss.py +66 -0
- tests/metrics/__init__.py +10 -0
- tests/metrics/test_compute_average_precision.py +162 -0
- tests/metrics/test_compute_confusion_matrix.py +294 -0
- tests/metrics/test_compute_f_beta.py +80 -0
- tests/metrics/test_compute_fid_metric.py +40 -0
- tests/metrics/test_compute_froc.py +143 -0
- tests/metrics/test_compute_generalized_dice.py +240 -0
- tests/metrics/test_compute_meandice.py +306 -0
- tests/metrics/test_compute_meaniou.py +223 -0
- tests/metrics/test_compute_mmd_metric.py +56 -0
- tests/metrics/test_compute_multiscalessim_metric.py +83 -0
- tests/metrics/test_compute_panoptic_quality.py +113 -0
- tests/metrics/test_compute_regression_metrics.py +196 -0
- tests/metrics/test_compute_roc_auc.py +155 -0
- tests/metrics/test_compute_variance.py +147 -0
- tests/metrics/test_cumulative.py +63 -0
- tests/metrics/test_cumulative_average.py +74 -0
- tests/metrics/test_cumulative_average_dist.py +48 -0
- tests/metrics/test_hausdorff_distance.py +209 -0
- tests/metrics/test_label_quality_score.py +134 -0
- tests/metrics/test_loss_metric.py +57 -0
- tests/metrics/test_metrics_reloaded.py +96 -0
- tests/metrics/test_ssim_metric.py +78 -0
- tests/metrics/test_surface_dice.py +416 -0
- tests/metrics/test_surface_distance.py +186 -0
- tests/networks/__init__.py +10 -0
- tests/networks/blocks/__init__.py +10 -0
- tests/networks/blocks/dints_block/__init__.py +10 -0
- tests/networks/blocks/dints_block/test_acn_block.py +41 -0
- tests/networks/blocks/dints_block/test_factorized_increase.py +37 -0
- tests/networks/blocks/dints_block/test_factorized_reduce.py +37 -0
- tests/networks/blocks/dints_block/test_p3d_block.py +78 -0
- tests/networks/blocks/test_adn.py +86 -0
- tests/networks/blocks/test_convolutions.py +156 -0
- tests/networks/blocks/test_crf_cpu.py +513 -0
- tests/networks/blocks/test_crf_cuda.py +528 -0
- tests/networks/blocks/test_crossattention.py +185 -0
- tests/networks/blocks/test_denseblock.py +105 -0
- tests/networks/blocks/test_dynunet_block.py +116 -0
- tests/networks/blocks/test_fpn_block.py +88 -0
- tests/networks/blocks/test_localnet_block.py +121 -0
- tests/networks/blocks/test_mlp.py +78 -0
- tests/networks/blocks/test_patchembedding.py +212 -0
- tests/networks/blocks/test_regunet_block.py +103 -0
- tests/networks/blocks/test_se_block.py +85 -0
- tests/networks/blocks/test_se_blocks.py +78 -0
- tests/networks/blocks/test_segresnet_block.py +57 -0
- tests/networks/blocks/test_selfattention.py +232 -0
- tests/networks/blocks/test_simple_aspp.py +87 -0
- tests/networks/blocks/test_spatialattention.py +55 -0
- tests/networks/blocks/test_subpixel_upsample.py +87 -0
- tests/networks/blocks/test_text_encoding.py +49 -0
- tests/networks/blocks/test_transformerblock.py +90 -0
- tests/networks/blocks/test_unetr_block.py +158 -0
- tests/networks/blocks/test_upsample_block.py +134 -0
- tests/networks/blocks/warp/__init__.py +10 -0
- tests/networks/blocks/warp/test_dvf2ddf.py +72 -0
- tests/networks/blocks/warp/test_warp.py +250 -0
- tests/networks/layers/__init__.py +10 -0
- tests/networks/layers/filtering/__init__.py +10 -0
- tests/networks/layers/filtering/test_bilateral_approx_cpu.py +399 -0
- tests/networks/layers/filtering/test_bilateral_approx_cuda.py +404 -0
- tests/networks/layers/filtering/test_bilateral_precise.py +437 -0
- tests/networks/layers/filtering/test_phl_cpu.py +259 -0
- tests/networks/layers/filtering/test_phl_cuda.py +167 -0
- tests/networks/layers/filtering/test_trainable_bilateral.py +474 -0
- tests/networks/layers/filtering/test_trainable_joint_bilateral.py +609 -0
- tests/networks/layers/test_affine_transform.py +385 -0
- tests/networks/layers/test_apply_filter.py +89 -0
- tests/networks/layers/test_channel_pad.py +51 -0
- tests/networks/layers/test_conjugate_gradient.py +56 -0
- tests/networks/layers/test_drop_path.py +46 -0
- tests/networks/layers/test_gaussian.py +317 -0
- tests/networks/layers/test_gaussian_filter.py +206 -0
- tests/networks/layers/test_get_layers.py +65 -0
- tests/networks/layers/test_gmm.py +314 -0
- tests/networks/layers/test_grid_pull.py +93 -0
- tests/networks/layers/test_hilbert_transform.py +131 -0
- tests/networks/layers/test_lltm.py +62 -0
- tests/networks/layers/test_median_filter.py +52 -0
- tests/networks/layers/test_polyval.py +55 -0
- tests/networks/layers/test_preset_filters.py +136 -0
- tests/networks/layers/test_savitzky_golay_filter.py +141 -0
- tests/networks/layers/test_separable_filter.py +87 -0
- tests/networks/layers/test_skip_connection.py +48 -0
- tests/networks/layers/test_vector_quantizer.py +89 -0
- tests/networks/layers/test_weight_init.py +50 -0
- tests/networks/nets/__init__.py +10 -0
- tests/networks/nets/dints/__init__.py +10 -0
- tests/networks/nets/dints/test_dints_cell.py +110 -0
- tests/networks/nets/dints/test_dints_mixop.py +84 -0
- tests/networks/nets/regunet/__init__.py +10 -0
- tests/networks/nets/regunet/test_localnet.py +86 -0
- tests/networks/nets/regunet/test_regunet.py +88 -0
- tests/networks/nets/test_ahnet.py +224 -0
- tests/networks/nets/test_attentionunet.py +88 -0
- tests/networks/nets/test_autoencoder.py +95 -0
- tests/networks/nets/test_autoencoderkl.py +337 -0
- tests/networks/nets/test_basic_unet.py +102 -0
- tests/networks/nets/test_basic_unetplusplus.py +109 -0
- tests/networks/nets/test_bundle_init_bundle.py +55 -0
- tests/networks/nets/test_cell_sam_wrapper.py +58 -0
- tests/networks/nets/test_controlnet.py +215 -0
- tests/networks/nets/test_daf3d.py +62 -0
- tests/networks/nets/test_densenet.py +121 -0
- tests/networks/nets/test_diffusion_model_unet.py +585 -0
- tests/networks/nets/test_dints_network.py +168 -0
- tests/networks/nets/test_discriminator.py +59 -0
- tests/networks/nets/test_dynunet.py +181 -0
- tests/networks/nets/test_efficientnet.py +400 -0
- tests/networks/nets/test_flexible_unet.py +341 -0
- tests/networks/nets/test_fullyconnectednet.py +69 -0
- tests/networks/nets/test_generator.py +59 -0
- tests/networks/nets/test_globalnet.py +103 -0
- tests/networks/nets/test_highresnet.py +67 -0
- tests/networks/nets/test_hovernet.py +218 -0
- tests/networks/nets/test_mednext.py +122 -0
- tests/networks/nets/test_milmodel.py +92 -0
- tests/networks/nets/test_net_adapter.py +68 -0
- tests/networks/nets/test_network_consistency.py +86 -0
- tests/networks/nets/test_patch_gan_dicriminator.py +179 -0
- tests/networks/nets/test_quicknat.py +57 -0
- tests/networks/nets/test_resnet.py +340 -0
- tests/networks/nets/test_segresnet.py +120 -0
- tests/networks/nets/test_segresnet_ds.py +156 -0
- tests/networks/nets/test_senet.py +151 -0
- tests/networks/nets/test_spade_autoencoderkl.py +295 -0
- tests/networks/nets/test_spade_diffusion_model_unet.py +574 -0
- tests/networks/nets/test_spade_vaegan.py +140 -0
- tests/networks/nets/test_swin_unetr.py +139 -0
- tests/networks/nets/test_torchvision_fc_model.py +201 -0
- tests/networks/nets/test_transchex.py +84 -0
- tests/networks/nets/test_transformer.py +108 -0
- tests/networks/nets/test_unet.py +208 -0
- tests/networks/nets/test_unetr.py +137 -0
- tests/networks/nets/test_varautoencoder.py +127 -0
- tests/networks/nets/test_vista3d.py +84 -0
- tests/networks/nets/test_vit.py +139 -0
- tests/networks/nets/test_vitautoenc.py +112 -0
- tests/networks/nets/test_vnet.py +81 -0
- tests/networks/nets/test_voxelmorph.py +280 -0
- tests/networks/nets/test_vqvae.py +274 -0
- tests/networks/nets/test_vqvaetransformer_inferer.py +295 -0
- tests/networks/schedulers/__init__.py +10 -0
- tests/networks/schedulers/test_scheduler_ddim.py +83 -0
- tests/networks/schedulers/test_scheduler_ddpm.py +104 -0
- tests/networks/schedulers/test_scheduler_pndm.py +108 -0
- tests/networks/test_bundle_onnx_export.py +71 -0
- tests/networks/test_convert_to_onnx.py +106 -0
- tests/networks/test_convert_to_torchscript.py +46 -0
- tests/networks/test_convert_to_trt.py +79 -0
- tests/networks/test_save_state.py +73 -0
- tests/networks/test_to_onehot.py +63 -0
- tests/networks/test_varnet.py +63 -0
- tests/networks/utils/__init__.py +10 -0
- tests/networks/utils/test_copy_model_state.py +187 -0
- tests/networks/utils/test_eval_mode.py +34 -0
- tests/networks/utils/test_freeze_layers.py +61 -0
- tests/networks/utils/test_replace_module.py +98 -0
- tests/networks/utils/test_train_mode.py +34 -0
- tests/optimizers/__init__.py +10 -0
- tests/optimizers/test_generate_param_groups.py +105 -0
- tests/optimizers/test_lr_finder.py +108 -0
- tests/optimizers/test_lr_scheduler.py +71 -0
- tests/optimizers/test_optim_novograd.py +100 -0
- tests/profile_subclass/__init__.py +10 -0
- tests/profile_subclass/cprofile_profiling.py +29 -0
- tests/profile_subclass/min_classes.py +30 -0
- tests/profile_subclass/profiling.py +73 -0
- tests/profile_subclass/pyspy_profiling.py +41 -0
- tests/transforms/__init__.py +10 -0
- tests/transforms/compose/__init__.py +10 -0
- tests/transforms/compose/test_compose.py +758 -0
- tests/transforms/compose/test_some_of.py +258 -0
- tests/transforms/croppad/__init__.py +10 -0
- tests/transforms/croppad/test_rand_weighted_crop.py +224 -0
- tests/transforms/croppad/test_rand_weighted_cropd.py +182 -0
- tests/transforms/functional/__init__.py +10 -0
- tests/transforms/functional/test_apply.py +75 -0
- tests/transforms/functional/test_resample.py +50 -0
- tests/transforms/intensity/__init__.py +10 -0
- tests/transforms/intensity/test_compute_ho_ver_maps.py +75 -0
- tests/transforms/intensity/test_compute_ho_ver_maps_d.py +79 -0
- tests/transforms/intensity/test_foreground_mask.py +98 -0
- tests/transforms/intensity/test_foreground_maskd.py +106 -0
- tests/transforms/intensity/test_rand_histogram_shiftd.py +76 -0
- tests/transforms/intensity/test_scale_intensity_range_percentiles.py +96 -0
- tests/transforms/intensity/test_scale_intensity_range_percentilesd.py +100 -0
- tests/transforms/inverse/__init__.py +10 -0
- tests/transforms/inverse/test_inverse_array.py +76 -0
- tests/transforms/inverse/test_traceable_transform.py +59 -0
- tests/transforms/post/__init__.py +10 -0
- tests/transforms/post/test_label_filterd.py +78 -0
- tests/transforms/post/test_probnms.py +72 -0
- tests/transforms/post/test_probnmsd.py +79 -0
- tests/transforms/post/test_remove_small_objects.py +102 -0
- tests/transforms/spatial/__init__.py +10 -0
- tests/transforms/spatial/test_convert_box_points.py +119 -0
- tests/transforms/spatial/test_grid_patch.py +134 -0
- tests/transforms/spatial/test_grid_patchd.py +102 -0
- tests/transforms/spatial/test_rand_grid_patch.py +150 -0
- tests/transforms/spatial/test_rand_grid_patchd.py +117 -0
- tests/transforms/spatial/test_spatial_resampled.py +124 -0
- tests/transforms/test_activations.py +120 -0
- tests/transforms/test_activationsd.py +64 -0
- tests/transforms/test_adaptors.py +160 -0
- tests/transforms/test_add_coordinate_channels.py +53 -0
- tests/transforms/test_add_coordinate_channelsd.py +67 -0
- tests/transforms/test_add_extreme_points_channel.py +80 -0
- tests/transforms/test_add_extreme_points_channeld.py +77 -0
- tests/transforms/test_adjust_contrast.py +70 -0
- tests/transforms/test_adjust_contrastd.py +64 -0
- tests/transforms/test_affine.py +245 -0
- tests/transforms/test_affine_grid.py +152 -0
- tests/transforms/test_affined.py +190 -0
- tests/transforms/test_as_channel_last.py +38 -0
- tests/transforms/test_as_channel_lastd.py +44 -0
- tests/transforms/test_as_discrete.py +81 -0
- tests/transforms/test_as_discreted.py +82 -0
- tests/transforms/test_border_pad.py +49 -0
- tests/transforms/test_border_padd.py +45 -0
- tests/transforms/test_bounding_rect.py +54 -0
- tests/transforms/test_bounding_rectd.py +53 -0
- tests/transforms/test_cast_to_type.py +63 -0
- tests/transforms/test_cast_to_typed.py +74 -0
- tests/transforms/test_center_scale_crop.py +55 -0
- tests/transforms/test_center_scale_cropd.py +56 -0
- tests/transforms/test_center_spatial_crop.py +56 -0
- tests/transforms/test_center_spatial_cropd.py +63 -0
- tests/transforms/test_classes_to_indices.py +93 -0
- tests/transforms/test_classes_to_indicesd.py +110 -0
- tests/transforms/test_clip_intensity_percentiles.py +196 -0
- tests/transforms/test_clip_intensity_percentilesd.py +193 -0
- tests/transforms/test_compose_get_number_conversions.py +127 -0
- tests/transforms/test_concat_itemsd.py +82 -0
- tests/transforms/test_convert_to_multi_channel.py +59 -0
- tests/transforms/test_convert_to_multi_channeld.py +37 -0
- tests/transforms/test_copy_itemsd.py +86 -0
- tests/transforms/test_create_grid_and_affine.py +274 -0
- tests/transforms/test_crop_foreground.py +164 -0
- tests/transforms/test_crop_foregroundd.py +205 -0
- tests/transforms/test_cucim_dict_transform.py +142 -0
- tests/transforms/test_cucim_transform.py +141 -0
- tests/transforms/test_data_stats.py +221 -0
- tests/transforms/test_data_statsd.py +249 -0
- tests/transforms/test_delete_itemsd.py +58 -0
- tests/transforms/test_detect_envelope.py +159 -0
- tests/transforms/test_distance_transform_edt.py +202 -0
- tests/transforms/test_divisible_pad.py +49 -0
- tests/transforms/test_divisible_padd.py +42 -0
- tests/transforms/test_ensure_channel_first.py +113 -0
- tests/transforms/test_ensure_channel_firstd.py +85 -0
- tests/transforms/test_ensure_type.py +94 -0
- tests/transforms/test_ensure_typed.py +110 -0
- tests/transforms/test_fg_bg_to_indices.py +83 -0
- tests/transforms/test_fg_bg_to_indicesd.py +78 -0
- tests/transforms/test_fill_holes.py +207 -0
- tests/transforms/test_fill_holesd.py +209 -0
- tests/transforms/test_flatten_sub_keysd.py +64 -0
- tests/transforms/test_flip.py +83 -0
- tests/transforms/test_flipd.py +90 -0
- tests/transforms/test_fourier.py +70 -0
- tests/transforms/test_gaussian_sharpen.py +92 -0
- tests/transforms/test_gaussian_sharpend.py +92 -0
- tests/transforms/test_gaussian_smooth.py +96 -0
- tests/transforms/test_gaussian_smoothd.py +96 -0
- tests/transforms/test_generate_label_classes_crop_centers.py +71 -0
- tests/transforms/test_generate_pos_neg_label_crop_centers.py +76 -0
- tests/transforms/test_generate_spatial_bounding_box.py +114 -0
- tests/transforms/test_get_extreme_points.py +57 -0
- tests/transforms/test_gibbs_noise.py +73 -0
- tests/transforms/test_gibbs_noised.py +88 -0
- tests/transforms/test_grid_distortion.py +113 -0
- tests/transforms/test_grid_distortiond.py +87 -0
- tests/transforms/test_grid_split.py +88 -0
- tests/transforms/test_grid_splitd.py +96 -0
- tests/transforms/test_histogram_normalize.py +59 -0
- tests/transforms/test_histogram_normalized.py +59 -0
- tests/transforms/test_image_filter.py +259 -0
- tests/transforms/test_intensity_stats.py +73 -0
- tests/transforms/test_intensity_statsd.py +90 -0
- tests/transforms/test_inverse.py +521 -0
- tests/transforms/test_inverse_collation.py +147 -0
- tests/transforms/test_invert.py +105 -0
- tests/transforms/test_invertd.py +142 -0
- tests/transforms/test_k_space_spike_noise.py +81 -0
- tests/transforms/test_k_space_spike_noised.py +98 -0
- tests/transforms/test_keep_largest_connected_component.py +419 -0
- tests/transforms/test_keep_largest_connected_componentd.py +348 -0
- tests/transforms/test_label_filter.py +78 -0
- tests/transforms/test_label_to_contour.py +179 -0
- tests/transforms/test_label_to_contourd.py +182 -0
- tests/transforms/test_label_to_mask.py +69 -0
- tests/transforms/test_label_to_maskd.py +70 -0
- tests/transforms/test_load_image.py +502 -0
- tests/transforms/test_load_imaged.py +198 -0
- tests/transforms/test_load_spacing_orientation.py +149 -0
- tests/transforms/test_map_and_generate_sampling_centers.py +86 -0
- tests/transforms/test_map_binary_to_indices.py +75 -0
- tests/transforms/test_map_classes_to_indices.py +135 -0
- tests/transforms/test_map_label_value.py +89 -0
- tests/transforms/test_map_label_valued.py +85 -0
- tests/transforms/test_map_transform.py +45 -0
- tests/transforms/test_mask_intensity.py +74 -0
- tests/transforms/test_mask_intensityd.py +68 -0
- tests/transforms/test_mean_ensemble.py +77 -0
- tests/transforms/test_mean_ensembled.py +91 -0
- tests/transforms/test_median_smooth.py +41 -0
- tests/transforms/test_median_smoothd.py +65 -0
- tests/transforms/test_morphological_ops.py +101 -0
- tests/transforms/test_nifti_endianness.py +107 -0
- tests/transforms/test_normalize_intensity.py +143 -0
- tests/transforms/test_normalize_intensityd.py +81 -0
- tests/transforms/test_nvtx_decorator.py +289 -0
- tests/transforms/test_nvtx_transform.py +143 -0
- tests/transforms/test_orientation.py +247 -0
- tests/transforms/test_orientationd.py +112 -0
- tests/transforms/test_rand_adjust_contrast.py +45 -0
- tests/transforms/test_rand_adjust_contrastd.py +44 -0
- tests/transforms/test_rand_affine.py +201 -0
- tests/transforms/test_rand_affine_grid.py +212 -0
- tests/transforms/test_rand_affined.py +281 -0
- tests/transforms/test_rand_axis_flip.py +50 -0
- tests/transforms/test_rand_axis_flipd.py +50 -0
- tests/transforms/test_rand_bias_field.py +69 -0
- tests/transforms/test_rand_bias_fieldd.py +65 -0
- tests/transforms/test_rand_coarse_dropout.py +110 -0
- tests/transforms/test_rand_coarse_dropoutd.py +107 -0
- tests/transforms/test_rand_coarse_shuffle.py +65 -0
- tests/transforms/test_rand_coarse_shuffled.py +59 -0
- tests/transforms/test_rand_crop_by_label_classes.py +170 -0
- tests/transforms/test_rand_crop_by_label_classesd.py +159 -0
- tests/transforms/test_rand_crop_by_pos_neg_label.py +152 -0
- tests/transforms/test_rand_crop_by_pos_neg_labeld.py +172 -0
- tests/transforms/test_rand_cucim_dict_transform.py +162 -0
- tests/transforms/test_rand_cucim_transform.py +162 -0
- tests/transforms/test_rand_deform_grid.py +138 -0
- tests/transforms/test_rand_elastic_2d.py +127 -0
- tests/transforms/test_rand_elastic_3d.py +104 -0
- tests/transforms/test_rand_elasticd_2d.py +177 -0
- tests/transforms/test_rand_elasticd_3d.py +156 -0
- tests/transforms/test_rand_flip.py +60 -0
- tests/transforms/test_rand_flipd.py +55 -0
- tests/transforms/test_rand_gaussian_noise.py +48 -0
- tests/transforms/test_rand_gaussian_noised.py +54 -0
- tests/transforms/test_rand_gaussian_sharpen.py +140 -0
- tests/transforms/test_rand_gaussian_sharpend.py +143 -0
- tests/transforms/test_rand_gaussian_smooth.py +98 -0
- tests/transforms/test_rand_gaussian_smoothd.py +98 -0
- tests/transforms/test_rand_gibbs_noise.py +103 -0
- tests/transforms/test_rand_gibbs_noised.py +117 -0
- tests/transforms/test_rand_grid_distortion.py +99 -0
- tests/transforms/test_rand_grid_distortiond.py +90 -0
- tests/transforms/test_rand_histogram_shift.py +92 -0
- tests/transforms/test_rand_k_space_spike_noise.py +92 -0
- tests/transforms/test_rand_k_space_spike_noised.py +76 -0
- tests/transforms/test_rand_rician_noise.py +52 -0
- tests/transforms/test_rand_rician_noised.py +52 -0
- tests/transforms/test_rand_rotate.py +166 -0
- tests/transforms/test_rand_rotate90.py +100 -0
- tests/transforms/test_rand_rotate90d.py +112 -0
- tests/transforms/test_rand_rotated.py +187 -0
- tests/transforms/test_rand_scale_crop.py +78 -0
- tests/transforms/test_rand_scale_cropd.py +98 -0
- tests/transforms/test_rand_scale_intensity.py +54 -0
- tests/transforms/test_rand_scale_intensity_fixed_mean.py +41 -0
- tests/transforms/test_rand_scale_intensity_fixed_meand.py +41 -0
- tests/transforms/test_rand_scale_intensityd.py +53 -0
- tests/transforms/test_rand_shift_intensity.py +52 -0
- tests/transforms/test_rand_shift_intensityd.py +67 -0
- tests/transforms/test_rand_simulate_low_resolution.py +83 -0
- tests/transforms/test_rand_simulate_low_resolutiond.py +73 -0
- tests/transforms/test_rand_spatial_crop.py +107 -0
- tests/transforms/test_rand_spatial_crop_samples.py +128 -0
- tests/transforms/test_rand_spatial_crop_samplesd.py +147 -0
- tests/transforms/test_rand_spatial_cropd.py +112 -0
- tests/transforms/test_rand_std_shift_intensity.py +43 -0
- tests/transforms/test_rand_std_shift_intensityd.py +38 -0
- tests/transforms/test_rand_zoom.py +105 -0
- tests/transforms/test_rand_zoomd.py +108 -0
- tests/transforms/test_randidentity.py +49 -0
- tests/transforms/test_random_order.py +144 -0
- tests/transforms/test_randtorchvisiond.py +65 -0
- tests/transforms/test_regularization.py +139 -0
- tests/transforms/test_remove_repeated_channel.py +34 -0
- tests/transforms/test_remove_repeated_channeld.py +44 -0
- tests/transforms/test_repeat_channel.py +34 -0
- tests/transforms/test_repeat_channeld.py +41 -0
- tests/transforms/test_resample_backends.py +65 -0
- tests/transforms/test_resample_to_match.py +110 -0
- tests/transforms/test_resample_to_matchd.py +93 -0
- tests/transforms/test_resampler.py +165 -0
- tests/transforms/test_resize.py +140 -0
- tests/transforms/test_resize_with_pad_or_crop.py +91 -0
- tests/transforms/test_resize_with_pad_or_cropd.py +86 -0
- tests/transforms/test_resized.py +163 -0
- tests/transforms/test_rotate.py +160 -0
- tests/transforms/test_rotate90.py +212 -0
- tests/transforms/test_rotate90d.py +106 -0
- tests/transforms/test_rotated.py +179 -0
- tests/transforms/test_save_classificationd.py +109 -0
- tests/transforms/test_save_image.py +80 -0
- tests/transforms/test_save_imaged.py +130 -0
- tests/transforms/test_savitzky_golay_smooth.py +73 -0
- tests/transforms/test_savitzky_golay_smoothd.py +73 -0
- tests/transforms/test_scale_intensity.py +76 -0
- tests/transforms/test_scale_intensity_fixed_mean.py +94 -0
- tests/transforms/test_scale_intensity_range.py +41 -0
- tests/transforms/test_scale_intensity_ranged.py +40 -0
- tests/transforms/test_scale_intensityd.py +57 -0
- tests/transforms/test_select_itemsd.py +41 -0
- tests/transforms/test_shift_intensity.py +31 -0
- tests/transforms/test_shift_intensityd.py +44 -0
- tests/transforms/test_signal_continuouswavelet.py +44 -0
- tests/transforms/test_signal_fillempty.py +52 -0
- tests/transforms/test_signal_fillemptyd.py +60 -0
- tests/transforms/test_signal_rand_add_gaussiannoise.py +50 -0
- tests/transforms/test_signal_rand_add_sine.py +52 -0
- tests/transforms/test_signal_rand_add_sine_partial.py +50 -0
- tests/transforms/test_signal_rand_add_squarepulse.py +58 -0
- tests/transforms/test_signal_rand_add_squarepulse_partial.py +62 -0
- tests/transforms/test_signal_rand_drop.py +50 -0
- tests/transforms/test_signal_rand_scale.py +52 -0
- tests/transforms/test_signal_rand_shift.py +55 -0
- tests/transforms/test_signal_remove_frequency.py +71 -0
- tests/transforms/test_smooth_field.py +177 -0
- tests/transforms/test_sobel_gradient.py +189 -0
- tests/transforms/test_sobel_gradientd.py +212 -0
- tests/transforms/test_spacing.py +381 -0
- tests/transforms/test_spacingd.py +178 -0
- tests/transforms/test_spatial_crop.py +82 -0
- tests/transforms/test_spatial_cropd.py +74 -0
- tests/transforms/test_spatial_pad.py +57 -0
- tests/transforms/test_spatial_padd.py +43 -0
- tests/transforms/test_spatial_resample.py +235 -0
- tests/transforms/test_squeezedim.py +62 -0
- tests/transforms/test_squeezedimd.py +98 -0
- tests/transforms/test_std_shift_intensity.py +76 -0
- tests/transforms/test_std_shift_intensityd.py +74 -0
- tests/transforms/test_threshold_intensity.py +38 -0
- tests/transforms/test_threshold_intensityd.py +58 -0
- tests/transforms/test_to_contiguous.py +47 -0
- tests/transforms/test_to_cupy.py +112 -0
- tests/transforms/test_to_cupyd.py +76 -0
- tests/transforms/test_to_device.py +42 -0
- tests/transforms/test_to_deviced.py +37 -0
- tests/transforms/test_to_numpy.py +85 -0
- tests/transforms/test_to_numpyd.py +68 -0
- tests/transforms/test_to_pil.py +52 -0
- tests/transforms/test_to_pild.py +50 -0
- tests/transforms/test_to_tensor.py +60 -0
- tests/transforms/test_to_tensord.py +71 -0
- tests/transforms/test_torchvision.py +66 -0
- tests/transforms/test_torchvisiond.py +63 -0
- tests/transforms/test_transform.py +62 -0
- tests/transforms/test_transpose.py +41 -0
- tests/transforms/test_transposed.py +52 -0
- tests/transforms/test_ultrasound_confidence_map_transform.py +711 -0
- tests/transforms/test_utils_pytorch_numpy_unification.py +90 -0
- tests/transforms/test_vote_ensemble.py +84 -0
- tests/transforms/test_vote_ensembled.py +107 -0
- tests/transforms/test_with_allow_missing_keys.py +76 -0
- tests/transforms/test_zoom.py +120 -0
- tests/transforms/test_zoomd.py +94 -0
- tests/transforms/transform/__init__.py +10 -0
- tests/transforms/transform/test_randomizable.py +52 -0
- tests/transforms/transform/test_randomizable_transform_type.py +37 -0
- tests/transforms/utility/__init__.py +10 -0
- tests/transforms/utility/test_apply_transform_to_points.py +81 -0
- tests/transforms/utility/test_apply_transform_to_pointsd.py +185 -0
- tests/transforms/utility/test_identity.py +29 -0
- tests/transforms/utility/test_identityd.py +30 -0
- tests/transforms/utility/test_lambda.py +71 -0
- tests/transforms/utility/test_lambdad.py +83 -0
- tests/transforms/utility/test_rand_lambda.py +87 -0
- tests/transforms/utility/test_rand_lambdad.py +77 -0
- tests/transforms/utility/test_simulatedelay.py +36 -0
- tests/transforms/utility/test_simulatedelayd.py +36 -0
- tests/transforms/utility/test_splitdim.py +52 -0
- tests/transforms/utility/test_splitdimd.py +96 -0
- tests/transforms/utils/__init__.py +10 -0
- tests/transforms/utils/test_correct_crop_centers.py +36 -0
- tests/transforms/utils/test_get_unique_labels.py +45 -0
- tests/transforms/utils/test_print_transform_backends.py +29 -0
- tests/transforms/utils/test_soft_clip.py +125 -0
- tests/utils/__init__.py +10 -0
- tests/utils/enums/__init__.py +10 -0
- tests/utils/enums/test_hovernet_loss.py +190 -0
- tests/utils/enums/test_ordering.py +289 -0
- tests/utils/enums/test_wsireader.py +663 -0
- tests/utils/misc/__init__.py +10 -0
- tests/utils/misc/test_ensure_tuple.py +53 -0
- tests/utils/misc/test_monai_env_vars.py +44 -0
- tests/utils/misc/test_monai_utils_misc.py +103 -0
- tests/utils/misc/test_str2bool.py +34 -0
- tests/utils/misc/test_str2list.py +33 -0
- tests/utils/test_alias.py +44 -0
- tests/utils/test_component_store.py +73 -0
- tests/utils/test_deprecated.py +455 -0
- tests/utils/test_enum_bound_interp.py +75 -0
- tests/utils/test_evenly_divisible_all_gather_dist.py +50 -0
- tests/utils/test_get_package_version.py +34 -0
- tests/utils/test_handler_logfile.py +84 -0
- tests/utils/test_handler_metric_logger.py +62 -0
- tests/utils/test_list_to_dict.py +43 -0
- tests/utils/test_look_up_option.py +87 -0
- tests/utils/test_optional_import.py +80 -0
- tests/utils/test_pad_mode.py +39 -0
- tests/utils/test_profiling.py +208 -0
- tests/utils/test_rankfilter_dist.py +77 -0
- tests/utils/test_require_pkg.py +83 -0
- tests/utils/test_sample_slices.py +43 -0
- tests/utils/test_set_determinism.py +74 -0
- tests/utils/test_squeeze_unsqueeze.py +71 -0
- tests/utils/test_state_cacher.py +67 -0
- tests/utils/test_torchscript_utils.py +113 -0
- tests/utils/test_version.py +91 -0
- tests/utils/test_version_after.py +65 -0
- tests/utils/type_conversion/__init__.py +10 -0
- tests/utils/type_conversion/test_convert_data_type.py +152 -0
- tests/utils/type_conversion/test_get_equivalent_dtype.py +65 -0
- tests/utils/type_conversion/test_safe_dtype_range.py +99 -0
- tests/visualize/__init__.py +10 -0
- tests/visualize/test_img2tensorboard.py +46 -0
- tests/visualize/test_occlusion_sensitivity.py +128 -0
- tests/visualize/test_plot_2d_or_3d_image.py +74 -0
- tests/visualize/test_vis_cam.py +98 -0
- tests/visualize/test_vis_gradcam.py +211 -0
- tests/visualize/utils/__init__.py +10 -0
- tests/visualize/utils/test_blend_images.py +63 -0
- tests/visualize/utils/test_matshow3d.py +133 -0
- monai_weekly-1.5.dev2506.dist-info/RECORD +0 -427
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2506.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
@@ -0,0 +1,232 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from unittest import skipUnless
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from parameterized import parameterized
|
20
|
+
|
21
|
+
from monai.networks import eval_mode
|
22
|
+
from monai.networks.blocks.selfattention import SABlock
|
23
|
+
from monai.networks.layers.factories import RelPosEmbedding
|
24
|
+
from monai.utils import optional_import
|
25
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save
|
26
|
+
|
27
|
+
# Optional einops import; has_einops gates the tests that need it.
einops, has_einops = optional_import("einops")

# Full grid of SABlock configurations. Each case is
# [SABlock kwargs, input shape, expected output shape], and the expected output
# shape equals the input shape. Flash attention is requested only when no
# relative positional embedding is used.
TEST_CASE_SABLOCK = [
    [
        {
            "hidden_size": hidden_size,
            "num_heads": num_heads,
            "dropout_rate": dropout_rate,
            "rel_pos_embedding": rel_pos_embedding,
            "input_size": input_size,
            "include_fc": include_fc,
            "use_combined_linear": use_combined_linear,
            "use_flash_attention": rel_pos_embedding is None,
        },
        (2, 512, hidden_size),
        (2, 512, hidden_size),
    ]
    for dropout_rate in np.linspace(0, 1, 4)
    for hidden_size in [360, 480, 600, 768]
    for num_heads in [4, 6, 8, 12]
    for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]
    for input_size in [(16, 32), (8, 8, 8)]
    for include_fc in [True, False]
    for use_combined_linear in [True, False]
]
|
52
|
+
|
53
|
+
|
54
|
+
class TestResBlock(unittest.TestCase):
    """Unit tests for ``SABlock`` (multi-head self-attention block).

    Despite the class name, every test in this case exercises ``SABlock``.
    """

    @parameterized.expand(TEST_CASE_SABLOCK)
    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_shape(self, input_param, input_shape, expected_shape):
        net = SABlock(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    def test_ill_arg(self):
        # dropout_rate=6.0 is not a valid probability.
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0)

        # hidden_size=620 is not divisible by num_heads=8.
        with self.assertRaises(ValueError):
            SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

    @SkipIfBeforePyTorchVersion((2, 0))
    def test_rel_pos_embedding_with_flash_attention(self):
        # num_heads=4 divides hidden_size=128 and input_size is provided, so the only
        # invalid part of this configuration is flash attention combined with a
        # relative positional embedding. A non-divisible head count (e.g. 3) raises
        # ValueError by itself (see test_attention_dim_not_multiple_of_heads) and
        # would let this test pass without checking the intended condition.
        with self.assertRaises(ValueError):
            SABlock(
                hidden_size=128,
                num_heads=4,
                dropout_rate=0.1,
                use_flash_attention=True,
                save_attn=False,
                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
                input_size=(16, 32),
            )

    @SkipIfBeforePyTorchVersion((1, 13))
    def test_save_attn_with_flash_attention(self):
        # num_heads=4 divides hidden_size=128, so the expected ValueError is attributable
        # to requesting save_attn together with flash attention, not to the head count.
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, use_flash_attention=True, save_attn=True)

    def test_attention_dim_not_multiple_of_heads(self):
        # 128 is not divisible by 3 heads.
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1)

    @skipUnless(has_einops, "Requires einops")
    def test_inner_dim_different(self):
        # Construction must succeed with an explicit dim_head (30) that differs from
        # hidden_size // num_heads (128 // 4 = 32).
        SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30)

    def test_causal_no_sequence_length(self):
        # causal=True without a sequence_length is rejected.
        with self.assertRaises(ValueError):
            SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True)

    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_causal_flash_attention(self):
        block = SABlock(
            hidden_size=128,
            num_heads=1,
            dropout_rate=0.1,
            causal=True,
            sequence_length=16,
            save_attn=False,
            use_flash_attention=True,
        )
        # Smoke test: a causal forward pass through the flash-attention path must run.
        block(torch.randn(1, 16, 128))

    @skipUnless(has_einops, "Requires einops")
    def test_causal(self):
        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True)
        block(torch.randn(1, 16, 128))
        # Causal masking: the strictly upper-triangular part of the saved attention matrix is zero.
        self.assertEqual(torch.triu(block.att_mat, diagonal=1).sum().item(), 0)

    def test_masked_selfattention(self):
        n = 64
        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
        input_shape = (1, n, 128)
        # Random boolean mask of shape (1, n); positions marked False are masked out.
        mask = torch.randint(0, 2, (1, n)).bool()
        block(torch.randn(input_shape), attn_mask=mask)
        att_mat = block.att_mat.squeeze()
        # Every masked column must receive zero attention.
        masked_cols = ~mask.squeeze(0)
        self.assertTrue(torch.allclose(att_mat[:, masked_cols], torch.zeros_like(att_mat[:, masked_cols])))

    def test_causal_and_mask(self):
        # Building a causal block is valid; only the forward call with an explicit
        # attention mask is expected to fail. Constructing outside assertRaises keeps a
        # constructor ValueError from satisfying the check.
        block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
        inputs = torch.randn(2, 64, 128)
        mask = torch.randint(0, 2, (2, 64)).bool()
        with self.assertRaises(ValueError):
            block(inputs, attn_mask=mask)

    @skipUnless(has_einops, "Requires einops")
    def test_access_attn_matrix(self):
        # input format
        hidden_size = 128
        num_heads = 2
        dropout_rate = 0
        input_shape = (2, 256, hidden_size)

        # Without save_attn the attention matrix is not kept: att_mat is an empty tensor.
        no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
        no_matrix_access_blk(torch.randn(input_shape))
        self.assertIsInstance(no_matrix_access_blk.att_mat, torch.Tensor)
        self.assertEqual(no_matrix_access_blk.att_mat.nelement(), 0)

        # With save_attn=True the attention matrix is stored per batch item and per head.
        matrix_access_blk = SABlock(
            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
        )
        matrix_access_blk(torch.randn(input_shape))
        # Expected layout: (batch, num_heads, seq_len, seq_len). num_heads is used explicitly
        # for axis 1; batch size and num_heads are both 2 here, so using the batch size
        # there would hide a mix-up.
        self.assertEqual(
            matrix_access_blk.att_mat.shape, (input_shape[0], num_heads, input_shape[1], input_shape[1])
        )

    def test_number_of_parameters(self):
        def count_sablock_params(*args, **kwargs):
            """Count the trainable parameters of an ``SABlock`` built with the given arguments."""
            sablock = SABlock(*args, **kwargs)
            return sum(x.numel() for x in sablock.parameters() if x.requires_grad)

        hidden_size = 128
        num_heads = 8
        default_dim_head = hidden_size // num_heads

        # Default dim_head is hidden_size // num_heads
        nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads)
        nparams_like_default = count_sablock_params(
            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head
        )
        self.assertEqual(nparams_default, nparams_like_default)

        # Increasing dim_head should increase the number of parameters
        nparams_custom_large = count_sablock_params(
            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2
        )
        self.assertGreater(nparams_custom_large, nparams_default)

        # Decreasing dim_head should decrease the number of parameters
        nparams_custom_small = count_sablock_params(
            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2
        )
        self.assertGreater(nparams_default, nparams_custom_small)

        # Increasing the number of heads with the default behaviour should not change the number of params.
        nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)
        self.assertEqual(nparams_default, nparams_default_more_heads)

    @parameterized.expand([[True, False], [True, True], [False, True], [False, False]])
    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_script(self, include_fc, use_combined_linear):
        input_param = {
            "hidden_size": 360,
            "num_heads": 4,
            "dropout_rate": 0.0,
            "rel_pos_embedding": None,
            "input_size": (16, 32),
            "include_fc": include_fc,
            "use_combined_linear": use_combined_linear,
        }
        net = SABlock(**input_param)
        input_shape = (2, 512, 360)
        test_data = torch.randn(input_shape)
        test_script_save(net, test_data)

    @skipUnless(has_einops, "Requires einops")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_flash_attention(self):
        # With shared weights, the flash and non-flash attention paths must agree
        # numerically, with and without causal masking.
        for causal in [True, False]:
            input_param = {"hidden_size": 360, "num_heads": 4, "input_size": (16, 32), "causal": causal}
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device)
            block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device)
            block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict())
            test_data = torch.randn(2, 512, 360).to(device)

            out_1 = block_w_flash_attention(test_data)
            out_2 = block_wo_flash_attention(test_data)
            assert_allclose(out_1, out_2, atol=1e-4)
+
|
230
|
+
|
231
|
+
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from parameterized import parameterized
|
18
|
+
|
19
|
+
from monai.networks import eval_mode
|
20
|
+
from monai.networks.blocks import SimpleASPP
|
21
|
+
|
22
|
+
# Valid configurations: [SimpleASPP kwargs, input shape, expected output shape].
# Spatial sizes are preserved. The expected channel counts equal conv_out_channels
# times the number of parallel branches: the cases encode 4 branches when
# kernel_sizes/dilations are not given, and len(kernel_sizes) when they are.
TEST_CASES = [
    [  # 32-channel 2D, batch 7, batch norm without affine parameters
        {"spatial_dims": 2, "in_channels": 32, "conv_out_channels": 3, "norm_type": ("batch", {"affine": False})},
        (7, 32, 18, 20),
        (7, 12, 18, 20),
    ],
    [  # 4-channel 1D, batch 16; PReLU num_parameters matches the 32 output channels
        {"spatial_dims": 1, "in_channels": 4, "conv_out_channels": 8, "acti_type": ("PRELU", {"num_parameters": 32})},
        (16, 4, 17),
        (16, 32, 17),
    ],
    [  # 3-channel 3D, batch 16, default kernel sizes and dilations
        {"spatial_dims": 3, "in_channels": 3, "conv_out_channels": 2},
        (16, 3, 17, 18, 19),
        (16, 8, 17, 18, 19),
    ],
    [  # 3-channel 3D, batch 16, three custom kernel_size/dilation pairs
        {
            "spatial_dims": 3,
            "in_channels": 3,
            "conv_out_channels": 2,
            "kernel_sizes": (1, 3, 3),
            "dilations": (1, 2, 4),
        },
        (16, 3, 17, 18, 19),
        (16, 6, 17, 18, 19),
    ],
]

# Invalid configurations: [SimpleASPP kwargs, input shape (not used by the test), expected exception type].
TEST_ILL_CASES = [
    [  # 3-channel 3D, batch 16: kernel_sizes and dilations have different lengths.
        {"spatial_dims": 3, "in_channels": 3, "conv_out_channels": 2, "kernel_sizes": (1, 3, 3), "dilations": (1, 2)},
        (16, 3, 17, 18, 19),
        ValueError,
    ],
    [  # 3-channel 3D, batch 16: lengths match, but kernel size 4 with dilation 3 has no supported padding.
        {
            "spatial_dims": 3,
            "in_channels": 3,
            "conv_out_channels": 2,
            "kernel_sizes": (1, 3, 4),
            "dilations": (1, 2, 3),
        },
        (16, 3, 17, 18, 19),
        NotImplementedError,  # unknown padding k=4, d=3
    ],
]
|
69
|
+
|
70
|
+
|
71
|
+
class TestChannelSELayer(unittest.TestCase):
    """Shape and argument-validation tests for ``SimpleASPP`` (the class name does not reflect the block under test)."""

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        # Build the block, push one random batch through it in eval mode and compare shapes.
        aspp = SimpleASPP(**input_param)
        with eval_mode(aspp):
            output = aspp(torch.randn(input_shape))
            self.assertEqual(output.shape, expected_shape)

    @parameterized.expand(TEST_ILL_CASES)
    def test_ill_args(self, input_param, input_shape, error_type):
        # Construction alone must fail; input_shape is not needed for that.
        self.assertRaises(error_type, SimpleASPP, **input_param)
|
84
|
+
|
85
|
+
|
86
|
+
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from unittest import skipUnless
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks import eval_mode
|
21
|
+
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
|
22
|
+
from monai.utils import optional_import
|
23
|
+
|
24
|
+
# Optional einops import; has_einops is used to skip the einops-dependent shape test.
einops, has_einops = optional_import("einops")
|
25
|
+
|
26
|
+
# [SpatialAttentionBlock kwargs, input shape, expected output shape]; in both the
# 2D and the 3D case the output is expected to keep the input shape.
TEST_CASES = [
    [
        dict(spatial_dims=2, num_channels=128, num_head_channels=32, norm_num_groups=32, norm_eps=1e-6),
        (1, 128, 32, 32),
        (1, 128, 32, 32),
    ],
    [
        dict(spatial_dims=3, num_channels=16, num_head_channels=8, norm_num_groups=8, norm_eps=1e-6),
        (1, 16, 8, 8, 8),
        (1, 16, 8, 8, 8),
    ],
]
|
38
|
+
|
39
|
+
|
40
|
+
class TestBlock(unittest.TestCase):
    """Tests for ``SpatialAttentionBlock``."""

    @parameterized.expand(TEST_CASES)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, input_param, input_shape, expected_shape):
        # One forward pass in eval mode must produce the expected shape.
        block = SpatialAttentionBlock(**input_param)
        with eval_mode(block):
            out = block(torch.randn(input_shape))
            self.assertEqual(out.shape, expected_shape)

    def test_attention_dim_not_multiple_of_heads(self):
        # 128 channels cannot be split evenly into heads of 33 channels each.
        self.assertRaises(ValueError, SpatialAttentionBlock, spatial_dims=2, num_channels=128, num_head_channels=33)
|
52
|
+
|
53
|
+
|
54
|
+
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
import torch
|
17
|
+
import torch.nn as nn
|
18
|
+
from parameterized import parameterized
|
19
|
+
|
20
|
+
from monai.networks import eval_mode
|
21
|
+
from monai.networks.blocks import SubpixelUpsample
|
22
|
+
from monai.networks.layers.factories import Conv
|
23
|
+
from tests.test_utils import SkipIfBeforePyTorchVersion, test_script_save
|
24
|
+
|
25
|
+
# Grid of [SubpixelUpsample kwargs, input shape, expected output shape] over
# in_channels 1-4, spatial_dims 1-3 and scale_factor 1-2. Inputs have size 8 on
# every spatial axis, and each axis is expected to grow by scale_factor.
TEST_CASE_SUBPIXEL = [
    [
        {"spatial_dims": dim, "in_channels": inch, "scale_factor": factor},
        (2, inch, *([8] * dim)),
        (2, inch, *([8 * factor] * dim)),
    ]
    for inch in range(1, 5)
    for dim in range(1, 4)
    for factor in range(1, 3)
]

# 2D case with unequal spatial sizes and a scale factor of 3.
TEST_CASE_SUBPIXEL_2D_EXTRA = [
    {"spatial_dims": 2, "in_channels": 2, "scale_factor": 3},
    (2, 2, 8, 4),  # different size for H and W
    (2, 2, 24, 12),
]

# 3D case with three different spatial sizes.
TEST_CASE_SUBPIXEL_3D_EXTRA = [
    {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2},
    (2, 1, 16, 8, 4),  # different size for H, W and D
    (2, 1, 32, 16, 8),
]

# Custom 3D convolution block (1 -> 4 -> 8 channels) handed to SubpixelUpsample.
# 8 output channels = in_channels (1) * scale_factor**spatial_dims (2**3),
# presumably as the pixel-shuffle step requires.
conv_block = nn.Sequential(
    Conv[Conv.CONV, 3](1, 4, kernel_size=1), Conv[Conv.CONV, 3](4, 8, kernel_size=3, stride=1, padding=1)
)

TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA = [
    {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2, "conv_block": conv_block},
    (2, 1, 16, 8, 4),  # different size for H, W and D
    (2, 1, 32, 16, 8),
]

TEST_CASE_SUBPIXEL.extend(  # type: ignore
    [TEST_CASE_SUBPIXEL_2D_EXTRA, TEST_CASE_SUBPIXEL_3D_EXTRA, TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA]
)

# Duplicate every case above with the pad/pool sequential component omitted
# (apply_pad_pool=False). The right-hand list is fully built from the current
# cases before the in-place extension happens.
TEST_CASE_SUBPIXEL += [  # type: ignore
    [{**case[0], "apply_pad_pool": False}, case[1], case[2]] for case in TEST_CASE_SUBPIXEL  # type: ignore
]
|
68
|
+
|
69
|
+
|
70
|
+
class TestSUBPIXEL(unittest.TestCase):
    """Output-shape and TorchScript tests for ``SubpixelUpsample``."""

    @parameterized.expand(TEST_CASE_SUBPIXEL)
    def test_subpixel_shape(self, input_param, input_shape, expected_shape):
        # A forward pass in eval mode must produce the upsampled shape.
        upsampler = SubpixelUpsample(**input_param)
        with eval_mode(upsampler):
            upsampled = upsampler.forward(torch.randn(input_shape))
            self.assertEqual(upsampled.shape, expected_shape)

    @SkipIfBeforePyTorchVersion((1, 8, 1))
    def test_script(self):
        # Run test_script_save on the first configuration (1 channel, 1 spatial dim, scale factor 1).
        params, shape, _expected = TEST_CASE_SUBPIXEL[0]
        test_script_save(SubpixelUpsample(**params), torch.randn(shape))
|
84
|
+
|
85
|
+
|
86
|
+
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
|
16
|
+
from monai.networks.blocks.text_embedding import TextEncoder
|
17
|
+
from tests.test_utils import skip_if_downloading_fails
|
18
|
+
|
19
|
+
|
20
|
+
class TestTextEncoder(unittest.TestCase):
    """Output-shape tests for ``TextEncoder`` in 2D and 3D."""

    def test_test_encoding_shape(self):
        # (spatial_dims, encoding, expected shape), checked in this order: the
        # clip_encoding_universal_model_32 encoding in 2D then 3D, followed by the
        # random embedding in 3D then 2D.
        cases = (
            (2, "clip_encoding_universal_model_32", (32, 256, 1, 1)),
            (3, "clip_encoding_universal_model_32", (32, 256, 1, 1, 1)),
            (3, "rand_embedding", (32, 256, 1, 1, 1)),
            (2, "rand_embedding", (32, 256, 1, 1)),
        )
        # All constructions share one guard, so a failed download skips the whole test.
        with skip_if_downloading_fails():
            for spatial_dims, encoding, expected_shape in cases:
                encoder = TextEncoder(spatial_dims=spatial_dims, out_channels=32, encoding=encoding, pretrained=True)
                self.assertEqual(encoder().shape, expected_shape)
|
46
|
+
|
47
|
+
|
48
|
+
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,90 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import unittest
|
15
|
+
from unittest import skipUnless
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from parameterized import parameterized
|
20
|
+
|
21
|
+
from monai.networks import eval_mode
|
22
|
+
from monai.networks.blocks.transformerblock import TransformerBlock
|
23
|
+
from monai.utils import optional_import
|
24
|
+
|
25
|
+
# einops is an optional dependency; the tests below that need it are skipped via `has_einops`.
einops, has_einops = optional_import("einops")
|
26
|
+
# Parameter grid for TestTransformerBlock.test_shape. Each entry is
# [constructor kwargs, input shape, expected output shape]; the block must preserve the
# (batch, seq_len, hidden_size) shape. Built as a comprehension so the loop variables do not
# leak into module scope (they share names with locals used in the tests).
# Iteration order (dropout outermost, cross-attention innermost) fixes the generated test names.
TEST_CASE_TRANSFORMERBLOCK = [
    [
        {
            "hidden_size": hidden_size,
            "num_heads": num_heads,
            "mlp_dim": mlp_dim,
            "dropout_rate": dropout_rate,
            "with_cross_attention": cross_attention,
        },
        (2, 512, hidden_size),
        (2, 512, hidden_size),
    ]
    for dropout_rate in np.linspace(0, 1, 4)
    for hidden_size in [360, 480, 600, 768]
    for num_heads in [4, 8, 12]
    for mlp_dim in [1024, 3072]
    for cross_attention in [False, True]
]
|
44
|
+
|
45
|
+
|
46
|
+
class TestTransformerBlock(unittest.TestCase):
    """Shape, argument-validation and attention-matrix tests for ``TransformerBlock``."""

    @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, input_param, input_shape, expected_shape):
        net = TransformerBlock(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    def test_ill_arg(self):
        # dropout_rate outside [0, 1]. hidden_size=128 is divisible by num_heads=8, so only the
        # dropout check can be responsible for the ValueError.
        with self.assertRaises(ValueError):
            TransformerBlock(hidden_size=128, num_heads=8, mlp_dim=2048, dropout_rate=4.0)

        # hidden_size (622) is not divisible by num_heads (8); dropout_rate is valid.
        with self.assertRaises(ValueError):
            TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)

    @skipUnless(has_einops, "Requires einops")
    def test_access_attn_matrix(self):
        # input format
        hidden_size = 128
        mlp_dim = 12
        num_heads = 2
        dropout_rate = 0
        input_shape = (2, 256, hidden_size)

        # Without save_attn the stored attention matrix is an empty tensor.
        no_matrix_access_blk = TransformerBlock(
            hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate
        )
        no_matrix_access_blk(torch.randn(input_shape))
        self.assertIsInstance(no_matrix_access_blk.attn.att_mat, torch.Tensor)
        self.assertEqual(no_matrix_access_blk.attn.att_mat.nelement(), 0)

        # With save_attn=True the attention matrix is kept after the forward pass.
        matrix_access_blk = TransformerBlock(
            hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
        )
        matrix_access_blk(torch.randn(input_shape))
        # Expected layout is (batch, num_heads, seq_len, seq_len), per SABlock's attention einsum.
        # num_heads is spelled out explicitly. The previous input_shape[0] only matched because
        # batch == num_heads == 2 here.
        self.assertEqual(
            matrix_access_blk.attn.att_mat.shape, (input_shape[0], num_heads, input_shape[1], input_shape[1])
        )


if __name__ == "__main__":
    unittest.main()
|